def testLastSavedStep(self):
    """Checks that `last_saved_step` tracks the most recent `save` call.

    Before anything is saved the property is None. After each save it
    equals the step just passed. A second Checkpoint created over the same
    directory still sees the latest step.
    """
    dummy = _DummyModel()
    # Run one forward pass so the model creates its variables before checkpointing.
    dummy(tf.random.uniform([4, 10]))
    ckpt_dir = os.path.join(self.get_temp_dir(), "model")

    saver = checkpoint_util.Checkpoint(dummy, model_dir=ckpt_dir)
    self.assertIsNone(saver.last_saved_step)
    for step in (10, 20):
        saver.save(step)
        self.assertEqual(saver.last_saved_step, step)

    # A new Checkpoint over the same model_dir must report the same step.
    # last_saved_step should not depend on in-memory state of the instance
    # that performed the save.
    reloaded = checkpoint_util.Checkpoint(dummy, model_dir=ckpt_dir)
    self.assertEqual(reloaded.last_saved_step, 20)
def _init_model(self, config):
    """Builds a fresh model from `config` and wraps it in a Checkpoint.

    The runner's `self._model` is cloned, then initialized with
    `config["data"]` and `config["params"]`. An optimizer is created from
    the model only when `config["params"]` contains an "optimizer" key;
    otherwise no optimizer is attached.

    Args:
      config: Mapping that must contain the keys "data", "params" and
        "train". "model_dir" is optional and defaults to None.

    Returns:
      A `checkpoint_util.Checkpoint` over the new model (and optimizer, if
      any). `keep_checkpoint_max` comes from `config["train"]` and defaults
      to 8.
    """
    fresh_model = misc.clone_layer(self._model)
    data_config = config["data"]
    model_params = config["params"]
    fresh_model.initialize(data_config, params=model_params)
    # Only build an optimizer when one is configured, e.g. not for inference-only use.
    opt = fresh_model.get_optimizer() if "optimizer" in model_params else None
    return checkpoint_util.Checkpoint(
        fresh_model,
        optimizer=opt,
        model_dir=config.get("model_dir"),
        keep_checkpoint_max=config["train"].get("keep_checkpoint_max", 8))