def evaluate(FLAGS, model, eval_set, log_entry, logger, trainer, vocabulary=None, show_sample=False, eval_index=0):
    """Evaluate ``model`` on a single evaluation set.

    A new record is appended to ``log_entry.evaluation`` and filled in by
    ``eval_stats``. When ``FLAGS.write_eval_report`` is set, every batch is
    handed to an ``EvalReporter`` and the report is written under
    ``FLAGS.log_path``.

    Args:
        FLAGS: Run configuration. Reads ``show_progress_bar``,
            ``use_internal_parser``, ``validate_transitions``, ``model_type``,
            ``write_eval_report``, ``eval_report_use_preds``, ``log_path``
            and ``experiment_name``.
        model: Model to evaluate; it is switched to eval mode.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` must support
            ``len()`` and iteration over raw batches.
        log_entry: Log message that receives the new ``evaluation`` record.
        logger: Used to log the first sampled tree, if any was produced.
        trainer: Not used by this function.
        vocabulary: Forwarded to ``model.get_samples``.
        show_sample: Whether to ask the model for parse samples.
        eval_index: Index of this eval set; part of the report file name.

    Returns:
        ``(eval_class_accuracy, eval_transition_accuracy)`` as read back from
        the evaluation record.
    """
    def _split_at(items, cut):
        # Split ``items`` at position ``cut``; ``None`` is passed through.
        if items is None:
            return None, None
        return items[:cut], items[cut:]

    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    num_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=num_batches)
    total_tokens = 0
    start = time.time()

    model.eval()
    # NOTE(review): the nesting below was reconstructed from a source whose
    # indentation was lost; `show_sample` is reset at loop-body level.
    for batch_idx, raw_batch in enumerate(dataset):
        batch = get_batch(raw_batch)
        (eval_X_batch, eval_transitions_batch, eval_y_batch,
         eval_num_transitions_batch, eval_ids) = batch

        # Forward pass.
        output = model(
            eval_X_batch, eval_transitions_batch, eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            store_parse_masks=show_sample,
            example_lengths=eval_num_transitions_batch)

        can_sample = (FLAGS.model_type == "RLSPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(samples)
        if not FLAGS.write_eval_report:
            # Without a report, one sample is enough no matter how many
            # batches follow.
            show_sample = False

        # Class accuracy: the predicted class is the argmax over dim 1.
        target = torch.from_numpy(eval_y_batch).long()
        pred = output.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # A sequence with nt transitions contributes (nt + 1) / 2 tokens.
        total_tokens += sum(
            (nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1))

        if FLAGS.write_eval_report:
            if FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser:
                style = "preds" if FLAGS.eval_report_use_preds else "given"
                transitions_per_example, _ = \
                    model.spinn.get_transitions_per_example(style=style)
            else:
                transitions_per_example = None

            if model.use_sentence_pair:
                batch_size = pred.size(0)
                sent1_transitions, sent2_transitions = _split_at(
                    transitions_per_example, batch_size)
                sent1_trees, sent2_trees = _split_at(tree_strs, batch_size)
            else:
                sent1_transitions, sent2_transitions = transitions_per_example, None
                sent1_trees, sent2_trees = tree_strs, None

            reporter.save_batch(
                pred, target, eval_ids, output.data.cpu().numpy(),
                sent1_transitions, sent2_transitions, sent1_trees, sent2_trees)

        progress_bar.step(batch_idx + 1, total=num_batches)
    progress_bar.finish()

    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    total_time = time.time() - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        report_name = FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report"
        reporter.write_report(os.path.join(FLAGS.log_path, report_name))

    return eval_log.eval_class_accuracy, eval_log.eval_transition_accuracy
def evaluate(FLAGS, model, data_manager, eval_set, index, logger, step, vocabulary=None):
    """Evaluate ``model`` on one evaluation set and log the formatted stats.

    Args:
        FLAGS: Run configuration. Reads ``metrics_path``, ``experiment_name``,
            ``show_progress_bar``, ``use_internal_parser``,
            ``validate_transitions``, ``write_eval_report``,
            ``eval_report_use_preds`` and ``log_path``.
        model: Model to evaluate; it is switched to eval mode.
        data_manager: Forwarded to ``eval_accumulate``.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` must support
            ``len()`` and iteration over raw batches.
        index: Position of this set among the eval sets; metrics are only
            written for index 0.
        logger: Receives the two formatted evaluation lines.
        step: Current training step, forwarded to ``eval_stats`` and
            ``eval_metrics``.
        vocabulary: Not used by this function.

    Returns:
        ``(class_acc, transition_acc)`` taken from the dict returned by
        ``eval_stats``.
    """
    filename, dataset = eval_set

    A = Accumulator()
    M = MetricsWriter(os.path.join(FLAGS.metrics_path, FLAGS.experiment_name))
    reporter = EvalReporter()

    eval_str = eval_format(model)
    eval_extra_str = eval_extra_format(model)

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        (eval_X_batch, eval_transitions_batch, eval_y_batch,
         eval_num_transitions_batch, eval_ids) = batch

        # Run model.
        output = model(
            eval_X_batch, eval_transitions_batch, eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy: index of the max log-probability.
        target = torch.from_numpy(eval_y_batch).long()
        pred = logits.data.max(1)[1].cpu()

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # A sequence with nt transitions contributes (nt + 1) / 2 tokens.
        total_tokens += sum(
            [(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            reporter_args = [pred, target, eval_ids, output.data.cpu().numpy()]
            # Transitions are only reported for models exposing a
            # `transition_loss` attribute.
            if hasattr(model, 'transition_loss'):
                transitions_per_example, _ = model.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
                if model.use_sentence_pair:
                    # First `batch_size` rows go to sentence 1, the rest to
                    # sentence 2.
                    batch_size = pred.size(0)
                    reporter_args.append(transitions_per_example[:batch_size])
                    reporter_args.append(transitions_per_example[batch_size:])
                else:
                    reporter_args.append(transitions_per_example)
            reporter.save_batch(*reporter_args)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    stats_args = eval_stats(model, A, step)
    stats_args['filename'] = filename

    logger.Log(eval_str.format(**stats_args))
    logger.Log(eval_extra_str.format(**stats_args))

    if FLAGS.write_eval_report:
        # NOTE(review): the report path does not include the eval-set index,
        # so every eval set writes to the same file -- confirm this is
        # intended.
        eval_report_path = os.path.join(
            FLAGS.log_path, FLAGS.experiment_name + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = stats_args['class_acc']
    eval_trans_acc = stats_args['transition_acc']

    if index == 0:
        eval_metrics(M, stats_args, step)

    return eval_class_acc, eval_trans_acc
def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators, logger):
    """Run the training loop until ``FLAGS.training_steps`` or early stopping.

    Each iteration trains on one batch and then, depending on
    ``trainer.step``, records statistics, logs sampled transition sequences,
    evaluates on every set in ``eval_iterators`` and checkpoints.

    Args:
        FLAGS: Run configuration (intervals, RL schedule settings, clipping
            value, parser options, ...).
        model: Model being trained.
        trainer: Provides ``step``, ``best_dev_step``, ``optimizer_zero_grad``,
            ``optimizer_step``, ``new_dev_accuracy`` and ``checkpoint``.
        training_data_iter: Iterator yielding raw training batches.
        eval_iterators: Sequence of eval sets passed to ``evaluate``; the
            accuracy on the first one is reported to the trainer.
        logger: Receives text messages and structured log entries.

    Returns:
        None.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(
        msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for _ in range(trainer.step, FLAGS.training_steps):
        if (trainer.step - trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        model.train()
        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        # A sequence with nt transitions contributes (nt + 1) / 2 tokens.
        total_tokens = sum(
            [(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        # Sinusoidal schedule over the step count, rescaled to [0, 1].
        temperature = math.sin(
            math.pi / 2 +
            trainer.step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        # Confidence Penalty for Transition Predictions.
        if FLAGS.rl_confidence_penalty:
            epsilon = FLAGS.rl_epsilon * \
                math.exp(-trainer.step / float(FLAGS.rl_epsilon_decay))
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Calculate class accuracy: index of the max score per example.
        target = torch.from_numpy(y_batch).long()
        pred = output.data.max(1, keepdim=False)[1].cpu()
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.CrossEntropyLoss()(
            output, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping (the embedding matrix is excluded).
        nn.utils.clip_grad_norm(
            [param for name, param in model.named_parameters()
             if name not in ["embed.embed.weight"]],
            FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()

        end = time.time()
        total_time = end - start

        train_accumulate(model, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        train_rl_accumulate(model, A, batch)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
            progress_bar.finish()

            A.add('xent_cost', xent_loss.data[0])
            stats(model, trainer, A, log_entry)
            should_log = True

        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True

            # Re-run the same batch in train mode and in eval mode to compare
            # the transitions the parser produces in each.
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            tr_transitions_per_example, tr_strength = \
                model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            ev_transitions_per_example, ev_strength = \
                model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                # Stack the two sentences of each pair along the batch axis.
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]],
                    axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost.
            # NOTE(review): shuffling, truncating to num_samples and sorting a
            # list of exactly num_samples indices always yields
            # 0..num_samples-1; kept as-is to leave the `random` state
            # consumption unchanged.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks(
                    [1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks(
                    [1] + ev_strength[t_idx].tolist(), dec_str)
                # Compare gold against this example's eval-mode transitions.
                # `pred` holds the batch's class predictions and is not a
                # transition sequence.
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                # Drop the leading spark produced by the prepended 1.
                log.strg_tr = strength_tr[1:]
                log.strg_ev = strength_ev[1:]

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(
                    FLAGS, model, eval_set, log_entry, logger, trainer,
                    eval_index=index)
                # Only the first eval set drives dev-accuracy tracking.
                if index == 0:
                    trainer.new_dev_accuracy(acc)
            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)
def evaluate(FLAGS, model, data_manager, eval_set, log_entry, logger, step, vocabulary=None, show_sample=False, eval_index=0):
    """Evaluate ``model`` on one evaluation set and fill an evaluation record.

    A new record is appended to ``log_entry.evaluation`` and filled in by
    ``eval_stats``. When ``FLAGS.write_eval_report`` is set, a per-example
    report is written under ``FLAGS.log_path``.

    Args:
        FLAGS: Run configuration. Reads ``show_progress_bar``, ``model_type``,
            the ``pyramid_temperature_*`` settings, ``use_internal_parser``,
            ``validate_transitions``, ``write_eval_report``,
            ``eval_report_use_preds``, ``log_path`` and ``experiment_name``.
        model: Model to evaluate; it is switched to eval mode.
        data_manager: Forwarded to ``eval_accumulate``.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` must support
            ``len()`` and iteration over raw batches.
        log_entry: Log message that receives the new ``evaluation`` record.
        logger: Used to log the first sampled tree, if any was produced.
        step: Current training step; drives the pyramid temperature.
        vocabulary: Forwarded to ``model.get_samples``.
        show_sample: Whether to ask the model for parse samples.
        eval_index: Index of this eval set; part of the report file name.

    Returns:
        ``(eval_class_accuracy, eval_transition_accuracy)`` as read back from
        the evaluation record.
    """
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
        # Exponential decay in units of 10k steps, optionally modulated by a
        # cosine cycle that is kept strictly positive by `min_temp`.
        pyramid_temperature_multiplier = \
            FLAGS.pyramid_temperature_decay_per_10k_steps ** (step / 10000.0)
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            pyramid_temperature_multiplier *= (
                math.cos((step) / FLAGS.pyramid_temperature_cycle_length) +
                1 + min_temp) / 2
    else:
        pyramid_temperature_multiplier = None

    model.eval()
    # NOTE(review): the nesting below was reconstructed from a source whose
    # indentation was lost; `show_sample` is reset at loop-body level.
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        (eval_X_batch, eval_transitions_batch, eval_y_batch,
         eval_num_transitions_batch, eval_ids) = batch

        # Run model.
        output = model(
            eval_X_batch, eval_transitions_batch, eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=pyramid_temperature_multiplier,
            store_parse_masks=show_sample,
            example_lengths=eval_num_transitions_batch)

        # TODO: Restore support in Pyramid if using.
        can_sample = FLAGS.model_type in ["ChoiPyramid"] or (
            FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy: index of the max log-probability.
        target = torch.from_numpy(eval_y_batch).long()
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # A sequence with nt transitions contributes (nt + 1) / 2 tokens.
        total_tokens += sum(
            [(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            if FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser:
                transitions_per_example, _ = model.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
            else:
                transitions_per_example = None

            if model.use_sentence_pair:
                # First `batch_size` rows go to sentence 1, the rest to
                # sentence 2.
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:batch_size] \
                    if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[batch_size:] \
                    if transitions_per_example is not None else None
                sent1_trees = tree_strs[:batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example
                sent2_transitions = None
                sent1_trees = tree_strs
                sent2_trees = None

            reporter.save_batch(
                pred, target, eval_ids, output.data.cpu().numpy(),
                sent1_transitions, sent2_transitions, sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
def train_loop(FLAGS, data_manager, model, optimizer, trainer, training_data_iter, eval_iterators, logger, step, best_dev_error):
    """Train ``model`` from ``step`` up to ``FLAGS.training_steps``.

    One batch is run through the model before training starts, ahead of
    building the log format strings. Each iteration then trains on one batch
    and, depending on ``step``, logs statistics, logs sampled transition
    sequences, evaluates on every set in ``eval_iterators`` and checkpoints.

    Args:
        FLAGS: Run configuration (paths, intervals, clipping value, learning
            rate settings, parser options, ...).
        data_manager: Forwarded to ``train_accumulate`` and ``evaluate``.
        model: Model being trained.
        optimizer: Optimizer providing ``zero_grad`` and ``step``.
        trainer: Provides ``save(path, step, best_dev_error)``.
        training_data_iter: Iterator with a ``next()`` method yielding raw
            training batches.
        eval_iterators: Sequence of eval sets passed to ``evaluate``; only
            the first one drives best-checkpoint selection.
        logger: Receives text log messages.
        step: Step to start (or resume) from.
        best_dev_error: Best dev error seen so far.

    Returns:
        None.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)
    M = MetricsWriter(os.path.join(FLAGS.metrics_path, FLAGS.experiment_name))

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings. One batch is run through the model first.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = \
        get_batch(training_data_iter.next())
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions)

    logger.Log("")
    logger.Log("# ----- BEGIN: Log Configuration ----- #")

    # Preview train string template.
    train_str = train_format(model)
    logger.Log("Train-Format: {}".format(train_str))
    train_extra_str = train_extra_format(model)
    logger.Log("Train-Extra-Format: {}".format(train_extra_str))

    # Preview eval string template.
    eval_str = eval_format(model)
    logger.Log("Eval-Format: {}".format(eval_str))
    eval_extra_str = eval_extra_format(model)
    logger.Log("Eval-Extra-Format: {}".format(eval_extra_str))

    logger.Log("# ----- END: Log Configuration ----- #")
    logger.Log("")

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(
        msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    for step in range(step, FLAGS.training_steps):
        model.train()

        start = time.time()

        batch = get_batch(training_data_iter.next())
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        # A sequence with nt transitions contributes (nt + 1) / 2 tokens.
        total_tokens = sum(
            [(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy: index of the max log-probability.
        target = torch.from_numpy(y_batch).long()
        pred = logits.data.max(1)[1].cpu()
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        total_loss += auxiliary_loss(model)

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping. Parameters that took no part in this
        # step's graph have no gradient (`grad is None`) and are skipped.
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        # NOTE(review): assigning `optimizer.lr` only takes effect if
        # `optimizer` reads that attribute; a plain torch.optim optimizer
        # keeps its rates in `param_groups` -- confirm.
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()
        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
            progress_bar.finish()

            A.add('xent_cost', xent_loss.data[0])
            # With L2 disabled there is no loss tensor; record a zero cost
            # instead of dereferencing None.
            A.add('l2_cost', l2_loss.data[0] if l2_loss is not None else 0.0)
            stats_args = train_stats(model, optimizer, A, step)

            train_metrics(M, stats_args, step)

            logger.Log(train_str.format(**stats_args))
            logger.Log(train_extra_str.format(**stats_args))

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            # Re-run the same batch in train mode and in eval mode to compare
            # the transitions the parser produces in each.
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            tr_transitions_per_example, tr_strength = \
                model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            ev_transitions_per_example, ev_strength = \
                model.spinn.get_transitions_per_example()

            transition_str = "Samples:"
            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                # Stack the two sentences of each pair along the batch axis.
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]],
                    axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost. `list(...)` keeps the shuffle working where `range` does
            # not return a list.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist())
                strength_ev = sparks([1] + ev_strength[t_idx].tolist())
                # Compare gold against this example's eval-mode transitions.
                # `pred` holds the batch's class predictions and is not a
                # transition sequence.
                _, crossing = evalb.crossing(gold, pred_ev)
                transition_str += "\n{}. crossing={}".format(t_idx, crossing)
                transition_str += "\n g{}".format("".join(map(str, gold)))
                # [1:] drops the leading spark produced by the prepended 1.
                transition_str += "\n {}".format(strength_tr[1:].encode('utf-8'))
                transition_str += "\n pt{}".format("".join(map(str, pred_tr)))
                transition_str += "\n {}".format(strength_ev[1:].encode('utf-8'))
                transition_str += "\n pe{}".format("".join(map(str, pred_ev)))
            logger.Log(transition_str)

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(
                    FLAGS, model, data_manager, eval_set, index, logger, step)
                # Save a "best" checkpoint only for the first eval set, and
                # only when the dev error improves by more than 1% relative.
                if FLAGS.ckpt_on_best_dev_error and index == 0 and \
                        (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log(
                        "Checkpointing with new best dev accuracy of %f" % acc)
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                          total=FLAGS.statistics_interval_steps)
def evaluate(FLAGS, model, data_manager, eval_set, log_entry, logger, step, vocabulary=None, show_sample=False, eval_index=0):
    """Evaluate ``model`` on one eval set, including bracket F1 against silver trees.

    A new record is appended to ``log_entry.evaluation`` and filled in by
    ``eval_stats``. For every example, the brackets of the tree sampled from
    the model are compared with the batch's ``silver_tree`` brackets and the
    resulting F1 is added to the accumulator under ``'f1'``.

    NOTE(review): the indentation of the original source was lost. The nesting
    here (F1 computed for every batch, outside the ``show_sample`` branch) is
    a reconstruction -- confirm against the upstream file.

    Args:
        FLAGS: Run configuration. Reads ``show_progress_bar``, ``model_type``,
            the ``pyramid_temperature_*`` settings, ``use_internal_parser``,
            ``validate_transitions``, ``write_eval_report``,
            ``eval_report_use_preds``, ``log_path`` and ``experiment_name``.
        model: Model to evaluate; it is switched to eval mode.
        data_manager: Forwarded to ``eval_accumulate``.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` must support
            ``len()`` and iteration over raw batches.
        log_entry: Log message that receives the new ``evaluation`` record.
        logger: Used to log the first sampled tree, if any was produced.
        step: Current training step; drives the pyramid temperature.
        vocabulary: Forwarded to ``model.get_samples``.
        show_sample: Whether to keep sampled trees for logging/reporting.
        eval_index: Index of this eval set; part of the report file name.

    Returns:
        ``(eval_class_accuracy, eval_transition_accuracy, f1)`` as read back
        from the evaluation record.
    """
    filename, dataset = eval_set

    A = Accumulator()
    # NOTE(review): `index` is not used anywhere below.
    index = len(log_entry.evaluation)
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
        # Exponential decay in units of 10k steps, optionally modulated by a
        # cosine cycle that is kept strictly positive by `min_temp`.
        pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps**(
            step / 10000.0)
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            pyramid_temperature_multiplier *= (math.cos(
                (step) / FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
    else:
        pyramid_temperature_multiplier = None

    model.eval()
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids, _, silver_tree = batch
        # eval_X_batch: <batch x maxlen x 2>
        # eval_y_batch: <batch >
        # silver_tree:
        # the dist is invalid for val

        # Run model. Parse masks are always stored here, regardless of
        # `show_sample`.
        output = model(
            eval_X_batch, eval_transitions_batch, eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=pyramid_temperature_multiplier,
            store_parse_masks=True,
            example_lengths=eval_num_transitions_batch)

        # TODO: Restore support in Pyramid if using.
        can_sample = FLAGS.model_type in [
            "ChoiPyramid"
        ] or (FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            # tree_strs = prettyprint_trees(tmp_samples)
            # Samples are kept as returned (not pretty-printed) in this version.
            tree_strs = [tree for tree in tmp_samples]

        # Sample a tree for every example to score it against the silver tree.
        tmp_samples = model.get_samples(eval_X_batch, vocabulary, only_one=False)

        # def get_max(s):
        #     # test f1
        #     max = 0
        #     for x in s:
        #         _, idx = x.split(',')
        #         if int(idx) > max:
        #             max = int(idx)
        #     return max

        # `s` selects the sentence (0, or 0 and 1 for sentence pairs); samples
        # are indexed as s * batch + b.
        for s in (range(int(model.use_sentence_pair) + 1)):
            for b in (range(silver_tree.shape[0])):
                model_out = tmp_samples[s * silver_tree.shape[0] + b]
                std_out = silver_tree[b, :, s]
                # '-1,-1' entries are dropped -- presumably padding; verify
                # against the data loader.
                std_out = set([x for x in std_out if x != '-1,-1'])
                model_out_brackets, model_out_max_l = get_brackets(model_out)
                model_out = set(convert_brackets_to_string(model_out_brackets))

                # Both sets get the outermost bracket, so neither can be empty.
                outmost_bracket = '{:d},{:d}'.format(0, model_out_max_l)
                std_out.add(outmost_bracket)
                model_out.add(outmost_bracket)

                # print get_max(model_out), get_max(std_out)
                # print model_out
                # print std_out
                # print '=' * 30
                # assert get_max(model_out) == get_max(std_out)

                overlap = model_out & std_out
                # 1e-8 guards the divisions against a zero denominator.
                prec = float(len(overlap)) / (len(model_out) + 1e-8)
                reca = float(len(overlap)) / (len(std_out) + 1e-8)
                # NOTE(review): after adding `outmost_bracket` above, these two
                # emptiness checks can never be true.
                if len(std_out) == 0:
                    reca = 1.
                if len(model_out) == 0:
                    prec = 1.
                f1 = 2 * prec * reca / (prec + reca + 1e-8)
                A.add('f1', f1)

        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Optionally calculate transition loss/acc.
        #model.transition_loss if hasattr(model, 'transition_loss') else None
        # TODO: review this. the original line seems to have no effect

        # Update Aggregate Accuracies
        # A sequence with nt transitions contributes (nt + 1) / 2 tokens.
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            # Transitions are only available for SPINN with its internal
            # parser; otherwise None is reported.
            transitions_per_example, _ = model.spinn.get_transitions_per_example(
                style="preds" if FLAGS.eval_report_use_preds else "given") if (
                FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser) else (None, None)

            if model.use_sentence_pair:
                # First `batch_size` rows go to sentence 1, the rest to
                # sentence 2.
                batch_size = pred.size(0)
                sent1_transitions = transitions_per_example[:
                                                            batch_size] if transitions_per_example is not None else None
                sent2_transitions = transitions_per_example[
                    batch_size:] if transitions_per_example is not None else None

                sent1_trees = tree_strs[:
                                        batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[
                    batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example if transitions_per_example is not None else None
                sent2_transitions = None

                sent1_trees = tree_strs if tree_strs is not None else None
                sent2_trees = None

            reporter.save_batch(pred, target, eval_ids, output.data.cpu().numpy(), sent1_transitions, sent2_transitions, sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()
    if tree_strs is not None:
        logger.Log('Sample: ' + str(tree_strs[0]))

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)  # get the eval statistics (e.g. average F1)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path, FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy
    eval_f1 = eval_log.f1

    return eval_class_acc, eval_trans_acc, eval_f1
def train_loop(FLAGS, data_manager, model, optimizer, trainer, training_data_iter, eval_iterators, logger, step, best_dev_error, vocabulary):
    """Train ``model`` from ``step`` up to ``FLAGS.training_steps``.

    One batch is run through the model before training starts. Each
    iteration then trains on one batch and, depending on ``step``, records
    statistics, logs sampled transition sequences, evaluates on every set in
    ``eval_iterators`` and checkpoints. Structured output is collected in a
    ``pb.SpinnEntry`` and emitted once per step when anything was recorded.

    Args:
        FLAGS: Run configuration (paths, intervals, clipping value, learning
            rate and pyramid temperature settings, parser options, ...).
        data_manager: Forwarded to ``train_accumulate`` and ``evaluate``.
        model: Model being trained.
        optimizer: Optimizer providing ``zero_grad`` and ``step``.
        trainer: Provides ``save(path, step, best_dev_error)``.
        training_data_iter: Iterator with a ``next()`` method yielding raw
            training batches.
        eval_iterators: Sequence of eval sets passed to ``evaluate``; only
            the first one drives best-checkpoint selection.
        logger: Receives text messages and structured log entries.
        step: Step to start (or resume) from.
        best_dev_error: Best dev error seen so far.
        vocabulary: Forwarded to ``evaluate``.

    Returns:
        None.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings. One batch is run through the model first.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        training_data_iter.next())
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions,
          pyramid_temperature_multiplier=1.0,
          example_lengths=num_transitions_batch)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(
        msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for step in range(step, FLAGS.training_steps):
        model.train()
        log_entry.Clear()
        log_entry.step = step
        should_log = False

        start = time.time()

        batch = get_batch(training_data_iter.next())
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        # A sequence with nt transitions contributes (nt + 1) / 2 tokens.
        total_tokens = sum(
            [(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
            # Exponential decay in units of 10k steps, optionally modulated
            # by a cosine cycle kept strictly positive by `min_temp`.
            pyramid_temperature_multiplier = \
                FLAGS.pyramid_temperature_decay_per_10k_steps ** (step / 10000.0)
            if FLAGS.pyramid_temperature_cycle_length > 0.0:
                min_temp = 1e-5
                pyramid_temperature_multiplier *= (
                    math.cos((step) / FLAGS.pyramid_temperature_cycle_length) +
                    1 + min_temp) / 2
        else:
            pyramid_temperature_multiplier = None

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                       example_lengths=num_transitions_batch)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy: index of the max log-probability.
        target = torch.from_numpy(y_batch).long()
        pred = logits.data.max(1, keepdim=False)[1].cpu()
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(model, FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping. Parameters that took no part in this
        # step's graph have no gradient (`grad is None`) and are skipped.
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        # NOTE(review): assigning `optimizer.lr` only takes effect if
        # `optimizer` reads that attribute; a plain torch.optim optimizer
        # keeps its rates in `param_groups` -- confirm.
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()
        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            A.add('xent_cost', xent_loss.data[0])
            # With L2 disabled there is no loss tensor; record a zero cost
            # instead of dereferencing None.
            A.add('l2_cost', l2_loss.data[0] if l2_loss is not None else 0.0)
            stats(model, optimizer, A, step, log_entry)
            should_log = True
            progress_bar.finish()

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True

            # Re-run the same batch in train mode and in eval mode to compare
            # the transitions the parser produces in each.
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                  example_lengths=num_transitions_batch)
            tr_transitions_per_example, tr_strength = \
                model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                  example_lengths=num_transitions_batch)
            ev_transitions_per_example, ev_strength = \
                model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                # Stack the two sentences of each pair along the batch axis.
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]],
                    axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost. `list(...)` keeps the shuffle working where `range` does
            # not return a list.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks(
                    [1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks(
                    [1] + ev_strength[t_idx].tolist(), dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                # [1:] drops the leading spark produced by the prepended 1.
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(
                    FLAGS, model, data_manager, eval_set, log_entry, logger, step,
                    show_sample=(step % FLAGS.sample_interval_steps == 0),
                    vocabulary=vocabulary, eval_index=index)
                # Save a "best" checkpoint only for the first eval set, and
                # only when the dev error improves by more than 1% relative.
                if FLAGS.ckpt_on_best_dev_error and index == 0 and \
                        (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log(
                        "Checkpointing with new best dev accuracy of %f" % acc)
                    # TODO: This mixes information across dev sets. Fix.
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)
def train_loop(FLAGS, data_manager, model, optimizer, trainer,
               training_data_iter, eval_iterators, logger, step,
               best_dev_error):
    """Train ``model`` from ``step`` up to ``FLAGS.training_steps``.

    Each iteration pulls one batch from ``training_data_iter``, runs the
    model, sums the cross-entropy loss with the optional L2, transition and
    auxiliary losses, clips gradients element-wise and steps ``optimizer``.
    Depending on the ``FLAGS.*_interval_steps`` settings it also records
    statistics, logs sampled transition sequences, evaluates on
    ``eval_iterators`` and saves checkpoints through ``trainer``.

    Args:
        FLAGS: flag container read for every hyper-parameter and interval.
        data_manager: forwarded to the accumulate/evaluate helpers.
        model: callable network exposing ``train``/``eval``/``parameters``
            and a ``spinn`` attribute.
        optimizer: object providing ``zero_grad``/``step``; its ``lr``
            attribute is overwritten when learning-rate decay is enabled.
        trainer: used only for ``trainer.save(path, step, best_dev_error)``.
        training_data_iter: iterator yielding raw training batches.
        eval_iterators: iterable of eval sets handed to ``evaluate``.
        logger: object providing ``Log`` and ``LogEntry``.
        step: first step to run (allows resuming).
        best_dev_error: best dev error seen so far; updated locally.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    # One forward pass on a first batch before the loop starts; its output is
    # not used here.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = \
        get_batch(next(training_data_iter))
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(
        msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for step in range(step, FLAGS.training_steps):
        model.train()
        log_entry.Clear()
        log_entry.step = step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum(
            [(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        # Epsilon Greedy w. Decay.
        # float() keeps the exponent a true division even when both operands
        # are integers.
        epsilon = FLAGS.rl_epsilon * \
            math.exp(-step / float(FLAGS.rl_epsilon_decay))
        model.spinn.epsilon = epsilon

        # Confidence Penalty for Transition Predictions.
        # `temperature` follows a cosine-shaped cycle in [0, 1] whose period
        # is FLAGS.rl_confidence_interval steps.
        temperature = math.sin(
            math.pi / 2 + step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        if FLAGS.rl_confidence_penalty:
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()
        # get the index of the max log-probability
        pred = logits.data.max(1)[1].cpu()
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(
            logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(
            model, FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            # A parameter that received no gradient this step has
            # `grad is None`; skip it instead of failing.
            if p.requires_grad and p.grad is not None:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        train_rl_accumulate(model, data_manager, A, batch)

        if step % FLAGS.statistics_interval_steps == 0 \
                or step % FLAGS.metrics_interval_steps == 0:
            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps,
                                  total=FLAGS.statistics_interval_steps)
                progress_bar.finish()

            A.add('xent_cost', xent_loss.data[0])
            # l2_loss is None when FLAGS.use_l2_loss is off; record 0.0 so
            # the 'l2_cost' key is still populated for `stats`.
            A.add('l2_cost', l2_loss.data[0] if l2_loss is not None else 0.0)
            stats(model, optimizer, A, step, log_entry)
            # NOTE(review): should_log is not set here, so an entry is only
            # emitted at a statistics step if another condition below fires
            # -- confirm this is intended.

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            # Re-run the same batch in train mode and in eval mode to compare
            # the transitions chosen in each.
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            tr_transitions_per_example, tr_strength = \
                model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            ev_transitions_per_example, ev_strength = \
                model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost.
            # list(...) so that shuffle gets a mutable sequence on Python 3.
            # Shuffling and re-sorting the whole range still yields indices
            # 0..num_samples-1; the shuffle is kept for its RNG side effect.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks(
                    [1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks(
                    [1] + ev_strength[t_idx].tolist(), dec_str)
                # Compare gold transitions with the eval-mode prediction for
                # this example (not `pred`, the class-label tensor).
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                # NOTE(review): this positional argument list does not match
                # the `evaluate` signatures visible in this file -- confirm
                # which definition is meant to be called.
                acc, tacc = evaluate(
                    FLAGS, model, data_manager, eval_set, log_entry, step)
                # Only the first eval set drives best-checkpoint selection,
                # and only for a relative error improvement of at least 1%.
                if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                        1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log(
                        "Checkpointing with new best dev accuracy of %f" % acc)
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        log_level = afs_safe_logger.ProtoLogger.INFO
        if not should_log and step % FLAGS.metrics_interval_steps == 0:
            # Log to file, but not to stderr.
            should_log = True
            log_level = afs_safe_logger.ProtoLogger.DEBUG

        if should_log:
            logger.LogEntry(log_entry, level=log_level)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                          total=FLAGS.statistics_interval_steps)
def evaluate(FLAGS, model, eval_set, log_entry, logger, trainer,
             vocabulary=None, show_sample=False, eval_index=0,
             target_vocabulary=None):
    """Evaluate a translation model on one eval set and score it with BLEU.

    Runs ``model`` over every batch of ``eval_set``, collects reference and
    predicted token sequences, writes them to ``<log_path>/ref_file`` and
    ``<log_path>/pred_file`` and scores them with
    ``spinn/util/multi-bleu.perl``. The BLEU value is stored in the
    accumulator under ``'class_correct'`` (with ``'class_total'`` = 1) before
    ``eval_stats`` fills the new ``log_entry.evaluation`` record.

    Args:
        FLAGS: flag container (paths, model type, report options).
        model: network called on each batch; ``model.encoder`` is used for
            sampling and transition extraction.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` must support
            ``len`` and iteration.
        log_entry: log message to which an evaluation record is appended.
        logger: object providing ``Log``.
        trainer: unused in this function.
        vocabulary: forwarded to ``model.encoder.get_samples``.
        show_sample: whether to pretty-print sampled trees.
        eval_index: index used in the eval-report file name.
        target_vocabulary: unused in this function.

    Returns:
        ``(eval_log.eval_class_accuracy, eval_log.eval_transition_accuracy)``
        as read back after ``eval_stats``.
    """
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()
    ref_file_name = FLAGS.log_path + "/ref_file"
    pred_file_name = FLAGS.log_path + "/pred_file"
    full_ref = []
    full_pred = []
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       example_lengths=eval_num_transitions_batch)

        can_sample = (FLAGS.model_type == "RLSPINN" and
                      FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.encoder.get_samples(
                eval_X_batch, vocabulary,
                only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Get reference translation
        # Each reference drops its last element and gets " ." appended.
        ref_out = [" ".join(map(str, k[:-1])) + " ." for k in eval_y_batch]
        full_ref += ref_out

        # Get predicted translation
        # `output` is iterated row by row; within a row, `index` selects the
        # example. Once an example emits the value 1 it is marked finished
        # and later values for it are ignored.
        # presumably rows are decoding steps and 1 is the end-of-sequence
        # id -- TODO confirm against the decoder.
        predicted = [[] for _ in range(len(eval_y_batch))]
        done = set()
        for x in output:
            for index, x_0 in enumerate(x):
                val = int(x_0)
                if val == 1:
                    done.add(index)
                elif index not in done:
                    predicted[index].append(val)
        pred_out = [" ".join(map(str, k)) + " ." for k in predicted]
        full_pred += pred_out

        eval_accumulate(model, A, batch)

        # Update Aggregate Accuracies
        total_tokens += sum(
            [(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            if FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser:
                transitions_per_example, _ = \
                    model.encoder.spinn.get_transitions_per_example(
                        style="preds" if FLAGS.eval_report_use_preds else "given")
            else:
                transitions_per_example = None

            sent1_transitions = transitions_per_example
            sent2_transitions = None
            sent1_trees = tree_strs
            sent2_trees = None
            reporter.save_batch(
                full_pred,
                full_ref,
                eval_ids,
                [None],
                sent1_transitions,
                sent2_transitions,
                sent1_trees,
                sent2_trees,
                mt=True)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()
    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    # Write both files inside `with` blocks so the handles are closed even if
    # a write fails; they are only needed once all batches are collected.
    with open(ref_file_name, "w") as reference_file:
        reference_file.write("\n".join(full_ref))
    with open(pred_file_name, "w") as predict_file:
        predict_file.write("\n".join(full_pred))

    bleu_score = os.popen(
        "perl spinn/util/multi-bleu.perl " +
        ref_file_name +
        " < " +
        pred_file_name).read()
    try:
        bleu_score = float(bleu_score)
    except ValueError:
        # The script output could not be parsed as a number (e.g. it
        # produced nothing); fall back to a zero score.
        bleu_score = 0.0

    end = time.time()
    total_time = end - start
    A.add('class_correct', bleu_score)
    A.add('class_total', 1)
    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)
    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

        # NOTE(review): the template below always names eval set 0,
        # independent of `eval_index` -- confirm this is intended.
        parse_stats = parse_comparison.run_main(
            data_type="mt",
            main_report_path_template=FLAGS.log_path + "/" +
            FLAGS.experiment_name + ".eval_set_0.report",
            main_data_path=FLAGS.source_eval_path)
        # To-do: include the following into log-formatter so it's reported in
        # standard format.
        if tree_strs is not None:
            logger.Log(
                'F1 w/ GT: ' + str(parse_stats['gt']) + '\n' +
                'F1 w/ LB: ' + str(parse_stats['lb']) + '\n' +
                'F1 w/ RB: ' + str(parse_stats['rb']) + '\n' +
                'Avg. tree depth: ' + str(parse_stats['depth'])
            )

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
               logger, vocabulary, target_vocabulary):
    """Train a translation model until ``FLAGS.training_steps`` or early stop.

    Each iteration runs the model on one batch, builds the masked NLL
    translation loss (skipped while in RL-only mode), adds the optional
    transition loss and the first element of the auxiliary loss, clips the
    gradient norm and calls ``trainer.optimizer_step()``. Statistics,
    transition samples, evaluation and checkpoints are triggered by the
    ``FLAGS.*_interval_steps`` settings.

    ``trainer.step`` is read throughout but never assigned here; presumably
    ``trainer.optimizer_step()`` advances it -- confirm in the trainer.

    Args:
        FLAGS: flag container with hyper-parameters and intervals.
        model: network returning ``(output, trg, attention, mask)``.
        trainer: provides ``step``, ``best_dev_step``, the optimizer
            wrappers, ``new_dev_accuracy`` and ``checkpoint``.
        training_data_iter: iterator yielding raw training batches.
        eval_iterators: iterable of eval sets handed to ``evaluate``.
        logger: object providing ``Log`` and ``LogEntry``.
        vocabulary: forwarded to ``evaluate``.
        target_vocabulary: forwarded to ``evaluate``.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(
        msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    rl_only = False  # toggled every 1000 steps when FLAGS.rl_alternate is set
    log_entry = pb.SpinnEntry()
    criterion = nn.NLLLoss()
    for _ in range(trainer.step, FLAGS.training_steps):
        if FLAGS.rl_alternate and trainer.step % 1000 == 0 and trainer.step > 0:
            rl_only = not rl_only
            if rl_only:
                logger.Log('Switching training mode: RL only.')
            else:
                logger.Log('Switching training mode: MT only.')

        if (trainer.step - trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        model.train()
        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum(
            [(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        # `temperature` follows a cosine-shaped cycle in [0, 1] whose period
        # is FLAGS.rl_confidence_interval steps.
        temperature = math.sin(
            math.pi / 2 + trainer.step /
            float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        # Confidence Penalty for Transition Predictions.
        if FLAGS.rl_confidence_penalty:
            epsilon = FLAGS.rl_epsilon * \
                math.exp(-trainer.step / float(FLAGS.rl_epsilon_decay))
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output, trg, _attention, mask = model(
            X_batch, transitions_batch, y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            example_lengths=num_transitions_batch)

        trg_seq_len = trg.shape[0]
        mt_loss = 0.0
        if not rl_only:
            mask = to_gpu(mask)
            for i in range(trg_seq_len):
                # Indices whose mask entry at position i is non-zero; only
                # those contribute to the loss at this position.
                active = mask[i].nonzero().squeeze(1)
                mt_loss += criterion(
                    output[i, :].index_select(0, active),
                    trg[i].index_select(0, active).view(-1))
        elif FLAGS.rl_alternate:
            model.policy_loss = 0.0
            model.value_loss = 0.0

        # Average the translation loss over target positions.
        mt_loss = mt_loss / trg_seq_len

        # Optionally calculate transition loss.
        model.transition_loss = model.encoder.transition_loss if hasattr(
            model.encoder, 'transition_loss') else None
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None
        model.mt_loss = mt_loss

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += mt_loss
        if transition_loss is not None and model.encoder.optimize_transition_loss:
            model.optimize_transition_loss = model.encoder.optimize_transition_loss
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss[0]

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        # The parameter named "embed.embed.weight" is excluded from clipping.
        nn.utils.clip_grad_norm_(
            [param for name, param in model.named_parameters()
             if name not in ["embed.embed.weight"]],
            FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, A, batch)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)
        A.add('mt_loss', float(mt_loss))

        train_rl_accumulate(model, A, batch)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            stats(model, trainer, A, log_entry)
            should_log = True
            progress_bar.finish()

        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            # Re-run the same batch in train mode and in eval mode to compare
            # the transitions chosen in each.
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  example_lengths=num_transitions_batch)
            tr_transitions_per_example, tr_strength = \
                model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  example_lengths=num_transitions_batch)
            ev_transitions_per_example, ev_strength = \
                model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]],
                    axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost.
            # Shuffling and re-sorting the whole range still yields indices
            # 0..num_samples-1; the shuffle is kept for its RNG side effect.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks(
                    [1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks(
                    [1] + ev_strength[t_idx].tolist(), dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:]
                log.strg_ev = strength_ev[1:]

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(
                    FLAGS, model, eval_set, log_entry, logger, trainer,
                    show_sample=(trainer.step % FLAGS.sample_interval_steps == 0),
                    vocabulary=vocabulary, eval_index=index,
                    target_vocabulary=target_vocabulary)
                # Only the first eval set is reported to the trainer.
                if index == 0:
                    trainer.new_dev_accuracy(acc)
            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)
def evaluate(model, eval_set, logger, step, vocabulary=None):
    """Score ``model`` on one eval set and log a one-line summary.

    Iterates over the batches of ``eval_set``, counts correct class
    predictions and, when the model exposes a ``transition_loss``, collects
    predicted and given transitions from ``model.spinn.memories`` to compute
    a transition accuracy. The summary is sent to ``logger.Log``.

    Args:
        model: network called on each batch.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` must support
            ``len`` and iteration.
        logger: object providing ``Log``.
        step: training step, used only in the logged summary.
        vocabulary: accepted for call compatibility; not used here.

    Returns:
        The class accuracy (correct predictions / total examples).
    """
    filename, dataset = eval_set
    reporter = EvalReporter()

    # Running tallies.
    n_correct = 0
    n_examples = 0
    n_batches = len(dataset)

    bar = SimpleProgressBar(
        msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    bar.step(0, total=n_batches)

    token_count = 0
    t_begin = time.time()

    model.eval()

    preds_by_batch = []
    given_by_batch = []

    for batch_no, raw_batch in enumerate(dataset):
        unpacked = get_batch(raw_batch)[:4]
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch = unpacked

        # Forward pass, then log-probabilities over classes.
        scores = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)
        log_probs = F.log_softmax(scores)

        # Class accuracy: arg-max of the log-probabilities vs. gold labels.
        gold = torch.from_numpy(eval_y_batch).long()
        guess = log_probs.data.max(1)[1].cpu()
        n_correct += guess.eq(gold).sum()
        n_examples += gold.size(0)

        # The transition loss is only present on some models.
        if hasattr(model, 'transition_loss'):
            trans_loss = model.transition_loss
        else:
            trans_loss = None

        # Token count per example is (num_transitions + 1) / 2.
        token_count += sum(
            [(nt+1)/2 for nt in eval_num_transitions_batch.reshape(-1)])

        # Keep per-batch transitions for the accuracy computed below.
        if trans_loss is not None:
            preds_by_batch.append(
                [mem["t_preds"] for mem in model.spinn.memories])
            given_by_batch.append(
                [mem["t_given"] for mem in model.spinn.memories])

        bar.step(batch_no+1, total=n_batches)
    bar.finish()

    elapsed = time.time() - t_begin

    # Time per token.
    time_metric = time_per_token([token_count], [elapsed])

    # Class accuracy.
    eval_class_acc = n_correct / float(n_examples)

    # Transition accuracy, when any transitions were collected.
    if preds_by_batch:
        flat_preds = np.array(flatten(preds_by_batch))
        flat_given = np.array(flatten(given_by_batch))
        eval_trans_acc = (flat_preds == flat_given).sum() / float(flat_given.shape[0])
    else:
        eval_trans_acc = 0.0

    summary = "Step: %i Eval acc: %f %f %s Time: %5f" % (
        step, eval_class_acc, eval_trans_acc, filename, time_metric)

    # Extra Component.
    summary += "\nEval Extra:"

    logger.Log(summary)

    return eval_class_acc
def run(only_forward=False):
    """Load data, build the model and either evaluate it or train it.

    Reads all configuration from the module-level ``FLAGS``. Loads and
    preprocesses the training and eval data, builds the model and trainer,
    restores a checkpoint when one exists, then either evaluates every eval
    set once (``only_forward=True``) or runs the training loop with periodic
    statistics, evaluation and checkpointing.

    Args:
        only_forward: when true, skip training and only evaluate.

    Raises:
        AssertionError: if ``only_forward`` is set and no checkpoint file is
            found.
    """
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(
                eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager,
        eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only())
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only())
        eval_it = util.MakeEvalIterator(
            eval_data, FLAGS.batch_size, FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval, rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    model, optimizer, trainer = init_model(
        FLAGS, logger, initial_embeddings, vocab_size, num_classes,
        data_manager)

    # Build trainer.
    # This replaces the trainer returned by init_model above.
    trainer = ModelTrainer(model, optimizer)

    standard_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), the_gpu.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for index, eval_set in enumerate(eval_iterators):
            acc = evaluate(model, eval_set, logger, step, vocabulary)
    else:
        # Build log format strings.
        # One forward pass before the format helpers are called.
        model.train()
        X_batch, transitions_batch, y_batch, num_transitions_batch = \
            get_batch(next(training_data_iter))[:4]
        model(X_batch, transitions_batch, y_batch,
              use_internal_parser=FLAGS.use_internal_parser,
              validate_transitions=FLAGS.validate_transitions)

        train_str = train_format(model)
        logger.Log("Train-Format: {}".format(train_str))
        train_extra_str = train_extra_format(model)
        logger.Log("Train-Extra-Format: {}".format(train_extra_str))

        # Train
        logger.Log("Training.")

        # New Training Loop
        progress_bar = SimpleProgressBar(
            msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
        progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

        for step in range(step, FLAGS.training_steps):
            model.train()

            start = time.time()

            batch = get_batch(next(training_data_iter))
            X_batch, transitions_batch, y_batch, num_transitions_batch = batch[:4]

            total_tokens = sum(
                [(nt+1)/2 for nt in num_transitions_batch.reshape(-1)])

            # Reset cached gradients.
            optimizer.zero_grad()

            if FLAGS.model_type == "RLSPINN":
                # float() keeps the exponent a true division even when both
                # operands are integers.
                model.spinn.epsilon = FLAGS.rl_epsilon * \
                    math.exp(-step / float(FLAGS.rl_epsilon_decay))

            # Run model.
            output = model(X_batch, transitions_batch, y_batch,
                           use_internal_parser=FLAGS.use_internal_parser,
                           validate_transitions=FLAGS.validate_transitions)

            # Normalize output.
            logits = F.log_softmax(output)

            # Calculate class accuracy.
            target = torch.from_numpy(y_batch).long()
            # get the index of the max log-probability
            pred = logits.data.max(1)[1].cpu()
            class_acc = pred.eq(target).sum() / float(target.size(0))

            # Calculate class loss.
            xent_loss = nn.NLLLoss()(
                logits, to_gpu(Variable(target, volatile=False)))

            # Optionally calculate transition loss.
            transition_loss = model.transition_loss if hasattr(
                model, 'transition_loss') else None

            # Extract L2 Cost
            l2_loss = l2_cost(
                model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

            # Accumulate Total Loss Variable
            total_loss = 0.0
            total_loss += xent_loss
            if l2_loss is not None:
                total_loss += l2_loss
            if transition_loss is not None and model.optimize_transition_loss:
                total_loss += transition_loss
            total_loss += auxiliary_loss(model)

            # Backward pass.
            total_loss.backward()

            # Hard Gradient Clipping
            clip = FLAGS.clipping_max_value
            for p in model.parameters():
                # A parameter that received no gradient this step has
                # `grad is None`; skip it instead of failing.
                if p.requires_grad and p.grad is not None:
                    p.grad.data.clamp_(min=-clip, max=clip)

            # Learning Rate Decay
            if FLAGS.actively_decay_learning_rate:
                optimizer.lr = FLAGS.learning_rate * \
                    (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

            # Gradient descent step.
            optimizer.step()

            end = time.time()

            total_time = end - start

            train_accumulate(model, data_manager, A, batch)
            A.add('class_acc', class_acc)
            A.add('total_tokens', total_tokens)
            A.add('total_time', total_time)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps,
                                  total=FLAGS.statistics_interval_steps)
                progress_bar.finish()
                A.add('xent_cost', xent_loss.data[0])
                # l2_loss is None when FLAGS.use_l2_cost is off; record 0.0
                # so the 'l2_cost' key is still populated for train_stats.
                A.add('l2_cost',
                      l2_loss.data[0] if l2_loss is not None else 0.0)
                stats_args = train_stats(model, optimizer, A, step)
                logger.Log(train_str.format(**stats_args))
                logger.Log(train_extra_str.format(**stats_args))

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(model, eval_set, logger, step)
                    # Only the first eval set drives best-checkpoint
                    # selection, and only for a relative error improvement
                    # of at least 1%.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                            1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                        best_dev_error = 1 - acc
                        logger.Log(
                            "Checkpointing with new best dev accuracy of %f" % acc)
                        trainer.save(best_checkpoint_path, step, best_dev_error)
                progress_bar.reset()

            if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
                logger.Log("Checkpointing.")
                trainer.save(standard_checkpoint_path, step, best_dev_error)

            progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)