## Issue

I would like to calculate a generalised inner product in TensorFlow, similar to this discussion for NumPy.

In particular, I would like a function `inner_product(f, a, b)` that takes a function `f` (of two 1D tensors, returning a scalar tensor) and applies `f` to slices of `a` and `b` such that the (*i*, *j*)th element of the output is given by `f(a[i,:], b[:,j])`.
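To pin down the intended semantics, here is a minimal plain-Python sketch with explicit loops (no TensorFlow). The name `generalized_inner_product` and the list-of-lists inputs are hypothetical, for illustration only. With `f` set to the dot product it reproduces the ordinary matrix product.

```python
def generalized_inner_product(f, a, b):
    """Hypothetical reference version: out[i][j] = f(a[i, :], b[:, j]).

    `a` is an n x d matrix and `b` is a d x m matrix, both given as lists of lists.
    """
    n, m = len(a), len(b[0])
    # Extract each column of b once, so b_cols[j] corresponds to b[:, j].
    b_cols = [[row[j] for row in b] for j in range(m)]
    return [[f(a[i], b_cols[j]) for j in range(m)] for i in range(n)]


def dot(x, y):
    # Ordinary dot product of two equal-length sequences.
    return sum(xi * yi for xi, yi in zip(x, y))


# With f = dot product, the result is the usual matrix product a @ b.
print(generalized_inner_product(dot, [[1, 2], [3, 4]], [[5, 6], [7, 8]]))
# [[19, 22], [43, 50]]
```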

This is just `tf.matmul` if `f(x, y) = tf.reduce_sum(x * y)`. However, I'm struggling to come up with an efficient solution for other `f` functions. Something that gets the correct answer (assuming `f` has the signature `f(x, y)`) is

```
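from functools import partial

import tensorflow as tf


def inner_product(f, a, b):
    # Inner loop: for one fixed column `row` of b, apply f(x, y=row) to every row x of a.
    def f_row_function(row, a):
        return tf.map_fn(partial(f, y=row), a)

    # Outer loop over the columns of b (rows of b^T). The result has shape (m, n),
    # so transpose it to (n, m).
    return tf.transpose(
        tf.map_fn(partial(f_row_function, a=a), tf.transpose(b))
    )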
```

but this is very slow (it is effectively two nested loops of calls to `f`).
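When `f` depends on `x` and `y` only through the Euclidean distance, as the Gaussian `f` in the example below does, one option is to avoid calling `f` per pair altogether by expanding ||x − y||² = ||x||² − 2 x·y + ||y||², which turns the cross term into a single matrix product. The sketch below is written in NumPy for brevity and assumes exactly that Gaussian `f`; the name `gaussian_inner_product` is hypothetical. The same steps should map onto `tf.reduce_sum(..., keepdims=True)`, `tf.matmul`, `tf.maximum` and `tf.exp`, but that translation is an assumption, not tested code.

```python
import numpy as np


def gaussian_inner_product(a, b, scale=0.1):
    """Hypothetical helper: out[i, j] = exp(-(scale * ||a[i, :] - b[:, j]||) ** 2).

    Only valid for this distance-based f. Uses
    ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, so the cross term is one matmul.
    """
    a_sq = np.sum(a * a, axis=1)[:, None]  # (n, 1): squared norms of the rows of a
    b_sq = np.sum(b * b, axis=0)[None, :]  # (1, m): squared norms of the columns of b
    sq_dist = a_sq - 2.0 * (a @ b) + b_sq  # (n, m): all pairwise squared distances
    sq_dist = np.maximum(sq_dist, 0.0)  # clip tiny negative values caused by round-off
    return np.exp(-(scale**2) * sq_dist)


a = np.array([[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]], dtype=np.float32)
b = np.array(
    [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]],
    dtype=np.float32,
)
out = gaussian_inner_product(a, b)
print(out.shape)  # (3, 4)
```

The trade-off is numerical: the expansion can lose precision when the two norms are large and the distance is small, which is why the result is clipped at zero.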

As an example, with

```
import tensorflow as tf

a = tf.cast(
    tf.constant([[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]]), tf.float32
)
b = tf.cast(
    tf.constant(
        [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]]
    ),
    tf.float32,
)


# Gaussian of the Euclidean distance between the two 1D tensors x and y.
def f(x, y):
    r = tf.norm(x - y)
    return tf.exp(-((0.1 * r) ** 2))
```

`inner_product(f, a, b)` should give

```
<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[0.82695913, 0.423162  , 0.5117086 , 0.5433509 ],
       [0.75578374, 0.6505091 , 0.6838614 , 0.4771139 ],
       [0.6004956 , 0.506617  , 0.50157607, 0.45384485]], dtype=float32)>
```
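As a quick sanity check of the expected output, the top-left entry can be worked out by hand from row 0 of `a` and column 0 of `b`:

$$
\begin{aligned}
a_{0,:} - b_{:,0} &= (1, 2, 3, 5, 2) - (4, 4, 1, 4, 1) = (-3, -2, 2, 1, 1),\\
r^2 &= 9 + 4 + 4 + 1 + 1 = 19,\\
f(a_{0,:}, b_{:,0}) &= \exp\!\left(-(0.1\, r)^2\right) = \exp(-0.19) \approx 0.82696.
\end{aligned}
$$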

## Solution

Here is my attempt to trade some memory for time. The idea is to generate the necessary pairs of `a` and `b` beforehand and then apply the reduction `f` to this (now larger) tensor **once**. Let me know if that's faster for you.
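The same "build all pairs, then reduce once" idea can also be sketched with broadcasting instead of index grids: inserting singleton axes lets `a` (shape `(n, d)`) and the transposed `b` (shape `(m, d)`) combine into an `(n, m, d)` difference tensor. The sketch below uses NumPy, and the names `pairwise_apply` and `batched_f` are hypothetical. In TensorFlow the analogous steps would presumably be `a[:, None, :]`, `tf.transpose(b)[None, :, :]` and `tf.norm(..., axis=-1)`, but treat that as an untested assumption.

```python
import numpy as np


def pairwise_apply(batched_f, a, b):
    """Hypothetical helper: evaluate batched_f on every (a[i, :], b[:, j]) pair in one call.

    a has shape (n, d) and b has shape (d, m); batched_f must reduce over the
    last axis (length d) and broadcast over the leading axes.
    """
    x = a[:, None, :]  # (n, 1, d): rows of a
    y = b.T[None, :, :]  # (1, m, d): columns of b
    return batched_f(x, y)  # broadcasting pairs every row with every column -> (n, m)


def batched_f(x, y):
    # The question's Gaussian f, rewritten to reduce over the last axis only.
    r = np.linalg.norm(x - y, axis=-1)
    return np.exp(-((0.1 * r) ** 2))


a = np.array([[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]], dtype=np.float32)
b = np.array(
    [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]],
    dtype=np.float32,
)
out = pairwise_apply(batched_f, a, b)
print(out.shape)  # (3, 4)
```

Like the gather-based version, the intermediate tensor has `n * m * d` elements, so memory grows with all three sizes.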

```
import tensorflow as tf

a = tf.cast(
    tf.constant([[1, 2, 3, 5, 2], [3, 4, 6, 3, 2], [1, 5, 6, 8, 1]]), tf.float32
)
b = tf.cast(
    tf.constant(
        [[4, 2, 5, 3], [4, 9, 3, 4], [1, 7, 8, 3], [4, 3, 5, 7], [1, 6, 7, 9]]
    ),
    tf.float32,
)


@tf.function
def modified_f(x):
    # x[..., 0] holds rows of "a" and x[..., 1] the matching columns of "b";
    # the norm reduces over the length-5 axis.
    r = tf.norm(x[..., 0] - x[..., 1], axis=-1)
    return tf.exp(-((0.1 * r) ** 2))


# Create indices for the axes we want to iterate over (i.e. i and j)
a_is = tf.range(a.shape[0])
b_js = tf.range(b.shape[-1])
print(a_is.shape)  # (3,)
print(b_js.shape)  # (4,)

# Get all combinations of indices. Because tf.meshgrid defaults to
# indexing="xy", the first two axes now correspond to [j, i].
A_IS, B_JS = tf.meshgrid(a_is, b_js)
all_a = tf.gather(a, A_IS)  # extract the corresponding rows of "a"
all_b = tf.gather(tf.transpose(b), B_JS)  # and the corresponding columns of "b"
x = tf.stack([all_a, all_b], axis=-1)  # stack both into a single tensor; you could probably skip this
print(all_a.shape)  # (4, 3, 5)
print(all_b.shape)  # (4, 3, 5)
print(x.shape)  # (4, 3, 5, 2)

out = tf.transpose(modified_f(x))  # (4, 3) -> (3, 4)
print(out.shape)  # (3, 4)
print(out)
# tf.Tensor(
# [[0.82695913 0.423162   0.5117086  0.5433508 ]
#  [0.75578374 0.65050906 0.68386143 0.47711396]
#  [0.6004956  0.50661695 0.50157607 0.45384485]], shape=(3, 4), dtype=float32)
```
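The final `tf.transpose` is only needed because of the meshgrid's default `"xy"` indexing. The NumPy sketch below shows the shape difference between the two modes; `tf.meshgrid` accepts the same `indexing` argument with the same default, so passing `indexing="ij"` there should make `all_a` and `all_b` come out as `(3, 4, 5)` and let the transpose be dropped (an assumption worth checking in your setup).

```python
import numpy as np

a_is = np.arange(3)  # indices of the rows of a
b_js = np.arange(4)  # indices of the columns of b

# Default indexing="xy": outputs have shape (len(b_js), len(a_is)), i.e. [j, i].
A_xy, B_xy = np.meshgrid(a_is, b_js)
print(A_xy.shape)  # (4, 3)

# indexing="ij": outputs have shape (len(a_is), len(b_js)), i.e. already [i, j].
A_ij, B_ij = np.meshgrid(a_is, b_js, indexing="ij")
print(A_ij.shape)  # (3, 4)
print(A_ij[1, 2], B_ij[1, 2])  # 1 2
```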

Answered By – André

**This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.**