This article is reproduced from serverfault.com.

TensorFlow layer that converts a 2D matrix to a vector of a certain length

Posted on 2020-11-28 14:44:00

I am trying to build a neural network that takes in data in the form of a matrix and outputs a vector, but I don't know which layers to use for that. My input has shape (10,4) and my desired output has shape (3,). My current model is the following:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(256,activation="relu"),
    tf.keras.layers.Dense(256,activation="relu"),
    tf.keras.layers.Dense(1),
])

This at least results in a vector instead of a matrix, but its shape is (10,) instead of (3,). I could probably find a way to reduce that to (3,), but I doubt this approach is the correct one.
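
The shape this model produces can be checked directly. The following is a minimal sketch, assuming an explicit Input layer of shape (10, 4) and a dummy batch of two all-zero matrices, neither of which appears in the snippet above. Because Dense is applied along the last axis only, the axis of length 10 passes through unchanged:

```python
import tensorflow as tf

# Assumption: an explicit Input layer is added so the model has a known
# input shape; the original snippet does not declare one.
dense_only = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(10, 4)),
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dense(1),
])

# Hypothetical dummy data: a batch of 2 matrices of shape (10, 4).
dummy_batch = tf.zeros((2, 10, 4))

# Dense transforms only the last axis (4 -> 256 -> 256 -> 1),
# so every one of the 10 rows is mapped to a single value.
dense_only_shape = tuple(dense_only(dummy_batch).shape)
print(dense_only_shape)  # (2, 10, 1)
```

With the batch axis included, the result has shape (2, 10, 1): one value per row of each input matrix, which is why the output length follows the 10 rows rather than the desired 3.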

Questioner: Lukas Gradl
Akshay Sehgal 2020-11-28 22:54:12

Assuming that your (10,4) matrix represents neither a sequence of length 10 (where you would typically use an LSTM) nor an image (where you would typically use a 2D CNN), you can simply flatten the input matrix with a Flatten layer and pass it through a few dense layers, as below.

from tensorflow.keras import layers, Model

inp = layers.Input((10, 4))  # output shape: (None, 10, 4)
x = layers.Flatten()(inp)    # output shape: (None, 40)
x = layers.Dense(256)(x)     # output shape: (None, 256)
out = layers.Dense(3)(x)     # output shape: (None, 3)

model = Model(inp, out)
model.summary()

Output of model.summary():

Layer (type)                 Output Shape              Param #   
=================================================================
input_43 (InputLayer)        [(None, 10, 4)]           0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 40)                0         
_________________________________________________________________
dense_82 (Dense)             (None, 256)               10496     
_________________________________________________________________
dense_83 (Dense)             (None, 3)                 771       
=================================================================
Total params: 11,267
Trainable params: 11,267
Non-trainable params: 0
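
The shapes and parameter counts in the summary can be reproduced by hand. The following is a rough sketch rather than part of the original answer: the `dense_params` helper and the all-zero dummy batch are hypothetical, and the ReLU activation on the hidden layer is an assumption added for illustration (it does not change the parameter count). A Dense layer with `n_inputs` inputs and `units` outputs has `n_inputs * units` weights plus `units` biases.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

inp = layers.Input(shape=(10, 4))
x = layers.Flatten()(inp)                    # (None, 10, 4) -> (None, 40)
# Assumption: ReLU added for illustration; the answer's layer has no activation.
x = layers.Dense(256, activation="relu")(x)  # (None, 40) -> (None, 256)
out = layers.Dense(3)(x)                     # (None, 256) -> (None, 3)
flat_model = Model(inp, out)

# Hypothetical dummy data: a batch of 2 matrices of shape (10, 4).
pred_shape = tuple(flat_model(tf.zeros((2, 10, 4))).shape)


def dense_params(n_inputs, units):
    """Weights plus biases of a fully connected layer."""
    return n_inputs * units + units


hidden_params = dense_params(10 * 4, 256)  # 40 * 256 + 256 = 10496
output_params = dense_params(256, 3)       # 256 * 3 + 3 = 771
expected_total = hidden_params + output_params

print(pred_shape)                 # (2, 3)
print(flat_model.count_params())  # 11267
print(expected_total)             # 11267
```

Each input matrix now maps to one vector of length 3, and the hand-computed total matches the "Total params" line of the summary.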