Cannot load checkpoints

Issue

I trained a model (following a TensorFlow tutorial) in Jupyter, saved it, and then successfully loaded it back after restarting the kernel. Here's the code:

import os

import tensorflow as tf
from tensorflow import keras

# Directory where the checkpoints will be saved
checkpoint_dir = '/home/charlie-chin/william_model/training_checkpoints'
# Name of the checkpoint files ({epoch} is filled in by ModelCheckpoint)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

model.save('/home/charlie-chin/william_model')

# `loss` is the loss function defined earlier in the notebook
model = keras.models.load_model('/home/charlie-chin/william_model', custom_objects={'loss': loss})

checkpoint_num = 10
model.load_weights(tf.train.Checkpoint("/home/charlie-chin/william_model/training_checkpoints/ckpt_" + str(checkpoint_num)))

Everything worked except the last two lines, which raised this error:

ValueError: `Checkpoint` was expecting root to be a trackable object (an object derived from `Trackable`), got /home/charlie-chin/william_model/training_checkpoints/ckpt_1. If you believe this object should be trackable (i.e. it is part of the TensorFlow Python API and manages state), please open an issue.

I checked the path and it is correct. Here is the full error output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [39], in <cell line: 4>()
      1 checkpoint_num = 10
      2 # model.load_weights(tf.train.load_checkpoint("./william_model/training_checkpoints/ckpt_"))
      3 # model.load_weights(tf.train.Checkpoint("/home/charlie-chin/william_model/training_checkpoints/ckpt_" + str(checkpoint_num)+".data-00000-of-00001"))
----> 4 model.load_weights(tf.train.Checkpoint("/home/charlie-chin/william_model/training_checkpoints/ckpt_" + str(checkpoint_num)))

File ~/.local/lib/python3.8/site-packages/tensorflow/python/training/tracking/util.py:2107, in Checkpoint.__init__(self, root, **kwargs)
   2105 if root:
   2106   trackable_root = root() if isinstance(root, weakref.ref) else root
-> 2107   _assert_trackable(trackable_root, "root")
   2108   attached_dependencies = []
   2110   # All keyword arguments (including root itself) are set as children
   2111   # of root.

File ~/.local/lib/python3.8/site-packages/tensorflow/python/training/tracking/util.py:1546, in _assert_trackable(obj, name)
   1543 def _assert_trackable(obj, name):
   1544   if not isinstance(
   1545       obj, (base.Trackable, def_function.Function)):
-> 1546     raise ValueError(
   1547         f"`Checkpoint` was expecting {name} to be a trackable object (an "
   1548         f"object derived from `Trackable`), got {obj}. If you believe this "
   1549         "object should be trackable (i.e. it is part of the "
   1550         "TensorFlow Python API and manages state), please open an issue.")

ValueError: `Checkpoint` was expecting root to be a trackable object (an object derived from `Trackable`), got /home/charlie-chin/william_model/training_checkpoints/ckpt_10. If you believe this object should be trackable (i.e. it is part of the TensorFlow Python API and manages state), please open an issue.
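For reference, here is a minimal sketch (assuming TensorFlow 2.x in eager mode) that reproduces this error with a plain string and then uses `tf.train.Checkpoint` the way it is intended: the objects to track are passed as keyword arguments, and the path goes to `save()`/`restore()` instead. The variable name `weights` and the temporary directory are illustrative stand-ins, not part of the original code.

```python
import os
import tempfile

import tensorflow as tf

# 1) Reproduce the question's error: the first positional argument of
#    tf.train.Checkpoint is `root`, which must be a trackable object.
try:
    tf.train.Checkpoint("/tmp/training_checkpoints/ckpt_10")
    raised = False
except ValueError:
    raised = True
print("ValueError raised for a path string:", raised)

# 2) Intended usage: pass the objects to track as keyword arguments,
#    then hand the *path* to save()/restore().
weights = tf.Variable([1.0, 2.0, 3.0])  # illustrative state to checkpoint
ckpt = tf.train.Checkpoint(weights=weights)

checkpoint_dir = tempfile.mkdtemp()  # stand-in for the real checkpoint directory
save_path = ckpt.save(os.path.join(checkpoint_dir, "ckpt"))  # e.g. ".../ckpt-1"

# save() also updates the "checkpoint" state file in the directory,
# so the newest prefix can be looked up without hard-coding it.
latest = tf.train.latest_checkpoint(checkpoint_dir)

# 3) Restore into a fresh variable with the same shape and dtype.
restored = tf.Variable([0.0, 0.0, 0.0])
tf.train.Checkpoint(weights=restored).restore(latest)
print(restored.numpy())
```

For weights written by a Keras `ModelCheckpoint` callback, the model itself is the tracked object, so `model.load_weights(path)` is the simpler route.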

Solution

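`tf.train.Checkpoint` is not a loader: its first positional argument is `root`, the trackable object (for example a model) whose state is saved or restored. Passing a path string therefore fails the `_assert_trackable` check shown in the traceback. `model.load_weights` already accepts a checkpoint path, so, as in the TensorFlow documentation, pass the checkpoint prefix (the path without the `.index` or `.data-00000-of-00001` suffix) to it directly: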

checkpoint_num = 10
model.load_weights("/home/charlie-chin/william_model/training_checkpoints/ckpt_" + str(checkpoint_num))
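With weights-only TensorFlow-format checkpoints (tf.keras in TF 2.x), each epoch produces files such as `ckpt_10.index` and `ckpt_10.data-00000-of-00001`, and `load_weights` expects their shared prefix `ckpt_10` rather than either file. The helper below, `checkpoint_prefix`, is a hypothetical sketch (not a TensorFlow API) that builds the prefix from the same `"ckpt_{epoch}"` pattern and fails early if the matching `.index` file is missing.

```python
import os


def checkpoint_prefix(checkpoint_dir: str, epoch: int, pattern: str = "ckpt_{epoch}") -> str:
    """Hypothetical helper: return the checkpoint prefix for `epoch`.

    Fills the same "ckpt_{epoch}" pattern used for the ModelCheckpoint
    filepath and checks that the corresponding ".index" file exists,
    because load_weights() wants the prefix, not an individual file.
    """
    prefix = os.path.join(checkpoint_dir, pattern.format(epoch=epoch))
    if not os.path.exists(prefix + ".index"):
        raise FileNotFoundError(f"no checkpoint index file at {prefix}.index")
    return prefix
```

In the question's setup this would be used as `model.load_weights(checkpoint_prefix(checkpoint_dir, checkpoint_num))`; a typo in the directory or epoch number then surfaces as a clear `FileNotFoundError` instead of a confusing loading error.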

Answered By – claudia

This Answer collected from stackoverflow, is licensed under cc by-sa 2.5 , cc by-sa 3.0 and cc by-sa 4.0
