def test(model: Model, evaluators, datasets: Dict[str, Dataset], loader, checkpoint,
         ema=True, aysnc_encoding=None, sample=None) -> Dict[str, Evaluation]:
    """Restore a model from ``checkpoint`` and run the evaluators on every dataset.

    Args:
        model: model to evaluate; its inputs are configured for all of ``datasets``.
        evaluators: evaluators handed to the evaluator runner.
        datasets: mapping of dataset name -> dataset to evaluate.
        loader: resource loader forwarded to ``model.set_inputs``.
        checkpoint: path of the TensorFlow checkpoint to restore.
        ema: if truthy, additionally restore the moving-average shadow copies of the
            trainable variables, for those whose shadow exists in the checkpoint.
        aysnc_encoding: if truthy, batches are encoded on a background queue of this
            size (``AysncEvaluatorRunner``); otherwise ``EvaluatorRunner`` is used
            with the model's placeholders.
        sample: forwarded unchanged to ``run_evaluators`` for every dataset.

    Returns:
        Mapping of dataset name -> the result of ``evaluator_runner.run_evaluators``.
    """
    print("Setting up model")
    model.set_inputs(list(datasets.values()), loader)

    if aysnc_encoding:
        evaluator_runner = AysncEvaluatorRunner(evaluators, model, aysnc_encoding)
        inputs = evaluator_runner.dequeue_op
    else:
        evaluator_runner = EvaluatorRunner(evaluators, model)
        inputs = model.get_placeholders()
    input_dict = dict(zip(model.get_placeholders(), inputs))

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    try:
        with sess.as_default():
            pred = model.get_predictions_for(input_dict)
        evaluator_runner.set_input(pred)

        print("Restoring variables")
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint)

        if ema:
            # FIXME This is a bit stupid, since we are loading variables twice, but I found it
            # a bit fiddly to load the variables directly....
            # The decay value is irrelevant: this object is only used to compute the
            # checkpoint names of the shadow (averaged) variables.
            ema_tracker = tf.train.ExponentialMovingAverage(0)
            reader = tf.train.NewCheckpointReader(checkpoint)
            expected_ema_names = {}
            for var in tf.trainable_variables():
                shadow_name = ema_tracker.average_name(var)
                if reader.has_tensor(shadow_name):
                    expected_ema_names[shadow_name] = var
            if expected_ema_names:
                print("Restoring EMA variables")
                # Load the averaged values (stored under the shadow names) into the
                # regular variables, overwriting what was restored above.
                saver = tf.train.Saver(expected_ema_names)
                saver.restore(sess, checkpoint)

        # Guard against accidentally adding ops to the graph during evaluation
        tf.get_default_graph().finalize()

        print("Begin evaluation")
        return {name: evaluator_runner.run_evaluators(sess, dataset, name, sample, {})
                for name, dataset in datasets.items()}
    finally:
        # The session is local to this function; release its resources
        sess.close()
def __init__(self, evaluators: List[Evaluator], model: Model, queue_size: int):
    """Create the TensorFlow queue through which evaluation batches reach the model.

    Args:
        evaluators: evaluators stored on the runner.
        model: model whose placeholders determine the dtypes (and static shapes)
            of the queued elements.
        queue_size: capacity passed to ``tf.FIFOQueue``.
    """
    self.evaluators = evaluators
    self.model = model
    self.tensors_needed = None

    model_inputs = model.get_placeholders()
    input_dtypes = [placeholder.dtype for placeholder in model_inputs]

    self.eval_queue = tf.FIFOQueue(queue_size, input_dtypes, name="eval_queue")
    self.enqueue_op = self.eval_queue.enqueue(model_inputs)
    self.dequeue_op = self.eval_queue.dequeue()
    # Closing op; the True argument also cancels pending enqueues
    self.close_queue = self.eval_queue.close(True)

    # The queue was created without shapes, so copy each placeholder's static
    # shape onto the matching dequeued tensor.
    for placeholder, dequeued in zip(model_inputs, self.dequeue_op):
        dequeued.set_shape(placeholder.shape)

    # NOTE: despite the name, this is the queue's size op (current element
    # count), not the ``queue_size`` capacity argument.
    self.queue_size = self.eval_queue.size()
