Using tensors as indexes in a network


My network has multiple inputs, one of which is an index that the network uses to index into other tensors.

I am having trouble using that tensor as an index.

import tensorflow as tf

class MemoryLayer(tf.keras.layers.Layer):
  def __init__(self, memory_size, k, **kwargs):
    super().__init__(**kwargs)
    self.memory_size = memory_size
    self.k = k

  def build(self, input_shape):
    # Set up the memory_var (omitted)
    ...

  # Shape of input is [(1,3,6), (1,3)]
  def call(self, input):
    for i in range(3):
      statement = input[0][0,i]
      cluster = input[1][0,i]
      old_sub_mem = self.memory_var[cluster, :-1]  # Error here
      # Here should be a bunch of stuff I removed because it's not relevant

    return tf.expand_dims(self.memory_var, axis=0)

I get a TypeError saying that <tf.Tensor 'memory_layer_19/strided_slice_2:0' shape=() dtype=float32> isn’t a valid index. I tried calling .numpy() on input[1], but that doesn’t work because the tensor has no shape. Given the data I feed in, cluster should be a single number.


By default, the inputs to a layer are tf.float32. To index a tensor, however, you need integers. You can either cast the layer’s inputs to integers, or declare that the layer’s input has an integer type.
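A minimal way to see the dtype problem in isolation, assuming TensorFlow 2.x running eagerly. The 4 x 5 `memory` tensor is a hypothetical stand-in for memory_var; the only thing that changes between the two lookups is the index dtype.

```python
import tensorflow as tf

# Hypothetical stand-in for memory_var: 4 rows of 5 values.
memory = tf.reshape(tf.range(20, dtype=tf.float32), (4, 5))

float_index = tf.constant(2.0)              # float32 scalar, like a default Keras input
int_index = tf.cast(float_index, tf.int32)  # same value, cast to an integer type

try:
    memory[float_index, :-1]                # float32 indices are rejected
except TypeError as err:
    print("float32 index rejected:", err)

row = memory[int_index, :-1]                # int32 scalar tensors are valid indices
print(row.numpy())                          # [10. 11. 12. 13.]
```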

With casting

cluster = tf.cast(input[1], dtype=tf.int32)[0,i]
old_sub_mem = self.memory_var[cluster, :-1]
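To show the cast in context, here is a runnable sketch of the question’s layer. The memory_var shape, the "ones" initializer and the returned value (the looked-up rows rather than the whole memory) are assumptions made only so the example runs; k is kept but unused, as in the trimmed question.

```python
import tensorflow as tf

class MemoryLayer(tf.keras.layers.Layer):
    def __init__(self, memory_size, k, **kwargs):
        super().__init__(**kwargs)
        self.memory_size = memory_size
        self.k = k

    def build(self, input_shape):
        # Assumed layout: one row per cluster, one column per statement
        # feature plus one extra column (dropped by the [:, :-1] slice).
        feature_dim = input_shape[0][-1]
        self.memory_var = self.add_weight(
            name="memory_var",
            shape=(self.memory_size, feature_dim + 1),
            initializer="ones",
            trainable=False,
        )

    def call(self, inputs):
        statements, clusters = inputs
        # Cast once so every element can be used as an index.
        clusters = tf.cast(clusters, tf.int32)
        rows = []
        for i in range(3):
            cluster = clusters[0, i]
            old_sub_mem = self.memory_var[cluster, :-1]  # now a valid index
            rows.append(old_sub_mem)
        # Assumed output: the looked-up rows, with a batch axis of 1.
        return tf.expand_dims(tf.stack(rows), axis=0)

layer = MemoryLayer(memory_size=10, k=1)    # k=1 is an arbitrary placeholder
statements = tf.zeros((1, 3, 6))
clusters = tf.constant([[0.0, 4.0, 9.0]])  # float32, like a default Keras input
out = layer([statements, clusters])
print(out.shape)
```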

Specifying the type of the input

This example uses the functional API:

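import tensorflow as tf
inp_statement = tf.keras.Input(shape=(3,6))
inp_cluster = tf.keras.Input(shape=(3,), dtype=tf.int32)  # integer indices from the start
# k=1 is a placeholder; pass whatever value your layer actually needs
memory = MemoryLayer(memory_size=10, k=1)([inp_statement, inp_cluster])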
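A self-contained sketch of the second approach. LookupLayer is a hypothetical, trimmed-down stand-in for MemoryLayer (its memory shape and output are assumptions); the point is that an index declared with dtype=tf.int32 in tf.keras.Input reaches call() as an integer tensor and can be used as an index without any cast.

```python
import tensorflow as tf

class LookupLayer(tf.keras.layers.Layer):
    """Hypothetical stand-in for MemoryLayer: looks up one memory row per index."""

    def __init__(self, memory_size, width, **kwargs):
        super().__init__(**kwargs)
        self.memory_size = memory_size
        self.width = width

    def build(self, input_shape):
        self.memory_var = self.add_weight(
            name="memory_var", shape=(self.memory_size, self.width),
            initializer="ones", trainable=False)

    def call(self, inputs):
        statements, clusters = inputs
        # clusters is already int32 because its Input declares that dtype.
        rows = [self.memory_var[clusters[0, i], :-1] for i in range(3)]
        return tf.expand_dims(tf.stack(rows), axis=0)

inp_statement = tf.keras.Input(shape=(3, 6))
inp_cluster = tf.keras.Input(shape=(3,), dtype=tf.int32)
out = LookupLayer(memory_size=10, width=7)([inp_statement, inp_cluster])
model = tf.keras.Model(inputs=[inp_statement, inp_cluster], outputs=out)

# tf.constant([[1, 2, 3]]) is int32, matching the declared input dtype.
result = model([tf.zeros((1, 3, 6)), tf.constant([[1, 2, 3]])])
print(result.shape)
```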

Note: I don’t know exactly what you are trying to achieve, but this for loop could probably be replaced by a single call to tf.gather or tf.gather_nd.
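A sketch of the tf.gather idea, assuming the loop only needs to read one memory row per cluster id. The memory contents are hypothetical, and the update code the asker removed isn’t shown, so this sketch doesn’t attempt it.

```python
import tensorflow as tf

# Hypothetical memory: 5 clusters, 4 columns each.
memory_var = tf.reshape(tf.range(20, dtype=tf.float32), (5, 4))
# One batch of three cluster ids, arriving as float32 and then cast.
clusters = tf.cast(tf.constant([[3.0, 0.0, 3.0]]), tf.int32)

# Loop version, as in the question.
looped = tf.stack([memory_var[clusters[0, i], :-1] for i in range(3)])

# Vectorized version: gather all three rows at once, then drop the last column.
gathered = tf.gather(memory_var, clusters[0])[:, :-1]

# Same idea for the whole batch at once: result shape (1, 3, 3).
batched = tf.gather(memory_var, clusters)[..., :-1]

print(gathered.numpy().tolist())  # [[12.0, 13.0, 14.0], [0.0, 1.0, 2.0], [12.0, 13.0, 14.0]]
```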

Answered By – Lescurel

This Answer collected from stackoverflow, is licensed under cc by-sa 2.5 , cc by-sa 3.0 and cc by-sa 4.0
