def eval_stats(model, A, eval_data):
    """Fill ``eval_data`` with evaluation metrics read from accumulator ``A``.

    Args:
        model: Model handed to ``inspect`` to find out which optional
            metrics (transition loss, invalid count) apply.
        A: Accumulator exposing ``get(key)`` and ``get_avg(key)``.
        eval_data: Record object whose metric attributes are assigned.

    Returns:
        The same ``eval_data`` object, after its fields have been set.
    """
    features = inspect(model)

    correct_counts = A.get('class_correct')
    example_counts = A.get('class_total')
    n_examples = sum(example_counts)
    if n_examples != 0:
        accuracy = sum(correct_counts).item() / float(n_examples)
    else:
        # Nothing was accumulated; report zero instead of dividing by zero.
        accuracy = 0
    eval_data.eval_class_accuracy = accuracy

    if features.has_transition_loss:
        predicted = np.array(flatten(A.get('preds')))
        gold = np.array(flatten(A.get('truth')))
        n_matching = (predicted == gold).sum()
        eval_data.eval_transition_accuracy = n_matching / float(gold.shape[0])

    if features.has_invalid:
        eval_data.invalid = A.get_avg('invalid')

    eval_data.time_per_token_seconds = time_per_token(
        A.get('total_tokens'), A.get('total_time'))

    return eval_data
def train_stats(model, optimizer, A, step):
    """Summarize the accumulated training statistics as a dict.

    Args:
        model: Model being trained. Its optional ``transition_loss``,
            ``optimize_transition_loss`` and ``spinn.invalid`` attributes
            decide which metrics are reported.
        optimizer: Object whose ``lr`` attribute is reported as the
            learning rate.
        A: Accumulator exposing ``get(key)`` and ``get_avg(key)``.
        step: Current training step, copied into the result.

    Returns:
        dict with keys ``step``, ``class_acc``, ``transition_acc``,
        ``xent_cost``, ``transition_cost``, ``l2_cost``, ``invalid``,
        ``learning_rate``, ``time`` and ``total_cost``.
    """
    has_spinn = hasattr(model, 'spinn')
    has_transition_loss = hasattr(
        model, 'transition_loss') and model.transition_loss is not None
    has_invalid = has_spinn and hasattr(model.spinn, 'invalid')

    if has_transition_loss:
        all_preds = np.array(flatten(A.get('preds')))
        all_truth = np.array(flatten(A.get('truth')))
        avg_trans_acc = (all_preds == all_truth).sum() / float(
            all_truth.shape[0])
    else:
        avg_trans_acc = 0.0

    time_metric = time_per_token(A.get('total_tokens'), A.get('total_time'))

    ret = dict(
        step=step,
        class_acc=A.get_avg('class_acc'),
        transition_acc=avg_trans_acc,
        xent_cost=A.get_avg('xent_cost'),  # not actual mean
        transition_cost=model.transition_loss.data[0] if has_transition_loss else 0.0,
        l2_cost=A.get_avg('l2_cost'),  # not actual mean
        invalid=A.get_avg('invalid') if has_invalid else 0.0,
        learning_rate=optimizer.lr,
        time=time_metric,
    )

    # The transition cost is reported either way, but it only counts toward
    # the total when the model is actually optimizing it.
    skip_transition_cost = (
        has_transition_loss and not model.optimize_transition_loss)

    total_cost = 0.0
    for key, value in ret.items():
        if 'cost' not in key:
            continue
        if key == 'transition_cost' and skip_transition_cost:
            continue
        total_cost += value
    ret['total_cost'] = total_cost

    return ret
