Binary classification model can't train and AUC/accuracy stays at 0.5

Issue

I am trying to train a model that predicts two classes, 1 and 0.
The training set is balanced, and the classifier's input is pretrained embeddings of shape (466, 1024).

The problem is that my model only ever predicts class 0.

The AUC is 0.5 and the accuracy does not change.

I tried changing the optimizer, the learning rate, the loss, the number of units per layer, and the activation function, but the issue persists.

I am using this classifier:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout

model = tf.keras.models.Sequential([
    Dense(n_input, input_shape=(n_input,), activation="elu"),  # input layer

    Dense(n_hidden1, activation="elu"),  # hidden layer 1
    Dropout(dropout_prob),

    Dense(n_hidden2, activation="elu"),  # hidden layer 2
    Dropout(dropout_prob),

    Dense(n_hidden3, activation="elu"),  # hidden layer 3
    Dropout(dropout_prob),

    Dense(n_output, activation="sigmoid")  # output layer: one probability per sample
])

Model summary:

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_14 (Dense)             (None, 1024)              1049600   
_________________________________________________________________
dense_15 (Dense)             (None, 1024)              1049600   
_________________________________________________________________
dropout_9 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_16 (Dense)             (None, 800)               820000    
_________________________________________________________________
dropout_10 (Dropout)         (None, 800)               0         
_________________________________________________________________
dense_17 (Dense)             (None, 400)               320400    
_________________________________________________________________
dropout_11 (Dropout)         (None, 400)               0         
_________________________________________________________________
dense_18 (Dense)             (None, 1)                 401       
=================================================================
Total params: 3,240,001
Trainable params: 3,240,001
Non-trainable params: 0

Training and validation:

bce_loss = tf.keras.losses.BinaryCrossentropy()
accuracy = tf.keras.metrics.BinaryAccuracy()
optimizer = tf.optimizers.Adam(learning_rate=0.001)

# Assumed compile step (not shown in the original post), built from the objects above
model.compile(optimizer=optimizer, loss=bce_loss, metrics=[accuracy])

model.fit(embeddings_train, np.array(y_train),
          validation_data=(embeddings_test, np.array(y_test)), epochs=300)

This is the training history:

Epoch 1/300
15/15 [==============================] - 0s 18ms/step - loss: 0.8819 - binary_accuracy: 0.4657 - val_loss: 0.5634 - val_binary_accuracy: 0.7699
Epoch 2/300
15/15 [==============================] - 0s 12ms/step - loss: 0.7407 - binary_accuracy: 0.4657 - val_loss: 0.6985 - val_binary_accuracy: 0.2301
Epoch 3/300
15/15 [==============================] - 0s 12ms/step - loss: 0.7156 - binary_accuracy: 0.4700 - val_loss: 0.6065 - val_binary_accuracy: 0.7699
Epoch 4/300
15/15 [==============================] - 0s 12ms/step - loss: 0.6995 - binary_accuracy: 0.5043 - val_loss: 0.8415 - val_binary_accuracy: 0.2301
Epoch 5/300
15/15 [==============================] - 0s 12ms/step - loss: 0.6960 - binary_accuracy: 0.5043 - val_loss: 0.8285 - val_binary_accuracy: 0.2301
Epoch 6/300
15/15 [==============================] - 0s 12ms/step - loss: 0.6600 - binary_accuracy: 0.5687 - val_loss: 0.7431 - val_binary_accuracy: 0.2301
Epoch 7/300
15/15 [==============================] - 0s 12ms/step - loss: 0.6825 - binary_accuracy: 0.5300 - val_loss: 0.5497 - val_binary_accuracy: 0.8110
Epoch 8/300
15/15 [==============================] - 0s 12ms/step - loss: 0.6625 - binary_accuracy: 0.5901 - val_loss: 0.5609 - val_binary_accuracy: 0.8137
Epoch 9/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5959 - binary_accuracy: 0.6459 - val_loss: 0.5328 - val_binary_accuracy: 0.8082
Epoch 10/300
15/15 [==============================] - 0s 12ms/step - loss: 0.6123 - binary_accuracy: 0.6416 - val_loss: 0.4928 - val_binary_accuracy: 0.8110
Epoch 11/300
15/15 [==============================] - 0s 12ms/step - loss: 0.6222 - binary_accuracy: 0.5966 - val_loss: 0.5483 - val_binary_accuracy: 0.8055
Epoch 12/300
15/15 [==============================] - 0s 12ms/step - loss: 0.6850 - binary_accuracy: 0.6137 - val_loss: 0.6373 - val_binary_accuracy: 0.8027
Epoch 13/300
15/15 [==============================] - 0s 12ms/step - loss: 0.6078 - binary_accuracy: 0.6717 - val_loss: 0.4625 - val_binary_accuracy: 0.8110
Epoch 14/300
15/15 [==============================] - 0s 11ms/step - loss: 0.6453 - binary_accuracy: 0.5837 - val_loss: 0.6565 - val_binary_accuracy: 0.6795
Epoch 15/300
15/15 [==============================] - 0s 11ms/step - loss: 0.5891 - binary_accuracy: 0.6524 - val_loss: 0.5181 - val_binary_accuracy: 0.8137
Epoch 16/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5548 - binary_accuracy: 0.6867 - val_loss: 0.4467 - val_binary_accuracy: 0.8164
Epoch 17/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5540 - binary_accuracy: 0.6888 - val_loss: 0.7391 - val_binary_accuracy: 0.3014
Epoch 18/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5437 - binary_accuracy: 0.7382 - val_loss: 0.5215 - val_binary_accuracy: 0.8164
Epoch 19/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5772 - binary_accuracy: 0.7082 - val_loss: 0.5882 - val_binary_accuracy: 0.7863
Epoch 20/300
15/15 [==============================] - 0s 11ms/step - loss: 0.5670 - binary_accuracy: 0.7232 - val_loss: 0.4696 - val_binary_accuracy: 0.8137
Epoch 21/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5664 - binary_accuracy: 0.6953 - val_loss: 0.4637 - val_binary_accuracy: 0.8110
Epoch 22/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5300 - binary_accuracy: 0.7103 - val_loss: 0.9387 - val_binary_accuracy: 0.2712
Epoch 23/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5733 - binary_accuracy: 0.6631 - val_loss: 0.7253 - val_binary_accuracy: 0.3178
Epoch 24/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5222 - binary_accuracy: 0.7339 - val_loss: 0.4540 - val_binary_accuracy: 0.8110
Epoch 25/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4770 - binary_accuracy: 0.7854 - val_loss: 0.4343 - val_binary_accuracy: 0.8137
Epoch 26/300
15/15 [==============================] - 0s 12ms/step - loss: 0.6285 - binary_accuracy: 0.6502 - val_loss: 0.6575 - val_binary_accuracy: 0.6849
Epoch 27/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5458 - binary_accuracy: 0.7425 - val_loss: 0.5516 - val_binary_accuracy: 0.7781
Epoch 28/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4916 - binary_accuracy: 0.7339 - val_loss: 0.4976 - val_binary_accuracy: 0.8274
Epoch 29/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4897 - binary_accuracy: 0.7575 - val_loss: 0.5883 - val_binary_accuracy: 0.7671
Epoch 30/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4722 - binary_accuracy: 0.7961 - val_loss: 0.6962 - val_binary_accuracy: 0.6575
Epoch 31/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5001 - binary_accuracy: 0.7361 - val_loss: 0.4427 - val_binary_accuracy: 0.8164
Epoch 32/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5273 - binary_accuracy: 0.7210 - val_loss: 0.5754 - val_binary_accuracy: 0.7836
Epoch 33/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5391 - binary_accuracy: 0.7124 - val_loss: 0.5743 - val_binary_accuracy: 0.8000
Epoch 34/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5869 - binary_accuracy: 0.7253 - val_loss: 0.4628 - val_binary_accuracy: 0.8137
Epoch 35/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5848 - binary_accuracy: 0.6459 - val_loss: 0.5605 - val_binary_accuracy: 0.8164
Epoch 36/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5117 - binary_accuracy: 0.7489 - val_loss: 0.4373 - val_binary_accuracy: 0.8164
Epoch 37/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4901 - binary_accuracy: 0.8112 - val_loss: 0.4458 - val_binary_accuracy: 0.8247
Epoch 38/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4542 - binary_accuracy: 0.8155 - val_loss: 0.6725 - val_binary_accuracy: 0.7068
Epoch 39/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4584 - binary_accuracy: 0.8026 - val_loss: 0.5492 - val_binary_accuracy: 0.7808
Epoch 40/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4387 - binary_accuracy: 0.8240 - val_loss: 0.4297 - val_binary_accuracy: 0.8192
Epoch 41/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4361 - binary_accuracy: 0.7961 - val_loss: 0.7132 - val_binary_accuracy: 0.6822
Epoch 42/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4312 - binary_accuracy: 0.8283 - val_loss: 0.4252 - val_binary_accuracy: 0.8247
Epoch 43/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5493 - binary_accuracy: 0.7103 - val_loss: 0.5322 - val_binary_accuracy: 0.8192
Epoch 44/300
15/15 [==============================] - 0s 12ms/step - loss: 0.6070 - binary_accuracy: 0.6695 - val_loss: 0.4435 - val_binary_accuracy: 0.8219
Epoch 45/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4556 - binary_accuracy: 0.8047 - val_loss: 0.5101 - val_binary_accuracy: 0.8219
Epoch 46/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5358 - binary_accuracy: 0.7554 - val_loss: 0.4601 - val_binary_accuracy: 0.8329
Epoch 47/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4053 - binary_accuracy: 0.8391 - val_loss: 0.8246 - val_binary_accuracy: 0.6356
Epoch 48/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4564 - binary_accuracy: 0.7897 - val_loss: 0.4349 - val_binary_accuracy: 0.8192
Epoch 49/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4999 - binary_accuracy: 0.7532 - val_loss: 0.9320 - val_binary_accuracy: 0.3507
Epoch 50/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5715 - binary_accuracy: 0.6910 - val_loss: 0.5226 - val_binary_accuracy: 0.8164
Epoch 51/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5390 - binary_accuracy: 0.7103 - val_loss: 0.5946 - val_binary_accuracy: 0.7589
Epoch 52/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4583 - binary_accuracy: 0.7811 - val_loss: 0.4642 - val_binary_accuracy: 0.8301
Epoch 53/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4385 - binary_accuracy: 0.8262 - val_loss: 0.4273 - val_binary_accuracy: 0.8274
Epoch 54/300
15/15 [==============================] - 0s 11ms/step - loss: 0.4468 - binary_accuracy: 0.8090 - val_loss: 0.4227 - val_binary_accuracy: 0.8247
Epoch 55/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4794 - binary_accuracy: 0.7768 - val_loss: 0.5298 - val_binary_accuracy: 0.8274
Epoch 56/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4287 - binary_accuracy: 0.8369 - val_loss: 0.4394 - val_binary_accuracy: 0.8274
Epoch 57/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4631 - binary_accuracy: 0.7682 - val_loss: 0.6234 - val_binary_accuracy: 0.7288
Epoch 58/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4495 - binary_accuracy: 0.8112 - val_loss: 0.4525 - val_binary_accuracy: 0.8384
Epoch 59/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4026 - binary_accuracy: 0.8391 - val_loss: 0.6229 - val_binary_accuracy: 0.7452
Epoch 60/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4011 - binary_accuracy: 0.8369 - val_loss: 0.4382 - val_binary_accuracy: 0.8274
Epoch 61/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4482 - binary_accuracy: 0.7811 - val_loss: 0.4247 - val_binary_accuracy: 0.8301
Epoch 62/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4543 - binary_accuracy: 0.7876 - val_loss: 0.4216 - val_binary_accuracy: 0.8301
Epoch 63/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4149 - binary_accuracy: 0.8240 - val_loss: 0.4449 - val_binary_accuracy: 0.8356
Epoch 64/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4314 - binary_accuracy: 0.8004 - val_loss: 0.6862 - val_binary_accuracy: 0.6822
Epoch 65/300
15/15 [==============================] - 0s 11ms/step - loss: 0.4343 - binary_accuracy: 0.8197 - val_loss: 0.4688 - val_binary_accuracy: 0.8329
Epoch 66/300
15/15 [==============================] - 0s 12ms/step - loss: 0.3882 - binary_accuracy: 0.8326 - val_loss: 0.9579 - val_binary_accuracy: 0.3616
Epoch 67/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4689 - binary_accuracy: 0.7876 - val_loss: 0.4199 - val_binary_accuracy: 0.8301
Epoch 68/300
15/15 [==============================] - 0s 13ms/step - loss: 0.4261 - binary_accuracy: 0.8197 - val_loss: 0.4541 - val_binary_accuracy: 0.8356
Epoch 69/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5395 - binary_accuracy: 0.7318 - val_loss: 0.4623 - val_binary_accuracy: 0.8219
Epoch 70/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4715 - binary_accuracy: 0.8219 - val_loss: 0.5407 - val_binary_accuracy: 0.7836
Epoch 71/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4223 - binary_accuracy: 0.8133 - val_loss: 0.4479 - val_binary_accuracy: 0.8247
Epoch 72/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4519 - binary_accuracy: 0.8004 - val_loss: 0.5923 - val_binary_accuracy: 0.7589
Epoch 73/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4383 - binary_accuracy: 0.7983 - val_loss: 0.5267 - val_binary_accuracy: 0.7890
Epoch 74/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4701 - binary_accuracy: 0.7854 - val_loss: 0.5831 - val_binary_accuracy: 0.7644
Epoch 75/300
15/15 [==============================] - 0s 12ms/step - loss: 0.5001 - binary_accuracy: 0.7511 - val_loss: 0.6085 - val_binary_accuracy: 0.7397
Epoch 76/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4607 - binary_accuracy: 0.8090 - val_loss: 0.6974 - val_binary_accuracy: 0.7014
Epoch 77/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4493 - binary_accuracy: 0.8283 - val_loss: 0.4760 - val_binary_accuracy: 0.8356
Epoch 78/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4071 - binary_accuracy: 0.8348 - val_loss: 0.6835 - val_binary_accuracy: 0.7233
Epoch 79/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4785 - binary_accuracy: 0.7940 - val_loss: 0.4332 - val_binary_accuracy: 0.8274
Epoch 80/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4130 - binary_accuracy: 0.8369 - val_loss: 0.5052 - val_binary_accuracy: 0.8110
Epoch 81/300
15/15 [==============================] - 0s 11ms/step - loss: 0.4021 - binary_accuracy: 0.8348 - val_loss: 0.4832 - val_binary_accuracy: 0.8301
Epoch 82/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4400 - binary_accuracy: 0.8112 - val_loss: 0.5291 - val_binary_accuracy: 0.7945
Epoch 83/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4275 - binary_accuracy: 0.8326 - val_loss: 0.5291 - val_binary_accuracy: 0.7890
Epoch 84/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4009 - binary_accuracy: 0.8412 - val_loss: 0.4268 - val_binary_accuracy: 0.8301
Epoch 85/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4073 - binary_accuracy: 0.8240 - val_loss: 0.4347 - val_binary_accuracy: 0.8356
Epoch 86/300
15/15 [==============================] - 0s 11ms/step - loss: 0.4035 - binary_accuracy: 0.8219 - val_loss: 0.4368 - val_binary_accuracy: 0.8274
Epoch 87/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4531 - binary_accuracy: 0.8112 - val_loss: 0.6171 - val_binary_accuracy: 0.7370
Epoch 88/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4331 - binary_accuracy: 0.8176 - val_loss: 0.4348 - val_binary_accuracy: 0.8356
Epoch 89/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4153 - binary_accuracy: 0.8219 - val_loss: 0.4363 - val_binary_accuracy: 0.8356
Epoch 90/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4107 - binary_accuracy: 0.8305 - val_loss: 0.4521 - val_binary_accuracy: 0.8329
Epoch 91/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4703 - binary_accuracy: 0.7768 - val_loss: 0.5170 - val_binary_accuracy: 0.8082
Epoch 92/300
15/15 [==============================] - 0s 11ms/step - loss: 0.4098 - binary_accuracy: 0.8455 - val_loss: 0.4286 - val_binary_accuracy: 0.8301
Epoch 93/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4631 - binary_accuracy: 0.7876 - val_loss: 0.4579 - val_binary_accuracy: 0.8384
Epoch 94/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4668 - binary_accuracy: 0.7704 - val_loss: 0.4289 - val_binary_accuracy: 0.8274
Epoch 95/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4109 - binary_accuracy: 0.8498 - val_loss: 0.4636 - val_binary_accuracy: 0.8384
Epoch 96/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4211 - binary_accuracy: 0.8219 - val_loss: 0.7324 - val_binary_accuracy: 0.6603
Epoch 97/300
15/15 [==============================] - 0s 11ms/step - loss: 0.4477 - binary_accuracy: 0.7790 - val_loss: 0.4237 - val_binary_accuracy: 0.8247
Epoch 98/300
15/15 [==============================] - 0s 11ms/step - loss: 0.4167 - binary_accuracy: 0.8305 - val_loss: 0.4499 - val_binary_accuracy: 0.8329
Epoch 99/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4005 - binary_accuracy: 0.8348 - val_loss: 0.6721 - val_binary_accuracy: 0.7068
Epoch 100/300
15/15 [==============================] - 0s 12ms/step - loss: 0.4581 - binary_accuracy: 0.7918 - val_loss: 0.6318 - val_binary_accuracy: 0.7178
<tensorflow.python.keras.callbacks.History at 0x7fc960249190>

I only get predictions of class 0.

prediction = model.predict(embeddings_val)
prediction = np.argmax(prediction, axis=1)
testdata = np.argmax(y_val, axis=1)
from sklearn.metrics import confusion_matrix
CM = confusion_matrix(prediction, testdata)

Test Accuracy: 1.0
[Figure: AUC of the model and ROC curve]

Solution

np.argmax returns the index of the maximum element along the given axis. Your model outputs a single value per sample, so each row has only one element and argmax always returns 0.
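A minimal sketch of that behavior, using made-up arrays (`probs` and `y_true` are hypothetical stand-ins for the output of `model.predict` and for `y_val`, both assumed to have shape (n, 1)). It also suggests why the reported test accuracy is 1.0: if the labels are passed through argmax too, both sides collapse to all zeros and trivially agree.

```python
import numpy as np

# Hypothetical sigmoid outputs for four samples, shape (4, 1).
probs = np.array([[0.91], [0.12], [0.67], [0.03]])
# Hypothetical 0/1 labels with the same (4, 1) shape.
y_true = np.array([[1], [0], [1], [0]])

# Each row holds a single element, so the index of the row maximum is always 0.
pred_argmax = np.argmax(probs, axis=1)
true_argmax = np.argmax(y_true, axis=1)

print(pred_argmax)  # [0 0 0 0]
print(true_argmax)  # [0 0 0 0]

# Comparing two all-zero vectors gives a meaningless "perfect" accuracy.
print(float(np.mean(pred_argmax == true_argmax)))  # 1.0
```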

That single value is a probability: below 0.5 you can interpret it as class 0, and at or above 0.5 as class 1. Use np.where instead of argmax:

# prediction = np.argmax(prediction, axis=1)   # remove this line
prediction = np.where(prediction < 0.5, 0, 1)  # add this line

# testdata = np.argmax(y_val, axis=1)          # remove this line as well
# y_val presumably already holds 0/1 labels, so thresholding it with
# np.where(y_val < 0.5, 0, 1) would change nothing; use it directly.
testdata = y_val
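To check the thresholding approach end to end, here is a small NumPy-only sketch. The arrays are invented for illustration (`probs` stands in for the sigmoid outputs and `y_val` for 0/1 labels, both assumed to have shape (n, 1)), and the confusion matrix is built by hand rather than with scikit-learn.

```python
import numpy as np

# Hypothetical stand-ins: sigmoid outputs and 0/1 labels, both of shape (5, 1).
probs = np.array([[0.91], [0.12], [0.67], [0.03], [0.40]])
y_val = np.array([[1], [0], [1], [0], [1]])

# Threshold at 0.5, then flatten both to 1-D label vectors.
y_pred = np.where(probs < 0.5, 0, 1).ravel()
y_true = y_val.ravel()

accuracy = float(np.mean(y_pred == y_true))

# 2x2 confusion matrix: rows are true classes, columns are predicted classes.
cm = np.zeros((2, 2), dtype=int)
np.add.at(cm, (y_true, y_pred), 1)

print(y_pred)    # [1 0 1 0 0]
print(accuracy)  # 0.8
print(cm)        # [[2 0]
                 #  [1 2]]
```

If you pass these vectors to scikit-learn's confusion_matrix instead, note that it expects the true labels first and the predictions second. For the AUC, the raw probabilities are usually the better input than the thresholded labels, since the ROC curve is traced by varying the threshold.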

Answered By – Kaveh

This Answer collected from stackoverflow, is licensed under cc by-sa 2.5 , cc by-sa 3.0 and cc by-sa 4.0
