I have a model that consists of two heads, a concatenation, and a stack of further layers, all dense. For the concatenation I use a Lambda layer, not only to concatenate both heads but also to change the concatenation order for each entry in the batch using gather. For this purpose I make use of an index Input of shape (batch_size, 512), and the Lambda layer I'm using is this:
Lambda(lambda x: gather(Concatenate()([x[0], x[1]]), x[2]))([h1, h2, idx])
Here h1 is the output of the first head, h2 is the output of the second head, and idx is the index tensor.
If I remove gather and leave only Concatenate, the model learns and the loss decreases. With gather in place, however, it doesn't learn and the loss gets stuck.
Concatenate()([h1, h2]) # this works well
Just in case it matters: idx has shape (None, 512), while h1 and h2 each have shape (None, 256). The batch size is 2048.
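For reference, here is a stripped-down version of the wiring; the head inputs and Dense sizes other than 256 are dummies, not the real model:

import tensorflow as tf
from tensorflow import gather
from tensorflow.keras.layers import Concatenate, Dense, Input, Lambda

in1 = Input(shape=(64,))                   # placeholder size for head 1
in2 = Input(shape=(64,))                   # placeholder size for head 2
idx = Input(shape=(512,), dtype="int32")   # per-entry reordering indices

h1 = Dense(256, activation="relu")(in1)    # (None, 256)
h2 = Dense(256, activation="relu")(in2)    # (None, 256)

# The problematic layer: concatenate to (None, 512), then reorder via idx
merged = Lambda(lambda x: gather(Concatenate()([x[0], x[1]]), x[2]))([h1, h2, idx])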
What am I doing wrong? Any help would be much appreciated.
I found the error. I had to specify the batch_dims=1 parameter in gather. So now I have this and it works well:
Lambda(lambda x: gather(Concatenate()([x[0], x[1]]), x[2], batch_dims=1))([h1, h2, idx])
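To see why batch_dims=1 matters, here is a small standalone demo (toy numbers, nothing from the real model). Without it, gather indexes along the batch axis, so each index picks a whole row from another batch entry and the output balloons to (batch, 512, 512); with batch_dims=1 it reorders within each row, which is what idx is meant to do:

import tensorflow as tf

x = tf.constant([[0., 1., 2., 3.],
                 [10., 11., 12., 13.]])   # stand-in for the concatenated heads, shape (2, 4)
idx = tf.constant([[2, 3, 0, 1],
                   [0, 1, 2, 3]])         # one permutation per batch entry, shape (2, 4)

# Without batch_dims, each index selects a whole row of x.
print(tf.gather(x, idx).shape)            # (2, 4, 4) instead of (2, 4)

# With batch_dims=1, out[b, i] = x[b, idx[b, i]].
print(tf.gather(x, idx, batch_dims=1).numpy())
# [[ 2.  3.  0.  1.]
#  [10. 11. 12. 13.]]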
Answered By – underwater