`torch.gather` without unbroadcasting

Issue

I have some batched input x of shape [batch, time, feature], and some batched indices i of shape [batch, new_time] that I want to use to gather along the time dim of x. As the output of this operation I want a tensor y of shape [batch, new_time, feature] with values like this:

y[b, t', f] = x[b, i[b, t'], f]
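To pin down the semantics of the formula above, here is a deliberately slow loop-based reference in PyTorch. The helper name batched_gather_reference is hypothetical and only for illustration; a sketch like this is mainly useful for checking a vectorized version on small inputs.

```python
import torch

def batched_gather_reference(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based reference for y[b, t', f] = x[b, i[b, t'], f].

    x: [batch, time, feature], i: [batch, new_time] (integer indices into time).
    Returns y: [batch, new_time, feature]. Slow; meant for checking results only.
    """
    batch, new_time = i.shape
    feature = x.shape[-1]
    y = x.new_empty((batch, new_time, feature))
    for b in range(batch):
        for tp in range(new_time):
            # Copy the whole feature vector stored at time step i[b, t'] of batch b.
            y[b, tp] = x[b, i[b, tp]]
    return y

x = torch.arange(2 * 3 * 2, dtype=torch.float32).reshape(2, 3, 2)
i = torch.tensor([[2, 0, 0, 1], [1, 1, 2, 0]])
y = batched_gather_reference(x, i)
print(y.shape)  # torch.Size([2, 4, 2])
```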

In TensorFlow, I can accomplish this with the batch_dims: int argument of tf.gather: y = tf.gather(x, i, axis=1, batch_dims=1).

In PyTorch, I can think of some functions which do similar things:

  1. torch.gather, of course, but it has no argument similar to TensorFlow's batch_dims. The output of torch.gather always has the same shape as the indices, so I would need to unbroadcast the feature dim into i before passing it to torch.gather.

  2. torch.index_select, but here, the indices must be one-dimensional. So to make it work I would need to unbroadcast x to add a "batch * new_time" dim, and then after torch.index_select reshape the output.

  3. torch.nn.functional.embedding. Here, the embedding matrices (one per batch entry) would correspond to x. But this function does not support batched weights, so I run into the same issue as with torch.index_select (looking at the code, torch.nn.functional.embedding uses torch.index_select under the hood).
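For options 2 and 3, one possible workaround (an assumption, not something from the original post) is to fold batch and time into a single dimension: adding an offset of b * time turns each i[b, t'] into a row index of the flattened x, so only the index tensor is rebuilt. The helper names below are hypothetical, and the sketch assumes x is contiguous so that reshape returns a view rather than a copy.

```python
import torch
import torch.nn.functional as F

def batched_gather_flat(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of option 2: x [batch, time, feature], i [batch, new_time] (int64)."""
    batch, time, feature = x.shape
    # Row b*time + t of x_flat is x[b, t]; reshape is a view when x is contiguous.
    x_flat = x.reshape(batch * time, feature)
    # Shift each batch's indices by b*time so they address rows of x_flat.
    offsets = torch.arange(batch, device=i.device).unsqueeze(1) * time  # [batch, 1]
    flat_idx = (i + offsets).reshape(-1)            # 1-D, length batch*new_time
    y = x_flat.index_select(0, flat_idx)            # [batch*new_time, feature]
    return y.reshape(batch, -1, feature)            # [batch, new_time, feature]

def batched_gather_embedding(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of option 3: the same offset trick, with x_flat as the weight."""
    batch, time, feature = x.shape
    offsets = torch.arange(batch, device=i.device).unsqueeze(1) * time
    # F.embedding accepts an index of any shape and appends the embedding dim.
    return F.embedding(i + offsets, x.reshape(batch * time, feature))

x = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
i = torch.tensor([[2, 0, 0, 1], [1, 1, 2, 0]])
y1 = batched_gather_flat(x, i)
y2 = batched_gather_embedding(x, i)
print(torch.equal(y1, y2))  # True
```

The offsets tensor has only batch elements and broadcasts against i, so the extra memory is proportional to the index, not to x.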

Is it possible to accomplish such gather operation without relying on unbroadcasting which is inefficient for large dims?

Solution

This is actually a very common case: the input and index tensors don't have the same number of dimensions. torch.gather requires them to match, but you can still use it, since you can rewrite your expression:

y[b, t, f] = x[b, i[b, t], f]

as:

y[b, t, f] = x[b, i[b, t, f], f]

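which ensures all three tensors have the same number of dimensions. This reveals a third (feature) dimension on i, which we can create essentially for free by unsqueezing a trailing dimension and expanding it to the feature size: i[..., None].expand(-1, -1, x.size(-1)). expand returns a view with stride 0 along the new dimension, so no index data is copied. torch.gather then returns a tensor shaped like the index, i.e. [batch, new_time, feature].

Here is a minimal example:

>>> import torch
>>> b, t, new_t, f = 2, 3, 4, 5
>>> x = torch.rand(b, t, f)
>>> i = torch.randint(0, t, (b, new_t))

>>> y = x.gather(1, i[..., None].expand(-1, -1, f))
>>> y.shape
torch.Size([2, 4, 5])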
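As a sanity check (not part of the original answer), the sketch below uses toy shapes to confirm two things: the expanded index is a stride-0 view that shares memory with i, and the gather result matches PyTorch advanced indexing x[torch.arange(b)[:, None], i], which is assumed here only as an independent way to compute the same y.

```python
import torch

torch.manual_seed(0)
b, t, new_t, f = 2, 3, 4, 5
x = torch.rand(b, t, f)                   # [batch, time, feature]
i = torch.randint(0, t, (b, new_t))       # [batch, new_time]

# Unsqueeze a trailing feature dim and expand it; this is a view, not a copy.
idx = i[..., None].expand(-1, -1, f)      # [batch, new_time, feature]
print(idx.stride())                       # (4, 1, 0): stride 0 along feature
print(idx.data_ptr() == i.data_ptr())     # True: same storage as i

y_gather = x.gather(1, idx)               # [batch, new_time, feature]

# Independent check via advanced indexing: the [b, 1] batch index broadcasts
# against i ([b, new_t]), and the untouched feature dim is appended at the end.
y_index = x[torch.arange(b)[:, None], i]
print(torch.equal(y_gather, y_index))     # True
```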

Answered By – Ivan

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
