def forward_hook(self, embeds, batch_size, seq_length, h):
    """Compute the RL baseline prediction and store it on ``self.baseline_outp``.

    Acts only while ``self.training`` is true and ``self.rl_baseline`` is
    "value" or "shared"; otherwise nothing is done.

    * "value": ``embeds.data`` is re-wrapped in a fresh Variable, which cuts
      the graph back to the encoder. The value LSTM ``self.v_rnn`` then runs
      from zero states, and its final hidden state is scored by ``self.v_mlp``.
    * "shared": ``h[0]`` (supplied by the caller) is scored by ``self.v_mlp``
      as-is. It is NOT detached.

    When ``self.use_sentence_pair`` is set, the first and second halves of
    the batch (split at ``batch_size // 2``) are concatenated along dim 1
    before scoring.
    """
    if self.rl_baseline not in ("value", "shared") or not self.training:
        return

    def _score(hidden):
        if not self.use_sentence_pair:
            return self.v_mlp(hidden.squeeze())
        hidden = hidden.squeeze()
        half = batch_size // 2
        joined = torch.cat([hidden[:half], hidden[half:]], 1)
        return self.v_mlp(joined.squeeze())

    if self.rl_baseline == "value":
        # Re-wrapping .data detaches the value network from the encoder graph.
        seq_input = Variable(embeds.data, volatile=not self.training).view(
            batch_size, seq_length, -1)
        state_shape = (1, batch_size, self.v_rnn_dim)
        h0 = Variable(to_gpu(torch.zeros(*state_shape)), volatile=not self.training)
        c0 = Variable(to_gpu(torch.zeros(*state_shape)), volatile=not self.training)
        _, (last_hidden, _) = self.v_rnn(seq_input, (h0, c0))
        self.baseline_outp = _score(last_hidden)
    else:
        # "shared": the caller's state h[0] is used directly (no detach).
        self.baseline_outp = _score(h[0])
def run_rnn(self, x):
    """Run ``self.rnn`` over ``x`` starting from all-zero hidden and cell states.

    Returns the final hidden state ``hn``. When ``self.data_type == "mt"`` it
    returns ``(hn, cn, output)`` instead.
    """
    batch_size, _, _ = x.data.size()
    num_directions = 2 if self.is_bidirectional else 1
    state_shape = (1 * num_directions, batch_size, self.model_dim)  # one layer
    h0 = Variable(to_gpu(torch.zeros(*state_shape)), volatile=not self.training)
    c0 = Variable(to_gpu(torch.zeros(*state_shape)), volatile=not self.training)

    # Expects (input, h_0):
    # input => batch_size x seq_len x model_dim
    # h_0 => (num_layers x num_directions[1,2]) x batch_size x model_dim
    # c_0 => (num_layers x num_directions[1,2]) x batch_size x model_dim
    output, (hn, cn) = self.rnn(x, (h0, c0))

    if self.data_type == "mt":
        return hn, cn, output
    return hn
def unwrap_tree(self, lefts, rights, writes):
    """Flatten premise/hypothesis index arrays and remap high indices.

    Each input is sliced as ``[:, :, 0]`` and ``[:, :, 1]``. The two slices
    are stacked along axis 0, with the ``0`` slice first. ``max_len`` is
    ``lefts.shape[1]``.

    * Entries >= 200 in lefts/rights are shifted down by ``200 - max_len``.
    * Entries >= 201 in writes are shifted down by ``201 - max_len``.

    So 200 (or 201) maps to ``max_len``.

    Returns:
        ``(left, right, write, write_mask)``. The first three are cast with
        ``.long()``. ``write_mask`` is ``(write >= 0).long()``, computed
        before the final write shift.
    """
    max_len = lefts.shape[1]

    def merge_halves(arr):
        return np.concatenate([arr[:, :, 0], arr[:, :, 1]], axis=0)

    def to_variable(arr):
        return to_gpu(Variable(torch.from_numpy(arr), volatile=not self.training))

    merged = [merge_halves(arr) for arr in (lefts, rights, writes)]
    left_var, right_var, write_var = [to_variable(arr) for arr in merged]

    left_var = left_var - (left_var.ge(200).int() * (200 - max_len))
    right_var = right_var - (right_var.ge(200).int() * (200 - max_len))
    write_var = write_var - (write_var.ge(201).int() * (201 - max_len))

    write_mask = write_var.ge(0).long()
    # NOTE(review): a write index of exactly 0 counts as valid in write_mask
    # (ge(0)) but is also shifted by 2 * max_len below (le(0)); confirm
    # whether lt(0) was intended.
    write_var = write_var + (write_var.le(0).int() * (2 * max_len))

    return left_var.long(), right_var.long(), write_var.long(), write_mask
def forward_hook(self, embeds, batch_size, seq_length):
    """Store the value network's baseline prediction on ``self.baseline_outp``.

    Does nothing unless ``self.rl_baseline == "value"`` and the module is
    training. ``embeds.data`` is re-wrapped in a new Variable so the value
    LSTM ``self.v_rnn`` does not backpropagate into the embeddings. The LSTM
    starts from zero states, and its final hidden state is fed to
    ``self.v_mlp``.
    """
    if self.rl_baseline != "value" or not self.training:
        return
    seq_input = Variable(embeds.data, volatile=not self.training).view(
        batch_size, seq_length, -1)
    initial_states = tuple(
        Variable(to_gpu(torch.zeros(1, batch_size, self.v_dim)), volatile=not self.training)
        for _ in range(2))
    _, (final_hidden, _) = self.v_rnn(seq_input, initial_states)
    self.baseline_outp = self.v_mlp(final_hidden.squeeze())
