Input tensor is passed to custom layer with wrong shape

Issue

I am trying to create a custom Keras layer. This layer will be added to a trained model before deployment so that the deployed model contains all preprocessing steps. During training, the preprocessing is done outside the model for several reasons, so this layer will never be used in training.

The input for the layer is a 1D tensor of 64 floats. This vector contains redundant values according to a mapping: a list of 64 name strings, one per component, where components with the same name are redundant. The layer is supposed to return the median of each set of redundant values, based on the mapping provided when the layer is initialised.
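As a plain reference for what the layer should compute on a single (unbatched) vector, here is a minimal NumPy sketch. The function name `median_per_name` and the 6-entry example mapping are hypothetical; the real mapping has 64 entries. `np.median` averages the two middle values for even-sized groups, which should agree with a 50th percentile using midpoint interpolation.

```python
from typing import Dict, List

import numpy as np


def median_per_name(x: np.ndarray, mapping: List[str]) -> Dict[str, float]:
    """Hypothetical reference: median of the components sharing each name."""
    names = np.asarray(mapping)
    result = {}
    # sorted() gives a deterministic order for the unique names
    for name in sorted(set(mapping)):
        # boolean mask over the components, then the median of the selected values
        result[name] = float(np.median(x[names == name]))
    return result


# hypothetical 6-entry mapping and input vector
mapping = ["t1", "t2", "t1", "t2", "t1", "t3"]
x = np.array([4.0, 10.0, 1.0, 20.0, 3.0, 7.0], dtype=np.float32)
medians = median_per_name(x, mapping)
print(medians)  # {'t1': 3.0, 't2': 15.0, 't3': 7.0}
```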

import tensorflow as tf
import tensorflow_probability as tfp
from typing import List


def pick_median(X, name, mapping):
  mask = (mapping == name) 
  median = tfp.stats.percentile(X[mask], 50.0, interpolation='midpoint')

  return median


class PickMedianLayer(tf.keras.layers.Layer):
  def __init__(self, mapping:List[str], **kwargs):
    super(PickMedianLayer, self).__init__(**kwargs)
    self.mapping = tf.constant(mapping, dtype='string')
    self.map_value_set = tf.constant(list(set(mapping)), dtype='string')
    self.trainable = False

  def build(self, input_shape):
    pass

  def call(self, X):
    picked_vec = tf.map_fn(fn=lambda name: pick_median(X, name, self.mapping), elems=self.map_value_set, fn_output_signature='float32')
    return picked_vec 

I use tf.map_fn to compute one output component for each unique name in the mapping.
As a first simple test of the layer, I created a Keras model containing only the custom layer and tried to predict on a simple vector:

from tensorflow.keras import Sequential

rel_model = Sequential()
rel_model.add(PickMedianLayer(mapping=mapping))

rel_model.compile()
#rel_model.build(input_shape=tf.TensorShape([64]))

vec = 1000. * tf.ones((64,), dtype='float32')
print(vec.shape)

vec1 = rel_model.predict(vec)
print(vec1)
print(vec1.shape)

I get the error "ValueError: Shapes (32,) and (64,) are incompatible" when executing the predict line. I also get the message:

Call arguments received by layer "pick_median_layer" (type PickMedianLayer):
      • X=tf.Tensor(shape=(32,), dtype=float32)

Right after it is created, vec has the shape TensorShape([64]). So it seems like the input vector is not fed into the layer correctly. While debugging I was able to verify that the other tensors inside the layer have the expected shapes and types, and that the mask is created as expected.
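One way to see where a shape of (32,) can come from: Model.predict treats the leading axis of its input as the batch axis and, unless batch_size is given, uses batches of 32. The sketch below only mimics that slicing with tf.data; it is not Keras' actual input pipeline.

```python
import numpy as np
import tensorflow as tf

vec = 1000.0 * np.ones((64,), dtype="float32")

# Slicing along axis 0 in chunks of 32, similar to predict's default batch_size:
# a 1D vector of 64 values becomes two batches of 32 scalars
batch_shapes = [tuple(b.shape) for b in tf.data.Dataset.from_tensor_slices(vec).batch(32)]
print(batch_shapes)  # [(32,), (32,)]

# Adding a leading batch axis turns the vector into one sample with 64 features
batched = vec[np.newaxis, :]
first = next(iter(tf.data.Dataset.from_tensor_slices(batched).batch(32)))
print(batched.shape, tuple(first.shape))  # (1, 64) (1, 64)
```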

What am I doing wrong? Can anybody help me? Thank you!

Solution

I was able to fix the problem. It was a combination of three problems:

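  1. The predict function expects the input to be a batch: it treats the first axis as the batch axis and, by default, feeds batches of 32 samples. My 1D tensor of 64 values was therefore passed to the layer as batches of 32 scalars, which is where the shape (32,) came from. So I had to create a 2D array with one row and 64 columns.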

  2. You need to use tf.boolean_mask with axis=1 to apply the mask, so that it selects along the 64 features and not along the batch axis.

  3. Layer variables should be stored as a tf.Variable, not tf.Tensor.
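A small illustration of the difference the axis argument of tf.boolean_mask makes, on a hypothetical 2×4 batch and 4-entry mapping. With axis=1 the mask selects feature columns and the batch axis is kept. With the default axis=0 the mask is checked against the batch axis instead, which here raises the same kind of "Shapes (...) and (...) are incompatible" error as in the question.

```python
import numpy as np
import tensorflow as tf

mapping = tf.constant(["a", "b", "a", "c"])  # hypothetical 4-entry mapping
mask = tf.equal(mapping, "a")                # [True, False, True, False]

X = tf.constant([[1.0, 2.0, 3.0, 4.0],
                 [5.0, 6.0, 7.0, 8.0]])      # shape (batch=2, features=4)

# axis=1: keep the batch axis, select the masked feature columns
selected = tf.boolean_mask(X, mask, axis=1)  # rows [1, 3] and [5, 7]

# axis=0 (default): the 4-entry mask is compared with the 2-row batch axis
try:
    tf.boolean_mask(X, mask)
    mismatch_raised = False
except (ValueError, tf.errors.InvalidArgumentError):
    mismatch_raised = True
```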

The working solution looks like this:

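import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Sequential
from typing import List


def pick_median(X, name, mapping):
  mask = (mapping == name)
  # axis=1 applies the (64,) mask to the feature axis and keeps the batch axis
  X_selected = tf.boolean_mask(
    X, mask, axis=1, name='boolean_mask'
  )
  median = tfp.stats.percentile(X_selected, 50.0, interpolation='midpoint')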

  return median

class PickMedianLayer(tf.keras.layers.Layer):
  def __init__(self, mapping:List[str], **kwargs):
    super(PickMedianLayer, self).__init__(**kwargs)
    self.mapping = tf.Variable(initial_value=np.array(mapping, dtype=np.str_), dtype='string', trainable=False)
    # sorted() keeps the output order stable across Python processes (set order depends on string hashing)
    self.map_value_set = tf.Variable(initial_value=np.array(sorted(set(mapping)), dtype=np.str_), dtype='string', trainable=False)
    self.trainable = False

  def build(self, input_shape):
    pass

  def call(self, X):
    picked_vec = tf.map_fn(fn=lambda name: pick_median(X, name, self.mapping), elems=self.map_value_set, fn_output_signature='float32')
    return picked_vec 


rel_model = Sequential()
rel_model.add(PickMedianLayer(mapping=mapping))

rel_model.compile()

vec = np.ones((1, 64), dtype='float32')

vec1 = rel_model.predict(vec)
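In the code above, tfp.stats.percentile is called without an axis, so each value is a median over every element of X_selected, and the output has no batch dimension. That is fine for the single-row call used here, but with several rows the medians would mix samples. If batched predictions are needed, one possible variant is sketched below under stated assumptions: the class name GroupMedianLayer is hypothetical, the mapping is assumed fixed at construction, and a hand-written midpoint median replaces tfp.stats.percentile, so tfp is not needed.

```python
from typing import List

import numpy as np
import tensorflow as tf


class GroupMedianLayer(tf.keras.layers.Layer):
    """Hypothetical batch-aware variant: (batch, len(mapping)) -> (batch, n_unique_names)."""

    def __init__(self, mapping: List[str], **kwargs):
        super().__init__(**kwargs)
        self.mapping = tuple(mapping)
        # sorted() so the output column order does not depend on string hashing
        self.names = tuple(sorted(set(mapping)))
        # The mapping is static, so each group's column indices are computed once
        # in Python instead of comparing strings on every call
        self.group_indices = tuple(
            tuple(i for i, m in enumerate(mapping) if m == name) for name in self.names
        )
        self.trainable = False

    def call(self, X):
        medians = []
        for idx in self.group_indices:
            # gather this group's columns and sort each row: shape (batch, k)
            group = tf.sort(tf.gather(X, idx, axis=1), axis=1)
            k = len(idx)
            # midpoint median: mean of the two middle values (the same value if k is odd)
            lo, hi = (k - 1) // 2, k // 2
            medians.append((group[:, lo] + group[:, hi]) / 2.0)
        return tf.stack(medians, axis=1)  # (batch, n_unique_names)

    def get_config(self):
        config = super().get_config()
        config.update({"mapping": list(self.mapping)})
        return config


# Example with a hypothetical 5-entry mapping (the real one has 64 entries)
layer = GroupMedianLayer(mapping=["a", "b", "a", "b", "c"])
x = np.array([[1, 10, 3, 20, 5],
              [2, 0, 4, 0, 7]], dtype="float32")
out = layer(x)
print(out.shape)  # (2, 3): one row per sample, one column per name in sorted order
```

The Python loop over groups is unrolled into the graph, which should be acceptable here since 64 components give at most 64 groups. Used in a model, a (1, 64) input would then produce an output of shape (1, number of unique names).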

Answered By – Tilagiho

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
