Multi-class weighted loss for semantic image segmentation in keras/tensorflow

Issue

Given batched RGB images as input, shape=(batch_size, width, height, 3)

And a multiclass target represented as one-hot, shape=(batch_size, width, height, n_classes)

And a model (Unet, DeepLab) with a softmax activation in the last layer.

I’m looking for a weighted categorical cross-entropy loss function in Keras/TensorFlow.

The class_weight argument in fit_generator doesn’t seem to work, and I didn’t find the answer on Stack Overflow or in https://github.com/keras-team/keras/issues/2115.

def weighted_categorical_crossentropy(weights):
    # weights = [0.9,0.05,0.04,0.01]
    def wcce(y_true, y_pred):
        # y_true, y_pred shape is (batch_size, width, height, n_classes)
        loss = ...  # ? this is the part I'm missing
        return loss

    return wcce
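For reference, the per-pixel loss the stub should return can be written as below. The notation is mine, not from the original post: $y$ is the one-hot target, $\hat{y}$ the softmax output, $w_c$ the weight of class $c$, $C$ is n_classes, and $b, i, j$ index the image in the batch and the pixel.

$$\ell_{b,i,j} = -\Big(\sum_{c=1}^{C} y_{b,i,j,c}\, w_c\Big) \sum_{c=1}^{C} y_{b,i,j,c} \log \hat{y}_{b,i,j,c}$$

Because $y_{b,i,j,:}$ is one-hot, the first sum just picks out the weight of the pixel's true class, so each pixel's ordinary cross-entropy is scaled by the weight of its ground-truth class.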

Solution

I will answer my own question:

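import tensorflow as tf
from tensorflow.keras import backend as K  # Keras 2.x backend API (tf.keras in TF 2.x)

def weighted_categorical_crossentropy(weights):
    # weights: one weight per class, e.g. [0.9, 0.05, 0.04, 0.01]
    Kweights = K.constant(weights)  # shape (n_classes,), built once

    def wcce(y_true, y_pred):
        # y_true, y_pred shape is (batch_size, width, height, n_classes)
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = K.cast(y_true, y_pred.dtype)
        # Weight of each pixel's true class (one-hot target dotted with the class weights)
        pixel_weights = K.sum(y_true * K.cast(Kweights, y_pred.dtype), axis=-1)
        # Per-pixel cross-entropy, shape (batch_size, width, height), scaled per pixel
        return K.categorical_crossentropy(y_true, y_pred) * pixel_weights
    return wcce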
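To check what wcce computes without building a model, here is a NumPy re-implementation of the same arithmetic. It is a sketch that assumes K.categorical_crossentropy (with its default from_logits=False) rescales the predictions to sum to 1 over the class axis and clips them to [eps, 1 - eps] with eps = 1e-7 before taking the log, which matches the TF 2.x backend as far as I know. The name weighted_cce_numpy is hypothetical.

```python
import numpy as np


def weighted_cce_numpy(y_true, y_pred, weights, eps=1e-7):
    """Per-pixel weighted categorical cross-entropy (NumPy reference sketch).

    y_true:  one-hot targets, shape (batch, width, height, n_classes)
    y_pred:  softmax probabilities, same shape
    weights: one weight per class, shape (n_classes,)
    Returns an array of shape (batch, width, height).
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    w = np.asarray(weights, dtype=np.float64)

    # Rescale so the class probabilities sum to 1, then clip to avoid log(0)
    y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)

    # Plain cross-entropy per pixel (sum over the class axis)
    ce = -np.sum(y_true * np.log(y_pred), axis=-1)

    # Weight of each pixel's true class: one-hot target dotted with the weights
    pixel_w = np.sum(y_true * w, axis=-1)
    return ce * pixel_w
```

For example, with y_pred = 0.5 everywhere for two classes and weights = [2.0, 0.5], a pixel whose true class is 0 contributes 2 * log 2 ≈ 1.386 and a class-1 pixel contributes 0.5 * log 2 ≈ 0.347. By default Keras then averages this (batch, width, height) map into the scalar it optimizes.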

Usage:

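from tensorflow import keras

weights = [0.9, 0.05, 0.04, 0.01]  # one weight per class
loss = weighted_categorical_crossentropy(weights)
optimizer = keras.optimizers.Adam(learning_rate=0.01)
model.compile(optimizer=optimizer, loss=loss)  # model: the Unet/DeepLab above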

Answered By – Mendi Barel

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
