# Generalised inner product in TensorFlow

## Issue

I would like to calculate a generalised inner product in TensorFlow, similar to this discussion for NumPy.

In particular, I would like a function `inner_product(f, a, b)` that takes a function `f` (of two 1D tensors, returning a scalar tensor) and applies `f` to slices of `a` and `b` such that the (i, j)-th element of the output is given by `f(a[i,:], b[:,j])`.

This is just `tf.matmul` if `f(x, y) = tf.reduce_sum(x * y)`. However, I’m struggling to come up with an efficient solution for other choices of `f`. Something that gets the correct answer (assuming a function `f` with arguments `f(x, y)`) is

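``````
from functools import partial

import tensorflow as tf


def inner_product(f, a, b):
    def f_row_function(row, a):
        # Apply f to every row of a, paired with one column of b
        return tf.map_fn(partial(f, y=row), a)

    # Iterate over the columns of b, then transpose back to shape (i, j)
    return tf.transpose(
        tf.map_fn(partial(f_row_function, a=a), tf.transpose(b))
    )
``````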

but this is very slow (it effectively runs two nested loops over `f`).
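As a sanity check on the claim above that the `tf.reduce_sum(x * y)` case reduces to `tf.matmul`, here is a minimal sketch. The helper name `dot` and the small test matrices are made up for illustration; `inner_product` is the same double `tf.map_fn` version shown above.

```python
from functools import partial

import numpy as np
import tensorflow as tf


def inner_product(f, a, b):
    # Same double-map_fn structure as the slow version in the question.
    def f_row_function(row, a):
        return tf.map_fn(partial(f, y=row), a)

    return tf.transpose(
        tf.map_fn(partial(f_row_function, a=a), tf.transpose(b))
    )


def dot(x, y):
    # Hypothetical name: the ordinary inner product of two 1D tensors.
    return tf.reduce_sum(x * y)


# Small illustrative inputs (not from the question).
a = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])          # shape (3, 2)
b = tf.constant([[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 3.0, 1.0]])  # shape (2, 4)

slow = inner_product(dot, a, b)
fast = tf.matmul(a, b)

# True if the generalised inner product agrees with tf.matmul here.
matches_matmul = np.allclose(slow.numpy(), fast.numpy())
print(matches_matmul)
```

The same comparison can be reused to validate any faster implementation against the slow reference.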

As an example, with

``````
a = tf.cast(
    tf.constant([[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]]), tf.float32
)
b = tf.cast(
    tf.constant(
        [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]]
    ),
    tf.float32,
)


def f(x, y):
    r = tf.norm(x - y)
    return tf.exp(-((0.1 * r) ** 2))
``````

`inner_product(f, a, b)` should give

``````
<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[0.82695913, 0.423162  , 0.5117086 , 0.5433509 ],
       [0.75578374, 0.6505091 , 0.6838614 , 0.4771139 ],
       [0.6004956 , 0.506617  , 0.50157607, 0.45384485]], dtype=float32)>
``````

## Solution

Here is my attempt to trade some memory for time. The idea is to generate all the necessary pairs of rows of `a` and columns of `b` beforehand, and then apply the reduction `f` to this (now larger) tensor once, instead of calling `f` once per pair. Let me know if that’s faster for you.

``````
import tensorflow as tf

a = tf.cast(
    tf.constant([[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]]), tf.float32
)
b = tf.cast(
    tf.constant(
        [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]]
    ),
    tf.float32,
)


@tf.function
def modified_f(x):
    # x[..., 0] holds the rows of a, x[..., 1] the columns of b;
    # the reduction runs over the shared axis of length 5.
    r = tf.norm(x[..., 0] - x[..., 1], axis=-1)
    return tf.exp(-((0.1 * r) ** 2))


# Create indices for the axes we want to iterate over (i.e. i and j)
a_is = tf.range(a.shape[0])
b_js = tf.range(b.shape[-1])
print(a_is.shape)  # (3,)
print(b_js.shape)  # (4,)

# Get all combinations of indices. With the default "xy" indexing the
# first two axes are ordered [j, i], hence the transpose at the end.
A_IS, B_JS = tf.meshgrid(a_is, b_js)
all_a = tf.gather(a, A_IS)  # extract the corresponding rows from "a"
all_b = tf.gather(tf.transpose(b), B_JS)  # and the columns from "b"
# Stack both into a single array; you could probably skip this...
x = tf.stack([all_a, all_b], axis=-1)
print(all_a.shape)  # (4, 3, 5)
print(all_b.shape)  # (4, 3, 5)
print(x.shape)  # (4, 3, 5, 2)

out = tf.transpose(modified_f(x))
print(out.shape)  # (3, 4)
print(out)
# tf.Tensor(
# [[0.82695913 0.423162   0.5117086  0.5433508 ]
#  [0.75578374 0.65050906 0.68386143 0.47711396]
#  [0.6004956  0.50661695 0.50157607 0.45384485]], shape=(3, 4), dtype=float32)
``````
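The `tf.stack` step can indeed be skipped, and the `tf.meshgrid`/`tf.gather` steps too: a possible variant is to let broadcasting form the pairs. The sketch below is an alternative to the code above, not part of it. The names `inner_product_broadcast` and `gaussian_batched` are hypothetical, and it assumes `f` can be rewritten to take two batched tensors and reduce over their last axis.

```python
import numpy as np
import tensorflow as tf


def inner_product_broadcast(f_batched, a, b):
    """Hypothetical helper: f_batched(x, y) must reduce over the last axis."""
    x = a[:, tf.newaxis, :]                # shape (m, 1, k): rows of a
    y = tf.transpose(b)[tf.newaxis, :, :]  # shape (1, n, k): columns of b
    return f_batched(x, y)                 # broadcasts to (m, n, k), reduces to (m, n)


def gaussian_batched(x, y):
    # Batched form of the question's f: the norm is taken over the shared axis.
    r = tf.norm(x - y, axis=-1)
    return tf.exp(-((0.1 * r) ** 2))


a = tf.constant(
    [[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]], dtype=tf.float32
)
b = tf.constant(
    [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]],
    dtype=tf.float32,
)

out = inner_product_broadcast(gaussian_batched, a, b)
print(out.shape)  # (3, 4)
```

Because the output axes are already ordered [i, j], no final transpose is needed. Like the gather-based version, this still materialises an intermediate of shape (m, n, k), so the memory-for-time trade is the same.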