Efficiently use Dense layers in parallel

Issue

I need to implement a layer in TensorFlow for a dataset of size N, where each sample has a set of M independent features (each feature is represented by a tensor of dimension L). I want to train M dense layers in parallel and then concatenate their output tensors.

I could implement the layer using a for loop, as below:

import tensorflow as tf

class MyParallelDenseLayer(tf.keras.layers.Layer):

    def __init__(self, dense_kwargs, **kwargs):
        super().__init__(**kwargs)
        self.dense_kwargs = dense_kwargs

    def build(self, input_shape):
        self.N, self.M, self.L = input_shape
        # One independent Dense layer per feature group m.
        self.list_dense_layers = [tf.keras.layers.Dense(**self.dense_kwargs) for _ in range(self.M)]
        super().build(input_shape)

    def call(self, inputs):
        # Apply the i-th Dense layer to the i-th slice (shape (N, L)), one group at a time.
        parallel_output = [self.list_dense_layers[i](inputs[:, i]) for i in range(self.M)]
        # Concatenate the M outputs along the last axis: (N, M * units).
        return tf.concat(parallel_output, axis=-1)

But the for loop in the call function makes my layer extremely slow.
Is there a faster way to implement this layer?

Solution

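This can be done with a single tf.einsum call that applies all M kernels at once instead of looping over them. Extend the layer to your liking with a bias, activation functions, and so on. Note that it returns a tensor of shape (batch, M, units); reshape it to (batch, M * units) if you need the same layout as the concatenated output of the loop version.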

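import tensorflow as tf

class ParallelDense(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        super().build(input_shape)
        # One (L, units) kernel per feature group, stacked: shape (M, L, units).
        self.kernel = self.add_weight(shape=[input_shape[1], input_shape[2], self.units])

    def call(self, inputs):
        # (batch, M, L) x (M, L, units) -> (batch, M, units); m is shared, l is summed.
        return tf.einsum("bml, mlk -> bmk", inputs, self.kernel)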
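
Written out, the contraction "bml, mlk -> bmk" computes the following for every sample b, feature group m and output unit k (with K = units; no bias, matching the layer above):

```latex
y_{b,m,k} = \sum_{l=1}^{L} x_{b,m,l} \, W_{m,l,k},
\qquad b = 1,\dots,N,\; m = 1,\dots,M,\; k = 1,\dots,K
```

Because m appears in both operands and in the output, it is carried through like a batch index instead of being summed, so each slice W_{m,:,:} acts as an independent L x K dense kernel for feature group m. Only l is summed over.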

Test it:

b = 16  # batch size
m = 200  # no. of feature groups (M)
l = 4  # no. of input features per m
k = 10  # no. of output features per m

layer = ParallelDense(k)
inp = tf.random.normal([b, m, l])

print(layer(inp).shape)

(16, 200, 10)
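
The answer suggests extending the layer with activations and the like. Below is a minimal sketch of one such extension; the class name ParallelDenseWithBias and its flatten option are hypothetical, not part of the original answer or of Keras. It adds a per-group bias and an activation, optionally reshapes (batch, M, units) to (batch, M * units) to match the question's concatenated output, and then checks the result against a plain per-group loop written in NumPy.

```python
import numpy as np
import tensorflow as tf


class ParallelDenseWithBias(tf.keras.layers.Layer):
    """Hypothetical extension: M independent dense layers (kernel, bias,
    activation) evaluated with a single einsum instead of a Python loop."""

    def __init__(self, units, activation=None, flatten=True, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)  # None -> linear
        self.flatten = flatten

    def build(self, input_shape):
        m, l = int(input_shape[1]), int(input_shape[2])
        # One (L, units) kernel and one (units,) bias per feature group m.
        self.kernel = self.add_weight(name="kernel", shape=(m, l, self.units),
                                      initializer="glorot_uniform")
        self.bias = self.add_weight(name="bias", shape=(m, self.units),
                                    initializer="zeros")
        super().build(input_shape)

    def call(self, inputs):
        # (b, m, l) x (m, l, k) -> (b, m, k); the (m, k) bias broadcasts over b.
        out = tf.einsum("bml,mlk->bmk", inputs, self.kernel) + self.bias
        out = self.activation(out)
        if self.flatten:
            # Row-major reshape puts the M blocks of size k side by side,
            # i.e. the same layout as concatenating along the last axis.
            out = tf.reshape(out, [-1, self.kernel.shape[0] * self.units])
        return out


layer = ParallelDenseWithBias(10, activation="relu")
x = tf.random.normal([16, 200, 4])
y = layer(x).numpy()

# Reference: the question's per-group loop, written as explicit matmuls on
# the same weights, followed by concatenation along the last axis.
w = layer.kernel.numpy()
bias = layer.bias.numpy()
xs = x.numpy()
ref = np.concatenate(
    [np.maximum(xs[:, i] @ w[i] + bias[i], 0.0) for i in range(w.shape[0])],
    axis=-1,
)

print(y.shape)  # (16, 2000)
print(np.allclose(y, ref, atol=1e-5))
```

Keeping the reshape optional lets you choose between the (batch, M, units) layout, which is convenient for further per-group processing, and the flat layout that the loop version produced.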

Answered By – xdurch0

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
