def test(model: Model, evaluators, datasets: Dict[str, Dataset], loader, checkpoint,
         ema=True, aysnc_encoding=None, sample=None) -> Dict[str, Evaluation]:
    """Restore `model` from `checkpoint` and run `evaluators` on each dataset.

    :param model: Model to evaluate; its inputs are spec'd from `datasets`
    :param evaluators: Evaluator objects that produce the per-dataset Evaluations
    :param datasets: dict of name -> Dataset to evaluate on
    :param loader: resource loader passed through to `model.set_inputs`
    :param checkpoint: path to the TF checkpoint to restore
    :param ema: if True, additionally restore exponential-moving-average shadow
        variables from the checkpoint when present (overwriting the raw weights)
    :param aysnc_encoding: if truthy, queue capacity for asynchronous example
        encoding (name keeps the existing misspelling for caller compatibility)
    :param sample: optional number of samples to evaluate on (None = all)
    :return: dict mapping dataset name -> Evaluation

    NOTE(review): a second `test` with an extra `elmo_char_cnn` flag is defined
    later in this module and shadows this one at import time — consider removing
    one of the two definitions.
    """
    print("Setting up model")
    model.set_inputs(list(datasets.values()), loader)

    if aysnc_encoding:
        # Examples are encoded in a background thread and fed through a queue
        evaluator_runner = AysncEvaluatorRunner(evaluators, model, aysnc_encoding)
        inputs = evaluator_runner.dequeue_op
    else:
        evaluator_runner = EvaluatorRunner(evaluators, model)
        inputs = model.get_placeholders()
    # In the synchronous case this maps each placeholder to itself, which keeps
    # `get_predictions_for` uniform across both code paths
    input_dict = {p: x for p, x in zip(model.get_placeholders(), inputs)}

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    with sess.as_default():
        pred = model.get_predictions_for(input_dict)

    evaluator_runner.set_input(pred)

    print("Restoring variables")
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint)

    if ema:
        # FIXME This is a bit stupid, since we are loading variables twice, but I found it
        # a bit fiddly to load the variables directly....
        # decay=0 is irrelevant here; the object only serves to derive the
        # checkpoint names of the EMA shadow variables.  Use a new local name
        # instead of rebinding the boolean `ema` parameter.
        averager = tf.train.ExponentialMovingAverage(0)
        reader = tf.train.NewCheckpointReader(checkpoint)
        expected_ema_names = {averager.average_name(x): x for x in tf.trainable_variables()
                              if reader.has_tensor(averager.average_name(x))}
        if len(expected_ema_names) > 0:
            print("Restoring EMA variables")
            saver = tf.train.Saver(expected_ema_names)
            saver.restore(sess, checkpoint)

    # Guard against accidental graph growth during evaluation
    tf.get_default_graph().finalize()

    print("Begin evaluation")

    dataset_outputs = {}
    for name, dataset in datasets.items():
        dataset_outputs[name] = evaluator_runner.run_evaluators(sess, dataset, name, sample, {})

    return dataset_outputs
