Saving Attributes and Methods with TensorFlow for a Custom Model

Issue

I created a basic model with a custom method, new_method, and a custom attribute, testing, and I want both to be saved. Is it possible to do this with model.save()? Below is an example of what I want to accomplish.

import numpy as np
import tensorflow as tf

@tf.keras.utils.register_keras_serializable()
class GreatClass(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.testing = 3424
        self.dense = tf.keras.layers.Dense(100)
        
    def get_config(self):
        config = super().get_config()
        config['testing'] = self.testing
        config['dense'] = self.dense
        return config
    
    def new_method(self):
        print('hello world')
    
    def call(self, inputs):
        return self.dense(inputs)

Below I create and save an instance of the above class.

model = GreatClass()
model.compile()

array = np.array([[100, 10]])  # a batch containing one sample with two features
model.predict(array)

model.save('testing')

I can save the model, but the loaded model does not have access to the new_method method or testing attribute.

loaded_model = tf.keras.models.load_model("testing")
loaded_model.new_method()

AttributeError: 'Custom>GreatClass' object has no attribute 'new_method'

loaded_model.testing

AttributeError: 'Custom>GreatClass' object has no attribute 'testing'

Is it possible to save custom methods and attributes using model.save()?

Solution

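Store the attributes you want to serialize as tf.Variable objects, and decorate the methods you want to keep with tf.function, giving each one an input_signature so that TensorFlow can trace it and save it with the model. A SavedModel stores tracked variables and traced TensorFlow functions, not arbitrary Python code, which is why a plain int attribute and an undecorated Python method are missing after reloading. Here is an example: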

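import tensorflow as tf

class GreatClass(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # A tf.Variable is tracked and written to the SavedModel;
        # a plain Python int would be lost on reload.
        self.testing = tf.Variable(3424, trainable=False)
        self.dense = tf.keras.layers.Dense(100)

    # input_signature=[] lets TensorFlow trace a concrete function for this
    # argument-free method, so the traced graph is stored in the SavedModel.
    @tf.function(input_signature=[])
    def new_method(self):
        tf.print('hello world')

    def call(self, inputs):
        return self.dense(inputs)

model = GreatClass()
model.compile()
model(tf.random.normal((1, 10)))  # build the model before saving
model.save('testing')             # SavedModel format (tf.keras in TF 2.x)

loaded_model = tf.keras.models.load_model("testing")
loaded_model.new_method()
# hello world
loaded_model.testing
# <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=3424>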
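The answer above uses Keras, but the behaviour comes from the underlying SavedModel format. As a minimal sketch, assuming TensorFlow 2.x, the block below uses a bare tf.Module with tf.saved_model.save and tf.saved_model.load to show what survives a save/load round trip, including a method whose input_signature declares an argument with tf.TensorSpec. The Demo class and the plain_attr, add_to_testing and plain_method names are hypothetical and exist only for illustration.

```python
import tempfile

import tensorflow as tf


class Demo(tf.Module):
    """Hypothetical module used only to show what a SavedModel keeps."""

    def __init__(self):
        super().__init__()
        self.plain_attr = 3424  # plain Python int: not written to the SavedModel
        self.testing = tf.Variable(3424, trainable=False)  # tracked variable: saved

    # The input_signature lets TensorFlow trace one concrete function
    # (a scalar int32 argument), and that traced function is what gets saved.
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int32)])
    def add_to_testing(self, x):
        return self.testing + x

    def plain_method(self):
        # Ordinary Python method: its code is not stored in the SavedModel.
        return 'hello'


export_dir = tempfile.mkdtemp()
tf.saved_model.save(Demo(), export_dir)
restored = tf.saved_model.load(export_dir)

print(restored.testing.numpy())                         # 3424
print(restored.add_to_testing(tf.constant(1)).numpy())  # 3425
print(hasattr(restored, 'plain_attr'))                  # False
print(hasattr(restored, 'plain_method'))                # False
```

Only the traced signature is available after loading: in this sketch, calling add_to_testing with, for example, a float32 tensor is expected to raise an error, because no concrete function was saved for that input type.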

Answered By – AloneTogether

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