def eval_stats(model, A, step):
    """Summarize the accumulated evaluation statistics as a dict.

    Args:
        model: Evaluated model. Its optional ``transition_loss`` and
            ``spinn.invalid`` attributes decide which metrics are reported.
        A: Accumulator exposing ``get(key)`` and ``get_avg(key)``.
        step: Training step the evaluation belongs to, copied into the result.

    Returns:
        dict with keys ``step``, ``class_acc``, ``transition_acc``,
        ``invalid`` and ``time``.
    """
    has_spinn = hasattr(model, 'spinn')
    has_transition_loss = hasattr(
        model, 'transition_loss') and model.transition_loss is not None
    has_invalid = has_spinn and hasattr(model.spinn, 'invalid')

    class_correct = A.get('class_correct')
    class_total = A.get('class_total')
    total_examples = sum(class_total)
    if total_examples != 0:
        class_acc = sum(class_correct) / float(total_examples)
    else:
        # No examples were accumulated: report zero accuracy instead of
        # raising ZeroDivisionError.
        class_acc = 0.0

    if has_transition_loss:
        all_preds = np.array(flatten(A.get('preds')))
        all_truth = np.array(flatten(A.get('truth')))
        avg_trans_acc = (all_preds == all_truth).sum() / float(
            all_truth.shape[0])
    else:
        avg_trans_acc = 0.0

    time_metric = time_per_token(A.get('total_tokens'), A.get('total_time'))

    ret = dict(
        step=step,
        class_acc=class_acc,
        transition_acc=avg_trans_acc,
        invalid=A.get_avg('invalid') if has_invalid else 0.0,
        time=time_metric,
    )

    return ret
def stats(model, trainer, A, log_entry):
    """Fill ``log_entry`` with training metrics read from accumulator ``A``.

    Args:
        model: Model handed to ``inspect`` to find out which optional
            metrics apply; also the source of ``transition_loss``,
            ``optimize_transition_loss`` and the temperature/epsilon values.
        trainer: Object providing ``step`` and ``learning_rate``.
        A: Accumulator exposing ``get(key)`` and ``get_avg(key)``.
        log_entry: Record object whose metric attributes are assigned.

    Returns:
        The same ``log_entry`` object, after its fields have been set.
    """
    features = inspect(model)

    if features.has_transition_loss:
        predicted = np.array(flatten(A.get('preds')))
        gold = np.array(flatten(A.get('truth')))
        transition_accuracy = (predicted == gold).sum() / \
            float(gold.shape[0])

    seconds_per_token = time_per_token(
        A.get('total_tokens'), A.get('total_time'))

    log_entry.step = trainer.step
    log_entry.class_accuracy = A.get_avg('class_acc')
    log_entry.cross_entropy_cost = A.get_avg('xent_cost')  # not actual mean
    log_entry.learning_rate = trainer.learning_rate
    log_entry.time_per_token_seconds = seconds_per_token

    # Costs are read back from log_entry so the total is built from the
    # values as the record stores them.
    running_cost = log_entry.cross_entropy_cost

    if features.has_transition_loss:
        log_entry.transition_accuracy = transition_accuracy
        log_entry.transition_cost = model.transition_loss.data[0]
        if model.optimize_transition_loss:
            running_cost += log_entry.transition_cost

    if features.has_invalid:
        log_entry.invalid = A.get_avg('invalid')

    # Fetch the advantage statistics now (same position as before relative
    # to the other accumulator reads); they are written to log_entry below.
    advantage_fields = (
        ('adv_mean', 'mean_adv_mean'),
        ('adv_mean_magnitude', 'mean_adv_mean_magnitude'),
        ('adv_var', 'mean_adv_var'),
        ('adv_var_magnitude', 'mean_adv_var_magnitude'),
    )
    advantage_arrays = [
        (field_name, np.array(A.get(source_key), dtype=np.float32))
        for source_key, field_name in advantage_fields
    ]

    if features.has_policy:
        log_entry.policy_cost = A.get_avg('policy_cost')
        running_cost += log_entry.policy_cost

    if features.has_value:
        log_entry.value_cost = A.get_avg('value_cost')
        running_cost += log_entry.value_cost

    for field_name, values in advantage_arrays:
        if len(values) > 0:
            average = values.mean()
            if not isinstance(average, float):
                average = float(average)
            setattr(log_entry, field_name, average)

    if features.has_epsilon:
        log_entry.epsilon = model.spinn.epsilon
    if features.has_spinn_temperature:
        log_entry.temperature = model.spinn.temperature
    if features.has_pyramid_temperature:
        log_entry.temperature = model.temperature_to_display

    log_entry.total_cost = running_cost
    return log_entry