def _train(model: Model, data: TrainingData,
           checkpoint: Union[str, None], parameter_checkpoint: Union[str, None],
           save_start: bool,
           train_params: TrainParams,
           evaluators: List[Evaluator],
           out: ModelDir,
           notes=None, dry_run=False, start_eval=False):
    """Synchronous training loop: build the graph, restore/init parameters,
    then iterate epochs of train steps with periodic logging, checkpointing
    and evaluation.

    :param model: Model to train
    :param data: TrainingData providing the train set, eval sets and resources
    :param checkpoint: full training checkpoint to resume from (restores the
        optimizer state and global step as well), or None
    :param parameter_checkpoint: checkpoint holding only trainable parameters
        to warm-start from, or None
    :param save_start: if True, save the graph and the training spec at start
    :param train_params: TrainParams with schedule/logging/eval settings
    :param evaluators: Evaluators run on the eval datasets
    :param out: ModelDir with log/save directories
    :param notes: free-text notes stored with the training spec
    :param dry_run: unused in this code path (NOTE(review): confirm whether
        it should suppress saving, as the name suggests)
    :param start_eval: if True, run an evaluation after the first step
    """
    if train_params.async_encoding:
        # Asynchronous encoding has its own loop implementation
        _train_async(model, data, checkpoint, parameter_checkpoint, save_start,
                     train_params, evaluators, out, notes, dry_run, start_eval)
        return
    if train_params.best_weights is not None:
        raise NotImplementedError

    # spec the model for the current voc/input/batching
    train = data.get_train()
    eval_datasets = data.get_eval()
    loader = data.get_resource_loader()

    evaluator_runner = EvaluatorRunner(evaluators, model)
    print("Training on %d batches" % len(train))
    print("Evaluation datasets: " + " ".join("%s (%d)" % (name, len(data)) for name, data in eval_datasets.items()))

    print("Init model...")
    model.set_inputs([train] + list(eval_datasets.values()), loader)

    print("Setting up model prediction / tf...")
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    with sess.as_default():
        pred = model.get_prediction()
    evaluator_runner.set_input(pred)

    if parameter_checkpoint is not None:
        # Warm-start the trainable parameters only; optimizer slots and the
        # global step are initialized fresh further below
        print("Restoring parameters from %s" % parameter_checkpoint)
        saver = tf.train.Saver(tf.trainable_variables())
        saver.restore(sess, parameter_checkpoint)
        saver = None

    loss, summary_tensor, train_opt, global_step, _ = _build_train_ops(train_params)

    # Pre-compute tensors we need at evaluations time
    eval_tensors = []
    for ev in evaluators:
        eval_tensors.append(ev.tensors_needed(pred))

    saver = tf.train.Saver(max_to_keep=train_params.max_checkpoints_to_keep)
    summary_writer = tf.summary.FileWriter(out.log_dir)

    # Load or initialize the model parameters
    if checkpoint is not None:
        print("Restoring training from checkpoint...")
        saver.restore(sess, checkpoint)
        print("Loaded checkpoint: " + str(sess.run(global_step)))
        # BUG FIX: an errant `return` here made the function exit right after
        # restoring, so resuming from a checkpoint never actually trained
    else:
        if parameter_checkpoint is not None:
            # Parameters were already warm-started above, so only initialize
            # the non-trainable variables (optimizer slots, global step, ...)
            print("Initializing training variables...")
            init_vars = [x for x in tf.global_variables()
                         if x not in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]
            sess.run(tf.variables_initializer(init_vars))
        else:
            print("Initializing parameters...")
            sess.run(tf.global_variables_initializer())

    # Make sure no bugs occur that add to the graph in the train loop, that can cause (eventuall) OOMs
    tf.get_default_graph().finalize()

    print("Start training!")

    on_step = sess.run(global_step)
    if save_start:
        summary_writer.add_graph(sess.graph, global_step=on_step)
        save_train_start(out.dir, data, on_step, evaluators, train_params, notes)

    if train_params.eval_at_zero:
        print("Running evaluation...")
        start_eval = False
        for name, data in eval_datasets.items():
            n_samples = train_params.eval_samples.get(name)
            evaluation = evaluator_runner.run_evaluators(sess, data, name, n_samples)
            for s in evaluation.to_summaries(name + "-"):
                summary_writer.add_summary(s, on_step)

    batch_time = 0
    for epoch in range(train_params.num_epochs):
        for batch_ix, batch in enumerate(train.get_epoch()):
            t0 = time.perf_counter()
            on_step = sess.run(global_step) + 1  # +1 because all calculations are done after step

            get_summary = on_step % train_params.log_period == 0
            encoded = model.encode(batch, True)

            if get_summary:
                summary, _, batch_loss = sess.run([summary_tensor, train_opt, loss],
                                                  feed_dict=encoded)
            else:
                summary = None
                _, batch_loss = sess.run([train_opt, loss], feed_dict=encoded)

            if np.isnan(batch_loss):
                raise RuntimeError("NaN loss!")

            batch_time += time.perf_counter() - t0
            if get_summary:
                print("on epoch=%d batch=%d step=%d time=%.3f" %
                      (epoch, batch_ix + 1, on_step, batch_time))
                summary_writer.add_summary(
                    tf.Summary(value=[tf.Summary.Value(tag="time", simple_value=batch_time)]),
                    on_step)
                summary_writer.add_summary(summary, on_step)
                batch_time = 0

            # occasional saving
            if on_step % train_params.save_period == 0:
                print("Checkpointing")
                saver.save(sess, join(out.save_dir, "checkpoint-" + str(on_step)),
                           global_step=global_step)

            # Occasional evaluation
            if (on_step % train_params.eval_period == 0) or start_eval:
                print("Running evaluation...")
                start_eval = False
                t0 = time.perf_counter()
                for name, data in eval_datasets.items():
                    n_samples = train_params.eval_samples.get(name)
                    evaluation = evaluator_runner.run_evaluators(sess, data, name, n_samples)
                    for s in evaluation.to_summaries(name + "-"):
                        summary_writer.add_summary(s, on_step)
                print("Evaluation took: %.3f seconds" % (time.perf_counter() - t0))

    # Final checkpoint once all epochs complete
    saver.save(sess, relpath(join(out.save_dir, "checkpoint-" + str(on_step))),
               global_step=global_step)
    sess.close()
