def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators, logger):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for _ in range(trainer.step, FLAGS.training_steps):
        if (trainer.step - trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        model.train()
        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        temperature = math.sin(
            math.pi / 2 + trainer.step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        # Confidence Penalty for Transition Predictions.
        if FLAGS.rl_confidence_penalty:
            epsilon = FLAGS.rl_epsilon * \
                math.exp(-trainer.step / float(FLAGS.rl_epsilon_decay))
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=False)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.CrossEntropyLoss()(output, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        nn.utils.clip_grad_norm([param for name, param in model.named_parameters()
                                 if name not in ["embed.embed.weight"]],
                                FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()

        end = time.time()
        total_time = end - start

        train_accumulate(model, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        train_rl_accumulate(model, A, batch)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
            progress_bar.finish()
            A.add('xent_cost', xent_loss.data[0])
            stats(model, trainer, A, log_entry)
            should_log = True

        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(), dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:]
                log.strg_ev = strength_ev[1:]

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(FLAGS, model, eval_set, log_entry,
                                  logger, trainer, eval_index=index)
                if index == 0:
                    trainer.new_dev_accuracy(acc)
            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)
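
# --- Illustrative sketch (not from the original repository) -------------------
# The sinusoidal confidence-penalty schedule assigned to model.spinn.temperature
# above, reproduced as a standalone helper. The default values of the rl_*
# arguments are placeholders standing in for the corresponding FLAGS.
import math


def confidence_temperature_sketch(step, rl_confidence_interval=1000,
                                  rl_epsilon=1.0, rl_epsilon_decay=50000,
                                  rl_confidence_penalty=0.1):
    # Sinusoid mapped into [0, 1], with a period of rl_confidence_interval steps.
    temperature = math.sin(math.pi / 2 +
                           step / float(rl_confidence_interval) * 2 * math.pi)
    temperature = (temperature + 1) / 2
    # Exponentially decaying epsilon shrinks the penalty as training progresses.
    epsilon = rl_epsilon * math.exp(-step / float(rl_epsilon_decay))
    temp = 1 + (temperature - .5) * rl_confidence_penalty * epsilon
    return max(1e-3, temp)

# Example usage: confidence_temperature_sketch(500) gives the softmax
# temperature that the loop above would assign to the transition parser at
# step 500 (given the placeholder flag values).
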
def train_loop(FLAGS, data_manager, model, optimizer, trainer,
               training_data_iter, eval_iterators, logger, step, best_dev_error):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)
    M = MetricsWriter(os.path.join(FLAGS.metrics_path, FLAGS.experiment_name))

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        training_data_iter.next())
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions
          )

    logger.Log("")
    logger.Log("# ----- BEGIN: Log Configuration ----- #")

    # Preview train string template.
    train_str = train_format(model)
    logger.Log("Train-Format: {}".format(train_str))
    train_extra_str = train_extra_format(model)
    logger.Log("Train-Extra-Format: {}".format(train_extra_str))

    # Preview eval string template.
    eval_str = eval_format(model)
    logger.Log("Eval-Format: {}".format(eval_str))
    eval_extra_str = eval_extra_format(model)
    logger.Log("Eval-Extra-Format: {}".format(eval_extra_str))

    logger.Log("# ----- END: Log Configuration ----- #")
    logger.Log("")

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    for step in range(step, FLAGS.training_steps):
        model.train()

        start = time.time()

        batch = get_batch(training_data_iter.next())
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions
                       )

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()
        pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        total_loss += auxiliary_loss(model)

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()
        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
            progress_bar.finish()
            A.add('xent_cost', xent_loss.data[0])
            A.add('l2_cost', l2_loss.data[0])
            stats_args = train_stats(model, optimizer, A, step)
            train_metrics(M, stats_args, step)
            logger.Log(train_str.format(**stats_args))
            logger.Log(train_extra_str.format(**stats_args))

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions
                  )
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions
                  )
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            transition_str = "Samples:"
            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            t_idxs = range(FLAGS.num_samples)
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist())
                strength_ev = sparks([1] + ev_strength[t_idx].tolist())
                _, crossing = evalb.crossing(gold, pred_ev)
                transition_str += "\n{}. crossing={}".format(t_idx, crossing)
                transition_str += "\n g{}".format("".join(map(str, gold)))
                transition_str += "\n {}".format(strength_tr[1:].encode('utf-8'))
                transition_str += "\n pt{}".format("".join(map(str, pred_tr)))
                transition_str += "\n {}".format(strength_ev[1:].encode('utf-8'))
                transition_str += "\n pe{}".format("".join(map(str, pred_ev)))
            logger.Log(transition_str)

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            for index, eval_set in enumerate(eval_iterators):
                acc, tacc = evaluate(FLAGS, model, data_manager, eval_set, index, logger, step)
                if FLAGS.ckpt_on_best_dev_error and index == 0 and \
                        (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                          total=FLAGS.statistics_interval_steps)
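
# --- Illustrative sketch (not from the original repository) -------------------
# One plausible implementation of the l2_cost() helper referenced above: the
# sum of squared values over all trainable parameters, scaled by l2_lambda.
# The repository's actual helper may differ (for example, it might skip the
# embedding matrix).
def l2_cost_sketch(model, l2_lambda):
    cost = 0.0
    for p in model.parameters():
        if p.requires_grad:
            cost = cost + (p * p).sum()
    return l2_lambda * cost

# Example usage inside the loop above:
#     l2_loss = l2_cost_sketch(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None
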
def train_loop(FLAGS, data_manager, model, optimizer, trainer, training_data_iter,
               eval_iterators, logger, step, best_dev_error, vocabulary):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        training_data_iter.next())
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions,
          pyramid_temperature_multiplier=1.0,
          example_lengths=num_transitions_batch
          )

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for step in range(step, FLAGS.training_steps):
        model.train()
        log_entry.Clear()
        log_entry.step = step
        should_log = False

        start = time.time()

        batch = get_batch(training_data_iter.next())
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
            pyramid_temperature_multiplier = \
                FLAGS.pyramid_temperature_decay_per_10k_steps ** (step / 10000.0)
            if FLAGS.pyramid_temperature_cycle_length > 0.0:
                min_temp = 1e-5
                pyramid_temperature_multiplier *= \
                    (math.cos(step / FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
        else:
            pyramid_temperature_multiplier = None

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                       example_lengths=num_transitions_batch
                       )

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(model, FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()
        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            A.add('xent_cost', xent_loss.data[0])
            A.add('l2_cost', l2_loss.data[0])
            stats(model, optimizer, A, step, log_entry)
            should_log = True
            progress_bar.finish()

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                  example_lengths=num_transitions_batch
                  )
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                  example_lengths=num_transitions_batch
                  )
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            t_idxs = range(FLAGS.num_samples)
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(), dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, tacc = evaluate(FLAGS, model, data_manager, eval_set,
                                     log_entry, logger, step,
                                     show_sample=(step % FLAGS.sample_interval_steps == 0),
                                     vocabulary=vocabulary, eval_index=index)
                if FLAGS.ckpt_on_best_dev_error and index == 0 and \
                        (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                    # TODO: This mixes information across dev sets. Fix.
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)
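
# --- Illustrative sketch (not from the original repository) -------------------
# The pyramid temperature multiplier computed above, as a standalone function.
# The default decay and cycle-length values are placeholders for the
# corresponding FLAGS.
import math


def pyramid_temperature_multiplier_sketch(step, decay_per_10k_steps=0.5,
                                          cycle_length=0.0):
    # Exponential decay: for a decay of 0.5, the multiplier halves every 10k steps.
    multiplier = decay_per_10k_steps ** (step / 10000.0)
    if cycle_length > 0.0:
        # Optional cosine cycling keeps the temperature oscillating slightly above zero.
        min_temp = 1e-5
        multiplier *= (math.cos(step / cycle_length) + 1 + min_temp) / 2
    return multiplier

# Example usage: pyramid_temperature_multiplier_sketch(20000, cycle_length=5000.0)
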
def train_loop(FLAGS, data_manager, model, optimizer, trainer, training_data_iter,
               eval_iterators, logger, step, best_dev_error):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        training_data_iter.next())
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions
          )

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(
        msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for step in range(step, FLAGS.training_steps):
        model.train()
        log_entry.Clear()
        log_entry.step = step
        should_log = False

        start = time.time()

        batch = get_batch(training_data_iter.next())
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        # Epsilon Greedy w. Decay.
        epsilon = FLAGS.rl_epsilon * math.exp(-step / FLAGS.rl_epsilon_decay)
        model.spinn.epsilon = epsilon

        # Confidence Penalty for Transition Predictions.
        temperature = math.sin(
            math.pi / 2 + step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        if FLAGS.rl_confidence_penalty:
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions
                       )

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()
        pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(model, FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()
        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        train_rl_accumulate(model, data_manager, A, batch)

        if step % FLAGS.statistics_interval_steps == 0 \
                or step % FLAGS.metrics_interval_steps == 0:
            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps,
                                  total=FLAGS.statistics_interval_steps)
                progress_bar.finish()
            A.add('xent_cost', xent_loss.data[0])
            A.add('l2_cost', l2_loss.data[0])
            stats(model, optimizer, A, step, log_entry)

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions
                  )
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions
                  )
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            t_idxs = range(FLAGS.num_samples)
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(), dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, tacc = evaluate(
                    FLAGS, model, data_manager, eval_set, log_entry, step)
                if FLAGS.ckpt_on_best_dev_error and index == 0 and \
                        (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        log_level = afs_safe_logger.ProtoLogger.INFO
        if not should_log and step % FLAGS.metrics_interval_steps == 0:
            # Log to file, but not to stderr.
            should_log = True
            log_level = afs_safe_logger.ProtoLogger.DEBUG

        if should_log:
            logger.LogEntry(log_entry, level=log_level)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                          total=FLAGS.statistics_interval_steps)
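
# --- Illustrative sketch (not from the original repository) -------------------
# The element-wise "hard" gradient clipping used above versus the global-norm
# clipping used by other variants of this loop, demonstrated on a single toy
# parameter. Only the toy setup is invented; both clipping calls are standard
# PyTorch.
import torch
import torch.nn as nn


def clipping_demo(clip=1.0):
    param = nn.Parameter(torch.zeros(3))

    # Element-wise clamp: every gradient entry is squashed into [-clip, clip]
    # independently, so the gradient's direction can change.
    param.grad = torch.tensor([4.0, -7.0, 0.5])
    param.grad.data.clamp_(min=-clip, max=clip)
    clamped = param.grad.clone()          # tensor([ 1.0, -1.0, 0.5])

    # Global norm clipping: all gradients are rescaled together so their joint
    # L2 norm is at most `clip`; relative directions are preserved.
    param.grad = torch.tensor([4.0, -7.0, 0.5])
    nn.utils.clip_grad_norm_([param], clip)
    normed = param.grad.clone()           # norm(normed) is approximately 1.0

    return clamped, normed
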
def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators, logger,
               vocabulary, target_vocabulary):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    rl_only = False

    log_entry = pb.SpinnEntry()
    for _ in range(trainer.step, FLAGS.training_steps):
        if FLAGS.rl_alternate and trainer.step % 1000 == 0 and trainer.step > 0:
            rl_only = not rl_only
            if rl_only:
                logger.Log('Switching training mode: RL only.')
            else:
                logger.Log('Switching training mode: MT only.')

        if (trainer.step - trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        model.train()
        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        temperature = math.sin(
            math.pi / 2 + trainer.step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        # Confidence Penalty for Transition Predictions.
        if FLAGS.rl_confidence_penalty:
            epsilon = FLAGS.rl_epsilon * \
                math.exp(-trainer.step / float(FLAGS.rl_epsilon_decay))
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output, trg, attention, mask = model(
            X_batch, transitions_batch, y_batch,
            use_internal_parser=FLAGS.use_internal_parser,
            validate_transitions=FLAGS.validate_transitions,
            example_lengths=num_transitions_batch)

        criterion = nn.NLLLoss()
        batch_size = len(y_batch)
        trg_seq_len = trg.shape[0]

        mt_loss = 0.0
        if not rl_only:
            num_classes = output.shape[-1]
            mask = to_gpu(mask)
            for i in range(trg_seq_len):
                mt_loss += criterion(
                    output[i, :].index_select(0, mask[i].nonzero().squeeze(1)),
                    trg[i].index_select(0, mask[i].nonzero().squeeze(1)).view(-1))
        elif FLAGS.rl_alternate:
            model.policy_loss = 0.0
            model.value_loss = 0.0

        # Optionally calculate transition loss.
        mt_loss = mt_loss / trg_seq_len
        model.transition_loss = model.encoder.transition_loss if hasattr(
            model.encoder, 'transition_loss') else None
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None
        model.mt_loss = mt_loss

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += mt_loss
        if transition_loss is not None and model.encoder.optimize_transition_loss:
            model.optimize_transition_loss = model.encoder.optimize_transition_loss
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss[0]

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        nn.utils.clip_grad_norm_([param for name, param in model.named_parameters()
                                  if name not in ["embed.embed.weight"]],
                                 FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()

        bb = list(model.parameters())[-1].clone()

        end = time.time()
        total_time = end - start

        train_accumulate(model, A, batch)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)
        A.add('mt_loss', float(mt_loss))

        train_rl_accumulate(model, A, batch)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            stats(model, trainer, A, log_entry)
            should_log = True
            progress_bar.finish()

        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  example_lengths=num_transitions_batch)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  example_lengths=num_transitions_batch)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(), dec_str)
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:]
                log.strg_ev = strength_ev[1:]

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(
                    FLAGS, model, eval_set, log_entry, logger, trainer,
                    show_sample=(trainer.step % FLAGS.sample_interval_steps == 0),
                    vocabulary=vocabulary, eval_index=index,
                    target_vocabulary=target_vocabulary)
                if index == 0:
                    trainer.new_dev_accuracy(acc)
            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)
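
# --- Illustrative sketch (not from the original repository) -------------------
# The masked per-timestep NLL used for the MT loss above, on toy tensors.
# `output` is assumed to hold log-probabilities shaped
# (trg_seq_len, batch, num_classes), `trg` target indices, and `mask` marks
# non-padding positions; all toy values here are invented.
import torch
import torch.nn as nn


def masked_mt_loss_demo():
    criterion = nn.NLLLoss()
    trg_seq_len, batch, num_classes = 4, 3, 5
    output = torch.log_softmax(torch.randn(trg_seq_len, batch, num_classes), dim=-1)
    trg = torch.randint(num_classes, (trg_seq_len, batch))
    mask = torch.tensor([[1, 1, 1], [1, 1, 0], [1, 0, 0], [1, 0, 0]])

    mt_loss = 0.0
    for i in range(trg_seq_len):
        keep = mask[i].nonzero().squeeze(1)       # indices of non-pad targets at step i
        mt_loss += criterion(output[i].index_select(0, keep),
                             trg[i].index_select(0, keep).view(-1))
    # Average over target time steps, as in the loop above.
    return mt_loss / trg_seq_len
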
def run_pyramid(self, x, show_sample=False):
    batch_size, seq_len, model_dim = x.data.size()

    all_state_pairs = []
    all_state_pairs.append(torch.chunk(x, seq_len, 1))

    # Temp fix:
    show_sample = False

    if show_sample:
        print

    for layer in range(seq_len - 1, 0, -1):
        composition_results = []
        selection_logits_list = []

        for position in range(layer):
            left = torch.squeeze(all_state_pairs[-1][position])
            right = torch.squeeze(all_state_pairs[-1][position + 1])
            composition_results.append(self.composition_fn(left, right))

        if self.gated:
            for position in range(layer):
                selection_logits_list.append(
                    self.selection_fn(composition_results[position]))
            selection_logits = torch.cat(selection_logits_list, 1)

            if show_sample:
                selection_probs = F.softmax(selection_logits)
                print sparks(
                    np.transpose(
                        selection_probs[0, :].data.cpu().numpy()).tolist())

            if self.training and self.selection_keep_rate is not None:
                noise = torch.bernoulli(
                    (to_gpu(torch.ones(1, 1)) * self.selection_keep_rate
                     ).expand_as(selection_logits)) * -1000.
                selection_logits += Variable(noise)

            selection_probs = F.softmax(selection_logits)

            layer_state_pairs = []
            for position in range(layer):
                if position < (layer - 1):
                    copy_left = torch.sum(selection_probs[:, position + 1:], 1)
                else:
                    copy_left = to_gpu(Variable(torch.zeros(1, 1)))
                if position > 0:
                    copy_right = torch.sum(selection_probs[:, :position], 1)
                else:
                    copy_right = to_gpu(Variable(torch.zeros(1, 1)))
                select = selection_probs[:, position]

                left = torch.squeeze(all_state_pairs[-1][position])
                right = torch.squeeze(all_state_pairs[-1][position + 1])
                composition_result = composition_results[position]

                new_state_pair = copy_left.expand_as(left) * left \
                    + copy_right.expand_as(right) * right \
                    + select.unsqueeze(1).expand_as(composition_result) * composition_result
                layer_state_pairs.append(new_state_pair)
        else:
            layer_state_pairs = composition_results

        all_state_pairs.append(layer_state_pairs)

    return all_state_pairs[-1][-1]
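
# --- Illustrative sketch (not from the original repository) -------------------
# The soft copy-left / copy-right / select mixture used per position in the
# gated branch of run_pyramid above, shown on toy selection probabilities for
# a single layer. For every position, the three weights sum to one per example,
# so each new state is a convex combination of the left child, the right child,
# and the composed representation.
import torch


def pyramid_selection_weights_demo(batch_size=2, layer=4):
    selection_probs = torch.softmax(torch.randn(batch_size, layer), dim=1)
    weights = []
    for position in range(layer):
        # Probability mass on merges strictly to the right: keep the left child.
        copy_left = (selection_probs[:, position + 1:].sum(1)
                     if position < layer - 1 else torch.zeros(batch_size))
        # Probability mass on merges strictly to the left: keep the right child.
        copy_right = (selection_probs[:, :position].sum(1)
                      if position > 0 else torch.zeros(batch_size))
        # Probability mass on this merge: use the composed representation.
        select = selection_probs[:, position]
        weights.append((copy_left, copy_right, select))
    return weights
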