def evaluate(model, eval_set, logger, metrics_logger, step, vocabulary=None):
    """Run the model over one evaluation set and log its accuracy.

    Args:
        model: Model to evaluate; it is switched to eval mode here.
        eval_set: ``(filename, dataset)`` pair. ``dataset`` must support
            ``len()`` and yield 5-tuples of
            ``(X, transitions, y, num_transitions, ids)``.
        logger: Object whose ``Log(str)`` receives the summary line.
        metrics_logger: Object whose ``Log(key, value, step)`` receives the
            class and transition accuracies.
        step: Training step used in log messages and in the attention
            matrix file name.
        vocabulary: Not used in this function.

    Returns:
        Class accuracy over the whole set.

    NOTE(review): ``class_total`` starts at 0 and only grows inside the batch
    loop, so an empty ``dataset`` makes the final accuracy division raise
    ZeroDivisionError.
    """
    filename, dataset = eval_set

    reporter = EvalReporter()

    # Evaluate
    class_correct = 0
    class_total = 0
    total_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=total_batches)
    total_tokens = 0
    start = time.time()

    model.eval()

    # One entry per batch; flattened after the loop to compute accuracy.
    transition_preds = []
    transition_targets = []

    for i, (eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch, eval_ids) in enumerate(dataset):
        if FLAGS.truncate_eval_batch:
            eval_X_batch, eval_transitions_batch = truncate(
                eval_X_batch, eval_transitions_batch, eval_num_transitions_batch)

        if FLAGS.saving_eval_attention_matrix:
            model.set_recording_attention_weight_matrix(True)

        # Run model.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,)

        if FLAGS.saving_eval_attention_matrix:
            # WARNING: only attention SPINN model have attention matrix
            attention_matrix = model.get_attention_matrix_from_last_forward()
            # Opened in append mode, so every batch of this step adds to the
            # same file.
            with open(os.path.join(FLAGS.metrics_path, FLAGS.experiment_name, 'attention-matrix-{}.txt'.format(step)), 'a') as txtfile:
                for eval_id, attmat in izip(eval_ids, attention_matrix):
                    # Per example: id line, "rows,cols" line, then one
                    # comma-separated line per row with weights scaled by 100.
                    txtfile.write('{}\n'.format(eval_id))
                    txtfile.write('{},{}\n'.format(len(attmat), len(attmat[0])))
                    for row in attmat:
                        txtfile.write(','.join(['{:.1f}'.format(x*100.0) for x in row]))
                        txtfile.write('\n')
            model.set_recording_attention_weight_matrix(False)  # reset it after run

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(eval_y_batch).long()
        pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
        class_correct += pred.eq(target).sum()
        class_total += target.size(0)

        # Optionally calculate transition loss/acc.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Update Aggregate Accuracies
        total_tokens += eval_num_transitions_batch.ravel().sum()

        # Accumulate stats for transition accuracy.
        if transition_loss is not None:
            transition_preds.append([m["t_preds"] for m in model.spinn.memories])
            transition_targets.append([m["t_given"] for m in model.spinn.memories])

        if FLAGS.write_eval_report:
            reporter_args = [pred, target, eval_ids, output.data.cpu().numpy()]
            # NOTE(review): this tests only that the attribute exists, while
            # the accumulation above also requires it to be non-None; confirm
            # the difference is intended.
            if hasattr(model, 'transition_loss'):
                transition_preds_per_example = model.spinn.get_transition_preds_per_example()
                if model.use_sentence_pair:
                    # The per-example predictions are split at batch_size
                    # into two lists, one per sentence.
                    batch_size = pred.size(0)
                    sent1_preds = transition_preds_per_example[:batch_size]
                    sent2_preds = transition_preds_per_example[batch_size:]
                    reporter_args.append(sent1_preds)
                    reporter_args.append(sent2_preds)
                else:
                    reporter_args.append(transition_preds_per_example)
            reporter.save_batch(*reporter_args)

        # Print Progress
        progress_bar.step(i+1, total=total_batches)
    progress_bar.finish()

    end = time.time()
    total_time = end - start

    # Get time per token.
    time_metric = time_per_token([total_tokens], [total_time])

    # Get class accuracy.
    eval_class_acc = class_correct / float(class_total)

    # Get transition accuracy if applicable.
    if len(transition_preds) > 0:
        all_preds = np.array(flatten(transition_preds))
        all_truth = np.array(flatten(transition_targets))
        eval_trans_acc = (all_preds == all_truth).sum() / float(all_truth.shape[0])
    else:
        eval_trans_acc = 0.0

    logger.Log("Step: %i Eval acc: %f  %f %s Time: %5f" %
              (step, eval_class_acc, eval_trans_acc, filename, time_metric))

    metrics_logger.Log('eval_class_acc', eval_class_acc, step)
    metrics_logger.Log('eval_trans_acc', eval_trans_acc, step)

    if FLAGS.write_eval_report:
        eval_report_path = os.path.join(FLAGS.log_path, FLAGS.experiment_name + ".report")
        reporter.write_report(eval_report_path)

    return eval_class_acc
