예제 #1
0
def wait_for_next_checkpoint(log_dir,
                             last_checkpoint=None,
                             seconds_to_sleep=1,
                             timeout=20):
  """Blocking wait until next checkpoint is written to logdir.

  Can timeout at regular intervals to log a timeout warning (a good indicator
  the thread is still alive).

  Args:
    log_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or None if we're expecting a
      checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum amount of time to wait before printing timeout warning
      and checking for a new checkpoint. If left as None, then the thread will
      wait indefinitely.

  Returns:
    next_checkpoint filename.
  """

  while True:
    logging.info('Waiting for next policy checkpoint...')
    next_checkpoint = contrib_training.wait_for_new_checkpoint(
        log_dir,
        last_checkpoint,
        seconds_to_sleep=seconds_to_sleep,
        timeout=timeout)
    if next_checkpoint is None:
      # Timed out with no checkpoint yet; log (so watchers know the thread is
      # alive) and retry. Note: logging.warn is a deprecated alias (removed in
      # Python 3.13) -- use logging.warning.
      logging.warning('Timeout waiting for checkpoint, trying again...')
    elif next_checkpoint != last_checkpoint:
      # Found a new checkpoint.
      logging.warning('Found a new checkpoint ("%s").', next_checkpoint)
      break
    else:
      logging.warning('No new checkpoint found, trying again...')

  return next_checkpoint
예제 #2
0
def evaluate(master,
             model_fn,
             data_fn,
             additional_trial_info,
             model_dir,
             preprocess_examples,
             hparams,
             name,
             num_steps=None):
  """Evaluation loop.

  Repeatedly waits for new checkpoints in `model_dir` and runs
  `estimator.evaluate` on each one. The final `while True` loop has no break,
  so this function blocks forever (or until `wait_for_new_checkpoint` raises).

  Args:
    master: TensorFlow master address, forwarded to `create_estimator`.
    model_fn: Estimator model function, forwarded to `create_estimator`.
    data_fn: Input-data callable; called with `preprocess_examples`,
      `is_training`, `shuffle_examples`, `skip_n_initial_records` and
      `params`. NOTE(review): `params` appears to be an hparams object with a
      `batch_size` attribute -- confirm against the data_fn implementation.
    additional_trial_info: Extra info forwarded to `_trial_summary`.
    model_dir: Directory that holds checkpoints; also the estimator model dir.
    preprocess_examples: Forwarded to `data_fn`.
    hparams: Hyperparameters object; deep-copied for the one-time
      record-counting pass so the original is not mutated.
    name: Name of this evaluation run (selects the estimator eval dir).
    num_steps: If None, evaluate the full dataset on every pass. Otherwise,
      evaluate only `num_steps` batches starting at a random record offset.
  """
  estimator = create_estimator(
      model_fn=model_fn, model_dir=model_dir, master=master, hparams=hparams)

  # Partial with the arguments common to both the full-dataset and the
  # subset-evaluation input functions.
  transcription_data_base = functools.partial(
      data_fn,
      preprocess_examples=preprocess_examples,
      is_training=False)

  if num_steps is None:
    transcription_data = functools.partial(
        transcription_data_base,
        shuffle_examples=False, skip_n_initial_records=0)
  else:
    # If num_steps is specified, we will evaluate only a subset of the data.
    #
    # The following is a hack that works around the problems of not being able
    # to determine the number of records in a given TFRecord shard without
    # reading the whole thing and not being able to persist a tf.data.Dataset
    # session across multiple estimator evaluate calls.
    #
    # This code tries to select a different subset for every evaluation by doing
    # the following:
    # - Setting shuffle_examples=True. This shuffles not only individual
    #   examples, but also shuffles the order in which shards are read.
    # - Skipping N examples before starting evaluation, where N is selected
    #   randomly for each evaluation run. This provides a different starting
    #   offset.

    # In order to skip a random number of records, we need to provide an upper
    # bound that will still let us run num_steps evaluation steps before running
    # out of data. The following code does a one-time check on startup to see
    # if there are up to num_steps * 5 records available, which would allow
    # a maximum skip range of [0, num_steps*4].
    records_to_check = num_steps * 5
    tf.logging.info('Checking for at least %d records...', records_to_check)
    records_available = 0
    # Count records in a throwaway graph/session so nothing leaks into the
    # estimator's graph. batch_size=1 so each sess.run consumes one record.
    with tf.Graph().as_default():
      record_check_params = copy.deepcopy(hparams)
      record_check_params.batch_size = 1
      iterator = transcription_data_base(
          params=record_check_params,
          shuffle_examples=False,
          skip_n_initial_records=0,
          ).make_initializable_iterator()
      next_record = iterator.get_next()
      with tf.Session() as sess:
        sess.run(iterator.initializer)
        try:
          for i in range(records_to_check):
            del i
            sess.run(next_record)
            records_available += 1
            if records_available % 10 == 0:
              tf.logging.info('Found %d records...', records_available)
        except tf.errors.OutOfRangeError:
          # Dataset exhausted before records_to_check; records_available holds
          # the true (smaller) count.
          pass
    # Determine max number of records we could skip and still have num_steps
    # records remaining.
    max_records_to_skip = max(0, records_available - num_steps)
    tf.logging.info('Found at least %d records. '
                    'Will skip a maximum of %d records during eval runs '
                    'in order to support %d evaluation steps.',
                    records_available, max_records_to_skip, num_steps)

    # Since we're doing a limited number of steps, we should shuffle the
    # examples we're evaluating so each evaluation is over a different portion
    # of the dataset.
    def transcription_data(params, *args, **kwargs):
      # Input fn passed to estimator.evaluate; picks a fresh random offset on
      # every call so successive evals cover different slices.
      assert not args
      skip_n_initial_records = random.randint(0, max_records_to_skip)
      tf.logging.info('Skipping %d initial record(s)', skip_n_initial_records)
      return transcription_data_base(
          params=params,
          shuffle_examples=True,
          skip_n_initial_records=skip_n_initial_records,
          **kwargs)

  _trial_summary(
      hparams=hparams,
      model_dir=model_dir,
      output_dir=estimator.eval_dir(name),
      additional_trial_info=additional_trial_info)

  # Evaluate every checkpoint as it appears, forever. wait_for_new_checkpoint
  # blocks until a checkpoint newer than checkpoint_path is written.
  checkpoint_path = None
  while True:
    checkpoint_path = contrib_training.wait_for_new_checkpoint(
        model_dir, last_checkpoint=checkpoint_path)
    estimator.evaluate(input_fn=transcription_data, steps=num_steps,
                       checkpoint_path=checkpoint_path, name=name)
