def train_loop(FLAGS, data_manager, model, optimizer, trainer,
               training_data_iter, eval_iterators, logger, step,
               best_dev_error, vocabulary):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        training_data_iter.next())
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions,
          pyramid_temperature_multiplier=1.0,
          example_lengths=num_transitions_batch)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(
        msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for step in range(step, FLAGS.training_steps):
        model.train()
        log_entry.Clear()
        log_entry.step = step
        should_log = False

        start = time.time()

        batch = get_batch(training_data_iter.next())
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum(
            [(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        if FLAGS.model_type in ["Pyramid", "ChoiPyramid"]:
            pyramid_temperature_multiplier = FLAGS.pyramid_temperature_decay_per_10k_steps ** (
                step / 10000.0)
            if FLAGS.pyramid_temperature_cycle_length > 0.0:
                min_temp = 1e-5
                pyramid_temperature_multiplier *= (
                    math.cos(step / FLAGS.pyramid_temperature_cycle_length) + 1 + min_temp) / 2
        else:
            pyramid_temperature_multiplier = None

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions,
                       pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                       example_lengths=num_transitions_batch)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1, keepdim=False)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(
            logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(
            model, FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()
        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            A.add('xent_cost', xent_loss.data[0])
            A.add('l2_cost', l2_loss.data[0])
            stats(model, optimizer, A, step, log_entry)
            should_log = True
            progress_bar.finish()

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                  example_lengths=num_transitions_batch)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions,
                  pyramid_temperature_multiplier=pyramid_temperature_multiplier,
                  example_lengths=num_transitions_batch)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:, :, 0],
                    transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            t_idxs = range(FLAGS.num_samples)
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks(
                    [1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks(
                    [1] + ev_strength[t_idx].tolist(), dec_str)

                _, crossing = evalb.crossing(gold, pred_ev)

                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, tacc = evaluate(
                    FLAGS, model, data_manager, eval_set, log_entry, logger, step,
                    show_sample=(step % FLAGS.sample_interval_steps == 0),
                    vocabulary=vocabulary, eval_index=index)
                if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                        1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log(
                        "Checkpointing with new best dev accuracy of %f" % acc)
                    # TODO: This mixes information across dev sets. Fix.
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)
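
# --- Illustrative sketch (assumption; not part of the project code). ---
# The training loop above (and the RL variant that follows) adds an optional
# weight-decay term via get_l2_loss(model, FLAGS.l2_lambda). The helper below
# only sketches the usual computation -- l2_lambda times the sum of squared
# values of every trainable parameter, returned as a scalar that can be added
# to the loss Variable. The project's real get_l2_loss may differ; the name
# l2_loss_sketch is hypothetical. Relies on the same module-level torch
# import used by the loops in this file.
def l2_loss_sketch(model, l2_lambda):
    l2 = None
    for p in model.parameters():
        if not p.requires_grad:
            continue  # skip frozen parameters, mirroring the clipping loop
        sq = torch.sum(p * p)
        l2 = sq if l2 is None else l2 + sq
    return l2_lambda * l2
# Hypothetical usage: total_loss += l2_loss_sketch(model, FLAGS.l2_lambda)
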
def train_loop(FLAGS, data_manager, model, optimizer, trainer,
               training_data_iter, eval_iterators, logger, step,
               best_dev_error):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(
        FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(
        training_data_iter.next())
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(
        msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for step in range(step, FLAGS.training_steps):
        model.train()
        log_entry.Clear()
        log_entry.step = step
        should_log = False

        start = time.time()

        batch = get_batch(training_data_iter.next())
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum(
            [(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        # Epsilon Greedy w. Decay.
        epsilon = FLAGS.rl_epsilon * math.exp(-step / FLAGS.rl_epsilon_decay)
        model.spinn.epsilon = epsilon

        # Confidence Penalty for Transition Predictions.
        temperature = math.sin(
            math.pi / 2 + step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        if FLAGS.rl_confidence_penalty:
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = logits.data.max(1)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(
            logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = get_l2_loss(
            model, FLAGS.l2_lambda) if FLAGS.use_l2_loss else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * \
                (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()
        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        train_rl_accumulate(model, data_manager, A, batch)

        if step % FLAGS.statistics_interval_steps == 0 \
                or step % FLAGS.metrics_interval_steps == 0:
            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps,
                                  total=FLAGS.statistics_interval_steps)
                progress_bar.finish()
            A.add('xent_cost', xent_loss.data[0])
            A.add('l2_cost', l2_loss.data[0])
            stats(model, optimizer, A, step, log_entry)

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:, :, 0],
                    transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            t_idxs = range(FLAGS.num_samples)
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks(
                    [1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks(
                    [1] + ev_strength[t_idx].tolist(), dec_str)

                # Compare the gold transitions against the eval-mode parse
                # (pred_ev), not the class prediction `pred`.
                _, crossing = evalb.crossing(gold, pred_ev)

                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:].encode('utf-8')
                log.strg_ev = strength_ev[1:].encode('utf-8')

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, tacc = evaluate(
                    FLAGS, model, data_manager, eval_set, log_entry, step)
                if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                        1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log(
                        "Checkpointing with new best dev accuracy of %f" % acc)
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        log_level = afs_safe_logger.ProtoLogger.INFO
        if not should_log and step % FLAGS.metrics_interval_steps == 0:
            # Log to file, but not to stderr.
            should_log = True
            log_level = afs_safe_logger.ProtoLogger.DEBUG

        if should_log:
            logger.LogEntry(log_entry, level=log_level)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                          total=FLAGS.statistics_interval_steps)
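
# --- Illustrative sketch (assumption; not part of the project code). ---
# The RL variant of train_loop above schedules two quantities each step: an
# epsilon-greedy exploration rate with exponential decay, and a cyclic
# temperature in [0, 1] (period FLAGS.rl_confidence_interval steps) that
# drives the confidence penalty and the soft wake/sleep weighting. The helper
# below merely restates those formulas in isolation so they can be inspected
# or plotted; the name rl_schedules_sketch and this packaging are hypothetical
# and not part of the project API. Relies on the module-level math import.
def rl_schedules_sketch(step, rl_epsilon, rl_epsilon_decay,
                        rl_confidence_interval, rl_confidence_penalty):
    # Exploration rate: decays exponentially in the training step.
    epsilon = rl_epsilon * math.exp(-step / rl_epsilon_decay)
    # Cyclic temperature: sine wave rescaled from [-1, 1] to [0, 1].
    temperature = math.sin(
        math.pi / 2 + step / float(rl_confidence_interval) * 2 * math.pi)
    temperature = (temperature + 1) / 2
    # Confidence-penalty temperature, floored at 1e-3 as in the loop above.
    temp = 1 + (temperature - .5) * rl_confidence_penalty * epsilon
    return epsilon, max(1e-3, temp)
# Hypothetical usage:
#     eps, temp = rl_schedules_sketch(step, FLAGS.rl_epsilon,
#                                     FLAGS.rl_epsilon_decay,
#                                     FLAGS.rl_confidence_interval,
#                                     FLAGS.rl_confidence_penalty)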