Why does TensorFlow throw this exception when loading a model that was normalized like this?

Issue

All versions are the latest available at the time of this post.

tensorflow-gpu: 2.6.0
Python: 3.9.7
CUDA: 11.4.2
cuDNN: 8.2.4

As the code below shows, a model whose normalizer was created with Normalization() (no arguments) throws an exception when it is loaded with load_model(). Before loading, I can use the model without any apparent issues, which makes you think it’s all good, since Normalization() did NOT complain and took care of the input shape. A model whose normalizer was created with Normalization(input_dim=5) does NOT throw any exception on load, since a known shape is specified. That is weird: if calling Normalization() without arguments leads to an exception when the model is loaded, it should at least warn you about it.

I’m not sure whether it’s a bug, so I’m posting it here before reporting it on GitHub; maybe I’m missing some setup step.

Here’s my code:

import numpy as np
import tensorflow as tf


def main():
    train_data = np.array([[1, 2, 3, 4, 5]])
    train_label = np.array([123])
    
    # Uncomment this to load the model and comment the next model and normalizer related lines.
    #model = tf.keras.models.load_model('AI/test.h5')

    normalizer = tf.keras.layers.experimental.preprocessing.Normalization()
    normalizer.adapt(train_data)

    model = tf.keras.Sequential([normalizer, tf.keras.layers.Dense(units=1)])

    model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.1), loss='mean_absolute_error')
    model.fit(train_data, train_label, epochs=3000)

    model.save('AI/test.h5')

    unseen_data = np.array([[1, 2, 3, 4, 6]])

    prediction = model.predict(unseen_data)
    print(prediction)


if __name__ == "__main__":
    main()

It throws the following exception:

Traceback (most recent call last):
  File "E:\Backup\Desktop\tensorflow_test.py", line 30, in <module>
    main()
  File "E:\Backup\Desktop\tensorflow_test.py", line 11, in main
    model = tf.keras.models.load_model('AI/test.h5')
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\saving\save.py", line 200, in load_model
    return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\saving\hdf5_format.py", line 180, in load_model_from_hdf5
    model = model_config_lib.model_from_config(model_config,
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\saving\model_config.py", line 52, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\layers\serialization.py", line 208, in deserialize
    return generic_utils.deserialize_keras_object(
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\utils\generic_utils.py", line 674, in deserialize_keras_object
    deserialized_obj = cls.from_config(
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\engine\sequential.py", line 434, in from_config
    model.add(layer)
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\tensorflow\python\training\tracking\base.py", line 530, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\engine\sequential.py", line 217, in add
    output_tensor = layer(self.outputs[0])
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\engine\base_layer.py", line 976, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\engine\base_layer.py", line 1114, in _functional_construction_call
    outputs = self._keras_tensor_symbolic_call(
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\engine\base_layer.py", line 848, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\engine\base_layer.py", line 886, in _infer_output_signature
    self._maybe_build(inputs)
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\engine\base_layer.py", line 2659, in _maybe_build
    self.build(input_shapes)  # pylint:disable=not-callable
  File "C:\Users\censored\AppData\Local\Programs\Python\Python39\lib\site-packages\keras\layers\preprocessing\normalization.py", line 145, in build
    raise ValueError(
ValueError: All `axis` values to be kept must have known shape. Got axis: (-1,), input shape: [None, None], with unknown axis at index: 1

Process finished with exit code 1

Solution

It looks like a bug.
The Keras layer constructor handles the shape arguments like this:

if 'input_dim' in kwargs and 'input_shape' not in kwargs:
  # Backwards compatibility: alias 'input_dim' to 'input_shape'.
  kwargs['input_shape'] = (kwargs['input_dim'],)
if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
  # In this case we will later create an input layer
  # to insert before the current layer
  if 'batch_input_shape' in kwargs:
    batch_input_shape = tuple(kwargs['batch_input_shape'])
  elif 'input_shape' in kwargs:
    if 'batch_size' in kwargs:
      batch_size = kwargs['batch_size']
    else:
      batch_size = None
    batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
  self._batch_input_shape = batch_input_shape

The error occurs because the Normalization layer could not get any shape information, which leads to self._batch_input_shape = (None, None).
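To make the kwargs handling above easier to follow, here is a minimal standalone sketch of the same branching. The function name `resolve_batch_input_shape` is hypothetical and not part of Keras; it only mirrors the quoted logic in plain Python, and it returns `None` as a marker when no shape argument was given at all.

```python
def resolve_batch_input_shape(**kwargs):
    """Hypothetical helper mirroring the quoted kwargs handling (not a Keras API)."""
    # Backwards compatibility: alias 'input_dim' to 'input_shape'.
    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
        kwargs['input_shape'] = (kwargs['input_dim'],)
    # An explicit batch_input_shape wins over input_shape.
    if 'batch_input_shape' in kwargs:
        return tuple(kwargs['batch_input_shape'])
    if 'input_shape' in kwargs:
        batch_size = kwargs.get('batch_size')  # None when not given
        return (batch_size,) + tuple(kwargs['input_shape'])
    # No shape argument at all: the layer itself records nothing.
    return None


print(resolve_batch_input_shape(input_dim=5))  # (None, 5)
print(resolve_batch_input_shape())             # None
```

With `input_dim=5` the feature axis is known up front, which matches the observation in the question that Normalization(input_dim=5) loads without an exception.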

When the model is loaded (deserialized), however, Keras calls the layer's build function, which requires a known shape on every axis that is kept.

# Sorted to avoid transposing axes.
self._keep_axis = sorted([d if d >= 0 else d + ndim for d in self.axis])
# All axes to be kept should have known shape.
for d in self._keep_axis:
  if input_shape[d] is None:
    raise ValueError(
        'All `axis` values to be kept must have known shape. Got axis: {}, '
        'input shape: {}, with unknown axis at index: {}'.format(
            self.axis, input_shape, d))
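The check above can be exercised outside of TensorFlow. The sketch below is a plain-Python stand-in, with the hypothetical name `check_keep_axes`, that applies the same rule to the two shapes discussed here: one with a known feature axis and one where every axis is unknown.

```python
def check_keep_axes(axis, input_shape):
    """Hypothetical stand-in for the quoted check (not a Keras API)."""
    ndim = len(input_shape)
    # Convert negative axes to positive indices and sort them.
    keep_axis = sorted(d if d >= 0 else d + ndim for d in axis)
    # Every kept axis must have a known size.
    for d in keep_axis:
        if input_shape[d] is None:
            raise ValueError(
                'All `axis` values to be kept must have known shape. Got axis: {}, '
                'input shape: {}, with unknown axis at index: {}'.format(
                    axis, input_shape, d))
    return keep_axis


print(check_keep_axes((-1,), [None, 5]))  # [1]

try:
    check_keep_axes((-1,), [None, None])
except ValueError as err:
    print(err)  # same message as in the traceback from the question
```

The second call fails for the same reason as the traceback: axis -1 resolves to index 1, and that entry of the input shape is None.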

Answered By – FancyXun

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
