Decoder output shape does not match encoder input shape in CNN autoencoder

Issue

I have a CNN autoencoder structured as shown below, but the decoder output shape does not match the encoder input shape. I have tried modifying the pooling and Conv layers, but it is hard to find a combination that works. How can I modify the network structure so that the shapes match? I am also considering resizing the input, but that may affect the input quality. Is that a good approach?

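# Imports assumed (tf.keras); the original post omits them.
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, UpSampling2D,
                                     Reshape, Dense, Lambda, Add)
input_img = Input(shape=(200, 800, 2))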
## Encoder
x = Conv2D(16, (3, 3), activation='tanh', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='tanh', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='tanh', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='tanh', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(4, (3, 3), activation='tanh', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(4, (3, 3), activation='tanh', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Reshape([4*13*4])(x) 
encoded = Dense(2,activation='tanh')(x)
## Two variables
val1= Lambda(lambda x: x[:,0:1])(encoded)
val2= Lambda(lambda x: x[:,1:2])(encoded)
## Decoder 1
x1 = Dense(4*13*4,activation='tanh')(val1)
x1 = Reshape([4,13,4])(x1)
x1 = UpSampling2D((2,2))(x1)
x1 = Conv2D(4,(3,3),activation='tanh',padding='same')(x1)
x1 = UpSampling2D((2,2))(x1)
x1 = Conv2D(8,(3,3),activation='tanh',padding='same')(x1)
x1 = UpSampling2D((2,2))(x1)
x1 = Conv2D(8,(3,3),activation='tanh',padding='same')(x1)
x1 = UpSampling2D((2,2))(x1)
x1 = Conv2D(8,(3,3),activation='tanh',padding='same')(x1)
x1 = UpSampling2D((2,2))(x1)
x1 = Conv2D(16,(3,3),activation='tanh',padding='same')(x1)
x1 = UpSampling2D((2,2))(x1)
x1d = Conv2D(2,(3,3),activation='linear',padding='same')(x1)
## Decoder 2
x2 = Dense(4*13*4,activation='tanh')(val2)
x2 = Reshape([4,13,4])(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(4,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(8,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(8,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(8,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(16,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2d = Conv2D(2,(3,3),activation='linear',padding='same')(x2)

decoded = Add()([x1d,x2d])

The output shape:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_6 (InputLayer)            [(None, 200, 800, 2) 0                                            
__________________________________________________________________________________________________
conv2d_78 (Conv2D)              (None, 200, 800, 16) 304         input_6[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_30 (MaxPooling2D) (None, 100, 400, 16) 0           conv2d_78[0][0]                  
__________________________________________________________________________________________________
conv2d_79 (Conv2D)              (None, 100, 400, 8)  1160        max_pooling2d_30[0][0]           
__________________________________________________________________________________________________
max_pooling2d_31 (MaxPooling2D) (None, 50, 200, 8)   0           conv2d_79[0][0]                  
__________________________________________________________________________________________________
conv2d_80 (Conv2D)              (None, 50, 200, 8)   584         max_pooling2d_31[0][0]           
__________________________________________________________________________________________________
max_pooling2d_32 (MaxPooling2D) (None, 25, 100, 8)   0           conv2d_80[0][0]                  
__________________________________________________________________________________________________
conv2d_81 (Conv2D)              (None, 25, 100, 8)   584         max_pooling2d_32[0][0]           
__________________________________________________________________________________________________
max_pooling2d_33 (MaxPooling2D) (None, 13, 50, 8)    0           conv2d_81[0][0]                  
__________________________________________________________________________________________________
conv2d_82 (Conv2D)              (None, 13, 50, 4)    292         max_pooling2d_33[0][0]           
__________________________________________________________________________________________________
max_pooling2d_34 (MaxPooling2D) (None, 7, 25, 4)     0           conv2d_82[0][0]                  
__________________________________________________________________________________________________
conv2d_83 (Conv2D)              (None, 7, 25, 4)     148         max_pooling2d_34[0][0]           
__________________________________________________________________________________________________
max_pooling2d_35 (MaxPooling2D) (None, 4, 13, 4)     0           conv2d_83[0][0]                  
__________________________________________________________________________________________________
reshape_13 (Reshape)            (None, 208)          0           max_pooling2d_35[0][0]           
__________________________________________________________________________________________________
dense_12 (Dense)                (None, 2)            418         reshape_13[0][0]                 
__________________________________________________________________________________________________
lambda_8 (Lambda)               (None, 1)            0           dense_12[0][0]                   
__________________________________________________________________________________________________
lambda_9 (Lambda)               (None, 1)            0           dense_12[0][0]                   
__________________________________________________________________________________________________
dense_13 (Dense)                (None, 208)          416         lambda_8[0][0]                   
__________________________________________________________________________________________________
dense_14 (Dense)                (None, 208)          416         lambda_9[0][0]                   
__________________________________________________________________________________________________
reshape_14 (Reshape)            (None, 4, 13, 4)     0           dense_13[0][0]                   
__________________________________________________________________________________________________
reshape_15 (Reshape)            (None, 4, 13, 4)     0           dense_14[0][0]                   
__________________________________________________________________________________________________
up_sampling2d_48 (UpSampling2D) (None, 8, 13, 4)     0           reshape_14[0][0]                 
__________________________________________________________________________________________________
up_sampling2d_54 (UpSampling2D) (None, 8, 13, 4)     0           reshape_15[0][0]                 
__________________________________________________________________________________________________
conv2d_84 (Conv2D)              (None, 8, 13, 4)     148         up_sampling2d_48[0][0]           
__________________________________________________________________________________________________
conv2d_90 (Conv2D)              (None, 8, 13, 4)     148         up_sampling2d_54[0][0]           
__________________________________________________________________________________________________
up_sampling2d_49 (UpSampling2D) (None, 16, 26, 4)    0           conv2d_84[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_55 (UpSampling2D) (None, 16, 26, 4)    0           conv2d_90[0][0]                  
__________________________________________________________________________________________________
conv2d_85 (Conv2D)              (None, 16, 26, 8)    296         up_sampling2d_49[0][0]           
__________________________________________________________________________________________________
conv2d_91 (Conv2D)              (None, 16, 26, 8)    296         up_sampling2d_55[0][0]           
__________________________________________________________________________________________________
up_sampling2d_50 (UpSampling2D) (None, 32, 52, 8)    0           conv2d_85[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_56 (UpSampling2D) (None, 32, 52, 8)    0           conv2d_91[0][0]                  
__________________________________________________________________________________________________
conv2d_86 (Conv2D)              (None, 32, 52, 8)    584         up_sampling2d_50[0][0]           
__________________________________________________________________________________________________
conv2d_92 (Conv2D)              (None, 32, 52, 8)    584         up_sampling2d_56[0][0]           
__________________________________________________________________________________________________
up_sampling2d_51 (UpSampling2D) (None, 64, 104, 8)   0           conv2d_86[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_57 (UpSampling2D) (None, 64, 104, 8)   0           conv2d_92[0][0]                  
__________________________________________________________________________________________________
conv2d_87 (Conv2D)              (None, 64, 104, 8)   584         up_sampling2d_51[0][0]           
__________________________________________________________________________________________________
conv2d_93 (Conv2D)              (None, 64, 104, 8)   584         up_sampling2d_57[0][0]           
__________________________________________________________________________________________________
up_sampling2d_52 (UpSampling2D) (None, 128, 208, 8)  0           conv2d_87[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_58 (UpSampling2D) (None, 128, 208, 8)  0           conv2d_93[0][0]                  
__________________________________________________________________________________________________
conv2d_88 (Conv2D)              (None, 128, 208, 16) 1168        up_sampling2d_52[0][0]           
__________________________________________________________________________________________________
conv2d_94 (Conv2D)              (None, 128, 208, 16) 1168        up_sampling2d_58[0][0]           
__________________________________________________________________________________________________
up_sampling2d_53 (UpSampling2D) (None, 256, 416, 16) 0           conv2d_88[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_59 (UpSampling2D) (None, 256, 416, 16) 0           conv2d_94[0][0]                  
__________________________________________________________________________________________________
conv2d_89 (Conv2D)              (None, 256, 416, 2)  290         up_sampling2d_53[0][0]           
__________________________________________________________________________________________________
conv2d_95 (Conv2D)              (None, 256, 416, 2)  290         up_sampling2d_59[0][0]           
__________________________________________________________________________________________________
add_4 (Add)                     (None, 256, 416, 2)  0           conv2d_89[0][0]                  
                                                                 conv2d_95[0][0]
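
The mismatch can be reproduced without building the model. The following is a minimal sketch in plain Python, assuming only that MaxPooling2D((2, 2), padding='same') rounds each spatial size up to ceil(n / 2) and that UpSampling2D((2, 2)) doubles it; the helper names `pool_same` and `upsample` are made up for illustration. It traces the code as posted. The summary above reports a final width of 416 rather than 832, which suggests it may have been produced by a slightly different version of the model.

```python
import math

def pool_same(size, times, factor=2):
    """Size after `times` poolings with padding='same' (rounds up each time)."""
    for _ in range(times):
        size = math.ceil(size / factor)
    return size

def upsample(size, times, factor=2):
    """Size after `times` upsamplings (multiplies by `factor` each time)."""
    return size * factor ** times

height, width = 200, 800
n_stages = 6  # six MaxPooling2D and six UpSampling2D layers in the posted code

bottleneck = (pool_same(height, n_stages), pool_same(width, n_stages))
restored = (upsample(bottleneck[0], n_stages), upsample(bottleneck[1], n_stages))

print(bottleneck)  # (4, 13)
print(restored)    # (256, 832)
```

Because 200 and 800 are not multiples of 2^6 = 64, the rounding in the encoder cannot be undone by doubling alone, so the decoder overshoots in both dimensions.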

Solution

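The mismatch comes from rounding. Each MaxPooling2D layer with padding='same' rounds odd sizes up (25 to 13, 13 to 7, 7 to 4), while each UpSampling2D layer simply doubles, so six upsamplings of the 4 x 13 bottleneck in the posted code give 256 x 832 instead of 200 x 800. You can fix this without resizing the input by letting the decoder overshoot and then trimming the result: two 7 x 7 convolutions with padding='valid' remove 12 pixels from each dimension (244 x 820), and a Cropping2D layer removes the rest. Decoder 2 is shown below; decoder 1 needs the same final layers so that Add() receives matching shapes.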

x2 = Dense(4*13*4,activation='tanh')(val2)
x2 = Reshape([4,13,4])(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(4,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(8,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(8,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(8,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(16,(3,3),activation='tanh',padding='same')(x2)
x2 = UpSampling2D((2,2))(x2)
x2 = Conv2D(4,(7,7),activation='tanh',padding='valid')(x2)
x2 = Conv2D(2,(7,7),activation='linear',padding='valid')(x2)
# Cropping2D crops along spatial dimensions, i.e. height and width.
x2d = Cropping2D(cropping=((22,22),(10,10)))(x2)
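
The cropping values above can be checked, or recomputed for a different input size, with a little arithmetic. This is a rough sketch in plain Python rather than part of the original answer; it assumes stride-1 convolutions, for which padding='valid' shrinks each dimension by kernel - 1, and the helper names `valid_conv` and `symmetric_crop` are hypothetical.

```python
def valid_conv(size, kernel):
    """Size after one stride-1 convolution with padding='valid'."""
    return size - (kernel - 1)

def symmetric_crop(size, target):
    """Per-side crop that reduces `size` to `target`; assumes an even surplus."""
    surplus = size - target
    if surplus < 0 or surplus % 2:
        raise ValueError(f"cannot crop {size} to {target} symmetrically")
    return surplus // 2

up_h, up_w = 256, 832  # 4 x 13 bottleneck after six 2x upsamplings
for _ in range(2):     # the two 7x7 'valid' convolutions
    up_h, up_w = valid_conv(up_h, 7), valid_conv(up_w, 7)

crop_h = symmetric_crop(up_h, 200)  # rows removed from top and from bottom
crop_w = symmetric_crop(up_w, 800)  # columns removed from left and from right
print((up_h, up_w), crop_h, crop_w)  # (244, 820) 22 10
```

The result matches the cropping=((22,22),(10,10)) argument used in the answer. If the surplus were odd, Cropping2D would need unequal values for the two sides of that dimension.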

Answered By – Tfer3

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
