def testReturnsEmptyIfNoCheckpointsFound(self):
    """checkpoints_iterator yields nothing for a directory with no checkpoints."""
    empty_dir = os.path.join(self.get_temp_dir(), 'no_checkpoints_found')
    # timeout=0 makes the iterator give up immediately instead of blocking.
    seen = sum(1 for _ in evaluation.checkpoints_iterator(empty_dir, timeout=0))
    self.assertEqual(seen, 0)
def testReturnsSingleCheckpointIfOneShardedCheckpoint(self):
    """A sharded checkpoint (several shard files) still counts as one checkpoint."""
    if tf.executing_eagerly():
        return
    ckpt_dir = os.path.join(self.get_temp_dir(),
                            'one_checkpoint_found_sharded')
    if not tf.io.gfile.exists(ckpt_dir):
        tf.io.gfile.makedirs(ckpt_dir)
    step = tf.compat.v1.train.get_or_create_global_step()
    # Placing variables on two CPU devices yields multiple shard files
    # for a single logical checkpoint.
    with tf.device('/cpu:0'):
        tf.Variable(10, name='v0')
    with tf.device('/cpu:1'):
        tf.Variable(20, name='v1')
    sharded_saver = tf.compat.v1.train.Saver(sharded=True)
    session_config = tf.compat.v1.ConfigProto(device_count={'CPU': 2})
    with tf.compat.v1.Session(target='', config=session_config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sharded_saver.save(
            sess, os.path.join(ckpt_dir, 'model.ckpt'), global_step=step)
    seen = sum(1 for _ in evaluation.checkpoints_iterator(ckpt_dir, timeout=0))
    self.assertEqual(seen, 1)
def testTimeoutFn(self):
    """timeout_fn is called on each timeout; iteration ends once it returns True."""
    calls = [0]

    def give_up():
        # Signal "stop waiting" on the fourth invocation.
        calls[0] += 1
        return calls[0] > 3

    yielded = list(
        evaluation.checkpoints_iterator(
            '/non-existent-dir', timeout=0.1, timeout_fn=give_up))
    self.assertEqual([], yielded)
    self.assertEqual(4, calls[0])
def run_continuous_eval(hparams):
    """Run evaluation continuously as new checkpoints appear in the model dir.

    Blocks inside `evaluation.checkpoints_iterator`, evaluating each newly
    written checkpoint until no new checkpoint appears within
    `hparams.debug_params.continuous_eval_timeout_secs` seconds.

    Args:
      hparams: Hyperparameters object; must provide `model_dir`,
        `num_eval_steps`, and `debug_params.continuous_eval_timeout_secs`.
        (Assumed shape — TODO confirm against the hparams definition.)
    """
    tf.logging.info('Continuous evaluation.')
    estimator = make_estimator(hparams)
    timeout = hparams.debug_params.continuous_eval_timeout_secs
    for ckpt_str in evaluation.checkpoints_iterator(hparams.model_dir,
                                                    timeout=timeout):
        # Lazy %-args: the message is only formatted if the log is emitted.
        tf.logging.info('Evaluating checkpoint: %s', ckpt_str)
        # Pass the checkpoint explicitly; otherwise evaluate() uses the
        # *latest* checkpoint, which can race past the one the iterator
        # just reported (skipping or double-evaluating checkpoints).
        estimator.evaluate(
            train_eval_input_fn,
            steps=hparams.num_eval_steps,
            checkpoint_path=ckpt_str,
            name='eval_continuous')
        tf.logging.info('Finished evaluating checkpoint: %s', ckpt_str)
def testReturnsSingleCheckpointIfOneCheckpointFound(self):
    """checkpoints_iterator yields exactly one path after a single save()."""
    ckpt_dir = os.path.join(self.get_temp_dir(), 'one_checkpoint_found')
    if not tf.io.gfile.exists(ckpt_dir):
        tf.io.gfile.makedirs(ckpt_dir)
    step = tf.compat.v1.train.get_or_create_global_step()
    # The saver persists the global step variable created above.
    saver = tf.compat.v1.train.Saver()
    with self.cached_session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver.save(
            sess, os.path.join(ckpt_dir, 'model.ckpt'), global_step=step)
    seen = sum(1 for _ in evaluation.checkpoints_iterator(ckpt_dir, timeout=0))
    self.assertEqual(seen, 1)