How to create a custom model class

Issue

Sorry if this is a simple question or if I'm doing something wrong; I am fairly new to both Keras/TensorFlow and Python. I am trying to test different models for image classification based on transfer learning. To do that, I wanted a way to build a model where I only specify some parameters and the desired model is generated automatically.
I have written the following code:

class modelMaker(tf.keras.Model):

  def __init__(self, img_height, img_width, trained='None'):
    super(modelMaker, self).__init__()
    self.x = tf.keras.Input(shape=(img_height, img_width, 3),name="input_layer")
    if (trained == 'None'):
      pass
    elif (trained == 'ResNet50'):
      self.x = tf.keras.applications.resnet50.preprocess_input(self.x)
      IMG_SHAPE = (img_height,img_width) + (3,)
      base_model = tf.keras.applications.ResNet50(input_shape=IMG_SHAPE,
                                                  include_top=False,
                                                  weights='imagenet')
      base_model.trainable = False
      for layer in base_model.layers:
        if isinstance(layer, keras.layers.BatchNormalization):
          layer.trainable = True
        else:
          layer.trainable = False
      self.x = base_model(self.x)

  def call(self, inputs):
    return self.x(inputs)

For now I have only implemented ResNet50 and an empty option, but I plan to add more. I tried to add layers using self.x = LAYER(self.x) because the model may have a different number of layers depending on future parameters.

However, when I try to get the summary of the model, using model.summary(), I get the following error:

ValueError: This model has not yet been built. Build the model first by calling build() or calling fit() with some data, or specify an input_shape argument in the first layer(s) for automatic build.

Is it possible to build models like this?
Thanks for the help

Solution

model.summary() needs information about the input shape and the layer structure of your model in order to print them. So you have to give this information to the model object somewhere.

If you use a Sequential model or the Functional API, specifying the input_shape parameter is enough for model.summary() to work. If you don't specify input_shape, you can call the model or use model.build() to provide this information.
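As a rough illustration of the Sequential case, the sketch below declares the input shape up front with tf.keras.Input, so the model can be summarized without seeing any data. The 8x8 image size and the layer sizes are made up for the example and are not taken from the question.

```python
import tensorflow as tf

# Hypothetical tiny model; the 8x8x3 input and the 2-unit output are
# illustrative values only.
seq_model = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 3), name="input_layer"),
    tf.keras.layers.Flatten(name="flatten"),
    tf.keras.layers.Dense(2, name="classify"),
])

# The input shape is known, so the weights already exist and summary()
# works right away.
seq_model.summary()
n_params = seq_model.count_params()  # 8*8*3*2 kernel weights + 2 biases
```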

But when you use subclassing (as you did), objects of the class have no information about shapes and layers until call() has been run, since the layer structure is defined in call() and the input is passed to it.

There are 3 ways to trigger call():

  1. model.fit(): calls it while training
    • may not fit your needs, since you have to train the model first.
  2. model.build(): calls it internally
    • just pass the input shape, e.g. model.build((1,128,128,3))
  3. model(): calls it directly
    • you need to pass at least one sample (tensor), e.g. model(tf.random.uniform((1,128,128,3)))
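The third option can be sketched with a small stand-in model. TinyModel and its layer sizes are hypothetical and only serve to show that a subclassed model has no weights until call() has run once.

```python
import tensorflow as tf

class TinyModel(tf.keras.Model):
    """Hypothetical small subclassed model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.flat = tf.keras.layers.Flatten(name="flatten")
        self.classify = tf.keras.layers.Dense(2, name="classify")

    def call(self, inputs):
        return self.classify(self.flat(inputs))

tiny = TinyModel()
weights_before = len(tiny.weights)  # call() has not run, so no weights exist

# Option 3: call the model directly on one random sample.
sample = tf.random.uniform((1, 8, 8, 3))
output = tiny(sample)

weights_after = len(tiny.weights)   # kernel and bias of the Dense layer
tiny.summary()                      # works now that the model has been called
```

A single random sample is enough here because only its shape matters for creating the weights; the values themselves are discarded.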

The modified code should look like this:

import tensorflow as tf

class modelMaker(tf.keras.Model):

    def __init__(self, img_height, img_width, num_classes=1, trained='dense'):
        super(modelMaker, self).__init__()
        self.trained = trained
        self.IMG_SHAPE = (img_height,img_width) + (3,)
        # define common layers
        self.flat = tf.keras.layers.Flatten(name="flatten")
        self.classify = tf.keras.layers.Dense(num_classes, name="classify")
        # define layers for when "trained" == "dense"
        if self.trained == "dense":
            self.dense = tf.keras.layers.Dense(128, name="dense128")

        # layers for any other value of "trained" (e.g. "resnet")
        else:
            self.pre_resnet = tf.keras.applications.resnet50.preprocess_input
            self.base_model = tf.keras.applications.ResNet50(input_shape=self.IMG_SHAPE, include_top=False, weights='imagenet')
            self.base_model.trainable = False
            for layer in self.base_model.layers:
                if isinstance(layer, tf.keras.layers.BatchNormalization):
                    layer.trainable = True
                else:
                    layer.trainable = False
    
    def call(self, inputs):
        # define your model without resnet 
        if self.trained == "dense":
            x = self.flat(inputs)
            x = self.dense(x)
            x = self.classify(x)
            return x
        # define your model with resnet
        else:
            x = self.pre_resnet(inputs)
            x = self.base_model(x)
            x = self.flat(x)
            x = self.classify(x)
            return x
        
    # add this function to get correct output for model summary
    def summary(self):
        x = tf.keras.Input(shape=self.IMG_SHAPE, name="input_layer")
        model = tf.keras.Model(inputs=[x], outputs=self.call(x))
        return model.summary()
    
model = modelMaker(128, 128, trained="resnet") # create object
model.build((10,128,128,3))                    # build model
model.summary()                                # print summary

Output is:

Model: "model_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_layer (InputLayer)     [(None, 128, 128, 3)]     0         
_________________________________________________________________
tf.__operators__.getitem_6 ( (None, 128, 128, 3)       0         
_________________________________________________________________
tf.nn.bias_add_6 (TFOpLambda (None, 128, 128, 3)       0         
_________________________________________________________________
resnet50 (Functional)        (None, 4, 4, 2048)        23587712  
_________________________________________________________________
flatten (Flatten)            (None, 32768)             0         
_________________________________________________________________
classify (Dense)             (None, 1)                 32769     
=================================================================
Total params: 23,620,481
Trainable params: 32,769
Non-trainable params: 23,587,712
_________________________________________________________________
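The summary() override in the modified code wraps call() in a throwaway functional model, so that summary() has a symbolic input to trace. A minimal sketch of the same idea on a small model follows; TinyModel, img_shape and as_functional are hypothetical names introduced for the example, and the helper is split out only so the wrapper can be inspected.

```python
import tensorflow as tf

class TinyModel(tf.keras.Model):
    """Hypothetical small model; img_shape and layer sizes are illustrative."""

    def __init__(self, img_shape=(8, 8, 3)):
        super().__init__()
        self.img_shape = img_shape
        self.flat = tf.keras.layers.Flatten(name="flatten")
        self.classify = tf.keras.layers.Dense(2, name="classify")

    def call(self, inputs):
        return self.classify(self.flat(inputs))

    def as_functional(self):
        # Trace call() on a symbolic input to get a functional model
        # that shares this model's layers.
        x = tf.keras.Input(shape=self.img_shape, name="input_layer")
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

    def summary(self):
        # Same trick as in the answer: summarize the functional wrapper.
        return self.as_functional().summary()

tiny = TinyModel()
wrapper = tiny.as_functional()
layer_names = [layer.name for layer in wrapper.layers]
n_params = wrapper.count_params()  # 8*8*3*2 kernel weights + 2 biases
tiny.summary()
```

The wrapper reuses the layers of the subclassed model, so it adds no parameters of its own.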

Answered By – Kaveh

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
