def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators, logger, vocabulary):
    """Train ``model`` until ``FLAGS.training_steps`` or until early stopping.

    Each iteration draws one batch from ``training_data_iter``, runs a forward
    and backward pass, clips the gradient norm of every parameter except
    ``embed.embed.weight``, and steps the optimizer through ``trainer``. At the
    configured intervals it records statistics, logs sampled parses, evaluates
    on every set in ``eval_iterators`` and checkpoints.

    Training stops early once ``trainer.step - trainer.best_dev_step`` exceeds
    ``FLAGS.early_stopping_steps_to_wait``.

    Args:
        FLAGS: Flag values controlling the run.
        model: Model being trained.
        trainer: Wrapper exposing ``step``, ``best_dev_step``,
            ``optimizer_zero_grad``, ``optimizer_step``, ``new_dev_accuracy``
            and ``checkpoint``. ``trainer.step`` is presumably advanced by
            ``optimizer_step()``; the range bound below only caps the number
            of iterations.
        training_data_iter: Iterator over raw training batches.
        eval_iterators: Eval sets, each passed to ``evaluate``. The accuracy on
            the first one is reported to ``trainer.new_dev_accuracy``.
        logger: Logger providing ``Log`` and ``LogEntry``.
        vocabulary: Forwarded to ``evaluate`` for printing samples.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for _ in range(trainer.step, FLAGS.training_steps):
        if (trainer.step - trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' + str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        model.train()
        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        if FLAGS.model_type in ["ChoiPyramid", "Maillard", "CatalanPyramid"]:
            pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps ** (
                trainer.step / 10000.0)
            if FLAGS.pyramid_temperature_cycle_length > 0.0:
                min_temp = 1e-5
                pyramid_temperature_multiplier *= (math.cos(
                    (trainer.step) / FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
        else:
            pyramid_temperature_multiplier = None

        # Run model.
        output = model(
            X_batch,
            transitions_batch,
            y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=pyramid_temperature_multiplier,
            example_lengths=num_transitions_batch)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=False)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.CrossEntropyLoss()(output, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping (the embedding matrix is excluded from the norm).
        nn.utils.clip_grad_norm([
            param for name, param in model.named_parameters()
            if name not in ["embed.embed.weight"]
        ], FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            A.add('xent_cost', xent_loss.data[0])
            stats(model, trainer, A, log_entry)
            should_log = True
            progress_bar.finish()

        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            # Parse the same batch in train mode and in eval mode to compare.
            model.train()
            model(
                X_batch,
                transitions_batch,
                y_batch,
                use_internal_parser=FLAGS.use_internal_parser,
                validate_transitions=FLAGS.validate_transitions,
                pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                example_lengths=num_transitions_batch)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(
                X_batch,
                transitions_batch,
                y_batch,
                use_internal_parser=FLAGS.use_internal_parser,
                validate_transitions=FLAGS.validate_transitions,
                pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                example_lengths=num_transitions_batch)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost. Draw distinct rows at random, and never ask for more rows
            # than exist, so a large num_samples cannot raise IndexError.
            num_rows = min(len(transitions_batch),
                           len(tr_transitions_per_example),
                           len(ev_transitions_per_example))
            t_idxs = sorted(random.sample(range(num_rows), min(FLAGS.num_samples, num_rows)))
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                # A leading 1 is prepended before rendering and stripped afterwards.
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(), dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:]
                log.strg_ev = strength_ev[1:]

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(
                    FLAGS, model, eval_set, log_entry, logger, trainer,
                    show_sample=(trainer.step % FLAGS.sample_interval_steps == 0),
                    vocabulary=vocabulary,
                    eval_index=index)
                if index == 0:
                    trainer.new_dev_accuracy(acc)
            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)
def run(only_forward=False):
    """Load data, build the model, then either train it or only evaluate it.

    All configuration is read from the module-level ``FLAGS``. Training data
    comes from ``FLAGS.training_data_path``. Eval sets come from the
    colon-separated ``FLAGS.eval_data_path``. A checkpoint is restored when
    one exists; the "best" checkpoint is preferred if ``FLAGS.load_best``.

    Args:
        only_forward: If True, evaluate every eval set once and skip training.
            This requires an existing checkpoint.

    Raises:
        AssertionError: If ``only_forward`` is set and no checkpoint exists.
    """
    logger = afs_safe_logger.Logger(os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" + json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager, eval_mode=False,
        logger=logger, sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only())
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_seq_length = (FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None
                           else FLAGS.seq_length)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary, eval_seq_length, data_manager, eval_mode=True,
            logger=logger, sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only())
        eval_it = util.MakeEvalIterator(
            eval_data, FLAGS.batch_size, FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval, shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    # NOTE(review): init_model also returns a trainer, but it is immediately
    # replaced below by a fresh ModelTrainer; confirm that is intended.
    model, optimizer, _ = init_model(
        FLAGS, logger, initial_embeddings, vocab_size, num_classes, data_manager)

    # Build trainer.
    trainer = ModelTrainer(model, optimizer)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), the_gpu.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for index, eval_set in enumerate(eval_iterators):
            evaluate(model, eval_set, logger, step, vocabulary)
        return

    # Build log format strings. This consumes one training batch so that the
    # model has run once before its format strings are queried.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch = get_batch(
        next(training_data_iter))[:4]
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions)

    train_str = train_format(model)
    logger.Log("Train-Format: {}".format(train_str))
    train_extra_str = train_extra_format(model)
    logger.Log("Train-Extra-Format: {}".format(train_extra_str))

    # Train
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    for step in range(step, FLAGS.training_steps):
        model.train()

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch = batch[:4]

        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        if FLAGS.model_type == "RLSPINN":
            # float() avoids Python 2 floor division when both operands are ints.
            model.spinn.epsilon = FLAGS.rl_epsilon * math.exp(
                -step / float(FLAGS.rl_epsilon_decay))

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()
        pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        total_loss += auxiliary_loss(model)

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping. Parameters unused in this forward pass have
        # no gradient (grad is None) and are skipped.
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            decayed_lr = FLAGS.learning_rate * (
                FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))
            # torch.optim optimizers read the rate from param_groups, not from an
            # ``lr`` attribute. Keep the attribute for existing readers and also
            # update the groups so the decay actually takes effect.
            optimizer.lr = decayed_lr
            for param_group in getattr(optimizer, "param_groups", []):
                param_group["lr"] = decayed_lr

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
            progress_bar.finish()
            A.add('xent_cost', xent_loss.data[0])
            # l2_loss is None when L2 is disabled; record 0.0 so the stats key
            # still exists instead of crashing on ``None.data``.
            A.add('l2_cost', l2_loss.data[0] if l2_loss is not None else 0.0)
            stats_args = train_stats(model, optimizer, A, step)

            logger.Log(train_str.format(**stats_args))
            logger.Log(train_extra_str.format(**stats_args))

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            for index, eval_set in enumerate(eval_iterators):
                acc = evaluate(model, eval_set, logger, step)
                if (FLAGS.ckpt_on_best_dev_error and index == 0 and
                        (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step):
                    best_dev_error = 1 - acc
                    logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                          total=FLAGS.statistics_interval_steps)
def evaluate(model, eval_set, logger, step, vocabulary=None):
    """Evaluate ``model`` on one eval set and log class/transition accuracy.

    Reads configuration from the module-level ``FLAGS``.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` supports ``len``
            and yields raw batches for ``get_batch``.
        logger: Logger providing ``Log``.
        step: Training step, used only in the log line.
        vocabulary: Unused; kept for call compatibility.

    Returns:
        Class accuracy as a float (0.0 if the dataset yielded no examples).
    """
    filename, dataset = eval_set

    # Evaluate
    class_correct = 0
    class_total = 0
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()

    transition_preds = []
    transition_targets = []

    for i, batch in enumerate(dataset):
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch = get_batch(batch)[:4]

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()
        pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
        class_correct += pred.eq(target).sum()
        class_total += target.size(0)

        # Optionally calculate transition loss/acc.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        # Accumulate stats for transition accuracy.
        if transition_loss is not None:
            transition_preds.append([m["t_preds"] for m in model.spinn.memories])
            transition_targets.append([m["t_given"] for m in model.spinn.memories])

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    end = time.time()
    total_time = end - start

    # Get time per token.
    time_metric = time_per_token([total_tokens], [total_time])

    # Get class accuracy; guard against an empty eval set.
    eval_class_acc = class_correct / float(class_total) if class_total else 0.0

    # Get transition accuracy if applicable.
    eval_trans_acc = 0.0
    if len(transition_preds) > 0:
        all_preds = np.array(flatten(transition_preds))
        all_truth = np.array(flatten(transition_targets))
        if all_truth.shape[0] > 0:
            eval_trans_acc = (all_preds == all_truth).sum() / float(all_truth.shape[0])

    stats_str = "Step: %i Eval acc: %f %f %s Time: %5f" % (
        step, eval_class_acc, eval_trans_acc, filename, time_metric)

    # Extra Component.
    stats_str += "\nEval Extra:"

    logger.Log(stats_str)

    return eval_class_acc
def evaluate(FLAGS, model, data_manager, eval_set, log_entry, step, vocabulary=None,
             show_sample=False, eval_index=0):
    """Evaluate ``model`` on one eval set, recording results in ``log_entry``.

    A new record is appended to ``log_entry.evaluation``. Per-batch statistics
    are accumulated and summarized into it by ``eval_stats``. When
    ``FLAGS.write_eval_report`` is set, per-example predictions (plus
    transitions/trees when available) are written to
    ``<log_path>/<experiment_name>.eval_set_<eval_index>.report``.

    Args:
        FLAGS: Flag values controlling the run.
        model: Model to evaluate; it is switched to eval mode.
        data_manager: Passed through to ``eval_accumulate``.
        eval_set: ``(filename, dataset)`` pair.
        log_entry: Protobuf log entry with a repeated ``evaluation`` field.
        step: Training step (not used directly here).
        vocabulary: Used to render parse samples.
        show_sample: If True, request parse masks and, for SPINN with the
            internal parser, render sample trees. Unless a report is being
            written, this is turned off after the first batch.
        eval_index: Index of this eval set, used in the report filename.

    Returns:
        ``(eval_class_accuracy, eval_transition_accuracy)`` read back from the
        evaluation record filled in by ``eval_stats``.
    """
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       store_parse_masks=show_sample,
                       example_lengths=eval_num_transitions_batch)

        can_sample = (FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            show_sample = False  # Only show one sample, regardless of the number of batches.

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, data_manager, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            if FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser:
                transitions_per_example, _ = model.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
            else:
                transitions_per_example = None

            if model.use_sentence_pair:
                # Premise rows come first, hypothesis rows after them.
                batch_size = pred.size(0)
                sent1_transitions = (transitions_per_example[:batch_size]
                                     if transitions_per_example is not None else None)
                sent2_transitions = (transitions_per_example[batch_size:]
                                     if transitions_per_example is not None else None)
                sent1_trees = tree_strs[:batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example
                sent2_transitions = None
                sent1_trees = tree_strs
                sent2_trees = None
            reporter.save_batch(pred, target, eval_ids, output.data.cpu().numpy(),
                                sent1_transitions, sent2_transitions, sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
def train_loop(FLAGS, data_manager, model, optimizer, trainer, training_data_iter, eval_iterators,
               logger, true_step, best_dev_error, perturbation_id, ev_step, header, root_id,
               vocabulary):
    """Train one perturbation for ``FLAGS.es_episode_length`` steps.

    Checkpoints are written under the name
    ``<experiment_name>_p<perturbation_id>``. Each logged entry is labelled
    with that name and with the root ``<experiment_name>_p<root_id>``.

    Args:
        FLAGS: Flag values controlling the run.
        data_manager: Passed through to accumulation helpers and ``evaluate``.
        model: Model being trained.
        optimizer: Optimizer stepping ``model``'s parameters.
        trainer: Object providing ``save`` for checkpoints.
        training_data_iter: Iterator over raw training batches.
        eval_iterators: Eval sets; index 0 drives best-checkpoint decisions.
        logger: Logger providing ``Log`` and ``LogEntry``.
        true_step: Global step count before this episode; incremented once
            per iteration.
        best_dev_error: Best dev error seen so far; may be lowered here.
        perturbation_id: Id of the perturbation being trained.
        ev_step: Passed through to ``trainer.save`` and returned.
        header: Record whose ``start_step`` and ``start_time`` are set here.
        root_id: Id of the root model this perturbation descends from.
        vocabulary: Forwarded to ``evaluate``.

    Returns:
        ``(ev_step, true_step, perturbation_id, dev_error)``. ``dev_error`` is
        ``best_dev_error`` when a best checkpoint file exists. Otherwise it is
        one minus the last accuracy computed in this episode, which is the
        last eval set of the last evaluation. If no evaluation ran, it falls
        back to ``best_dev_error``.
    """
    perturbation_name = FLAGS.experiment_name + "_p" + str(perturbation_id)
    root_name = FLAGS.experiment_name + "_p" + str(root_id)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    header.start_step = true_step
    header.start_time = int(time.time())
    # header.model_label = perturbation_name

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, perturbation_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, perturbation_name, best=True)

    # Build log format strings. This consumes one training batch and runs the
    # model once on it.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        next(training_data_iter))
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions)

    # Train.
    logger.Log("Training perturbation %s" % perturbation_id)

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    # Most recent eval accuracy; stays None if no evaluation happens this episode.
    last_acc = None

    log_entry = pb.SpinnEntry()
    for _ in range(FLAGS.es_episode_length):
        true_step += 1
        model.train()
        log_entry.Clear()
        log_entry.step = true_step
        log_entry.model_label = perturbation_name
        log_entry.root_label = root_name
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(model, FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping. Parameters without a gradient (grad is None)
        # are skipped.
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            decayed_lr = FLAGS.learning_rate * (
                FLAGS.learning_rate_decay_per_10k_steps ** (true_step / 10000.0))
            # torch.optim optimizers read the rate from param_groups; setting
            # only ``optimizer.lr`` would have no effect on the update.
            optimizer.lr = decayed_lr
            for param_group in getattr(optimizer, "param_groups", []):
                param_group["lr"] = decayed_lr

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if true_step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
            progress_bar.finish()

            A.add('xent_cost', xent_loss.data[0])
            # l2_loss is None when L2 is disabled; record 0.0 instead of crashing.
            A.add('l2_cost', l2_loss.data[0] if l2_loss is not None else 0.0)
            stats(model, optimizer, A, true_step, log_entry)
            should_log = True

        if true_step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            # Parse the same batch in train mode and in eval mode to compare.
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost. Sample distinct rows at random, capped at the rows that
            # exist. random.sample also accepts a range on Python 3, unlike
            # random.shuffle.
            num_rows = min(len(transitions_batch),
                           len(tr_transitions_per_example),
                           len(ev_transitions_per_example))
            t_idxs = sorted(random.sample(range(num_rows), min(FLAGS.num_samples, num_rows)))
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(), dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if true_step > 0 and true_step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(
                    FLAGS, model, data_manager, eval_set, log_entry, true_step, vocabulary,
                    show_sample=(true_step % FLAGS.sample_interval_steps == 0),
                    eval_index=index)
                last_acc = acc
                if (FLAGS.ckpt_on_best_dev_error and index == 0 and
                        (1 - acc) < 0.99 * best_dev_error and true_step > FLAGS.ckpt_step):
                    best_dev_error = 1 - acc
                    logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                    # TODO: This mixes information across dev sets. Fix.
                    trainer.save(best_checkpoint_path, true_step, best_dev_error, ev_step)
            progress_bar.reset()

        if true_step > FLAGS.ckpt_step and true_step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, true_step, best_dev_error, ev_step)

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(true_step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)

    if os.path.exists(best_checkpoint_path) or last_acc is None:
        return ev_step, true_step, perturbation_id, best_dev_error
    return ev_step, true_step, perturbation_id, (1 - last_acc)
def evaluate(FLAGS, model, eval_set, log_entry, logger, trainer, vocabulary=None,
             show_sample=False, eval_index=0):
    """Evaluate ``model`` on one eval set, optionally computing the cp metric.

    A new record is appended to ``log_entry.evaluation`` and summarized by
    ``eval_stats``. When ``FLAGS.write_eval_report`` is set, per-example
    results go to ``<log_path>/<experiment_name>.eval_set_<eval_index>.report``.
    When ``FLAGS.cp_metric`` is also set, ``reporter.save_batch`` returns
    ``(cp, cp_max)`` per batch. The ratio of their sums is logged as
    ``Eval cp_acc``.

    Args:
        FLAGS: Flag values controlling the run.
        model: Model to evaluate; it is switched to eval mode.
        eval_set: ``(filename, dataset)`` pair.
        log_entry: Protobuf log entry with a repeated ``evaluation`` field.
        logger: Logger providing ``Log``.
        trainer: Unused here; kept for call compatibility.
        vocabulary: Used to render parse samples.
        show_sample: If True, request parse masks and, for RLSPINN with the
            internal parser, render sample trees.
        eval_index: Index of this eval set, used in the report filename.

    Returns:
        ``(eval_class_accuracy, eval_transition_accuracy)`` from the
        evaluation record.
    """
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    cpt = 0
    cpt_max = 0
    start = time.time()

    # Print whole arrays instead of summaries. This is a process-wide numpy
    # setting, so it only needs to be applied once rather than per batch.
    np.set_printoptions(threshold=np.inf)

    model.eval()
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       store_parse_masks=show_sample,
                       example_lengths=eval_num_transitions_batch)

        can_sample = (FLAGS.model_type == "RLSPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()

        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Update Aggregate Accuracies
        total_tokens += sum(
            [(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            if FLAGS.model_type == "RLSPINN" and FLAGS.use_internal_parser:
                transitions_per_example, _ = model.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
            else:
                transitions_per_example = None

            if model.use_sentence_pair:
                # Premise rows come first, hypothesis rows after them.
                batch_size = pred.size(0)
                sent1_transitions = (transitions_per_example[:batch_size]
                                     if transitions_per_example is not None else None)
                sent2_transitions = (transitions_per_example[batch_size:]
                                     if transitions_per_example is not None else None)
                sent1_trees = tree_strs[:batch_size] if tree_strs is not None else None
                sent2_trees = tree_strs[batch_size:] if tree_strs is not None else None
            else:
                sent1_transitions = transitions_per_example
                sent2_transitions = None
                sent1_trees = tree_strs
                sent2_trees = None

            if FLAGS.cp_metric:
                cp, cp_max = reporter.save_batch(
                    pred, target, eval_ids, output.data.cpu().numpy(),
                    sent1_transitions, sent2_transitions, sent1_trees, sent2_trees,
                    cp_metric=FLAGS.cp_metric)
                cpt += cp
                cpt_max += cp_max
            else:
                reporter.save_batch(
                    pred, target, eval_ids, output.data.cpu().numpy(),
                    sent1_transitions, sent2_transitions, sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    # cpt_max stays 0 unless both cp_metric and write_eval_report are enabled
    # (or no batch contributed), so only compute the ratio when it is defined.
    cp_metric_value = cpt / float(cpt_max) if cpt_max else None

    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)
    if cp_metric_value is not None:
        logger.Log("Eval cp_acc: " + str(cp_metric_value))

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
def evaluate(FLAGS, model, eval_set, log_entry, logger, trainer, vocabulary=None,
             show_sample=False, eval_index=0, target_vocabulary=None):
    """Evaluate a translation model on one eval set and score it with BLEU.

    The procedure:

    1. Decode model outputs into token-id strings.
    2. Write the references to ``<log_path>/ref_file`` and the predictions to
       ``<log_path>/pred_file``.
    3. Run ``spinn/util/multi-bleu.perl`` (path relative to the working
       directory).
    4. Record the BLEU score in place of class accuracy: ``class_correct`` is
       the score and ``class_total`` is 1.

    Args:
        FLAGS: Flag values controlling the run.
        model: Model to evaluate; it is switched to eval mode.
        eval_set: ``(filename, dataset)`` pair.
        log_entry: Protobuf log entry with a repeated ``evaluation`` field.
        logger: Logger providing ``Log``.
        trainer: Provides ``step``; used for the pyramid temperature.
        vocabulary: Used to render parse samples.
        show_sample: If True and the model supports it, render sample trees.
        eval_index: Index of this eval set, used in the report filename.
        target_vocabulary: Unused; kept for call compatibility.

    Returns:
        ``(eval_class_accuracy, eval_transition_accuracy)`` from the
        evaluation record.
    """
    import re
    import subprocess

    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    if FLAGS.model_type in ["ChoiPyramid", "Maillard", "CatalanPyramid"]:
        pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps ** (
            trainer.step / 10000.0)
        if FLAGS.pyramid_temperature_cycle_length > 0.0:
            min_temp = 1e-5
            pyramid_temperature_multiplier *= (math.cos(
                (trainer.step) / FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
    else:
        pyramid_temperature_multiplier = None

    model.eval()

    ref_file_name = FLAGS.log_path + "/ref_file"
    pred_file_name = FLAGS.log_path + "/pred_file"
    full_ref = []
    full_pred = []
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids = batch

        # Run model.
        output = model(
            eval_X_batch,
            eval_transitions_batch,
            eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=pyramid_temperature_multiplier,
            example_lengths=eval_num_transitions_batch)

        can_sample = FLAGS.model_type in [
            "ChoiPyramid", "Maillard", "CatalanPyramid"
        ] or (FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)
        if show_sample and can_sample:
            tmp_samples = model.encoder.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Get reference translation (the last token of each target is dropped).
        ref_out = [" ".join(map(str, k[:-1])) + " ." for k in eval_y_batch]
        full_ref += ref_out

        # Get predicted translation. ``output`` is iterated as one row per
        # decoding step with one entry per example (assumed layout, TODO
        # confirm). A value of 1 ends that example's sequence (presumably
        # EOS), and later tokens for it are ignored. The comprehension uses
        # ``_`` so that, on Python 2, it cannot clobber the outer loop's ``i``.
        predicted = [[] for _ in range(len(eval_y_batch))]
        finished = set()
        for step_preds in output:
            for example_idx, token in enumerate(step_preds):
                token_id = int(token)
                if token_id == 1:
                    finished.add(example_idx)
                elif example_idx not in finished:
                    predicted[example_idx].append(token_id)
        pred_out = [" ".join(map(str, k)) + " ." for k in predicted]
        full_pred += pred_out

        eval_accumulate(model, A, batch)

        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2 for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            if FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser:
                transitions_per_example, _ = model.encoder.spinn.get_transitions_per_example(
                    style="preds" if FLAGS.eval_report_use_preds else "given")
            else:
                transitions_per_example = None
            # NOTE(review): this passes the cumulative full_pred/full_ref, not
            # this batch's pred_out/ref_out, alongside only this batch's
            # eval_ids. Confirm that EvalReporter.save_batch(mt=True) expects
            # that.
            reporter.save_batch(full_pred, full_ref, eval_ids, [None],
                                transitions_per_example, None, tree_strs, None, mt=True)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()
    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    with open(ref_file_name, "w") as reference_file:
        reference_file.write("\n".join(full_ref))
    with open(pred_file_name, "w") as predict_file:
        predict_file.write("\n".join(full_pred))

    # Score with multi-bleu.perl. Arguments are passed as a list (no shell),
    # so paths containing spaces or shell metacharacters are safe, and the
    # hypotheses are fed on stdin like a shell "< pred_file" redirect.
    try:
        with open(pred_file_name) as hypothesis_file:
            bleu_proc = subprocess.Popen(
                ["perl", "spinn/util/multi-bleu.perl", ref_file_name],
                stdin=hypothesis_file, stdout=subprocess.PIPE)
            bleu_output, _ = bleu_proc.communicate()
    except (IOError, OSError):
        # e.g. perl is not installed; scored as 0.0 below, like unparseable output.
        bleu_output = b""
    if isinstance(bleu_output, bytes):
        bleu_output = bleu_output.decode("utf-8", "replace")
    try:
        bleu_score = float(bleu_output)
    except ValueError:
        # Stock multi-bleu.perl prints "BLEU = 23.45, ..." rather than a bare number.
        bleu_match = re.search(r"BLEU = (\d+(?:\.\d+)?)", bleu_output)
        bleu_score = float(bleu_match.group(1)) if bleu_match else 0.0

    end = time.time()
    total_time = end - start

    A.add('class_correct', bleu_score)
    A.add('class_total', 1)
    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc
def _pyramid_temperature_multiplier(FLAGS, step):
    """Return the annealed temperature multiplier for pyramid-style models.

    Uses the same model list and schedule as ``train_loop`` so evaluation runs
    the model at the temperature it is currently trained with. Returns None
    for model types that do not use a temperature multiplier.
    """
    if FLAGS.model_type not in ("ChoiPyramid", "Maillard", "CatalanPyramid"):
        return None
    multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps ** (
        step / 10000.0)
    if FLAGS.pyramid_temperature_cycle_length > 0.0:
        # Cosine cycling; min_temp keeps the factor strictly positive.
        min_temp = 1e-5
        multiplier *= (math.cos(
            step / FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
    return multiplier


def evaluate(FLAGS, model, eval_set, log_entry, logger, trainer,
             vocabulary=None, show_sample=False, eval_index=0):
    """Run ``model`` over one evaluation set and record classification stats.

    Args:
        FLAGS: Parsed command-line flags.
        model: Model to evaluate; it is switched to eval mode here.
        eval_set: ``(filename, dataset)`` pair; ``dataset`` is iterated batch
            by batch and passed through ``get_batch``.
        log_entry: Log entry; a new record is appended to
            ``log_entry.evaluation`` and filled by ``eval_stats``.
        logger: Used to print one sample tree, if any was collected.
        trainer: Its ``step`` drives the temperature schedule.
        vocabulary: Forwarded to ``model.get_samples`` when sampling trees.
        show_sample: Whether to collect parse-tree samples.
        eval_index: Index embedded in the eval report filename.

    Returns:
        ``(eval_class_accuracy, eval_transition_accuracy)`` read back from the
        evaluation record after ``eval_stats`` has populated it.
    """
    filename, dataset = eval_set

    A = Accumulator()
    eval_log = log_entry.evaluation.add()
    reporter = EvalReporter()
    tree_strs = None

    # Evaluate
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(
        msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    pyramid_temperature_multiplier = _pyramid_temperature_multiplier(
        FLAGS, trainer.step)

    # These depend only on FLAGS, so compute them once outside the loop.
    # TODO: Restore support in Pyramid if using.
    can_sample = FLAGS.model_type in ["ChoiPyramid"] or (
        FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)
    # Per-example transitions for the report come from SPINN's internal parser.
    report_transitions = (
        FLAGS.model_type == "SPINN" and FLAGS.use_internal_parser)

    model.eval()
    for i, dataset_batch in enumerate(dataset):
        batch = get_batch(dataset_batch)
        (eval_X_batch, eval_transitions_batch, eval_y_batch,
         eval_num_transitions_batch, eval_ids) = batch

        # Run model.
        output = model(
            eval_X_batch,
            eval_transitions_batch,
            eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            pyramid_temperature_multiplier=pyramid_temperature_multiplier,
            store_parse_masks=show_sample,
            example_lengths=eval_num_transitions_batch)

        if show_sample and can_sample:
            tmp_samples = model.get_samples(
                eval_X_batch, vocabulary, only_one=not FLAGS.write_eval_report)
            tree_strs = prettyprint_trees(tmp_samples)
        if not FLAGS.write_eval_report:
            # Only show one sample, regardless of the number of batches.
            show_sample = False

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()
        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=False)[1].cpu()

        eval_accumulate(model, A, batch)
        A.add('class_correct', pred.eq(target).sum())
        A.add('class_total', target.size(0))

        # Update Aggregate Accuracies
        total_tokens += sum([(nt + 1) / 2
                             for nt in eval_num_transitions_batch.reshape(-1)])

        if FLAGS.write_eval_report:
            if report_transitions:
                style = "preds" if FLAGS.eval_report_use_preds else "given"
                transitions_per_example, _ = \
                    model.spinn.get_transitions_per_example(style=style)
            else:
                transitions_per_example = None

            if model.use_sentence_pair:
                # Entries for both sentences appear to be stacked along the
                # batch axis (sentence 1 first), mirroring the concatenation
                # done in train_loop — split at the batch size.
                batch_size = pred.size(0)
                if transitions_per_example is not None:
                    sent1_transitions = transitions_per_example[:batch_size]
                    sent2_transitions = transitions_per_example[batch_size:]
                else:
                    sent1_transitions = sent2_transitions = None
                if tree_strs is not None:
                    sent1_trees = tree_strs[:batch_size]
                    sent2_trees = tree_strs[batch_size:]
                else:
                    sent1_trees = sent2_trees = None
            else:
                sent1_transitions = transitions_per_example
                sent2_transitions = None
                sent1_trees = tree_strs
                sent2_trees = None

            reporter.save_batch(
                pred, target, eval_ids, output.data.cpu().numpy(),
                sent1_transitions, sent2_transitions,
                sent1_trees, sent2_trees)

        # Print Progress
        progress_bar.step(i + 1, total=total_batches)
    progress_bar.finish()

    if tree_strs is not None:
        logger.Log('Sample: ' + tree_strs[0])

    end = time.time()
    total_time = end - start

    A.add('total_tokens', total_tokens)
    A.add('total_time', total_time)

    eval_stats(model, A, eval_log)
    eval_log.filename = filename

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(
            FLAGS.log_path,
            FLAGS.experiment_name + ".eval_set_" + str(eval_index) + ".report")
        reporter.write_report(eval_report_path)

    eval_class_acc = eval_log.eval_class_accuracy
    eval_trans_acc = eval_log.eval_transition_accuracy

    return eval_class_acc, eval_trans_acc