# NOTE: This excerpt relies on the enclosing module's imports and helpers:
# `time`, `tf` (TensorFlow 1.x), absl's `logging` and `FLAGS`, plus the
# project's `corpus`, `evaluation` and `utils` modules and `train_1`.

def make_test_iterator():
  # Dynamic evaluation adapts the parameters while consuming the test set,
  # so it is restricted to a batch size of 1.
  return corpus.get_batches(
      data['test'], vocab,
      eval_batch_size(data['test']) if not dyneval else 1,
      max_time_steps, episodic=episodic, deterministic=True,
      max_epochs=1, max_batches=max_test_eval_batches,
      conditioning_separator=conditioning_separator)
def make_train_iterator():
  return corpus.get_batches(
      data['training'], vocab, batch_size, max_time_steps,
      episodic=episodic, deterministic=True, max_epochs=1,
      max_batches=max_training_eval_batches,
      conditioning_separator=conditioning_separator)
def make_quick_eval_iterator():
  # Cover roughly eval_softmax_temperature_estimation_num_tokens tokens of
  # the validation set, but always yield at least one batch.
  quick_eval_max_batches = max(
      1,
      eval_softmax_temperature_estimation_num_tokens
      // batch_size // max_time_steps)
  return corpus.get_batches(
      data['valid'], vocab, batch_size, max_time_steps,
      episodic=episodic, deterministic=True, max_epochs=1,
      max_batches=quick_eval_max_batches,
      conditioning_separator=conditioning_separator)
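# A worked example of the batch arithmetic above (hypothetical values, a
# sketch only): with eval_softmax_temperature_estimation_num_tokens=50000,
# batch_size=64 and max_time_steps=35, quick evaluation runs for
# max(1, 50000 // 64 // 35) == 22 batches, which covers
# 22 * 64 * 35 == 49280 tokens, close to the requested 50000.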
def _train_loop(monitor, lr_scheduler, averaged, dyneval, model, data, vocab,
                config, summary_writer, save_last_checkpoint_fn):
  source_iterator = corpus.get_batches(
      data['training'], vocab, config.batch_size, config.max_time_steps,
      num_samples=config.num_training_samples,
      episodic=FLAGS.episodic, deterministic=False,
      conditioning_separator=config.conditioning_separator)
  last_state = None
  steps_per_sec = 0.0

  def munge_max_batches_flag_value(max_batches):
    # Map the flag convention of -1 (meaning no limit) to None.
    if max_batches == -1:
      return None
    else:
      return max_batches

  def evaluate0():
    # KLUDGE: This depends on monitor calling this function before using the
    # worst target.
    monitor.set_es_worst_target(es_worst_target())
    global_step = model.global_step()
    logging.info('turn: %s (eval), step: %d (opt) (%.2f/s)',
                 monitor.turn(), global_step, steps_per_sec)
    if config.accum_batch_size == -1:
      eval_batch_size = config.batch_size
    else:
      eval_batch_size = config.accum_batch_size
    training_xe, valid_xe, test_xe = evaluation.evaluate_all(
        model, data, vocab, eval_batch_size, config.max_time_steps,
        FLAGS.min_non_episodic_eval_examples_per_stripe,
        munge_max_batches_flag_value(FLAGS.max_training_eval_batches),
        munge_max_batches_flag_value(FLAGS.max_eval_eval_batches),
        munge_max_batches_flag_value(FLAGS.max_test_eval_batches),
        FLAGS.episodic,
        config.eval_softmax_temperature,
        config.eval_softmax_temperature_estimation_num_tokens,
        config.eval_method,
        config.num_eval_samples,
        config.eval_power_mean_power,
        config.eval_dropout_multiplier,
        config.validation_prediction_file,
        dyneval,
        conditioning_separator=config.conditioning_separator)
    return valid_xe, {'training_xe': training_xe,
                      'test_xe': test_xe,
                      'global_step': global_step}

  def evaluate():
    if monitor.averaging_triggered():
      # Temporarily swap the averaged parameters in for evaluation.
      with averaged:
        logging.info('Evaluating with averaged parameters.')
        return evaluate0()
    else:
      return evaluate0()

  def add_summary(summary_str):
    if summary_writer is not None:
      summary_writer.add_summary(summary_str, model.global_step())

  def add_summaries_for_metrics():
    metrics = monitor.metrics()
    summary = tf.Summary()
    for key in metrics:
      summary.value.add(tag=key, simple_value=metrics[key])
    add_summary(summary)

  # Compute the early stopping worst target. It may change when the learning
  # rate is dropped.
  def es_worst_target():
    if FLAGS.early_stopping_worst_xe_target is None:
      return -1.0
    else:
      targets_for_lr_drops = [
          float(string) for string
          in FLAGS.early_stopping_worst_xe_target.split(',')
          if string
      ]
      num_drops = lr_scheduler.num_drops()
      if targets_for_lr_drops:
        return targets_for_lr_drops[
            min(num_drops, len(targets_for_lr_drops)-1)]
      else:
        return None

  def log_summaries(summary):
    utils.log_scalar_summaries(summary)
    add_summary(summary)

  while monitor.next_turn(evaluate):
    logging.info('metrics: %r', monitor.metrics())
    logging.info(
        'early stopping: turns: %s, worst xe target: %s, '
        'best expected xe: %s',
        monitor.effective_es_turns(), monitor.es_worst_target(),
        monitor.best_expected_xe())
    add_summaries_for_metrics()

    # If enough turns passed without improvement, turn on averaging.
    best_turn = monitor.best_xe_turn() or 0
    num_turns_since_best = monitor.turn() - best_turn
    if (averaged and
        ((monitor.turn() > 0 and
          num_turns_since_best >= FLAGS.trigger_averaging_turns) or
         (FLAGS.trigger_averaging_at_the_latest >= 0 and
          monitor.turn() >= FLAGS.trigger_averaging_at_the_latest))):
      monitor.set_averaging_triggered(True)

    start_time = time.time()
    sum_cost = 0.0
    sum_tokens = 0
    for _ in range(FLAGS.steps_per_turn):
      cost, summary, last_state, num_tokens = train_1(
          model, source_iterator, last_state,
          learning_rate=lr_scheduler.learning_rate(),
          accum_batch_size=model.config.accum_batch_size)
      if monitor.averaging_triggered():
        averaged.take_sample()
      sum_cost += cost
      sum_tokens += num_tokens
      # Log summaries at the very beginning of training to make it easier to
      # debug initialization problems.
      if (model.global_step() == 1 or
          (model.global_step()+1) %
          FLAGS.print_training_stats_every_num_steps == 1):
        log_summaries(summary)
        logging.info('avg training cost at step %d: %.5f',
                     model.global_step(), sum_cost / sum_tokens)
        sum_cost = 0.0
        sum_tokens = 0
    steps_per_sec = FLAGS.steps_per_turn / (time.time()-start_time)
    # TODO(melisgl): Is this the right frequency for saving?
    save_last_checkpoint_fn()

  metrics = monitor.metrics()
  logging.info('Finished at turn %d for reason: %s',
               monitor.turn(), monitor.finished_reason())
  logging.info('Best XE was %5.5f at turn %d',
               metrics['best_xe'], metrics['best_xe_turn'])
  return metrics
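# A worked example of the early stopping schedule in es_worst_target above
# (hypothetical flag value, a sketch only): with
# --early_stopping_worst_xe_target='4.60,4.45,4.40' the worst acceptable
# validation XE is
#   targets[min(0, 2)] == 4.60  # before any learning rate drop
#   targets[min(1, 2)] == 4.45  # after the first drop
#   targets[min(3, 2)] == 4.40  # index clamped: third and later drops
# An empty flag value parses to no targets and disables the check (None),
# while an unset flag maps to -1.0.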