def testReturnsSingleCheckpointIfOneShardedCheckpoint(self):
  """A sharded save (several shard files) is reported as one checkpoint."""
  import shutil  # Local import: only used to remove the temp dir afterwards.

  # mkdtemp creates the directory itself, so no existence check is needed.
  checkpoint_dir = tempfile.mkdtemp(suffix='one_checkpoint_found_sharded')
  # Previously the directory was leaked on every run.
  self.addCleanup(shutil.rmtree, checkpoint_dir, ignore_errors=True)

  global_step = variables.get_or_create_global_step()

  # This will result in 3 different checkpoint shard files.
  with ops.device('/cpu:0'):
    variables_lib.Variable(10, name='v0')
  with ops.device('/cpu:1'):
    variables_lib.Variable(20, name='v1')

  saver = saver_lib.Saver(sharded=True)

  # Two CPU devices are requested so that both '/cpu:0' and '/cpu:1' exist.
  with session_lib.Session(
      target='',
      config=config_pb2.ConfigProto(device_count={'CPU': 2})) as session:
    session.run(variables_lib.global_variables_initializer())
    save_path = os.path.join(checkpoint_dir, 'model.ckpt')
    saver.save(session, save_path, global_step=global_step)

  num_found = sum(
      1 for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0))
  self.assertEqual(1, num_found)
def testReturnsEmptyIfNoCheckpointsFound(self):
  """An existing but empty directory yields no checkpoints."""
  import shutil  # Local import: only used to remove the temp dir afterwards.

  checkpoint_dir = tempfile.mkdtemp(suffix='no_checkpoints_found')
  # Previously the directory was leaked on every run.
  self.addCleanup(shutil.rmtree, checkpoint_dir, ignore_errors=True)

  num_found = sum(
      1 for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0))
  self.assertEqual(0, num_found)
def testTimeoutFn(self):
  """Iteration ends once the supplied timeout_fn returns True.

  The callback answers False for its first three calls and True on the
  fourth, so the iterator must consult it exactly four times and yield
  nothing for a directory that does not exist.
  """
  call_count = [0]  # Mutable cell so the nested function can update it.

  def timeout_fn():
    call_count[0] = call_count[0] + 1
    return call_count[0] > 3

  iterator = evaluation.checkpoints_iterator(
      '/non-existent-dir', timeout=0.1, timeout_fn=timeout_fn)
  collected = [checkpoint for checkpoint in iterator]

  self.assertEqual([], collected)
  self.assertEqual(4, call_count[0])
def testReturnsSingleCheckpointIfOneCheckpointFound(self):
  """A single saved checkpoint is yielded exactly once by the iterator."""
  import shutil  # Local import: only used to remove the temp dir afterwards.

  # mkdtemp creates the directory itself, so no existence check is needed.
  checkpoint_dir = tempfile.mkdtemp(suffix='one_checkpoint_found')
  # Previously the directory was leaked on every run.
  self.addCleanup(shutil.rmtree, checkpoint_dir, ignore_errors=True)

  global_step = variables.get_or_create_global_step()
  saver = saver_lib.Saver()  # Saves the global step.
  with self.cached_session() as session:
    session.run(variables_lib.global_variables_initializer())
    save_path = os.path.join(checkpoint_dir, 'model.ckpt')
    saver.save(session, save_path, global_step=global_step)

  num_found = sum(
      1 for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0))
  self.assertEqual(1, num_found)
def testMonitorCheckpointsLoopTimeout(self):
  """A missing directory with a zero timeout produces no checkpoints."""
  iterator = evaluation_lib.checkpoints_iterator(
      '/non-existent-dir', timeout=0)
  found = [path for path in iterator]
  self.assertEqual(found, [])