def run(only_forward=False):
    """Build the data pipeline, model and trainer, then evaluate or train.

    The function reads its configuration from the module-level ``FLAGS``:
    it loads and preprocesses the training and eval data, builds the model
    and optimizer selected by the flags, restores a checkpoint when one
    exists, and then either evaluates every eval set once or runs the
    training loop with periodic logging, evaluation and checkpointing.

    Args:
        only_forward: When true, skip training and only evaluate. A
            checkpoint must exist in that case (enforced by an ``assert``).

    Returns:
        None. Returns early, after logging, when ``FLAGS.data_type`` is not
        one of the supported values.

    Raises:
        Exception: ``FLAGS.model_type`` is not a supported model.
        NotImplementedError: ``FLAGS.optimizer_type`` is not supported.
    """
    logger = afs_safe_logger.Logger(os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    elif FLAGS.data_type == "arithmetic":
        data_manager = load_simple_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Make Metrics Logger.
    metrics_path = "{}/{}".format(FLAGS.metrics_path, FLAGS.experiment_name)
    if not os.path.exists(metrics_path):
        os.makedirs(metrics_path)
    metrics_logger = MetricsLogger(metrics_path)
    # M collects the values written through metrics_logger; the separate
    # accumulator A (created further down) feeds the periodic text log.
    M = Accumulator(maxlen=FLAGS.deque_length)

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        # Several eval files may be given, separated by ":".
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary.
    # NOTE: train_embeddings is assigned in both branches but is not read
    # anywhere else in this function.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager, eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only(),
        use_left_padding=FLAGS.use_left_padding)
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        # Eval data uses FLAGS.eval_seq_length when set, otherwise the
        # training sequence length.
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only(),
            use_left_padding=FLAGS.use_left_padding)
        eval_it = util.MakeEvalIterator(eval_data,
            FLAGS.batch_size, FLAGS.eval_data_limit, bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval, rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Choose model.
    model_specific_params = {}
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        model_module = spinn.cbow
    elif FLAGS.model_type == "RNN":
        model_module = spinn.plain_rnn
    elif FLAGS.model_type == "SPINN":
        model_module = spinn.fat_stack
    elif FLAGS.model_type == "RLSPINN":
        model_module = spinn.rl_spinn
    elif FLAGS.model_type == "RAESPINN":
        model_module = spinn.rae_spinn
    elif FLAGS.model_type == "GENSPINN":
        model_module = spinn.gen_spinn
    elif FLAGS.model_type == "ATTSPINN":
        model_module = spinn.att_spinn
        model_specific_params['using_diff_in_mlstm'] = FLAGS.using_diff_in_mlstm
        model_specific_params['using_prod_in_mlstm'] = FLAGS.using_prod_in_mlstm
        model_specific_params['using_null_in_attention'] = FLAGS.using_null_in_attention
        # The attention model gets its own log file next to the main log.
        attlogger = logging.getLogger('spinn.attention')
        attlogger.setLevel(logging.INFO)
        fh = logging.FileHandler(os.path.join(FLAGS.log_path, '{}.att.log'.format(FLAGS.experiment_name)))
        fh.setLevel(logging.INFO)
        fh.setFormatter(logging.Formatter('%(asctime)s %(levelname)s: %(message)s'))
        attlogger.addHandler(fh)
    else:
        raise Exception("Requested unimplemented model type %s" % FLAGS.model_type)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    if data_manager.SENTENCE_PAIR_DATA:
        trainer_cls = model_module.SentencePairTrainer
        model_cls = model_module.SentencePairModel
        use_sentence_pair = True
    else:
        trainer_cls = model_module.SentenceTrainer
        model_cls = model_module.SentenceModel
        # Same value as already assigned above.
        num_classes = len(data_manager.LABEL_MAP)
        use_sentence_pair = False

    model = model_cls(model_dim=FLAGS.model_dim,
         word_embedding_dim=FLAGS.word_embedding_dim,
         vocab_size=vocab_size,
         initial_embeddings=initial_embeddings,
         num_classes=num_classes,
         mlp_dim=FLAGS.mlp_dim,
         embedding_keep_rate=FLAGS.embedding_keep_rate,
         classifier_keep_rate=FLAGS.semantic_classifier_keep_rate,
         tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
         transition_weight=FLAGS.transition_weight,
         encode_style=FLAGS.encode_style,
         encode_reverse=FLAGS.encode_reverse,
         encode_bidirectional=FLAGS.encode_bidirectional,
         encode_num_layers=FLAGS.encode_num_layers,
         use_sentence_pair=use_sentence_pair,
         use_skips=FLAGS.use_skips,
         lateral_tracking=FLAGS.lateral_tracking,
         use_tracking_in_composition=FLAGS.use_tracking_in_composition,
         use_difference_feature=FLAGS.use_difference_feature,
         use_product_feature=FLAGS.use_product_feature,
         num_mlp_layers=FLAGS.num_mlp_layers,
         mlp_bn=FLAGS.mlp_bn,
         rl_mu=FLAGS.rl_mu,
         rl_baseline=FLAGS.rl_baseline,
         rl_reward=FLAGS.rl_reward,
         rl_weight=FLAGS.rl_weight,
         predict_leaf=FLAGS.predict_leaf,
         gen_h=FLAGS.gen_h,
         model_specific_params=model_specific_params,
        )

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate, betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=FLAGS.learning_rate, eps=1e-08)
    else:
        raise NotImplementedError

    # Build trainer.
    classifier_trainer = trainer_cls(model, optimizer)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    # The "best" checkpoint is preferred only when FLAGS.load_best is set;
    # otherwise the standard checkpoint is used if present.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = classifier_trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = classifier_trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    # The reduce starts from 1.0, so the reported parameter count is a float.
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0) for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()

    # Debug
    # Copies FLAGS.debug onto every module that model.apply visits.
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for index, eval_set in enumerate(eval_iterators):
            acc = evaluate(model, eval_set, logger, metrics_logger, step, vocabulary)
    else:
        # Train
        logger.Log("Training.")

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
        progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

        # Starts from the restored step, so a resumed run continues counting
        # where the checkpoint left off.
        for step in range(step, FLAGS.training_steps):
            model.train()

            start = time.time()

            X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = training_data_iter.next()

            if FLAGS.truncate_train_batch:
                X_batch, transitions_batch = truncate(
                    X_batch, transitions_batch, num_transitions_batch)

            total_tokens = num_transitions_batch.ravel().sum()

            # Reset cached gradients.
            optimizer.zero_grad()

            # Run model.
            output = model(X_batch, transitions_batch, y_batch,
                use_internal_parser=FLAGS.use_internal_parser,
                validate_transitions=FLAGS.validate_transitions
                )

            # Normalize output.
            logits = F.log_softmax(output)

            # Calculate class accuracy.
            target = torch.from_numpy(y_batch).long()
            pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
            class_acc = pred.eq(target).sum() / float(target.size(0))

            A.add('class_acc', class_acc)
            M.add('class_acc', class_acc)

            # Calculate class loss.
            xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

            # Optionally calculate transition loss/accuracy.
            # NOTE(review): the rae/leaf/gen lookups read model.spinn without
            # first checking that the model has a `spinn` attribute - confirm
            # every supported model type defines it.
            transition_acc = model.transition_acc if hasattr(model, 'transition_acc') else 0.0
            transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None
            rl_loss = model.rl_loss if hasattr(model, 'rl_loss') else None
            policy_loss = model.policy_loss if hasattr(model, 'policy_loss') else None
            rae_loss = model.spinn.rae_loss if hasattr(model.spinn, 'rae_loss') else None
            leaf_loss = model.spinn.leaf_loss if hasattr(model.spinn, 'leaf_loss') else None
            gen_loss = model.spinn.gen_loss if hasattr(model.spinn, 'gen_loss') else None

            # Force Transition Loss Optimization
            if FLAGS.force_transition_loss:
                model.optimize_transition_loss = True

            # Accumulate stats for transition accuracy.
            if transition_loss is not None:
                preds = [m["t_preds"] for m in model.spinn.memories]
                truth = [m["t_given"] for m in model.spinn.memories]
                A.add('preds', preds)
                A.add('truth', truth)

            # Accumulate stats for leaf prediction accuracy.
            if leaf_loss is not None:
                A.add('leaf_acc', model.spinn.leaf_acc)

            # Accumulate stats for word prediction accuracy.
            if gen_loss is not None:
                A.add('gen_acc', model.spinn.gen_acc)

            # Note: Keep track of transition_acc, although this is a naive average.
            # Should be weighted by length of sequences in batch.
            M.add('transition_acc', transition_acc)

            # Extract L2 Cost
            l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

            # Boilerplate for calculating loss values.
            xent_cost_val = xent_loss.data[0]
            transition_cost_val = transition_loss.data[0] if transition_loss is not None else 0.0
            l2_cost_val = l2_loss.data[0] if l2_loss is not None else 0.0
            rl_cost_val = rl_loss.data[0] if rl_loss is not None else 0.0
            policy_cost_val = policy_loss.data[0] if policy_loss is not None else 0.0
            rae_cost_val = rae_loss.data[0] if rae_loss is not None else 0.0
            leaf_cost_val = leaf_loss.data[0] if leaf_loss is not None else 0.0
            gen_cost_val = gen_loss.data[0] if gen_loss is not None else 0.0

            # Accumulate Total Loss Data
            # The transition cost is counted only when the model is
            # optimizing it, matching the total-loss accumulation below.
            total_cost_val = 0.0
            total_cost_val += xent_cost_val
            if transition_loss is not None and model.optimize_transition_loss:
                total_cost_val += transition_cost_val
            total_cost_val += l2_cost_val
            total_cost_val += rl_cost_val
            total_cost_val += policy_cost_val
            total_cost_val += rae_cost_val
            total_cost_val += leaf_cost_val
            total_cost_val += gen_cost_val

            M.add('total_cost', total_cost_val)
            M.add('xent_cost', xent_cost_val)
            M.add('transition_cost', transition_cost_val)
            M.add('l2_cost', l2_cost_val)

            # Logging for RL
            rl_keys = ['rl_loss', 'policy_loss', 'norm_rewards', 'norm_baseline', 'norm_advantage']
            for k in rl_keys:
                if hasattr(model, k):
                    val = getattr(model, k)
                    # Variables are reduced to their first data element;
                    # anything else is recorded as-is.
                    val = val.data[0] if isinstance(val, Variable) else val
                    M.add(k, val)

            # Accumulate Total Loss Variable
            total_loss = 0.0
            total_loss += xent_loss
            if l2_loss is not None:
                total_loss += l2_loss
            if transition_loss is not None and model.optimize_transition_loss:
                total_loss += transition_loss
            if rl_loss is not None:
                total_loss += rl_loss
            if policy_loss is not None:
                total_loss += policy_loss
            if rae_loss is not None:
                total_loss += rae_loss
            if leaf_loss is not None:
                total_loss += leaf_loss
            if gen_loss is not None:
                total_loss += gen_loss

            # Useful for debugging gradient flow.
            # With FLAGS.debug set, every training step stops in the ipdb
            # debugger after the gradient report.
            if FLAGS.debug:
                losses = [('total_loss', total_loss), ('xent_loss', xent_loss)]
                if l2_loss is not None:
                    losses.append(('l2_loss', l2_loss))
                if transition_loss is not None and model.optimize_transition_loss:
                    losses.append(('transition_loss', transition_loss))
                if rl_loss is not None:
                    losses.append(('rl_loss', rl_loss))
                if policy_loss is not None:
                    losses.append(('policy_loss', policy_loss))
                debug_gradient(model, losses)
                import ipdb; ipdb.set_trace()

            # Backward pass.
            total_loss.backward()

            # Hard Gradient Clipping
            # NOTE(review): assumes every parameter with requires_grad has a
            # populated .grad after backward() - confirm for parameters that
            # do not take part in the current loss.
            clip = FLAGS.clipping_max_value
            for p in model.parameters():
                if p.requires_grad:
                    p.grad.data.clamp_(min=-clip, max=clip)

            # Learning Rate Decay
            # Exponential decay: the base rate is multiplied by the decay
            # factor once per 10,000 steps (fractionally in between).
            if FLAGS.actively_decay_learning_rate:
                optimizer.lr = FLAGS.learning_rate * (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

            # Gradient descent step.
            optimizer.step()

            end = time.time()

            total_time = end - start

            A.add('total_tokens', total_tokens)
            A.add('total_time', total_time)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
                progress_bar.finish()
                avg_class_acc = A.get_avg('class_acc')
                if transition_loss is not None:
                    all_preds = np.array(flatten(A.get('preds')))
                    all_truth = np.array(flatten(A.get('truth')))
                    avg_trans_acc = (all_preds == all_truth).sum() / float(all_truth.shape[0])
                else:
                    avg_trans_acc = 0.0
                if leaf_loss is not None:
                    avg_leaf_acc = A.get_avg('leaf_acc')
                else:
                    avg_leaf_acc = 0.0
                if gen_loss is not None:
                    avg_gen_acc = A.get_avg('gen_acc')
                else:
                    avg_gen_acc = 0.0
                time_metric = time_per_token(A.get('total_tokens'), A.get('total_time'))
                # The accuracies come from accumulator A, but the cost values
                # are those of the current step only.
                stats_args = {
                    "step": step,
                    "class_acc": avg_class_acc,
                    "transition_acc": avg_trans_acc,
                    "total_cost": total_cost_val,
                    "xent_cost": xent_cost_val,
                    "transition_cost": transition_cost_val,
                    "l2_cost": l2_cost_val,
                    "rl_cost": rl_cost_val,
                    "policy_cost": policy_cost_val,
                    "rae_cost": rae_cost_val,
                    "leaf_acc": avg_leaf_acc,
                    "leaf_cost": leaf_cost_val,
                    "gen_acc": avg_gen_acc,
                    "gen_cost": gen_cost_val,
                    "time": time_metric,
                }
                stats_str = "Step: {step}"

                # Accuracy Component.
                stats_str += " Acc: {class_acc:.5f} {transition_acc:.5f}"
                if leaf_loss is not None:
                    stats_str += " leaf{leaf_acc:.5f}"
                if gen_loss is not None:
                    stats_str += " gen{gen_acc:.5f}"

                # Cost Component.
                stats_str += " Cost: {total_cost:.5f} {xent_cost:.5f} {transition_cost:.5f} {l2_cost:.5f}"
                if rl_loss is not None:
                    stats_str += " r{rl_cost:.5f}"
                if policy_loss is not None:
                    stats_str += " p{policy_cost:.5f}"
                if rae_loss is not None:
                    stats_str += " rae{rae_cost:.5f}"
                if leaf_loss is not None:
                    stats_str += " leaf{leaf_cost:.5f}"
                if gen_loss is not None:
                    stats_str += " gen{gen_cost:.5f}"

                # Time Component.
                stats_str += " Time: {time:.5f}"
                logger.Log(stats_str.format(**stats_args))

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(model, eval_set, logger, metrics_logger, step)
                    # Only the first eval set drives "best" checkpointing, and
                    # only when its error drops below 99% of the best so far.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        classifier_trainer.save(best_checkpoint_path, step, best_dev_error)
                progress_bar.reset()

            if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
                logger.Log("Checkpointing.")
                classifier_trainer.save(standard_checkpoint_path, step, best_dev_error)

            if step % FLAGS.metrics_interval_steps == 0:
                # Write the average of every key collected in M.
                m_keys = M.cache.keys()
                for k in m_keys:
                    metrics_logger.Log(k, M.get_avg(k), step)

            progress_bar.step(i=step % FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
def evaluate(model, eval_set, logger, step, vocabulary=None):
    """Run the model over one evaluation set and log its accuracy.

    Args:
        model: Model to evaluate; it is switched to eval mode here.
        eval_set: ``(filename, dataset)`` pair. ``dataset`` must support
            ``len()`` and yield batches accepted by ``get_batch``.
        logger: Object whose ``Log(str)`` receives the summary text.
        step: Training step shown in the log message.
        vocabulary: Not used in this function.

    Returns:
        Class accuracy over the whole set.
    """
    filename, dataset = eval_set

    # Constructed but not otherwise used in this function.
    reporter = EvalReporter()

    # Running totals for the whole evaluation set.
    n_correct = 0
    n_examples = 0
    n_batches = len(dataset)
    progress_bar = SimpleProgressBar(msg="Run Eval", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(0, total=n_batches)
    n_tokens = 0
    started_at = time.time()

    model.eval()

    # One entry per batch; flattened after the loop.
    predicted_transitions = []
    given_transitions = []

    for batch_index, batch in enumerate(dataset):
        batch_fields = get_batch(batch)[:4]
        eval_X_batch, eval_transitions_batch, eval_y_batch, eval_num_transitions_batch = batch_fields

        # Forward pass.
        output = model(eval_X_batch, eval_transitions_batch, eval_y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        log_probs = F.log_softmax(output)

        # Class accuracy: argmax of the log-probabilities against the labels.
        gold_labels = torch.from_numpy(eval_y_batch).long()
        predicted_labels = log_probs.data.max(1)[1].cpu()
        n_correct += predicted_labels.eq(gold_labels).sum()
        n_examples += gold_labels.size(0)

        # The transition loss is optional on the model.
        if hasattr(model, 'transition_loss'):
            transition_loss = model.transition_loss
        else:
            transition_loss = None

        # Token count per example is derived from its number of transitions.
        n_tokens += sum((nt+1)/2 for nt in eval_num_transitions_batch.reshape(-1))

        if transition_loss is not None:
            memories = model.spinn.memories
            predicted_transitions.append([m["t_preds"] for m in memories])
            given_transitions.append([m["t_given"] for m in model.spinn.memories])

        progress_bar.step(batch_index+1, total=n_batches)
    progress_bar.finish()

    elapsed = time.time() - started_at

    time_metric = time_per_token([n_tokens], [elapsed])

    eval_class_acc = n_correct / float(n_examples)

    # Transition accuracy only applies when some batch recorded transitions.
    if not predicted_transitions:
        eval_trans_acc = 0.0
    else:
        flat_preds = np.array(flatten(predicted_transitions))
        flat_truth = np.array(flatten(given_transitions))
        eval_trans_acc = (flat_preds == flat_truth).sum() / float(flat_truth.shape[0])

    summary_fields = (step, eval_class_acc, eval_trans_acc, filename, time_metric)
    stats_str = "Step: %i Eval acc: %f  %f %s Time: %5f" % summary_fields

    # Extra Component.
    stats_str += "\nEval Extra:"

    logger.Log(stats_str)

    return eval_class_acc