Example 1
  def testReturnsEmptyIfNoCheckpointsFound(self):
    """checkpoints_iterator over an empty directory yields nothing at timeout=0."""
    empty_dir = os.path.join(self.get_temp_dir(), 'no_checkpoints_found')

    # The iterator should terminate immediately without producing any path.
    found = sum(
        1 for _ in evaluation.checkpoints_iterator(empty_dir, timeout=0))
    self.assertEqual(found, 0)
Example 2
  def testReturnsSingleCheckpointIfOneShardedCheckpoint(self):
    """A sharded checkpoint (several shard files) still counts as one checkpoint."""
    if tf.executing_eagerly():
      return  # Graph-mode-only test: relies on v1 Saver/Session.
    ckpt_dir = os.path.join(self.get_temp_dir(),
                            'one_checkpoint_found_sharded')
    if not tf.io.gfile.exists(ckpt_dir):
      tf.io.gfile.makedirs(ckpt_dir)

    step = tf.compat.v1.train.get_or_create_global_step()

    # Placing variables on two CPU devices makes the sharded Saver emit
    # multiple shard files for one logical checkpoint.
    with tf.device('/cpu:0'):
      tf.Variable(10, name='v0')
    with tf.device('/cpu:1'):
      tf.Variable(20, name='v1')

    sharded_saver = tf.compat.v1.train.Saver(sharded=True)

    with tf.compat.v1.Session(
        target='',
        config=tf.compat.v1.ConfigProto(device_count={'CPU': 2})) as sess:

      sess.run(tf.compat.v1.global_variables_initializer())
      sharded_saver.save(
          sess, os.path.join(ckpt_dir, 'model.ckpt'), global_step=step)

    # Despite several files on disk, exactly one checkpoint is reported.
    found = sum(
        1 for _ in evaluation.checkpoints_iterator(ckpt_dir, timeout=0))
    self.assertEqual(found, 1)
Example 3
  def testTimeoutFn(self):
    """timeout_fn is polled after each timeout until it returns True."""
    calls = 0

    def stop_after_four():
      # Signal the iterator to give up on the 4th poll.
      nonlocal calls
      calls += 1
      return calls > 3

    results = list(
        evaluation.checkpoints_iterator(
            '/non-existent-dir', timeout=0.1, timeout_fn=stop_after_four))
    self.assertEqual([], results)
    self.assertEqual(4, calls)
Example 4
def run_continuous_eval(hparams):
    """What to run in continuous eval mode.

    Watches hparams.model_dir and evaluates every new checkpoint that
    appears, until no new checkpoint shows up within the configured
    timeout (hparams.debug_params.continuous_eval_timeout_secs).

    Args:
      hparams: Hyper-parameter object; this function reads `model_dir`,
        `num_eval_steps` and `debug_params.continuous_eval_timeout_secs`.
    """
    tf.logging.info('Continuous evaluation.')
    estimator = make_estimator(hparams)
    timeout = hparams.debug_params.continuous_eval_timeout_secs
    for ckpt_str in evaluation.checkpoints_iterator(hparams.model_dir,
                                                    timeout=timeout):
        tf.logging.info('Evaluating checkpoint: %s', ckpt_str)
        # Pin evaluation to the checkpoint the iterator reported. Without
        # checkpoint_path, Estimator.evaluate loads the *latest* checkpoint
        # in model_dir, which may already be a newer one than ckpt_str if
        # training outpaces evaluation (checkpoints silently skipped or
        # evaluated twice).
        estimator.evaluate(train_eval_input_fn,
                           steps=hparams.num_eval_steps,
                           checkpoint_path=ckpt_str,
                           name='eval_continuous')
        tf.logging.info('Finished evaluating checkpoint: %s', ckpt_str)
Example 5
  def testReturnsSingleCheckpointIfOneCheckpointFound(self):
    """One checkpoint written to the directory -> iterator yields exactly once."""
    ckpt_dir = os.path.join(self.get_temp_dir(), 'one_checkpoint_found')
    if not tf.io.gfile.exists(ckpt_dir):
      tf.io.gfile.makedirs(ckpt_dir)

    step = tf.compat.v1.train.get_or_create_global_step()
    saver = tf.compat.v1.train.Saver()  # Saves the global step.

    with self.cached_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      saver.save(sess, os.path.join(ckpt_dir, 'model.ckpt'),
                 global_step=step)

    found = sum(
        1 for _ in evaluation.checkpoints_iterator(ckpt_dir, timeout=0))
    self.assertEqual(found, 1)