def testNeedToCheckpoint(self):
    """A saver built with zero steps is disabled; after setting `_steps`
    it becomes enabled and triggers on every multiple of that interval."""
    saver = CheckpointSaver("", 0, 5, False)
    self.assertFalse(saver.is_enabled())
    saver._steps = 3
    self.assertTrue(saver.is_enabled())
    # version -> whether a checkpoint should fire (every 3rd version)
    expectations = {1: False, 2: False, 3: True, 4: False, 5: False, 6: True}
    for version, should_save in expectations.items():
        self.assertEqual(saver.need_to_checkpoint(version), should_save)
def testSaveLoadCheckpoint(self):
    """Round-trip test: save a model's non-embedding parameters to a
    checkpoint directory, restore them, and verify they are unchanged."""
    variables = m["custom_model"]().trainable_variables
    with tempfile.TemporaryDirectory() as tempdir:
        ckpt_dir = os.path.join(tempdir, "testSaveLoadCheckpoint")
        os.makedirs(ckpt_dir)
        saver = CheckpointSaver(ckpt_dir, 3, 5, False)
        self.assertTrue(saver.is_enabled())

        # Populate a Parameters object with the model's trainable variables.
        params = Parameters()
        for var in variables:
            params.non_embedding_params[var.name] = var
        saver.save(0, params.to_model_pb(), False)

        # Restore from the version-0 checkpoint and compare against source.
        version_dir = os.path.join(ckpt_dir, "version-0")
        restored = CheckpointSaver.restore_params_from_checkpoint(
            version_dir, 0, 1)
        self.assertEqual(restored.version, params.version)
        for name, original in params.non_embedding_params.items():
            self.assertTrue(
                np.array_equal(
                    original.numpy(),
                    restored.non_embedding_params[name].numpy(),
                )
            )