# TensorFlow custom reduction function with axis support

## Issue

I would like to get, along a given axis, the value with the maximum absolute value in a tensor. Note that I don’t want the maximum absolute value itself; I want the element that has the maximum absolute value (so I need to keep its sign).

Ideally, I would like something similar to tf.reduce_max or tf.reduce_min:

tensor = tf.constant(
[
[[ 1,  5, -3],
[ 2, -3,  1],
[ 3, -6,  2]],

[[-2,  3, -5],
[-1,  4,  2],
[ 4, -1,  0]]
]
)
# tensor.shape = (2, 3, 3)

tf.reduce_maxamplitude(tensor, axis=0)
# Tensor(
#  [[-2,  5, -5],
#   [ 2,  4,  2],
#   [ 4, -6,  2]]
# )
# shape: (3, 3)

tf.reduce_maxamplitude(tensor, axis=1)
# Tensor(
#  [[3, -6, -3],
#   [4,  4, -5]]
# )
# shape: (2, 3)

tf.reduce_maxamplitude(tensor, axis=2)
# Tensor(
#  [[5, -3, -6],
#   [-5,  4, 4]]
# )
# shape: (2, 3)

but I did not find anything like this in the TensorFlow documentation.

With a flat tensor, I know that I could use tf.foldl or tf.foldr:

flat = tf.reshape(tensor, -1)
tf.foldr(lambda a, x: x if tf.abs(x) > tf.abs(a) else a, flat)
# -6

However, I don’t know how to handle an axis parameter in the case of multidimensional tensors.

## Solution

The exact code depends on how many dimensions your tensor has, but for a 2D tensor (reducing over the last axis) you can simply do:

import tensorflow as tf

tensor = tf.constant(
[[1,  5, -3],
[2, -3,  1],
[3, -6,  2]])

tf.gather(tensor, tf.argmax(tf.abs(tensor), axis=1), axis=1, batch_dims=1)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 5, -3, -6], dtype=int32)>

3D example:

tensor = tf.constant(
[
[[ 1,  5, -3],
[ 2, -3,  1],
[ 3, -6,  2]],

[[-2,  3, -5],
[-1,  4,  2],
[ 4, -1,  0]]
]
)

# axis = 0
argmax = tf.argmax(tf.abs(tensor), axis=0)
i, j = tf.meshgrid(
tf.range(tensor.shape[1], dtype=tf.int64),
tf.range(tensor.shape[2], dtype=tf.int64),
indexing='ij')
tf.gather_nd(tensor, tf.stack([argmax, i, j], axis=-1))
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[-2,  5, -5],
[ 2,  4,  2],
[ 4, -6,  2]], dtype=int32)>
# axis = 1
argmax = tf.argmax(tf.abs(tensor), axis=1)
i, j = tf.meshgrid(
tf.range(tensor.shape[0], dtype=tf.int64),
tf.range(tensor.shape[2], dtype=tf.int64),
indexing='ij')
tf.gather_nd(tensor, tf.stack([i, argmax, j], axis=-1))
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[ 3, -6, -3],
[ 4,  4, -5]], dtype=int32)>
# axis = 2
argmax = tf.argmax(tf.abs(tensor), axis=2)
i, j = tf.meshgrid(
tf.range(tensor.shape[0], dtype=tf.int64),
tf.range(tensor.shape[1], dtype=tf.int64),
indexing='ij')
tf.gather_nd(tensor, tf.stack([i, j, argmax], axis=-1))
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[ 5, -3, -6],
[-5,  4,  4]], dtype=int32)>

For a 4D tensor, just extend the meshgrid with one more index (shown here for the last axis, axis=-1):

# axis=-1
i, j, k = tf.meshgrid(
tf.range(tensor.shape[0], dtype=tf.int64),
tf.range(tensor.shape[1], dtype=tf.int64),
tf.range(tensor.shape[2], dtype=tf.int64),
indexing='ij')

A quick function bundling everything together (by @leleogere):

def reduce_maxamplitude(tensor, axis=None):
    """Return the elements with the largest absolute value along ``axis``.

    Unlike ``tf.reduce_max(tf.abs(tensor), axis)``, the sign of the selected
    element is kept: for the slice ``[1, -6, 2]`` the result is ``-6``.

    Args:
        tensor: A tensor (or tensor-convertible value) whose rank is
            statically known.
        axis: Python int, the axis to reduce. Negative values count from the
            last axis, as in ``tf.reduce_max``. If ``None``, the tensor is
            flattened and reduced to a scalar.

    Returns:
        A tensor with the same dtype as ``tensor`` and ``axis`` removed from
        its shape (a scalar when ``axis`` is ``None``).

    Raises:
        ValueError: If the rank is unknown or ``axis`` is out of range.

    Note:
        The TensorFlow docs say ``tf.argmax`` does not guarantee which index
        is returned on ties, so when e.g. ``5`` and ``-5`` both occur in a
        slice, either may be selected.
    """
    tensor = tf.convert_to_tensor(tensor)
    if axis is None:
        # Reduce over every element, mirroring tf.reduce_max(axis=None).
        tensor = tf.reshape(tensor, [-1])
        axis = 0
    rank = tensor.shape.rank
    if rank is None:
        raise ValueError("reduce_maxamplitude requires a tensor of known rank")
    if not -rank <= axis < rank:
        raise ValueError(
            f"axis {axis} is out of range for a tensor of rank {rank}"
        )
    # Normalize negative axes so the `i != axis` filter and the
    # `mesh[:axis]` / `mesh[axis:]` split below use the same position.
    axis %= rank

    # Position (int64, tf.argmax's default output type) of the
    # largest-magnitude element along `axis`.
    argmax = tf.argmax(tf.abs(tensor), axis=axis)

    # Dynamic shape, so dimensions of unknown size still work.
    shape = tf.shape(tensor, out_type=tf.int64)
    # One index grid per non-reduced axis; each grid has argmax's shape.
    mesh = tf.meshgrid(
        *[tf.range(shape[i]) for i in range(rank) if i != axis],
        indexing='ij',
    )
    # Insert the argmax indices at the reduced axis to form full coordinates.
    indices = tf.stack([*mesh[:axis], argmax, *mesh[axis:]], axis=-1)
    return tf.gather_nd(tensor, indices)