コード例 #1
0
  def testRemoveCheckpoint(self):
    """Verifies that remove_checkpoint deletes a checkpoint that was saved.

    For every combination of `sharded` (False, True) and write version
    (V2, V1), a single variable is saved under `self._base_dir`; the test
    asserts the checkpoint exists, removes it, and asserts it is gone.
    """
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        # `self.session` replaces the deprecated `self.test_session`; a fresh
        # graph per combination keeps the "v" variables from colliding.
        with self.session(graph=ops_lib.Graph()) as sess:
          # Not referenced again in this test; it is created before the Saver
          # is built (presumably so the Saver has a variable to write).
          unused_v = variables.Variable(1.0, name="v")
          variables.global_variables_initializer().run()
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          # One distinct save path per (sharded, version) combination.
          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
          checkpoint_management.remove_checkpoint(ckpt_prefix, version)
          self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
コード例 #2
0
  def testRemoveCheckpoint(self):
    """Checks that remove_checkpoint deletes a freshly saved checkpoint.

    Runs once for each pairing of the sharded flag (False, then True) with
    the write version (V2, then V1), each inside its own graph and session.
    """
    write_versions = (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1)
    combos = [
        (shard, ver) for shard in (False, True) for ver in write_versions
    ]
    for is_sharded, write_version in combos:
      with self.session(graph=ops_lib.Graph()) as session:
        # Never read back; presumably gives the Saver a variable to store.
        unused_var = variables.Variable(1.0, name="v")
        init_op = variables.global_variables_initializer()
        session.run(init_op)
        ckpt_saver = saver_module.Saver(
            sharded=is_sharded, write_version=write_version)

        # Save under a name unique to this (sharded, version) pairing.
        file_name = "%s-%s" % (is_sharded, write_version)
        prefix = ckpt_saver.save(
            session, os.path.join(self._base_dir, file_name))

        existed_after_save = checkpoint_management.checkpoint_exists(prefix)
        self.assertTrue(existed_after_save)

        checkpoint_management.remove_checkpoint(prefix, write_version)
        exists_after_removal = checkpoint_management.checkpoint_exists(prefix)
        self.assertFalse(exists_after_removal)