Calling Keras Model.evaluate() on every batch element separately

Issue

I would like to call tf.keras.Model.evaluate() (or a similar method) on a batch of my test_data and get back the losses/metrics separately for every batch element. So if the batches are 64 elements long, I would like to get back a list of 64 losses/metrics.

I need this in order to find outliers in the test dataset.

I tried calling test_on_batch(), or evaluate() on single batches, but these methods aggregate the result over the batch (I assume via a mean). Evaluating each element as its own batch of size one is possible, but it takes 10-20x as long on my GPU.

I also tried calling predict() and calculating the losses/metrics manually, but this approach also suffers from a steep drop in performance (because of the extra manual step of calculating every loss/metric from the test dataset and the predictions).

Is there a way to do this without compromising performance?

Solution

Applying the TensorFlow loss/metric functions to the output of model.predict() is fast and doesn't involve loops.

Consider this dummy classification task:

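import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten

# Dummy data: 64 grayscale 28x28 images with binary labels
X = np.random.uniform(0, 1, (64, 28, 28, 1))
y = np.random.randint(0, 2, 64)

model = Sequential([Flatten(), Dense(2, activation='softmax')])
model.compile('adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
model.fit(X, y, epochs=3)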

You can evaluate the score for every batch element in this way:

preds = model.predict(X)  # run inference once and reuse the predictions

scce = tf.keras.losses.sparse_categorical_crossentropy(y, preds)
# scce.shape ==> (64,)

scca = tf.keras.metrics.sparse_categorical_accuracy(y, preds)
# scca.shape ==> (64,)
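
To make it clearer what those two per-element results contain, here is a rough NumPy-only sketch of the same quantities. It assumes the predictions are already softmax probabilities (as in the dummy model above). The helper names per_sample_scce and per_sample_accuracy are hypothetical and not part of TensorFlow, and the clipping constant is an assumption meant only to avoid log(0).

```python
import numpy as np

def per_sample_scce(y_true, probs, eps=1e-7):
    """Per-sample sparse categorical cross-entropy: -log(probs[i, y_true[i]])."""
    probs = np.clip(probs, eps, 1.0)      # assumed guard against log(0)
    rows = np.arange(len(y_true))
    return -np.log(probs[rows, y_true])   # pick the probability of the true class

def per_sample_accuracy(y_true, probs):
    """1.0 where the argmax class matches the label, 0.0 otherwise."""
    return (np.argmax(probs, axis=-1) == y_true).astype(np.float32)

# Tiny hand-made example: 3 samples, 2 classes
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
y_true = np.array([0, 1, 1])

losses = per_sample_scce(y_true, probs)    # one loss per sample
accs = per_sample_accuracy(y_true, probs)  # one 0/1 score per sample
```

Both results have one entry per sample, with no Python loop over the batch, which is why this route stays fast.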

These per-element scores are the same values that model.evaluate() aggregates into a single number:

scce_eval, scca_eval = model.evaluate(X,y, verbose=0)

scce_eval is equal to tf.reduce_mean(scce), up to floating-point error.

scca_eval is equal to tf.reduce_mean(scca), up to floating-point error.
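
Since the goal in the question is to find outliers, the per-element losses can be ranked directly. This is a minimal sketch, assuming the losses have been converted to a 1-D NumPy array (for example via scce.numpy()). The helper name top_k_outliers and the sample values are made up for illustration.

```python
import numpy as np

def top_k_outliers(losses, k=5):
    """Return the indices of the k largest losses, highest loss first."""
    losses = np.asarray(losses)
    k = min(k, losses.size)              # never ask for more items than exist
    return np.argsort(losses)[::-1][:k]  # sort ascending, reverse, keep first k

# Made-up per-sample losses for five test elements
losses = np.array([0.10, 2.30, 0.05, 0.70, 1.90])
worst = top_k_outliers(losses, k=2)      # indices of the two highest losses
```

The returned indices can then be used to look up the corresponding rows of the test set for inspection.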

Answered By – Marco Cerliani

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