def _train_async(model: Model,
                 data: TrainingData,
                 checkpoint: Union[str, None],
                 parameter_checkpoint: Union[str, None],
                 save_start: bool,
                 train_params: TrainParams,
                 evaluators: List[Evaluator],
                 out: ModelDir,
                 notes=None,
                 dry_run=False,
                 start_eval=False,
                 good_news_threshold: float = 0.37):
    """Train while encoding batches on a separate thread and storing them in a
    TensorFlow queue, which can be much faster than using the feed_dict approach.

    Args:
        model: model to train; its inputs are configured for the train and eval data.
        data: provides the training set, the evaluation sets and the resource loader.
        checkpoint: if not None, the full training state is restored from it;
            otherwise variables are initialized.
        parameter_checkpoint: if not None, the variables that exist before the
            training ops are built are restored from it. When ``checkpoint`` is
            None, only the remaining variables are then initialized.
        save_start: if True, write the graph to the summary log and call
            ``save_train_start``.
        train_params: training configuration. ``best_weights``, when set, is indexed
            as (evaluation dataset name, scalar metric name).
        evaluators: evaluators run on each evaluation dataset.
        out: output directories (log_dir, save_dir, best_weight_dir, dir).
        notes: forwarded to ``save_train_start``.
        dry_run: if True, build and initialize everything, then return without training.
        start_eval: if True, run an evaluation right after the first training step.
        good_news_threshold: a new best dev score above this value triggers the
            'Good News EveryOne!' email.

    Raises:
        RuntimeError: if the training loss becomes NaN.
    """
    def fmt_score(score):
        # A score stays None if the corresponding evaluation never ran
        return "None" if score is None else "%.3f" % score

    train = data.get_train()
    eval_datasets = data.get_eval()
    loader = data.get_resource_loader()
    print("Training on %d batches" % len(train))
    print("Evaluation datasets: " + " ".join(
        "%s (%d)" % (name, len(ds)) for name, ds in eval_datasets.items()))

    # Spec the model for the given datasets
    model.set_inputs([train] + list(eval_datasets.values()), loader)
    placeholders = model.get_placeholders()

    train_queue = tf.FIFOQueue(train_params.async_encoding,
                               [x.dtype for x in placeholders], name="train_queue")
    evaluator_runner = AysncEvaluatorRunner(evaluators, model, train_params.async_encoding)
    train_enqueue_op = train_queue.enqueue(placeholders)
    train_close = train_queue.close(True)

    # Feeding is_train selects whether the model consumes the training queue or
    # the evaluator's queue.
    is_train = tf.placeholder(tf.bool, ())
    input_tensors = tf.cond(is_train,
                            lambda: train_queue.dequeue(),
                            lambda: evaluator_runner.eval_queue.dequeue())

    # TensorFlow can't infer the shape for an unsized queue, so set it manually
    for input_tensor, pl in zip(input_tensors, placeholders):
        input_tensor.set_shape(pl.shape)

    print("Init model...")
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    with sess.as_default():
        pred = model.get_predictions_for(dict(zip(placeholders, input_tensors)))
    evaluator_runner.set_input(pred)

    restored_vars = None
    if parameter_checkpoint is not None:
        print("Restoring parameters from %s" % parameter_checkpoint)
        # Remember which variables this restore covers, so they are not
        # re-initialized below.
        restored_vars = tf.global_variables()
        tf.train.Saver().restore(sess, parameter_checkpoint)

    print("Setting up model prediction / tf...")
    all_vars = tf.global_variables()

    loss, summary_tensor, train_opt, global_step, weight_ema = _build_train_ops(train_params)

    # Pre-compute tensors we need at evaluation time. NOTE(review): the results are
    # not used here; the calls are kept in case they add ops to the graph, which
    # must happen before the graph is finalized below.
    for ev in evaluators:
        ev.tensors_needed(pred)

    best_weights = train_params.best_weights
    cur_best = None
    best_weight_saver = None
    if best_weights is not None:
        # Build a new list (do not alias/mutate all_vars) holding the variables
        # plus their moving-average shadows, when an EMA is in use.
        best_vars = list(all_vars)
        if weight_ema is not None:
            best_vars += [avg for avg in (weight_ema.average(x) for x in all_vars)
                          if avg is not None]
        best_weight_saver = tf.train.Saver(var_list=best_vars, max_to_keep=1)

    saver = tf.train.Saver(max_to_keep=train_params.max_checkpoints_to_keep)
    summary_writer = tf.summary.FileWriter(out.log_dir)

    # Load or initialize the model parameters
    if checkpoint is not None:
        print("Restoring from checkpoint...")
        saver.restore(sess, checkpoint)
        print("Loaded checkpoint: " + str(sess.run(global_step)))
    elif restored_vars is not None:
        # Keep the weights loaded from parameter_checkpoint; only initialize the
        # variables it did not cover (e.g. those created by _build_train_ops).
        print("Initializing parameters...")
        restored_names = {v.name for v in restored_vars}
        sess.run(tf.variables_initializer(
            [v for v in tf.global_variables() if v.name not in restored_names]))
    else:
        print("Initializing parameters...")
        sess.run(tf.global_variables_initializer())

    # Make sure no bugs occur that add to the graph in the train loop, that can cause (eventually) OOMs
    tf.get_default_graph().finalize()

    if dry_run:
        summary_writer.close()
        sess.close()
        return

    on_step = sess.run(global_step)

    if save_start:
        summary_writer.add_graph(sess.graph, global_step=on_step)
        save_train_start(out.dir, data, on_step, evaluators, train_params, notes)

    def enqueue_train():
        try:
            # Feed data from the dataset iterator -> encoder -> queue
            for _ in range(train_params.num_epochs):
                for batch in train.get_epoch():
                    feed_dict = model.encode(batch, True)
                    sess.run(train_enqueue_op, feed_dict)
        except tf.errors.CancelledError:
            # The queue_close operator has been called, exit gracefully
            return
        except Exception:
            # Crashes the main thread with a queue exception
            sess.run(train_close)
            raise

    train_enqueue_thread = Thread(target=enqueue_train)
    train_enqueue_thread.daemon = True  # Ensure we exit the program on an exception

    print("Start training!")

    batch_time = 0
    epoch_best = 0
    val_train = None
    dev_acc = []
    train_acc = []
    train_dict = {is_train: True}
    eval_dict = {is_train: False}
    try:
        train_enqueue_thread.start()

        for epoch in range(train_params.num_epochs):
            for batch_ix in range(len(train)):
                t0 = time.perf_counter()
                on_step = sess.run(global_step) + 1
                get_summary = on_step % train_params.log_period == 0

                if get_summary:
                    summary, _, batch_loss = sess.run(
                        [summary_tensor, train_opt, loss], feed_dict=train_dict)
                    print('batch_loss is: ' + str(batch_loss))
                else:
                    summary = None
                    _, batch_loss = sess.run([train_opt, loss], feed_dict=train_dict)

                if np.isnan(batch_loss):
                    raise RuntimeError("NaN loss!")

                batch_time += time.perf_counter() - t0
                if summary is not None:
                    print("on epoch=%d batch=%d step=%d, time=%.3f" %
                          (epoch, batch_ix + 1, on_step, batch_time))
                    summary_writer.add_summary(
                        tf.Summary(value=[
                            tf.Summary.Value(tag="time", simple_value=batch_time)
                        ]), on_step)
                    summary_writer.add_summary(summary, on_step)
                    batch_time = 0

                # Occasional saving
                if on_step % train_params.save_period == 0:
                    print("Checkpointing")
                    saver.save(sess, join(out.save_dir, "checkpoint-" + str(on_step)),
                               global_step=global_step)

                # Occasional evaluation
                if (on_step % train_params.eval_period == 0) or start_eval:
                    print("Running evaluation...")
                    start_eval = False
                    t0 = time.perf_counter()
                    for name, eval_data in eval_datasets.items():
                        n_samples = train_params.eval_samples.get(name)
                        evaluation = evaluator_runner.run_evaluators(
                            sess, eval_data, name, n_samples, eval_dict)
                        for s in evaluation.to_summaries(name + "-"):
                            summary_writer.add_summary(s, on_step)

                        if best_weights is None:
                            # No tracked metric: nothing to compare or record
                            continue

                        score = evaluation.scalars[best_weights[1]]
                        if name == best_weights[0]:
                            # Dev dataset: maybe save as the best weights
                            val = score
                            dev_acc.append(val)
                            if cur_best is None or val > cur_best:
                                epoch_best = epoch
                                send_email(
                                    'epoch: ' + str(epoch_best) + 'acc: ' + str(val),
                                    'New Best')
                                print(
                                    "Save weights with current best weights (%s vs %.5f)" %
                                    ("None" if cur_best is None else ("%.5f" % cur_best), val))
                                best_weight_saver.save(
                                    sess, join(out.best_weight_dir, "best"),
                                    global_step=global_step)
                                cur_best = val
                                if cur_best > good_news_threshold:
                                    email_text = ('Best accuracy for dev data: ' +
                                                  ('%.3f' % cur_best) +
                                                  ' <br> On epoch n: ' + str(epoch_best) +
                                                  ' out of: ' + str(train_params.num_epochs) +
                                                  ' <br> Folder: ' + str(out.save_dir))
                                    email_title = 'Good News EveryOne!'
                                    send_email(email_text, email_title)
                            print('Current accuracy for dev data: ' + ('%.3f' % val))
                            print('Best accuracy for dev data: ' + ('%.3f' % cur_best) +
                                  'on epoch n:' + str(epoch_best))
                        else:
                            # Any other evaluation dataset is tracked as "train" accuracy
                            val_train = score
                            train_acc.append(val_train)
                            print('Current accuracy for train data: ' + ('%.3f' % val_train))

                    print("Evaluation took: %.3f seconds" % (time.perf_counter() - t0))
    finally:
        sess.run(train_close)  # terminates the enqueue thread with an exception

    train_enqueue_thread.join()

    # cur_best / val_train stay None if no matching evaluation ever ran
    email_text = ('Finished ' + str(train_params.num_epochs) +
                  ' Best accuracy for dev data: ' + fmt_score(cur_best) +
                  ' <br> On epoch n: ' + str(epoch_best) +
                  ' <br> Acc for train data last: ' + fmt_score(val_train) +
                  ' <br> Folder: ' + str(out.save_dir))
    email_title = 'Test Finished'
    image_path = create_train_dev_plot(dev_acc, train_acc, out.save_dir)
    send_email(email_text, email_title, image_path)

    saver.save(sess, relpath(join(out.save_dir, "checkpoint-" + str(on_step))),
               global_step=global_step)
    summary_writer.close()
    sess.close()