Memory increase when training a TensorFlow 2.2.0 Keras graph with inputs of different shapes

Issue

Below is a simple example:

import os
import psutil
import numpy as np
import tensorflow as tf

process = psutil.Process(os.getpid())
class TestKeras3:
    def __init__(self):
        pass

    def build_graph(self):
        inputs = tf.keras.Input(shape=(None, None, 3), batch_size=1)
        x = tf.keras.layers.Conv2D(100, (2, 2), padding='SAME', name='x')(inputs)
        y = tf.reshape(x, (-1,))
        z = tf.multiply(y, y)
        model = tf.keras.models.Model(inputs=inputs, outputs=z)
        return model

    def train(self):
        model = self.build_graph()
        model.summary()
        size = np.arange(1000)
        for i in range(1000):
            inputs = tf.random.normal([1, size[999-i], size[999-i], 3])
            with tf.GradientTape() as tape:
                output = model(inputs)
                print(i, tf.shape(output), process.memory_info().rss)

and the output is:

i                 output shape                  RSS memory (bytes)
979 tf.Tensor([40000], shape=(1,), dtype=int32) 2481123328
980 tf.Tensor([36100], shape=(1,), dtype=int32) 2481582080
981 tf.Tensor([32400], shape=(1,), dtype=int32) 2482122752
982 tf.Tensor([28900], shape=(1,), dtype=int32) 2482393088
983 tf.Tensor([25600], shape=(1,), dtype=int32) 2482933760
984 tf.Tensor([22500], shape=(1,), dtype=int32) 2483453952
985 tf.Tensor([19600], shape=(1,), dtype=int32) 2483793920
986 tf.Tensor([16900], shape=(1,), dtype=int32) 2484330496
987 tf.Tensor([14400], shape=(1,), dtype=int32) 2484871168
988 tf.Tensor([12100], shape=(1,), dtype=int32) 2485137408
989 tf.Tensor([10000], shape=(1,), dtype=int32) 2485665792
990 tf.Tensor([8100], shape=(1,), dtype=int32) 2486206464
991 tf.Tensor([6400], shape=(1,), dtype=int32) 2486579200
992 tf.Tensor([4900], shape=(1,), dtype=int32) 2487119872
993 tf.Tensor([3600], shape=(1,), dtype=int32) 2487390208
994 tf.Tensor([2500], shape=(1,), dtype=int32) 2487930880
995 tf.Tensor([1600], shape=(1,), dtype=int32) 2488463360
996 tf.Tensor([900], shape=(1,), dtype=int32) 2488811520
997 tf.Tensor([400], shape=(1,), dtype=int32) 2489335808
998 tf.Tensor([100], shape=(1,), dtype=int32) 2489868288
999 tf.Tensor([0], shape=(1,), dtype=int32) 2490241024

I found that every time the input size changes, memory consumption increases as well.

My question: the Conv2D kernel in the model has a fixed shape of (2, 2, 3, 100), so the trainable parameters do not grow. Does the model cache some tensors during the forward pass, causing memory to increase continuously? If so, how can these resources be released during training? If not, what else could be the reason?
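One way to check whether the growth comes from ordinary Python garbage is to force a collection between iterations and re-read RSS. Below is a minimal diagnostic sketch, reusing the tf and process objects from the snippet above; measure_with_gc is just an illustrative helper, not part of the original code. If RSS still grows with gc.collect() in the loop, the memory is being held inside TensorFlow rather than by unreferenced Python objects.

import gc

def measure_with_gc(model, steps=100):
    # Run forward passes with shrinking input sizes, forcing a garbage
    # collection before each RSS reading.
    for i in range(steps):
        s = steps - i
        inputs = tf.random.normal([1, s, s, 3])
        output = model(inputs)
        gc.collect()  # release any unreferenced Python-side objects
        print(i, int(tf.size(output)), process.memory_info().rss)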

Solution

After trying many methods, I solved this problem.
It seems that using plain tf operations (here tf.reshape and tf.multiply) directly in a Keras functional graph causes a memory leak; it can be fixed by wrapping those ops in a tf.keras.layers.Layer subclass.

class ReshapeMulti(tf.keras.layers.Layer):
    def __init__(self):
        super(ReshapeMulti, self).__init__()

    def call(self, inputs):
        y = tf.reshape(inputs, (-1, ))
        z = tf.multiply(y, y)
        return z

class TestKeras3:
    def __init__(self):
        pass

    def build_graph(self):
        inputs = tf.keras.Input(shape=(None, None, 3), batch_size=1)
        x = tf.keras.layers.Conv2D(100, (2, 2), padding='SAME', name='x')(inputs)
        # y = tf.reshape(x, (-1,))
        # z = tf.multiply(y, y)
        z = ReshapeMulti()(x)
        model = tf.keras.models.Model(inputs=inputs, outputs=z)
        return model

    def train(self):
        model = self.build_graph()
        model.summary()
        size = np.arange(1000)
        for i in range(1000):
            inputs = tf.random.normal([1, size[999-i], size[999-i], 3])
            with tf.GradientTape() as tape:
                output = model(inputs)
                print(i, tf.shape(output), process.memory_info().rss)
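For completeness, a minimal way to run the fixed version (my own usage sketch; it assumes the same imports and process object as the question's snippet):

if __name__ == '__main__':
    trainer = TestKeras3()
    trainer.train()  # the RSS printed each iteration should now stay roughly flat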

Answered By – Yang

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
