TensorFlow: custom reduction function with axis support

Issue

I would like to get the value with the maximum absolute value in a tensor, with respect to an axis. Note that I don’t want the maximum absolute value, I want the value that has the maximum absolute value (so I need to keep the sign).

Ideally, I would like something similar to reduce_max or reduce_min:

tensor = tf.constant(
  [
    [[ 1,  5, -3],
     [ 2, -3,  1],
     [ 3, -6,  2]],

    [[-2,  3, -5],
     [-1,  4,  2],
     [ 4, -1,  0]]
   ]
)
# tensor.shape = (2, 3, 3)

tensor.reduce_maxamplitude(tensor, axis=0)
# Tensor(
#  [[-2,  5, -5],
#   [ 2,  4,  2],
#   [ 4, -6,  2]]
# )
# shape: (3, 3)

tensor.reduce_maxamplitude(tensor, axis=1)
# Tensor(
#  [[3, -6, -3],
#   [4,  4, -5]]
# )
# shape: (2, 3)

tensor.reduce_maxamplitude(tensor, axis=2)
# Tensor(
#  [[5, -3, -6],
#   [-5,  4, 4]]
# )
# shape: (2, 3)

However, I did not find anything suitable in the TensorFlow documentation.

With a flat tensor, I know that I could use tf.foldl or tf.foldr:

flat = tf.reshape(tensor, -1)
tf.foldr(lambda a, x: x if tf.abs(x) > tf.abs(a) else a, flat)
# -6

However, I don’t know how to handle an axis parameter in the case of multidimensional tensors.

Solution

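The exact code depends on how many dimensions your tensor has. For a 2D tensor, you can take the `tf.argmax` of the absolute values along the axis and then gather the (signed) values at those positions with `tf.gather` and `batch_dims=1`: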

import tensorflow as tf

tensor = tf.constant(
  [[1,  5, -3],
   [2, -3,  1],
   [3, -6,  2]])

tf.gather(tensor, tf.argmax(tf.abs(tensor), axis=1), axis=1, batch_dims=1)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 5, -3, -6], dtype=int32)>

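3D example. Build an index grid over the axes that are not reduced with `tf.meshgrid`, insert the argmax indices at the position of the reduced axis, and look the values up with `tf.gather_nd`: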

tensor = tf.constant(
  [
    [[ 1,  5, -3],
     [ 2, -3,  1],
     [ 3, -6,  2]],

    [[-2,  3, -5],
     [-1,  4,  2],
     [ 4, -1,  0]]
   ]
)

# axis = 0
argmax = tf.argmax(tf.abs(tensor), axis=0)
i, j = tf.meshgrid(
    tf.range(tensor.shape[1], dtype=tf.int64), 
    tf.range(tensor.shape[2], dtype=tf.int64),
                              indexing='ij')
tf.gather_nd(tensor, tf.stack([argmax, i, j], axis=-1))
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[-2,  5, -5],
       [ 2,  4,  2],
       [ 4, -6,  2]], dtype=int32)>
# axis = 1
argmax = tf.argmax(tf.abs(tensor), axis=1)
i, j = tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64), 
    tf.range(tensor.shape[2], dtype=tf.int64),
                              indexing='ij')
tf.gather_nd(tensor, tf.stack([i, argmax, j], axis=-1))
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[ 3, -6, -3],
       [ 4,  4, -5]], dtype=int32)>
# axis = 2
# Recompute argmax for this axis; reusing the axis=1 result from above
# would index the wrong elements.
argmax = tf.argmax(tf.abs(tensor), axis=2)
i, j = tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64),
    tf.range(tensor.shape[1], dtype=tf.int64),
    indexing='ij')
tf.gather_nd(tensor, tf.stack([i, j, argmax], axis=-1))
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[ 5, -3, -6],
       [-5,  4,  4]], dtype=int32)>

For a 4D tensor, just extend the meshgrid with one more range per axis that is not reduced:

# axis=-1
i, j, k = tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64), 
    tf.range(tensor.shape[1], dtype=tf.int64),
    tf.range(tensor.shape[2], dtype=tf.int64),
                              indexing='ij')

A quick function bundling everything together (contributed by @leleogere):

def reduce_maxamplitude(tensor, axis):
    """Reduce ``tensor`` along ``axis``, keeping the signed element of largest magnitude.

    Works like ``tf.reduce_max`` but compares elements by absolute value and
    returns the original (signed) element, e.g. ``[3, -6, 2] -> -6``.

    Args:
        tensor: A tensor (or anything ``tf.convert_to_tensor`` accepts) whose
            rank is known statically; individual dimensions may be dynamic.
        axis: Python int, the dimension to reduce. Negative values count from
            the end, as in ``tf.reduce_max``.

    Returns:
        A tensor with ``tensor``'s dtype and the shape of ``tensor`` with
        dimension ``axis`` removed. When several elements tie in magnitude,
        which one is returned follows ``tf.argmax``'s tie-breaking.
    """
    tensor = tf.convert_to_tensor(tensor)
    rank = tensor.shape.rank
    # tf.argmax rejects out-of-range axes, so `axis` is known to be valid
    # (in [-rank, rank)) once this line succeeds.
    argmax = tf.argmax(tf.abs(tensor), axis=axis)
    # Normalize negative axes (e.g. -1) so the `i != axis` filter and the
    # mesh slicing below refer to the same dimension tf.argmax reduced.
    axis = axis % rank
    # Dynamic shape: works even when some dimensions are unknown at trace time
    # (static `tensor.shape[i]` would be None there).
    shape = tf.shape(tensor, out_type=tf.int64)
    # One coordinate grid per non-reduced dimension; each grid has the same
    # shape as `argmax`.
    mesh = tf.meshgrid(
        *[tf.range(shape[i], dtype=tf.int64) for i in range(rank) if i != axis],
        indexing='ij',
    )
    # Full index vectors: the grid coordinates with the argmax index inserted
    # at the reduced position.
    indices = tf.stack([*mesh[:axis], argmax, *mesh[axis:]], axis=-1)
    return tf.gather_nd(tensor, indices)

Answered By – AloneTogether

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0, and CC BY-SA 4.0.

Leave a Reply

(*) Required. Your email will not be published.