Generalised inner product in TensorFlow

Issue

I would like to calculate a generalised inner product in TensorFlow, similarly to this discussion for numpy.

In particular, I would like a function inner_product(f,a,b) that takes a function f (of two 1D tensors, which returns a scalar tensor) and applies f to slices of a and b such that the i,jth element of the output is given by f(a[i,:], b[:,j]).

This is just tf.matmul if f(x, y) = tf.reduce_sum(x * y). However, I’m struggling to come up with an efficient solution for other functions f. Something that gets the correct answer (assuming f has the signature f(x, y)) is

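from functools import partial

import tensorflow as tf


def inner_product(f, a, b):
    # Inner loop: apply f to every row of a, paired with one fixed column of b.
    def f_row_function(row, a):
        return tf.map_fn(partial(f, y=row), a)

    # Outer loop: iterate over the columns of b (the rows of its transpose).
    return tf.transpose(
        tf.map_fn(partial(f_row_function, a=a), tf.transpose(b))
    )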

but this is very slow (it’s doing effectively two loops over f).
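As a sanity check of the tf.matmul special case mentioned above, here is a minimal sketch that runs the nested tf.map_fn version with f(x, y) = tf.reduce_sum(x * y) and compares it against tf.matmul. The names dot_f and result are only illustrative, and the inputs are the example values used further down.

```python
from functools import partial

import tensorflow as tf


def inner_product(f, a, b):
    # Same nested map_fn approach as above: two loops over f.
    def f_row_function(row, a):
        return tf.map_fn(partial(f, y=row), a)

    return tf.transpose(
        tf.map_fn(partial(f_row_function, a=a), tf.transpose(b))
    )


def dot_f(x, y):
    # Illustrative choice of f: the ordinary dot product of two 1D tensors.
    return tf.reduce_sum(x * y)


a = tf.constant(
    [[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]], dtype=tf.float32
)
b = tf.constant(
    [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]],
    dtype=tf.float32,
)

result = inner_product(dot_f, a, b)  # shape (3, 4)
matches_matmul = tf.reduce_all(tf.abs(result - tf.matmul(a, b)) < 1e-5)
print(matches_matmul)
```

If the two agree, the nested version really is computing element [i, j] as f(a[i, :], b[:, j]).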

As an example, with

a = tf.cast(
    tf.constant([[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]]), tf.float32
)
b = tf.cast(
    tf.constant(
        [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]]
    ),
    tf.float32,
)
def f(x, y):
    r = tf.norm(x - y)
    return tf.exp(-((0.1 * r) ** 2))

inner_product(f, a, b) should give

<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[0.82695913, 0.423162  , 0.5117086 , 0.5433509 ],
       [0.75578374, 0.6505091 , 0.6838614 , 0.4771139 ],
       [0.6004956 , 0.506617  , 0.50157607, 0.45384485]], dtype=float32)>

Solution

Here is my attempt to trade some memory for time. The idea is to generate all the necessary pairs of rows of a and columns of b beforehand and then apply the reduction f to this (now larger) tensor once. Let me know if that’s faster for you.

import tensorflow as tf

a = tf.cast(
    tf.constant([[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]]), tf.float32
)
b = tf.cast(
    tf.constant(
        [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]]
    ),
    tf.float32,
)

# Create indices for the axes we want to iterate (i.e. i and j)
a_is = tf.range(a.shape[0])
b_js = tf.range(b.shape[-1])
print(a_is.shape) # (3,)
print(b_js.shape) # (4,)

# Get all combinations of indices. With the default indexing="xy" both grids
# have shape (4, 3), so the first two axes are ordered [j, i], not [i, j].
A_IS, B_JS = tf.meshgrid(a_is, b_js)
all_a = tf.gather(a, A_IS)  # rows a[i, :] for every (j, i) pair
all_b = tf.gather(tf.transpose(b), B_JS)  # columns b[:, j] for every (j, i) pair
x = tf.stack([all_a, all_b], axis=-1)  # stack both into a single tensor (this step could probably be skipped)
print(all_a.shape) # (4, 3, 5)
print(all_b.shape) # (4, 3, 5)
print(x.shape) # (4, 3, 5, 2)

@tf.function
def modified_f(x):
    r = tf.norm(x[..., 0] - x[..., 1], axis=-1) 
    return tf.exp(-((0.1 * r) ** 2))

out = tf.transpose(modified_f(x))
print(out.shape) # (3, 4)
print(out)
# tf.Tensor(
# [[0.82695913 0.423162   0.5117086  0.5433508 ]
#  [0.75578374 0.65050906 0.68386143 0.47711396]
#  [0.6004956  0.50661695 0.50157607 0.45384485]], shape=(3, 4), dtype=float32)

Answered By – André

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
