Flatten everything, including the batch axis, in TensorFlow / Keras


In a Sequential model, I'm trying to go from a layer output shape of (None, 300) to something like (1, 1, None*300) so that I can apply an AveragePooling layer. In other words, I would like to flatten everything, including the batch axis, but both the Flatten and Reshape layers always skip the batch axis. Any ideas?


You can use a Lambda layer together with K.reshape from the Keras backend, like this:

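from keras import backend as K
from keras.layers import Lambda

# inp is the (None, 300) tensor from the previous layer; -1 lets reshape infer batch_size * 300
out = Lambda(lambda x: K.reshape(x, (1, 1, -1)))(inp)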
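
To see what that reshape does to the shapes, here is a small NumPy sketch of the same operation. It is only an illustration, not the Keras code itself: the batch size of 4 and the helper name flatten_with_batch are made up for the example, and it assumes K.reshape flattens in the same row-major order as np.reshape.

```python
import numpy as np

def flatten_with_batch(x):
    """Collapse every axis, including the leading batch axis, into shape (1, 1, N)."""
    # -1 asks reshape to infer the last dimension from the total element count.
    return np.reshape(x, (1, 1, -1))

# Hypothetical batch: 4 samples with 300 features each, standing in for (None, 300).
batch = np.arange(4 * 300, dtype=np.float32).reshape(4, 300)

flat = flatten_with_batch(batch)
print(flat.shape)  # (1, 1, 1200)
```

Note that the result has a leading dimension of 1, so whatever comes after the Lambda layer sees a single "sample" holding all batch_size * 300 values, with the samples laid out one after another.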

Answered By – today

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