def build_reward(self, output, target, mask, rl_reward="mean"):
    """Compute one scalar reward per example from decoder outputs.

    Args:
        output: decoder scores indexed ``[step, batch, vocab]``. The final
            step (the <end> prediction) is dropped before scoring.
        target: gold token ids indexed ``[batch, step]``.
        mask: unused; kept for interface compatibility.
        rl_reward: which reward to compute.
            * "xent": sums ``-log(1 - p(target))`` over steps, using softmax
              probabilities.
            * "mean" / "sum": applies ``nn.NLLLoss`` with that reduction to
              each example independently. ``output`` is presumably
              log-probabilities here, since NLLLoss expects them; confirm
              against the caller.

    Returns:
        A CPU FloatTensor of shape ``(batch,)``.

    Raises:
        ValueError: if ``rl_reward`` is not one of "xent", "mean", "sum".
    """
    if rl_reward not in ("xent", "mean", "sum"):
        # Previously this fell through to an UnboundLocalError on `criterion`.
        raise ValueError("Unknown rl_reward: {}".format(rl_reward))

    if rl_reward == "xent":
        batch_size = target.size(0)
        seq_length = target.size(1)
        _target = target.permute(1, 0).long()
        output = output[:-1, :, :]  # drop <end> token
        probs = F.softmax(output, dim=2).data.cpu()
        log_inv_prob = torch.log(1 - probs)
        # Loop over seq_length to sum rewards across the full sequence.
        # Element-wise mean is not supported yet.
        rewards = torch.zeros(batch_size)
        for i in range(seq_length):
            rewards += -1 * torch.gather(
                log_inv_prob[i], 1, _target[i].unsqueeze(1)).squeeze()
        return rewards

    output = output.permute(1, 0, 2)
    target = to_gpu(Variable(target))
    # "mean" is the current name of the deprecated "elementwise_mean"
    # reduction; the two have the same semantics.
    criterion = nn.NLLLoss(reduction=rl_reward)
    batch_size = output.shape[0]
    # NLLLoss is used per example: each call receives one full sequence
    # (steps x vocab) instead of a batch of single tokens. "sum"/"mean"
    # therefore reduce over that example's positions.
    rewards = [
        float(criterion(output[i][:-1, :], target[i].long()))
        for i in range(batch_size)
    ]
    return torch.tensor(rewards)
def run_spinn(self, example, embeds, use_internal_parser, validate_transitions=True):
    """Run ``self.spinn`` on ``example`` and pad its attention memory.

    ``attended`` from ``self.spinn`` is a list with one list per example. It
    is processed in three steps:

    1. Each inner list is right-padded with ``1 x self.model_dim`` zero
       Variables up to the longest list.
    2. Each padded list is concatenated along dim 0.
    3. The results are stacked along a new dim 1.

    ``h`` is ``torch.cat(h_list).unsqueeze(0)`` when ``self.data_type ==
    "mt"`` and ``self.wrap(h_list)`` otherwise. ``embeds`` is not used, and
    ``memory_lengths`` is always ``None``.

    Returns:
        ``(h, h_list, transition_acc, transition_loss, attended,
        memory_lengths)``.
    """
    self.spinn.reset_state()
    h_list, transition_acc, transition_loss, attended = self.spinn(
        example,
        use_internal_parser=use_internal_parser,
        validate_transitions=validate_transitions)

    longest = max(len(items) for items in attended)
    memory_lengths = None  # per-example lengths are not tracked here

    padded = []
    for items in attended:
        filler = (longest - len(items)) * [to_gpu(Variable(torch.zeros(1, self.model_dim)))]
        padded.append(torch.cat(items + filler))
    attended = torch.cat([t.unsqueeze(1) for t in padded], 1)

    if self.data_type == "mt":
        h = torch.cat(h_list).unsqueeze(0)
    else:
        h = self.wrap(h_list)
    return h, h_list, transition_acc, transition_loss, attended, memory_lengths
def unwrap_sentence_pair(self, sentences, transitions):
    """Stack premises above hypotheses and return them as a token Variable.

    ``sentences[:, :, 0]`` and ``sentences[:, :, 1]`` are concatenated along
    axis 0. ``transitions`` is ignored.
    """
    premises, hypotheses = sentences[:, :, 0], sentences[:, :, 1]
    tokens = np.concatenate([premises, hypotheses], axis=0)
    return to_gpu(Variable(torch.from_numpy(tokens), volatile=not self.training))
def predict_actions(self, transition_output):
    """Turn transition scores into log-probabilities and chosen actions.

    Scores are divided by ``max(self.temperature, TINY)`` before the softmax.
    Column 0 is the SHIFT probability; ``shift_probs`` reads ``[:, 0]``.
    With ``self.catalan``, the distribution is multiplied by a Catalan prior
    from ``self.shift_probabilities`` and renormalized row by row.

    Returns:
        ``(transition_logdist, transition_preds)``.
        * ``transition_logdist``: log-probabilities.
        * ``transition_preds``: an int32 numpy array. While training, action
          1 is sampled with probability ``1 - p_shift``. Otherwise it is
          ``round(1 - p_shift)`` (greedy).
    """
    transition_output_t = transition_output / max(self.temperature, TINY)
    transition_dist = F.softmax(transition_output_t, dim=1)

    if self.catalan:
        # Use the catalan distribution as a prior.
        p_shift_catalan = torch.FloatTensor([
            self.shift_probabilities.prob(n_red, n_step, n_tok)
            for n_red, n_step, n_tok in zip(self.n_reduces, self.n_steps, self.n_tokens)
        ]).view(-1, 1)
        p_catalan = torch.cat([p_shift_catalan, 1. - p_shift_catalan], 1)
        p_catalan = to_gpu(Variable(p_catalan))
        _p_new = transition_dist * p_catalan
        # keepdim=True keeps the row sums as (batch, 1), so each row is
        # divided by its own total. A (batch,) sum would broadcast across
        # columns (or fail for batch != 2).
        transition_dist = _p_new / (_p_new.sum(1, keepdim=True) + TINY)

    if self.catalan and self.catalan_backprop:
        transition_logdist = torch.log(transition_dist + TINY)
    else:
        transition_logdist = F.log_softmax(transition_output_t, dim=1)

    shift_probs = transition_dist.data[:, 0]
    if self.training:
        np_shift_probs = shift_probs.cpu().numpy()
        # Picks 1 when the uniform draw exceeds p_shift, i.e. with prob 1 - p_shift.
        transition_preds = (np.random.rand(*np_shift_probs.shape)
                            > np_shift_probs).astype('int32')
    else:
        # Greedy prediction
        transition_preds = torch.round(1 - shift_probs).cpu().numpy().astype('int32')
    return transition_logdist, transition_preds