def test(model, evaluators, datasets: Dict, loader, checkpoint,
         ema=True, aysnc_encoding=None, sample=None, elmo_char_cnn=True) -> Dict[str, Evaluation]:
    """Load `checkpoint` into `model` and score it with `evaluators` on every dataset.

    :param model: model to evaluate; its inputs are spec'd from `datasets`
    :param evaluators: evaluator objects producing the per-dataset Evaluations
    :param datasets: dict of name -> dataset to evaluate on
    :param loader: resource loader handed to `model.set_inputs`
    :param checkpoint: path of the TF checkpoint to restore
    :param ema: if True, also restore exponential-moving-average shadow weights
        from the checkpoint when they exist there
    :param aysnc_encoding: queue capacity for asynchronous example encoding, or
        None/0 for the synchronous path (name kept as-is for callers)
    :param elmo_char_cnn: if True, variables whose names start with "bilm" are
        freshly initialized instead of being read from the checkpoint
    :param sample: optional sample size per dataset (None = evaluate everything)
    :return: dict mapping dataset name -> Evaluation
    """
    print("Setting up model")
    model.set_inputs(list(datasets.values()), loader)

    if not aysnc_encoding:
        runner = EvaluatorRunner(evaluators, model)
        input_tensors = model.get_placeholders()
    else:
        # Background-thread encoding: inputs arrive via the runner's queue
        runner = AysncEvaluatorRunner(evaluators, model, aysnc_encoding)
        input_tensors = runner.dequeue_op
    # Synchronous path maps each placeholder onto itself, so prediction
    # building is identical for both branches
    feed_map = dict(zip(model.get_placeholders(), input_tensors))

    session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    with session.as_default():
        prediction = model.get_predictions_for(feed_map)

    runner.set_input(prediction)

    print("Restoring variables")
    if elmo_char_cnn:
        # Split the variables: everything named "bilm*" is initialized fresh,
        # the remainder comes out of the checkpoint
        all_vars = tf.global_variables() + tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS)
        lm_var_names = {x.name for x in all_vars if x.name.startswith("bilm")}
        to_restore, to_init = [], []
        for v in all_vars:
            (to_init if v.name in lm_var_names else to_restore).append(v)
        restorer = tf.train.Saver(to_restore)
        session.run(tf.variables_initializer(to_init))
        restorer.restore(session, checkpoint)
    else:
        tf.train.Saver().restore(session, checkpoint)

    if ema:
        # Loading twice is wasteful but simple: derive each trainable
        # variable's shadow name and overwrite it from the checkpoint when
        # the checkpoint actually holds that tensor (decay value is unused)
        averager = tf.train.ExponentialMovingAverage(0)
        reader = tf.train.NewCheckpointReader(checkpoint)
        shadow_to_var = {}
        for v in tf.trainable_variables():
            shadow_name = averager.average_name(v)
            if reader.has_tensor(shadow_name):
                shadow_to_var[shadow_name] = v
        if shadow_to_var:
            print("Restoring EMA variables")
            tf.train.Saver(shadow_to_var).restore(session, checkpoint)

    # Freeze the graph so evaluation cannot accidentally grow it
    tf.get_default_graph().finalize()

    print("Begin evaluation")

    return {name: runner.run_evaluators(session, dataset, name, sample, {})
            for name, dataset in datasets.items()}