예제 #3
0
      # only. Since learning goes very fast, we save often.
      if i % params['eval_steps'] == 0 or i == params['steps']:
        saver.save(sess, args.model_dir + '/model.ckpt', global_step=i)
else:
  # Evaluation branch: loop over checkpoints produced by the training run and
  # compute metrics over the whole chosen partition for each one.
  print('Evaluating on %s' % params['partition'])
  # For each checkpoint the entire dataset is evaluated.
  steps_per_eval = params['%s_size' % params['partition']]
  checkpoint = None
  # Basic session since we will only manually save summaries.
  with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # Queue runners will take care of reading data in separate threads.
    threads = tf.train.start_queue_runners(coord=coord)
    while True:
      # Block until a checkpoint newer than `checkpoint` exists; times out
      # after 1200 s (20 minutes) and returns None.
      checkpoint = wait_for_new_checkpoint(args.model_dir,
                                           checkpoint,
                                           seconds_to_sleep=1,
                                           timeout=1200)
      if checkpoint is None:
        print('No checkpoint found for 20 min, exiting evaluation.')
        break
      # Init for variables that are not part of checkpoint,
      # in this case the ones used for metrics.
      sess.run(init)
      # Restore a checkpoint saved by the training run.
      saver.restore(sess, checkpoint)
      # Update the metrics for every element in the dataset.
      # Ceil so a final partial batch still gets processed.
      batch_steps = int(np.ceil(steps_per_eval/float(params['read_batch'])))
      for i in range(batch_steps):
        sess.run([eval_update])
      # Get the resulting metrics.
      cur_step, cur_reward, cur_summary = sess.run([global_step, mean_reward, merged_summary])