Efficiently implementing Chamfer's distance as a loss for TensorFlow


I’d like to implement Chamfer’s distance as a loss for my TensorFlow model, but it’s computationally impractical. Is there a more efficient approach than the minimal running example below? (The input and output are of size (1, 216, 216, 3).)

import tensorflow as tf

class EulerResnetBlock(tf.keras.Model):
    def __init__(self):
        super(EulerResnetBlock, self).__init__()

        self.conv2a = tf.keras.layers.Conv2D(50, 1, padding='same')
        self.conv2b = tf.keras.layers.Conv2D(3, 1, padding='same')

    def call(self, input_tensor, training=False):
        # residual connection: add the block input back to the two-conv branch
        return tf.nn.relu(input_tensor + self.conv2b(tf.nn.relu(self.conv2a(input_tensor))))

# custom class for computing Chamfer's distance
class ChamfersDistance(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        # y_true and y_pred size: (1, 216, 216, 3); each pixel is treated as a 3-D point
        cd = 0
        for i in range(216):
            for j in range(216):
                # distance from predicted pixel (i, j) to its nearest target pixel,
                # plus distance from target pixel (i, j) to its nearest predicted pixel
                cd += tf.math.add(tf.math.sqrt(tf.math.reduce_min(tf.math.reduce_sum(tf.math.square(y_pred[0,i,j,:]-y_true), axis=3))),
                                  tf.math.sqrt(tf.math.reduce_min(tf.math.reduce_sum(tf.math.square(y_true[0,i,j,:]-y_pred), axis=3))))
        return cd
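Read as a formula, the loop above appears to compute the symmetric Chamfer distance below, where P and T are the sets of n = 216 × 216 = 46656 RGB vectors in y_pred and y_true (taking the square root of the minimum squared distance is the same as taking the minimum Euclidean distance, since the square root is monotone). This is an interpretation of the code, not a definition taken from the question:

```latex
\mathrm{cd}(P, T) = \sum_{p \in P} \min_{t \in T} \lVert p - t \rVert_2 \;+\; \sum_{t \in T} \min_{p \in P} \lVert t - p \rVert_2
```

Each min scans all n points of the other set, so a direct evaluation needs 2n² ≈ 4.4 × 10^9 squared-distance computations per image pair.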

And the net:

eulernet = tf.keras.models.Sequential()
initializer = tf.keras.initializers.HeNormal()

eulernet.add(tf.keras.layers.Reshape((46656, 3)))

opt = tf.keras.optimizers.SGD(learning_rate=10e-2, momentum=0.5)
loss_func = ChamfersDistance()
eulernet.compile(optimizer=opt, loss=loss_func)

I think my implementation is correct, since it’s written with TensorFlow ops and the automatic gradients work out; I’m just not sure why it’s so slow in the first place.


To begin with, it’s slow by definition. Given 2 sets, you loop over every point of one set (a double loop here, since the input is 2-D), and for each point you do an O(n) scan over the other set, so your loss is at least O(n^2).
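To make the quadratic cost concrete, here is a minimal pure-Python sketch that counts how many pairwise distances a naive Chamfer distance evaluates. The function name naive_chamfer and the tiny 4×4 test "image" are illustrative assumptions, not code from the question:

```python
import math

def naive_chamfer(P, T):
    """Naive Chamfer distance between two lists of 3-D points.

    Returns (distance, number_of_pairwise_distance_evaluations) so the
    quadratic cost is visible.
    """
    evals = 0
    total = 0.0
    # direction 1: every predicted point looks for its nearest target point
    for p in P:
        best = math.inf
        for t in T:
            evals += 1
            best = min(best, sum((a - b) ** 2 for a, b in zip(p, t)))
        total += math.sqrt(best)
    # direction 2: every target point looks for its nearest predicted point
    for t in T:
        best = math.inf
        for p in P:
            evals += 1
            best = min(best, sum((a - b) ** 2 for a, b in zip(t, p)))
        total += math.sqrt(best)
    return total, evals

# a 4x4 "image" has n = 16 pixels, so each direction costs 16 * 16 evaluations
P = [(float(i), float(j), 0.0) for i in range(4) for j in range(4)]
T = [(x + 1.0, y, z) for x, y, z in P]
dist, evals = naive_chamfer(P, T)
print(evals)  # 512
```

Doubling the image side quadruples n and multiplies the number of evaluations by 16, which is why 216×216 inputs are so expensive.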

Given this, with two 216×216 images the pairwise distance matrix has (216×216)^2 entries; assuming you use float32 (32 bits, i.e. 4 bytes per entry), it will consume:

216 * 216 * 216 * 216 * 32 bits ≈ 69.7 Gbit ≈ 8.7 GB

and that is for each pair of images. With a batch of 10 images (which is pretty small), you would need around 87 GB of memory.
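The arithmetic above can be checked with a few lines of Python. The helper name pairwise_matrix_bytes is hypothetical, and it assumes one float32 (4-byte) entry per pixel pair and one matrix per image pair:

```python
# Memory needed to materialize the full pairwise-distance matrix between two
# images, assuming float32 entries (4 bytes) and one matrix per image pair.
def pairwise_matrix_bytes(height, width, bytes_per_entry=4):
    n = height * width  # number of pixels (points) per image
    return n * n * bytes_per_entry

per_pair = pairwise_matrix_bytes(216, 216)
print(per_pair)                              # 8707129344 bytes, about 8.7 GB (8.1 GiB)
print(10 * per_pair / 1e9)                   # about 87 GB for a batch of 10
print(pairwise_matrix_bytes(50, 50) / 1e6)   # 25.0 MB for a 50x50 image
```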

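And this is without considering:

  • the time it takes to fill that matrix
  • that TensorFlow graphs are limited to 2 GB when serialized (a protobuf limit), so tensors that large cannot be embedded in a graph as constants.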

Now, either the paper you are following used images on the order of 50×50 (and even then the distance matrix is pretty big and slow to compute), or they used an iterative algorithm like the one you are using.
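One way to read "iterative algorithm" here is to process the points in chunks, so the full n × n matrix is never held in memory at once. The NumPy sketch below is my assumption, not the paper's method: chamfer_full, nearest_sq_dist and chamfer_chunked are hypothetical helpers, NumPy stands in for TensorFlow (tf.reduce_sum, tf.reduce_min and tf.sqrt play the same roles), and each pixel is treated as a 3-D point as in the question's loss:

```python
import numpy as np

def chamfer_full(P, T):
    """Chamfer distance using one (n, m) matrix of squared distances.

    P: (n, 3) array, T: (m, 3) array. Memory grows as n * m.
    """
    # broadcast (n, 1, 3) - (1, m, 3) -> (n, m, 3), then sum over channels
    d2 = np.sum((P[:, None, :] - T[None, :, :]) ** 2, axis=-1)
    return np.sqrt(d2.min(axis=1)).sum() + np.sqrt(d2.min(axis=0)).sum()

def nearest_sq_dist(A, B, chunk):
    """For every row of A, the squared distance to its nearest row of B.

    Processes A in blocks of `chunk` rows, so peak memory is about chunk * len(B).
    """
    out = np.empty(len(A), dtype=A.dtype)
    for start in range(0, len(A), chunk):
        block = A[start:start + chunk]
        d2 = np.sum((block[:, None, :] - B[None, :, :]) ** 2, axis=-1)
        out[start:start + chunk] = d2.min(axis=1)
    return out

def chamfer_chunked(P, T, chunk=256):
    """Same value as chamfer_full, but never builds the full (n, m) matrix."""
    return (np.sqrt(nearest_sq_dist(P, T, chunk)).sum()
            + np.sqrt(nearest_sq_dist(T, P, chunk)).sum())

rng = np.random.default_rng(0)
# two 50x50 RGB "images" flattened to 2500 points each
P = rng.random((50 * 50, 3), dtype=np.float32)
T = rng.random((50 * 50, 3), dtype=np.float32)
print(np.isclose(chamfer_full(P, T), chamfer_chunked(P, T)))  # True
```

Chunking does not change the O(n^2) amount of work; it only trades one huge allocation for a loop over smaller ones.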

About the implementation, I tried a parallelized version of your code:

import numpy as np
import tensorflow as tf

# custom class for computing Chamfer's distance, mapping over pixel indexes
class ChamfersDistanceVect(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        # y_true and y_pred size: (1, 216, 216, 3)
        dim1 = y_true.shape[1]
        dim2 = y_true.shape[2]
        # row index of every pixel, flattened in row-major order
        mat1 = np.reshape(np.repeat(np.expand_dims(np.linspace(0, dim1-1, dim1, dtype="float32"), -1), dim2, axis=1), -1)
        # column index of every pixel, flattened in row-major order
        mat2 = np.reshape(np.repeat(np.transpose(np.expand_dims(np.linspace(0, dim2-1, dim2, dtype="float32"), -1)), dim1, axis=0), -1)
        # (dim1 * dim2, 2) array of (i, j) pairs
        indexes = np.transpose(np.stack((mat1, mat2)))
        def dist(index):
            # index with tensors (no .numpy()) so this also works inside a traced graph
            i = tf.cast(index[0], tf.int32)
            j = tf.cast(index[1], tf.int32)
            return tf.math.add(tf.math.sqrt(tf.math.reduce_min(tf.math.reduce_sum(tf.math.square(y_pred[0,i,j,:]-y_true), axis=3))),
                               tf.math.sqrt(tf.math.reduce_min(tf.math.reduce_sum(tf.math.square(y_true[0,i,j,:]-y_pred), axis=3))))
        res = tf.reduce_sum(
            tf.map_fn(dist, indexes))
        return res
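The mat1/mat2/indexes construction is just a row-major enumeration of pixel coordinates. The second repeat count must be dim1 for it to work on non-square inputs: with dim2 there, mat2 has dim2 * dim2 entries instead of dim1 * dim2, and np.stack raises a ValueError unless the image is square. A small NumPy check, with hypothetical helper names, that np.indices produces the same grid more simply:

```python
import numpy as np

def index_grid_original(dim1, dim2):
    """The (i, j) grid built as in ChamfersDistanceVect, with dim1 for the second repeat."""
    mat1 = np.reshape(np.repeat(np.expand_dims(np.linspace(0, dim1 - 1, dim1, dtype="float32"), -1), dim2, axis=1), -1)
    mat2 = np.reshape(np.repeat(np.transpose(np.expand_dims(np.linspace(0, dim2 - 1, dim2, dtype="float32"), -1)), dim1, axis=0), -1)
    return np.transpose(np.stack((mat1, mat2)))

def index_grid_simple(dim1, dim2):
    """Same (dim1 * dim2, 2) array of row-major (i, j) pairs via np.indices."""
    return np.indices((dim1, dim2)).reshape(2, -1).T.astype("float32")

grid = index_grid_original(3, 4)  # deliberately non-square
print(grid.shape)                                     # (12, 2)
print(np.array_equal(grid, index_grid_simple(3, 4)))  # True
```

Either way, tf.map_fn still processes one pixel per iteration, so this changes readability, not the O(n^2) cost.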

It is faster on the first run, since its graph is easier to build, but the gap narrows after that. In particular, these are the numbers (using only ONE image of size 50×50):

  • sequential version: first run (including graph creation) 21 seconds, then 12 seconds
  • vectorized version: first run 12 seconds, then 9 seconds

Given all this, either you accept that TF builds the graph during the first run, so that run takes much longer than the following ones, or you have to reduce your input size significantly.
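If you go the "reduce your input size" route, one option (an assumption on my part; the answer does not say how to shrink the input) is to average-pool each image before computing the loss. A NumPy sketch with a hypothetical downsample_mean helper; inside a TF model, tf.nn.avg_pool2d would be the analogous op:

```python
import numpy as np

def downsample_mean(img, factor):
    """Average-pool an (H, W, C) image by `factor` in both spatial dims.

    Assumes H and W are divisible by `factor` (216 = 4 * 54).
    """
    h, w, c = img.shape
    blocks = img.reshape(h // factor, factor, w // factor, factor, c)
    return blocks.mean(axis=(1, 3))

img = np.random.default_rng(0).random((216, 216, 3), dtype=np.float32)
small = downsample_mean(img, 4)
points = small.reshape(-1, 3)          # one 3-D point per remaining pixel
print(small.shape, points.shape)       # (54, 54, 3) (2916, 3)
# float32 pairwise matrix: 2916**2 * 4 bytes, about 34 MB instead of about 8.7 GB
print(points.shape[0] ** 2 * 4)        # 34012224
```

Pooling by 4 turns 46656 points into 2916, shrinking the pairwise matrix by a factor of 256, at the cost of comparing coarser images.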

Answered By – Alberto Sinigaglia

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
