def testUpdateCheckpointState(self):
   save_dir = self._get_test_dir("update_checkpoint_state")
   os.chdir(save_dir)
   # Make a temporary train directory.
   train_dir = "train"
   os.mkdir(train_dir)
   abs_path = os.path.join(save_dir, "model-0")
   rel_path = os.path.join("train", "model-2")
   checkpoint_management.update_checkpoint_state(
       train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
   ckpt = checkpoint_management.get_checkpoint_state(train_dir)
   self.assertEqual(ckpt.model_checkpoint_path, rel_path)
   self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
   self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
   self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)
 def testFSPath(self):
     save_dir = self._get_test_dir("fspath")
     os.chdir(save_dir)
     # Make a temporary train directory.
     train_dir = "train"
     os.mkdir(train_dir)
     abs_path = os.path.join(save_dir, "model-0")
     rel_path = os.path.join("train", "model-2")
     checkpoint_management.update_checkpoint_state(
         train_dir,
         rel_path,
         all_model_checkpoint_paths=[abs_path, rel_path])
     ckpt = checkpoint_management.get_checkpoint_state(
         pathlib.Path(train_dir))
     self.assertEqual(ckpt.model_checkpoint_path, rel_path)
 def testUpdateCheckpointState(self):
   save_dir = self._get_test_dir("update_checkpoint_state")
   os.chdir(save_dir)
   # Make a temporary train directory.
   train_dir = "train"
   os.mkdir(train_dir)
   abs_path = os.path.join(save_dir, "model-0")
   rel_path = os.path.join("train", "model-2")
   checkpoint_management.update_checkpoint_state(
       train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
   ckpt = checkpoint_management.get_checkpoint_state(train_dir)
   self.assertEqual(ckpt.model_checkpoint_path, rel_path)
   self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
   self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
   self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)