def forward(self, example, use_internal_parser=False, validate_transitions=True):
    """Set up per-example buffers and stacks, then run the shift-reduce parser.

    ``self.buffers_n`` counts the non-zero tokens in each row of
    ``example.tokens``. Each buffer in ``example.bufs`` keeps only its last
    ``buffers_n`` items and is prefixed with one zero vector. Each stack
    starts as two zero vectors, so the tracker always has inputs.

    Raises:
        ValueError: if ``example`` has no ``transitions`` attribute.
    """
    self.buffers_n = (example.tokens.data != 0).long().sum(1).view(-1).tolist()

    if self.debug:
        seq_length = example.tokens.size(1)
        assert all(buf_n <= (seq_length + 1) // 2 for buf_n in self.buffers_n), \
            "All sentences (including cropped) must be the appropriate length."

    self.bufs = example.bufs

    # Notes on adding zeros to bufs/stacks.
    # - After the buffer is consumed, we need one zero on the buffer
    #   used as input to the tracker.
    # - For the first two steps, the stack would be empty, but we add
    #   zeros so that the tracker still gets input.
    template = self.bufs[0][0]
    zeros = to_gpu(Variable(torch.from_numpy(
        np.zeros(template.size(), dtype=np.float32)),
        volatile=template.volatile))

    # Trim unused tokens, keeping the last b_n items. Slicing with -b_n
    # would keep the whole buffer when b_n == 0 (b[-0:] == b[:]); the
    # clamped start index keeps an all-padding example empty instead.
    self.bufs = [[zeros] + b[max(len(b) - b_n, 0):]
                 for b, b_n in zip(self.bufs, self.buffers_n)]
    self.stacks = [[zeros, zeros] for _ in self.bufs]

    if hasattr(self, 'tracker'):
        self.tracker.reset_state()

    if not hasattr(example, 'transitions'):
        # TODO: Support no transitions. In the meantime, must at least pass dummy transitions.
        raise ValueError('Transitions must be included.')

    self.forward_hook()

    return self.run(example.transitions,
                    run_internal_parser=True,
                    use_internal_parser=use_internal_parser,
                    validate_transitions=validate_transitions)
def forward(self, top_buf, top_stack_1, top_stack_2):
    """Combine the buffer top and two stack tops into a tracker state.

    With ``self.tracking_ln``, each input first goes through its own
    layer-norm module.

    * ``self.lateral_tracking`` set: the projected inputs, plus a lateral
      projection of the previous ``self.h`` when one exists, drive one
      ``lstm`` step. ``self.c`` is lazily created as zeros. Returns
      ``(self.h, self.c)``.
    * Otherwise: returns the three inputs concatenated along dim 1, and
      ``None``.
    """
    if self.tracking_ln:
        top_buf, top_stack_1, top_stack_2 = (
            self.buf_ln(top_buf),
            self.stack1_ln(top_stack_1),
            self.stack2_ln(top_stack_2),
        )

    if not self.lateral_tracking:
        return torch.cat([top_buf, top_stack_1, top_stack_2], 1), None

    tracker_inp = self.buf(top_buf)
    tracker_inp += self.stack1(top_stack_1)
    tracker_inp += self.stack2(top_stack_2)
    batch_size = tracker_inp.size(0)

    if self.h is not None:
        tracker_inp += self.lateral(self.h)
    if self.c is None:
        blank = np.zeros((batch_size, self.state_size), dtype=np.float32)
        self.c = to_gpu(Variable(torch.from_numpy(blank), volatile=tracker_inp.volatile))

    # Run tracking lstm.
    self.c, self.h = lstm(self.c, tracker_inp)
    return self.h, self.c
def forward(self, sentences, _, __=None, example_lengths=None, store_parse_masks=False,
            pyramid_temperature_multiplier=1.0, **kwargs):
    """Encode ``sentences`` with ``self.binary_tree_lstm`` and classify them.

    The second and third positional arguments are not used. While training,
    the temperature reported by the tree LSTM is saved on
    ``self.temperature_to_display``. With ``store_parse_masks``, the masks it
    returns are saved as numpy arrays on ``self.mask_memory``.

    Returns:
        The output of ``self.mlp``.
    """
    # Useful when investigating dynamic batching:
    # self.seq_lengths = sentences.shape[1] - (sentences == 0).sum(1)
    tokens, example_lengths = self.unwrap(sentences, example_lengths)
    emb = self.run_embed(tokens)
    _n, _len, _dim = emb.data.size()  # unpacking requires a 3-D embedding

    lengths_var = to_gpu(Variable(torch.from_numpy(example_lengths))).long()
    hh, _ignored, masks, temperature = self.binary_tree_lstm(
        emb, lengths_var, temperature_multiplier=pyramid_temperature_multiplier)

    if self.training:
        self.temperature_to_display = temperature
    if store_parse_masks:
        self.mask_memory = [m.data.cpu().numpy() for m in masks]

    features = self.build_features(self.wrap(hh))
    return self.mlp(features)
def forward(self, sentences, _, __, dist=None, pyramid_temperature_multiplier=1.0,
            example_lengths=None, store_parse_masks=False, **kwargs):
    """Encode sentences with the tree LSTM (optionally guided by ``dist``) and classify.

    ``sentences`` and ``dist`` (when given) are unwrapped with the same,
    original ``example_lengths``. After the tree LSTM runs, the tree LSTM's
    ``sbs_loss`` and ``sbs_acc`` are each divided by its ``n_total``. The
    results are stored as ``self.sbs_loss`` and ``self.sbs_acc``.

    While training, the temperature is saved on
    ``self.temperature_to_display``. With ``store_parse_masks``, the returned
    masks are saved as numpy arrays on ``self.mask_memory``.

    Returns:
        The output of ``self.mlp``.
    """
    # Useful when investigating dynamic batching:
    # self.seq_lengths = sentences.shape[1] - (sentences == 0).sum(1)
    raw_lengths = example_lengths
    tokens, example_lengths = self.unwrap(sentences, raw_lengths)
    if dist is not None:
        dist, _ignored = self.unwrap(dist, raw_lengths)

    emb = self.run_embed(tokens)
    _n, _len, _dim = emb.data.size()  # unpacking requires a 3-D embedding
    lengths_var = to_gpu(Variable(torch.from_numpy(example_lengths))).long()

    tree = self.binary_tree_lstm
    hh, _ignored, masks, temperature = tree(
        emb, lengths_var, dist=dist,
        temperature_multiplier=pyramid_temperature_multiplier)

    if self.training:
        self.temperature_to_display = temperature

    # TODO: sbs_acc may not be divided by num at this moment
    total = tree.n_total.float()
    self.sbs_loss = tree.sbs_loss / total
    self.sbs_acc = tree.sbs_acc / total

    if store_parse_masks:
        self.mask_memory = [m.data.cpu().numpy() for m in masks]

    features = self.build_features(self.wrap(hh))
    return self.mlp(features)
def build_baseline(self, rewards, sentences, transitions, y_batch=None, embeds=None):
    """Return the baseline to subtract from ``rewards`` for REINFORCE.

    The strategy comes from ``self.rl_baseline``:

    * "ema": returns the running average from *before* this batch, then
      moves ``self.baseline[0]`` toward ``rewards.mean()`` with rate
      ``self.rl_mu``.
    * "pass": returns ``0.``.
    * "greedy": returns ``self.build_reward`` applied to the softmax of
      ``self.run_greedy(sentences, transitions)``.
    * "value": uses ``self.baseline_outp``.
      - ``rl_reward == "standard"``: the sigmoid of it is trained with BCE
        against ``rewards``.
      - ``rl_reward == "xent"``: it is used raw and trained with MSE.
      The loss is stored on ``self.value_loss``.

    ``embeds`` is not used.

    Raises:
        NotImplementedError: for any other strategy, or for "value" with an
        unsupported ``self.rl_reward``.
    """
    strategy = self.rl_baseline

    if strategy == "ema":
        mu = self.rl_mu
        previous = self.baseline[0]
        self.baseline[0] = self.baseline[0] * (1 - mu) + rewards.mean() * mu
        return previous

    if strategy == "pass":
        return 0.

    if strategy == "greedy":
        # Pass inputs to Greedy Max, then estimate the reward.
        greedy_output = self.run_greedy(sentences, transitions)
        probs = F.softmax(greedy_output).data.cpu()
        target = torch.from_numpy(y_batch).long()
        # NOTE(review): confirm that self.build_reward accepts
        # (probs, target, rl_reward=...); a variant that also requires a
        # positional `mask` argument would raise TypeError here.
        return self.build_reward(probs, target, rl_reward=self.rl_reward)

    if strategy == "value":
        output = self.baseline_outp
        if self.rl_reward == "standard":
            baseline = F.sigmoid(output)
            loss_fn = nn.BCELoss()
        elif self.rl_reward == "xent":
            baseline = output
            loss_fn = nn.MSELoss()
        else:
            raise NotImplementedError
        self.value_loss = loss_fn(
            baseline, to_gpu(Variable(rewards, volatile=not self.training)))
        return baseline.data.cpu()

    raise NotImplementedError
def mc_reinforce(self, rewards, baseline):
    """REINFORCE policy loss for the transitions sampled by ``self.spinn``.

    Collects ``t_preds``, ``t_mask`` and ``t_logprobs`` from
    ``self.spinn.memories``. At the positions selected by ``t_mask``, it
    takes the log-probability of the action that was actually taken and
    weights it by that example's advantage ``-(rewards - baseline)``.
    Minimizing the loss therefore increases log p(action) for above-baseline
    rewards. ``t_mask`` is assumed to be a boolean numpy mask over the
    flattened (step, example) positions — TODO confirm.

    Returns:
        Scalar loss: the mean over selected actions of
        ``advantage * log p(action)``, times a fixed scale factor.
    """
    memories = self.spinn.memories
    t_preds = np.concatenate([m['t_preds'] for m in memories if 't_preds' in m])
    t_mask = np.concatenate([m['t_mask'] for m in memories if 't_mask' in m])
    t_logprobs = torch.cat(
        [m['t_logprobs'] for m in memories if 't_logprobs' in m], 0)

    if self.use_sentence_pair:
        # Handles the case of SNLI where each reward is used for two
        # sentences.
        rewards = torch.cat([rewards, rewards], 0)
        baseline = torch.cat([baseline, baseline], 0)

    advantage = -1 * (rewards - baseline)
    batch_size = advantage.size(0)
    # Integer division: numpy's repeat needs an int count, and `/` yields a
    # float under Python 3.
    seq_length = t_preds.shape[0] // batch_size

    # Example index for every flattened (step, example) position.
    a_index = np.arange(batch_size).reshape(1, -1).repeat(seq_length, axis=0).flatten()
    a_index = torch.from_numpy(a_index[t_mask]).long()
    t_index = to_gpu(Variable(torch.from_numpy(
        np.arange(t_mask.shape[0])[t_mask])).long())
    t_logprobs = torch.index_select(t_logprobs, 0, t_index)

    actions = to_gpu(Variable(
        torch.from_numpy(t_preds[t_mask]).long().view(-1, 1),
        volatile=not self.training))
    log_p_action = torch.gather(t_logprobs, 1, actions).view(-1)

    advantage = torch.index_select(advantage, 0, a_index)
    # Both factors stay floating point. The former .long() casts truncated
    # advantages and log-probabilities to integers and cut the gradient to
    # the policy entirely.
    weights = to_gpu(Variable(advantage.float()))
    policy_loss = torch.sum(weights * log_p_action) / log_p_action.size(0)

    # NOTE(review): fixed scale factor carried over from the original; its
    # origin is undocumented — consider exposing it as a configurable weight.
    return policy_loss * 0.000121392198451
def build_example(self, sentences, transitions):
    """Return premise rows stacked above hypothesis rows as a token Variable.

    Slices ``[:, :, 0]`` and ``[:, :, 1]`` of ``sentences`` are concatenated
    along axis 0. ``transitions`` is ignored.
    """
    halves = [sentences[:, :, side] for side in (0, 1)]
    stacked = np.concatenate(halves, axis=0)
    return to_gpu(Variable(torch.from_numpy(stacked), volatile=not self.training))
def unwrap_sentence_pair(self, sentences, lengths=None):
    """Stack premise and hypothesis tokens (and optional lengths) along axis 0.

    Returns:
        ``(tokens_variable, lengths)``. ``lengths`` is returned unchanged
        when it is ``None``; otherwise it becomes ``lengths[:, 0]`` followed
        by ``lengths[:, 1]``.
    """
    tokens = np.concatenate([sentences[:, :, 0], sentences[:, :, 1]], axis=0)
    if lengths is not None:
        lengths = np.concatenate([lengths[:, 0], lengths[:, 1]], axis=0)
    token_var = to_gpu(Variable(torch.from_numpy(tokens), volatile=not self.training))
    return token_var, lengths
def unwrap_sentence(self, sentences, transitions):
    """Wrap single-sentence inputs in an ``Example``.

    ``example.tokens`` is a Variable over ``sentences``.
    ``example.transitions`` is ``transitions`` unchanged.
    """
    example = Example()
    example.tokens = to_gpu(
        Variable(torch.from_numpy(sentences), volatile=not self.training))
    example.transitions = transitions
    return example
def run_rnn(self, x):
    """Run ``self.rnn`` from zero states and return the flattened final hidden state.

    The initial states have ``self.model_dim // num_directions`` features per
    direction. The final hidden state is transposed to batch-first and
    reshaped to ``(batch_size, -1)``, which puts the per-direction states
    side by side.
    """
    batch_size, _, _ = x.data.size()
    num_layers = 1
    num_directions = 2 if self.bidirectional else 1
    # Integer division: torch.zeros rejects the float size that `/` produces
    # under Python 3.
    hidden_size = self.model_dim // num_directions

    h0 = Variable(to_gpu(
        torch.zeros(num_layers * num_directions, batch_size, hidden_size)),
        volatile=not self.training)
    c0 = Variable(to_gpu(
        torch.zeros(num_layers * num_directions, batch_size, hidden_size)),
        volatile=not self.training)

    # Expects (input, h_0):
    # input => batch_size x seq_len x model_dim
    # h_0 => (num_layers x num_directions[1,2]) x batch_size x model_dim
    # c_0 => (num_layers x num_directions[1,2]) x batch_size x model_dim
    output, (hn, cn) = self.rnn(x, (h0, c0))

    hn = hn.transpose(0, 1).contiguous().view(batch_size, -1)
    return hn
def auxiliary_loss(model):
    """Sum the optional RL losses attached to ``model``.

    ``model.policy_loss`` and ``model.value_loss`` are added only when
    ``model`` has a ``spinn`` attribute and that loss attribute exists.

    Returns:
        A one-element Variable, equal to 0 when nothing applies.
    """
    total_loss = to_gpu(Variable(torch.Tensor([0.0])))
    if not hasattr(model, 'spinn'):
        return total_loss
    for loss_name in ('policy_loss', 'value_loss'):
        if hasattr(model, loss_name):
            total_loss += getattr(model, loss_name)
    return total_loss
def build_example(self, sentences, transitions):
    """Wrap single-sentence ``sentences`` and ``transitions`` in an ``Example``.

    ``example.tokens`` is a Variable over ``sentences``.
    ``example.transitions`` is ``transitions`` unchanged.
    """
    example = Example()
    token_var = Variable(torch.from_numpy(sentences), volatile=not self.training)
    example.tokens = to_gpu(token_var)
    example.transitions = transitions
    return example
def unwrap_sentence_pair(self, sentences, transitions):
    """Build an ``Example`` with premise rows followed by hypothesis rows.

    For both ``sentences`` and ``transitions``, slices ``[:, :, 0]`` and
    ``[:, :, 1]`` are concatenated along axis 0. Tokens become a Variable;
    transitions stay a numpy array.
    """
    def stack_halves(arr):
        return np.concatenate([arr[:, :, 0], arr[:, :, 1]], axis=0)

    tokens = stack_halves(sentences)
    flat_transitions = stack_halves(transitions)

    example = Example()
    example.tokens = to_gpu(Variable(torch.from_numpy(tokens), volatile=not self.training))
    example.transitions = flat_transitions
    return example
def reset_decoder(self, example):
    """Run decoder on input to initialize rnn states.

    Each buffer in ``example.bufs`` is concatenated along dim 0, and the
    results are stacked along a new leading dim. ``self.dec_h`` and
    ``self.dec_c`` are reset to lists of zero chunks, one per example, each
    of shape ``(1, 1, self.decoder_dim)``. Then ``self.run_decoder_rnn`` is
    called for every example index.
    """
    n_examples = len(example.bufs)

    # TODO: Would prefer to run decoder forwards or backwards?
    batch = torch.cat([torch.cat(buf, 0).unsqueeze(0) for buf in example.bufs], 0)

    zeros = to_gpu(Variable(torch.zeros(1, n_examples, self.decoder_dim),
                            volatile=not self.training))
    self.dec_h = list(torch.chunk(zeros, n_examples, 1))
    self.dec_c = list(torch.chunk(zeros, n_examples, 1))

    # TODO: Right now the decoder runs over the entire sentence, which is a bit like cheating!
    self.run_decoder_rnn(range(n_examples), batch)
def build_example(self, sentences, transitions):
    """Build an ``Example`` from sentence-pair inputs.

    ``sentences`` and ``transitions`` are indexed ``[batch, feature, 2]``.
    For each, the ``0`` slice (premises) and the ``1`` slice (hypotheses)
    are concatenated along axis 0. Tokens become a Variable; transitions stay
    a numpy array.
    """
    tokens, flat_transitions = (
        np.concatenate([arr[:, :, 0], arr[:, :, 1]], axis=0)
        for arr in (sentences, transitions)
    )
    example = Example()
    example.tokens = to_gpu(Variable(torch.from_numpy(tokens), volatile=not self.training))
    example.transitions = flat_transitions
    return example
def t_reduce(self, buf, stack, tracking, lefts, rights, trackings):
    """REDUCE: Should compose top two items of the stack into new item.

    Pops the top of ``stack`` into ``rights``, then the next item into
    ``lefts``; the right-most input is popped first. ``tracking`` is then
    appended to ``trackings``.

    If the stack runs out, which only happens on cropped data:
    * with ``self.debug`` set, IndexError is raised;
    * otherwise a zero Variable shaped like ``buf[0]`` is appended in place
      of the missing input.
    """
    for destination in (rights, lefts):
        if stack:
            destination.append(stack.pop())
            continue
        if self.debug:
            raise IndexError
        filler = to_gpu(Variable(
            torch.from_numpy(np.zeros(buf[0].size(), dtype=np.float32)),
            volatile=buf[0].volatile))
        destination.append(filler)
    trackings.append(tracking)
def build_baseline(self, output, rewards, sentences, transitions, y_batch=None):
    """Return the baseline to subtract from ``rewards`` for REINFORCE.

    The strategy comes from ``self.rl_baseline``:

    * "ema": moves ``self.baseline[0]`` toward ``rewards.mean()`` with rate
      ``self.rl_mu`` and returns the updated value.
    * "policy": returns ``self.policy(sentences, transitions)`` and stores an
      MSE loss against ``rewards`` on ``self.policy_loss``.
    * "greedy": returns ``self.build_reward`` applied to the softmax of a
      greedy run from ``self.run_greedy``, scored against ``y_batch``.

    Raises:
        NotImplementedError: for any other strategy.
    """
    if self.rl_baseline == "ema":
        mu = self.rl_mu
        self.baseline[0] = self.baseline[0] * (1 - mu) + rewards.mean() * mu
        return self.baseline[0]

    if self.rl_baseline == "policy":
        # Pass inputs to Policy Net and use its output as the reward estimate.
        policy_prob = self.policy(sentences, transitions)
        # Save MSE Loss using Reward as target
        self.policy_loss = nn.MSELoss()(
            policy_prob, to_gpu(Variable(rewards, volatile=not self.training)))
        return policy_prob.data.cpu()

    if self.rl_baseline == "greedy":
        # Pass inputs to Greedy Max and score *its* predictions. Previously
        # greedy_outp was computed and then discarded in favor of the
        # `output` argument.
        greedy_outp = self.run_greedy(sentences, transitions)
        # dim=1 matches PyTorch's implicit choice for 2-D (batch, classes) scores.
        greedy_probs = F.softmax(greedy_outp, dim=1).data.cpu()
        target = torch.from_numpy(y_batch).long()
        return self.build_reward(greedy_probs, target)

    raise NotImplementedError
def reinforce(self, rewards):
    """REINFORCE loss: ``-mean(reward * log p(action)) * self.rl_weight``.

    Statistics come from ``self.spinn.get_statistics()``. For sentence pairs,
    ``rewards`` is duplicated so each reward covers both sentences. Rewards
    are expanded with ``index_select`` using ``t_mask`` as indices.

    Raises:
        NotImplementedError: when ``self.spinn.use_skips`` is set.
    """
    t_preds, t_logits, t_given, t_mask = self.spinn.get_statistics()

    # TODO: Many of these ops are on the cpu. Might be worth shifting to GPU.
    if self.use_sentence_pair:
        rewards = torch.cat([rewards, rewards], 0)

    if self.spinn.use_skips:
        raise NotImplementedError
    rewards = rewards.index_select(0, torch.from_numpy(t_mask).long())

    log_p_action = torch.cat([t_logits[i, p] for i, p in enumerate(t_preds)], 0)
    weighted = log_p_action * to_gpu(Variable(rewards, volatile=log_p_action.volatile))

    rl_loss = -1. * torch.sum(weighted) / log_p_action.size(0) * self.rl_weight
    return rl_loss
def forward(self, top_buf, top_stack_1, top_stack_2):
    """Project the ``.h`` of the buffer top and two stack tops into a tracker step.

    The three projections are summed.

    * ``self.lateral_tracking`` set: a lateral projection of the previous
      ``self.h`` (when present) is added. ``self.c`` is lazily created as
      zeros, one ``lstm`` step is run, and ``(self.h, self.c)`` is returned.
    * Otherwise: returns ``(self.transform(summed), None)``.
    """
    tracker_inp = self.buf(top_buf.h)
    for projection, node in ((self.stack1, top_stack_1), (self.stack2, top_stack_2)):
        tracker_inp += projection(node.h)

    if not self.lateral_tracking:
        return self.transform(tracker_inp), None

    batch_size = tracker_inp.size(0)
    if self.h is not None:
        tracker_inp += self.lateral(self.h)
    if self.c is None:
        blank = np.zeros((batch_size, self.state_size), dtype=np.float32)
        self.c = to_gpu(Variable(torch.from_numpy(blank), volatile=tracker_inp.volatile))

    # Run tracking lstm.
    self.c, self.h = lstm(self.c, tracker_inp)
    return self.h, self.c
def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators, logger):
    """Run the training loop until ``FLAGS.training_steps`` or early stopping.

    Each iteration does the following:

    * Draws a batch from ``training_data_iter``.
    * Optionally adjusts ``model.spinn.temperature`` (confidence penalty)
      and ``model.rl_weight`` (wake/sleep) from a cosine schedule.
    * Runs the model and backpropagates the total loss: cross-entropy, plus
      the transition loss when ``model.optimize_transition_loss``, plus
      ``auxiliary_loss(model)``.
    * Clips the gradient norm of every parameter except
      ``embed.embed.weight`` and steps the optimizer via ``trainer``.

    At the configured intervals it also records statistics, logs sampled
    parses, evaluates on ``eval_iterators`` and checkpoints. The first eval
    set feeds ``trainer.new_dev_accuracy``.

    Training stops early once ``trainer.step - trainer.best_dev_step`` exceeds
    ``FLAGS.early_stopping_steps_to_wait``.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    # The loop index is unused; the step counter is trainer.step, presumably
    # advanced by trainer.optimizer_step() — confirm in the trainer.
    for _ in range(trainer.step, FLAGS.training_steps):
        if (trainer.step - trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' + str(FLAGS.early_stopping_steps_to_wait) + ' steps. Stopping training.')
            break

        model.train()
        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        # Token count per example, assuming nt transitions span (nt + 1) / 2
        # tokens (presumably 2n - 1 shift/reduce steps for n tokens).
        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        # sin(pi/2 + x) == cos(x), so this is a cosine in [0, 1] with a period
        # of FLAGS.rl_confidence_interval steps.
        temperature = math.sin(math.pi / 2 + trainer.step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        # Confidence Penalty for Transition Predictions.
        if FLAGS.rl_confidence_penalty:
            epsilon = FLAGS.rl_epsilon * math.exp(-trainer.step / float(FLAGS.rl_epsilon_decay))
            temp = 1 + (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=False)[1].cpu()

        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.CrossEntropyLoss()(output, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping: clip_grad_norm rescales by the total
        # gradient norm (it does not clamp element-wise). The embedding
        # matrix is excluded.
        # NOTE(review): newer PyTorch names this clip_grad_norm_ and
        # deprecates the un-suffixed alias.
        nn.utils.clip_grad_norm([param for name, param in model.named_parameters() if name not in ["embed.embed.weight"]], FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        train_rl_accumulate(model, A, batch)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
            progress_bar.finish()

            # NOTE(review): .data[0] on a 0-dim loss raises on PyTorch >= 0.4;
            # .item() is the modern equivalent.
            A.add('xent_cost', xent_loss.data[0])
            stats(model, trainer, A, log_entry)
            should_log = True

        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            # Re-run the same batch in train mode and then in eval mode to
            # compare the transitions each produces.
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

            # This could be done prior to running the batch for a tiny speed
            # boost.
            # NOTE(review): shuffling range(num_samples) and keeping the first
            # num_samples always yields 0..num_samples-1; the range was
            # probably meant to cover the whole batch.
            t_idxs = list(range(FLAGS.num_samples))
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                # NOTE(review): dec_str is not defined in this function; it
                # must come from module scope.
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(), dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(), dec_str)
                # NOTE(review): `pred` is the class-prediction tensor computed
                # above, not a transition sequence; pred_tr or pred_ev was
                # probably intended — confirm before trusting `crossing`.
                _, crossing = evalb.crossing(gold, pred)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:]
                log.strg_ev = strength_ev[1:]

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(FLAGS, model, eval_set, log_entry, logger, trainer, eval_index=index)
                # Only the first eval set drives dev-accuracy tracking.
                if index == 0:
                    trainer.new_dev_accuracy(acc)
            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) + 1,
                          total=FLAGS.statistics_interval_steps)
def unwrap_sentence(self, sentences, lengths=None):
    """Wrap ``sentences`` in a token Variable and pass ``lengths`` through.

    Returns:
        ``(tokens_variable, lengths)``.
    """
    token_var = Variable(torch.from_numpy(sentences), volatile=not self.training)
    return to_gpu(token_var), lengths
def train_loop(FLAGS, data_manager, model, optimizer, trainer, training_data_iter, eval_iterators, logger, step, best_dev_error):
    """Run the optimization loop from ``step`` up to ``FLAGS.training_steps``.

    Before the loop, one training batch is run through the model, and the
    train/eval log format strings are written to ``logger``.

    Each iteration:

    * Minimizes NLL on ``log_softmax(output)``, plus the optional L2 cost,
      the optional transition loss and ``auxiliary_loss(model)``.
    * Clamps each gradient element to ``[-FLAGS.clipping_max_value,
      FLAGS.clipping_max_value]``.
    * Optionally decays the learning rate, then calls ``optimizer.step()``.

    At the configured intervals the loop also logs statistics and metrics,
    logs sampled parses, evaluates, and saves checkpoints. A "best"
    checkpoint is saved when the first eval set's error drops below 0.99 *
    ``best_dev_error``.

    NOTE(review): this block uses Python 2 idioms: ``training_data_iter.next()``,
    ``random.shuffle`` on ``range(...)``, and ``.encode`` inside ``format``.
    It will not run unchanged on Python 3.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)
    M = MetricsWriter(os.path.join(FLAGS.metrics_path, FLAGS.experiment_name))

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Build log format strings.
    # One forward pass on a (consumed) training batch runs before the format
    # helpers below; presumably they inspect state set during forward.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(training_data_iter.next())
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions
          )

    logger.Log("")
    logger.Log("# ----- BEGIN: Log Configuration ----- #")

    # Preview train string template.
    train_str = train_format(model)
    logger.Log("Train-Format: {}".format(train_str))
    train_extra_str = train_extra_format(model)
    logger.Log("Train-Extra-Format: {}".format(train_extra_str))

    # Preview eval string template.
    eval_str = eval_format(model)
    logger.Log("Eval-Format: {}".format(eval_str))
    eval_extra_str = eval_extra_format(model)
    logger.Log("Eval-Extra-Format: {}".format(eval_extra_str))

    logger.Log("# ----- END: Log Configuration ----- #")
    logger.Log("")

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    # The loop variable reuses (and shadows) the `step` parameter.
    for step in range(step, FLAGS.training_steps):
        model.train()

        start = time.time()

        batch = get_batch(training_data_iter.next())
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        total_tokens = sum([(nt+1)/2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions
                       )

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()
        pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

        # Extract L2 Cost
        l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        total_loss += auxiliary_loss(model)

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping (element-wise clamp).
        # NOTE(review): p.grad can be None for a trainable parameter that got
        # no gradient this step, which would raise AttributeError here.
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay
        # NOTE(review): torch.optim optimizers read the rate from
        # optimizer.param_groups[i]['lr']; assigning optimizer.lr only has an
        # effect if `optimizer` is a wrapper that honors it — confirm.
        if FLAGS.actively_decay_learning_rate:
            optimizer.lr = FLAGS.learning_rate * (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
            progress_bar.finish()
            A.add('xent_cost', xent_loss.data[0])
            # NOTE(review): l2_loss is None when FLAGS.use_l2_cost is off, so
            # this line raises AttributeError in that configuration.
            A.add('l2_cost', l2_loss.data[0])
            stats_args = train_stats(model, optimizer, A, step)

            train_metrics(M, stats_args, step)

            logger.Log(train_str.format(**stats_args))
            logger.Log(train_extra_str.format(**stats_args))

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            # Re-run the same batch in train mode and then in eval mode to
            # compare the transitions each produces.
            model.train()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions
                  )
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

            model.eval()
            model(X_batch, transitions_batch, y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions
                  )
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

            transition_str = "Samples:"
            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate([
                    transitions_batch[:,:,0], transitions_batch[:,:,1]], axis=0)

            # This could be done prior to running the batch for a tiny speed boost.
            # NOTE(review): shuffling range(num_samples) and keeping the first
            # num_samples always yields 0..num_samples-1.
            t_idxs = range(FLAGS.num_samples)
            random.shuffle(t_idxs)
            t_idxs = sorted(t_idxs[:FLAGS.num_samples])
            for t_idx in t_idxs:
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                stength_tr = sparks([1] + tr_strength[t_idx].tolist())
                stength_ev = sparks([1] + ev_strength[t_idx].tolist())
                # NOTE(review): `pred` is the class-prediction tensor from
                # above, not a transition sequence; pred_tr or pred_ev was
                # probably intended — confirm before trusting `crossing`.
                _, crossing = evalb.crossing(gold, pred)
                transition_str += "\n{}. crossing={}".format(t_idx, crossing)
                transition_str += "\n g{}".format("".join(map(str, gold)))
                transition_str += "\n {}".format(stength_tr[1:].encode('utf-8'))
                transition_str += "\n pt{}".format("".join(map(str, pred_tr)))
                transition_str += "\n {}".format(stength_ev[1:].encode('utf-8'))
                transition_str += "\n pe{}".format("".join(map(str, pred_ev)))
            logger.Log(transition_str)

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            for index, eval_set in enumerate(eval_iterators):
                acc, tacc = evaluate(FLAGS, model, data_manager, eval_set, index, logger, step)
                # Save a "best" checkpoint only for the first eval set, and only
                # when its error improves on best_dev_error by more than 1%.
                if FLAGS.ckpt_on_best_dev_error and index == 0 and (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)