Example 1
def make_test_iterator():
    return corpus.get_batches(
        data['test'],
        vocab,
        eval_batch_size(data['test']) if not dyneval else 1,
        max_time_steps,
        episodic=episodic,
        deterministic=True,
        max_epochs=1,
        max_batches=max_test_eval_batches,
        conditioning_separator=conditioning_separator)
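A minimal consumption sketch, assuming `corpus.get_batches` returns an ordinary Python iterable of batches; `evaluate_batch` is a hypothetical placeholder, not part of the source:

# Hypothetical usage of the factory above.
for batch in make_test_iterator():
    evaluate_batch(batch)  # placeholder for the model's per-batch eval step
# With deterministic=True and max_epochs=1, calling make_test_iterator()
# again replays the same batch sequence, keeping eval runs comparable.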
Example 2
def make_train_iterator():
    return corpus.get_batches(
        data['training'],
        vocab,
        batch_size,
        max_time_steps,
        episodic=episodic,
        deterministic=True,
        max_epochs=1,
        max_batches=max_training_eval_batches,
        conditioning_separator=conditioning_separator)
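This mirrors Example 1 but reads the training split: deterministic=True and max_epochs=1 make it a fixed, single-pass iterator for measuring cross entropy on the training data, not for optimization (compare the non-deterministic training iterator in Example 4 below).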
Example 3
def make_quick_eval_iterator():
    quick_eval_max_batches = max(
        1, eval_softmax_temperature_estimation_num_tokens //
        batch_size // max_time_steps)
    return corpus.get_batches(
        data['valid'],
        vocab,
        batch_size,
        max_time_steps,
        episodic=episodic,
        deterministic=True,
        max_epochs=1,
        max_batches=quick_eval_max_batches,
        conditioning_separator=conditioning_separator)
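The cap arithmetic converts a token budget into a batch count; a worked instance with illustrative numbers (none of these values come from the source):

# Illustrative numbers only: a 250000-token budget, batch_size=64,
# max_time_steps=35. Each batch covers up to 64 * 35 = 2240 tokens, so
# 250000 // 64 // 35 == 111 batches are read; max(1, ...) guards against
# a zero batch count when the token budget is smaller than one batch.
quick_eval_max_batches = max(1, 250000 // 64 // 35)  # == 111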
Example 4
def _train_loop(monitor, lr_scheduler, averaged, dyneval, model,
                data, vocab, config, summary_writer, save_last_checkpoint_fn):
  source_iterator = corpus.get_batches(
      data['training'], vocab,
      config.batch_size,
      config.max_time_steps,
      num_samples=config.num_training_samples,
      episodic=FLAGS.episodic,
      deterministic=False,
      conditioning_separator=config.conditioning_separator)
  last_state = None
  steps_per_sec = 0.0

  def munge_max_batches_flag_value(max_batches):
    if max_batches == -1:
      return None
    else:
      return max_batches
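  # Note (inferred from this excerpt): -1 on the command line is a sentinel
  # for "no limit"; it is mapped to None before being handed to
  # evaluation.evaluate_all below, e.g.
  #   munge_max_batches_flag_value(-1)   -> None
  #   munge_max_batches_flag_value(500)  -> 500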

  def evaluate0():
    # KLUDGE: This depends on monitor calling this function before using the
    # worst target.
    monitor.set_es_worst_target(es_worst_target())
    global_step = model.global_step()
    logging.info('turn: %s (eval), step: %d (opt) (%.2f/s)',
                 monitor.turn(), global_step, steps_per_sec)
    if config.accum_batch_size == -1:
      eval_batch_size = config.batch_size
    else:
      eval_batch_size = config.accum_batch_size
    training_xe, valid_xe, test_xe = evaluation.evaluate_all(
        model, data, vocab, eval_batch_size, config.max_time_steps,
        FLAGS.min_non_episodic_eval_examples_per_stripe,
        munge_max_batches_flag_value(FLAGS.max_training_eval_batches),
        munge_max_batches_flag_value(FLAGS.max_eval_eval_batches),
        munge_max_batches_flag_value(FLAGS.max_test_eval_batches),
        FLAGS.episodic,
        config.eval_softmax_temperature,
        config.eval_softmax_temperature_estimation_num_tokens,
        config.eval_method,
        config.num_eval_samples,
        config.eval_power_mean_power,
        config.eval_dropout_multiplier,
        config.validation_prediction_file,
        dyneval,
        conditioning_separator=config.conditioning_separator)
    return valid_xe, {'training_xe': training_xe,
                      'test_xe': test_xe,
                      'global_step': global_step}

  def evaluate():
    if monitor.averaging_triggered():
      with averaged:
        logging.info('Evaluating with averaged parameters.')
        return evaluate0()
    else:
      return evaluate0()
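  # Assumption, not confirmed by this excerpt: `averaged` appears to be a
  # context manager that swaps the model's parameters for their running
  # averages on entry and restores the originals on exit, so evaluate0()
  # sees the averaged weights only inside the `with averaged:` block.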

  def add_summary(summary_str):
    if summary_writer is not None:
      summary_writer.add_summary(summary_str, model.global_step())

  def add_summaries_for_metrics():
    metrics = monitor.metrics()
    summary = tf.Summary()
    for key in metrics:
      summary.value.add(tag=key, simple_value=metrics[key])
    add_summary(summary)

  # Compute the early stopping worst target. It may change when the learning
  # rate is dropped.
  def es_worst_target():
    if FLAGS.early_stopping_worst_xe_target is None:
      return -1.0
    else:
      targets_for_lr_drops = [
          float(string) for string
          in FLAGS.early_stopping_worst_xe_target.split(',')
          if string
      ]
      num_drops = lr_scheduler.num_drops()
      if targets_for_lr_drops:
        return targets_for_lr_drops[min(num_drops, len(targets_for_lr_drops)-1)]
      else:
        return None
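  # Worked example (illustrative flag value, not from the source): with
  # --early_stopping_worst_xe_target=4.6,4.5,4.45 the string parses to
  # [4.6, 4.5, 4.45]; after zero LR drops the target is 4.6, after one drop
  # 4.5, and after two or more drops it stays clamped at the last value.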

  def log_summaries(summary):
    utils.log_scalar_summaries(summary)
    add_summary(summary)

  while monitor.next_turn(evaluate):

    logging.info('metrics: %r', monitor.metrics())
    logging.info(
        'early stopping: turns: %s, worst xe target: %s, best expected xe: %s',
        monitor.effective_es_turns(), monitor.es_worst_target(),
        monitor.best_expected_xe())
    add_summaries_for_metrics()

    # If enough turns passed without improvement, turn on averaging.
    best_turn = monitor.best_xe_turn() or 0
    num_turns_since_best = monitor.turn() - best_turn
    if (averaged and
        ((monitor.turn() > 0 and
          num_turns_since_best >= FLAGS.trigger_averaging_turns) or
         (FLAGS.trigger_averaging_at_the_latest >= 0 and
          monitor.turn() >= FLAGS.trigger_averaging_at_the_latest))):
      monitor.set_averaging_triggered(True)
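    # The trigger fires either after FLAGS.trigger_averaging_turns turns with
    # no new best XE, or unconditionally once the turn count reaches
    # FLAGS.trigger_averaging_at_the_latest (when that flag is >= 0).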

    start_time = time.time()
    sum_cost = 0.0
    sum_tokens = 0
    for _ in range(FLAGS.steps_per_turn):
      cost, summary, last_state, num_tokens = train_1(
          model, source_iterator, last_state,
          learning_rate=lr_scheduler.learning_rate(),
          accum_batch_size=model.config.accum_batch_size)
      if monitor.averaging_triggered():
        averaged.take_sample()
      sum_cost += cost
      sum_tokens += num_tokens
      # Log summaries at the very beginning of training to make it easier to
      # debug initialization problems.
      if (model.global_step() == 1 or
          (model.global_step()+1) %
          FLAGS.print_training_stats_every_num_steps == 1):
        log_summaries(summary)
        logging.info('avg training cost at step %d: %.5f',
                     model.global_step(), sum_cost / sum_tokens)
        sum_cost = 0.0
        sum_tokens = 0
    steps_per_sec = FLAGS.steps_per_turn / (time.time()-start_time)

    # TODO(melisgl): Is this the right frequency for saving?
    save_last_checkpoint_fn()

  metrics = monitor.metrics()
  logging.info('Finished at turn %d for reason: %s',
               monitor.turn(), monitor.finished_reason())
  logging.info('Best XE was %5.5f at turn %d',
               metrics['best_xe'], metrics['best_xe_turn'])
  return metrics