def pre_decode(self, encoder_outputs, tgt_seq, extra_states=None, src_mask=None, tgt_mask=None):
    """Prepare the context and initial states for decoding.
    """
    feedback_embeds = self.lookup_feedback(tgt_seq)
    B = tgt_seq.shape[0]
    context = encoder_outputs
    states = MapDict({"t": 0})
    context["feedbacks"] = tgt_seq
    context["feedback_embeds"] = feedback_embeds
    # Process initial states
    if self._stepwise_training:
        for state_name, size in zip(self._state_names, self._state_sizes):
            init_key = "init_{}".format(state_name)
            if init_key in context:
                states[state_name] = context[init_key]
                if len(states[state_name].shape) == 2:
                    # Add the layer axis: (B, size) -> (1, B, size)
                    states[state_name] = states[state_name].unsqueeze(0)
                del context[init_key]
            else:
                states[state_name] = torch.zeros((1, B, size))
                if torch.cuda.is_available():
                    states[state_name] = states[state_name].cuda()
    if extra_states is not None:
        states.update(extra_states)
    # Process masks
    context["src_mask"] = src_mask
    context["tgt_mask"] = tgt_mask
    return context, states
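# A minimal standalone sketch of the state-shape convention assumed above
# (an illustration, not part of nmtlab): decoder states are kept as
# (num_layers=1, batch, state_size), so a 2-D initial state produced by
# the encoder gains a leading axis before decoding starts.
import torch

B, size = 4, 256
init_state = torch.zeros((B, size))       # e.g. an encoder-derived init state
if init_state.dim() == 2:
    init_state = init_state.unsqueeze(0)  # -> (1, B, size)
assert init_state.shape == (1, B, size)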
def combine_states(self, hyps):
    """Batch all states in the different hypotheses.
    """
    states = MapDict()
    # Combine states
    for name in self.model.state_names():
        states[name] = torch.cat([h["state"][name] for h in hyps], 1)
    # Combine last tokens
    last_tokens = torch.tensor([h["tokens"][-1] for h in hyps])
    if torch.cuda.is_available():
        last_tokens = last_tokens.cuda()
    states.feedback_embed = self.model.lookup_feedback(last_tokens)
    return states
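# Illustration (assumed shapes, not nmtlab code) of how combine_states
# batches the per-hypothesis states: each hypothesis carries a
# (1, 1, size) tensor, and concatenating along dimension 1 restores the
# (1, B, size) layout with B = number of live hypotheses.
import torch

size = 8
hyp_states = [torch.randn(1, 1, size) for _ in range(5)]
batched = torch.cat(hyp_states, 1)
assert batched.shape == (1, 5, size)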
def initialize_hyps(self, encoder_outputs, items=None):
    """Initialize the first hypothesis for beam search.
    """
    final_hyps = []
    states = MapDict()
    # Create initial states
    for name, size in zip(self.model.state_names(), self.model.state_sizes()):
        init_key = "init_{}".format(name)
        if init_key in encoder_outputs:
            states[name] = encoder_outputs[init_key]
            if len(states[name].shape) == 2:
                states[name] = states[name].unsqueeze(0)
        else:
            states[name] = torch.zeros((1, 1, size))
            if torch.cuda.is_available():
                states[name] = states[name].cuda()
    # Create the first hypothesis
    first_hyp = {
        "state": states,
        "tokens": [self.start_token_id],
        "score": 0.
    }
    if items:
        first_hyp.update(items)
    hyps = [first_hyp]
    return hyps, final_hyps
def expand_hyps(self, hyps, new_states, batch_scores, sort=True, expand_num=None):
    """Expand each hypothesis, creating beam_size x beam_size new candidates.
    """
    if not expand_num:
        expand_num = self.beam_size
    new_hyps = []
    best_scores, best_tokens = batch_scores.topk(expand_num)
    for i, hyp in enumerate(hyps):
        # Slice out the states belonging to hypothesis i
        new_hyp_state = MapDict()
        for sname in self.model.state_names():
            new_hyp_state[sname] = new_states[sname][:, i, :].unsqueeze(1)
        new_scores = best_scores[i].cpu().detach().numpy().tolist()
        new_tokens = best_tokens[i].cpu().detach().numpy().tolist()
        for new_token, new_score in zip(new_tokens, new_scores):
            new_hyp = {
                "state": new_hyp_state,
                "tokens": hyp["tokens"] + [new_token],
                "score": new_score + hyp["score"],
                "last_token_score": new_score,
                "old_state": hyp["state"]
            }
            new_hyp = self.fix_new_hyp(i, hyp, new_hyp)
            # Carry over any extra information from the old hypothesis
            for key in hyp:
                if key not in new_hyp:
                    new_hyp[key] = copy.copy(hyp[key])
            new_hyps.append(new_hyp)
    if sort:
        new_hyps.sort(key=lambda h: h["score"], reverse=True)
    return new_hyps
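# A rough driver sketch showing how initialize_hyps, combine_states and
# expand_hyps could fit together in one beam-search loop. `decode_step`
# is a hypothetical stand-in for the model-specific step that returns
# new states and a (num_hyps, vocab) matrix of log-probabilities, and
# the (t, hyps) variant of combine_states is assumed; the real decoding
# loop in nmtlab may differ.
def beam_search_sketch(self, encoder_outputs, max_steps=100):
    hyps, final_hyps = self.initialize_hyps(encoder_outputs)
    for t in range(max_steps):
        states = self.combine_states(t, hyps)
        new_states, batch_scores = self.decode_step(encoder_outputs, states)  # hypothetical
        new_hyps = self.expand_hyps(hyps, new_states, batch_scores)
        # Keep the best beam_size candidates; finished ones leave the beam
        hyps = []
        for hyp in new_hyps[:self.beam_size]:
            if hyp["tokens"][-1] == self.end_token_id:
                final_hyps.append(hyp)
            else:
                hyps.append(hyp)
        if not hyps:
            break
    final_hyps.sort(key=lambda h: h["score"], reverse=True)
    return final_hyps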
def __init__(self, model, source_vocab, target_vocab, start_token="<s>", end_token="</s>", beam_size=5, length_norm=False, opts=None, device=None): assert isinstance(model, EncoderDecoderModel) # Iniliatize horovod for multigpu translate self._is_multigpu = False try: import horovod.torch as hvd hvd.init() if torch.cuda.is_available(): torch.cuda.set_device(hvd.local_rank()) self._is_multigpu = True except ImportError: pass if torch.cuda.is_available(): model.cuda(device) self.length_norm = length_norm self.model = model self.source_vocab = source_vocab self.target_vocab = target_vocab self.start_token = start_token self.end_token = end_token self.start_token_id = self.source_vocab.encode_token(start_token) self.end_token_id = self.target_vocab.encode_token(end_token) self.opts = MapDict(opts) if opts else opts self.beam_size = beam_size self.prepare()
def forward(self, src_seq, tgt_seq, sampling=False): """Forward to compute the loss. """ sampling = False src_mask = self.to_float(torch.ne(src_seq, 0)) tgt_mask = self.to_float(torch.ne(tgt_seq, 0)) encoder_outputs = MapDict(self.encode(src_seq, src_mask)) context, states = self.pre_decode(encoder_outputs, tgt_seq, src_mask=src_mask, tgt_mask=tgt_mask) decoder_outputs = self.decode(context, states) if self._shard_size is not None and self._shard_size > 0: from nmtlab.utils.distributed import local_rank if local_rank() == 0: from line_profiler import LineProfiler lp = LineProfiler() compute_shard_loss = lp(self.compute_shard_loss) else: compute_shard_loss = self.compute_shard_loss compute_shard_loss(decoder_outputs, tgt_seq, tgt_mask) else: logits = self.expand(decoder_outputs) loss = self.compute_loss(logits, tgt_seq, tgt_mask) acc = self.compute_word_accuracy(logits, tgt_seq, tgt_mask) self.monitor("loss", loss) self.monitor("word_acc", acc) if sampling: context, states = self.pre_decode(encoder_outputs, tgt_seq, src_mask=src_mask, tgt_mask=tgt_mask) sample_outputs = self.decode(context, states, sampling=True) self.monitor("sampled_tokens", sample_outputs.prev_token) return self._monitors
def forward(self, src_seq, tgt_seq, sampling=False): """Forward to compute the loss. """ sampling = False src_mask = torch.ne(src_seq, 0).float() tgt_mask = torch.ne(tgt_seq, 0).float() encoder_outputs = MapDict(self.encode(src_seq, src_mask)) context, states = self.pre_decode(encoder_outputs, tgt_seq, src_mask=src_mask, tgt_mask=tgt_mask) decoder_outputs = self.decode(context, states) if self._shard_size is not None and self._shard_size > 0: self.compute_shard_loss(decoder_outputs, tgt_seq, tgt_mask) else: logits = self.expand(decoder_outputs) loss = self.compute_loss(logits, tgt_seq, tgt_mask) acc = self.compute_word_accuracy(logits, tgt_seq, tgt_mask) self.monitor("loss", loss) self.monitor("word_acc", acc) if sampling: context, states = self.pre_decode(encoder_outputs, tgt_seq, src_mask=src_mask, tgt_mask=tgt_mask) sample_outputs = self.decode(context, states, sampling=True) self.monitor("sampled_tokens", sample_outputs.prev_token) return self._monitors
def combine_states(self, t, hyps):
    """Batch all states in the different hypotheses.

    Args:
        t - time step
        hyps - hypotheses
    """
    states = MapDict({"t": t})
    # Combine states
    for name in self.model.state_names():
        states[name] = torch.cat([h["state"][name] for h in hyps], 1)
    # Combine last tokens
    last_tokens = torch.tensor([h["tokens"][-1] for h in hyps])
    if torch.cuda.is_available():
        last_tokens = last_tokens.cuda()
    states.prev_token = last_tokens.unsqueeze(0)
    states.feedback_embed = self.model.lookup_feedback(last_tokens)
    return states
def encode(self, input_tokens):
    """Run the encoder to get the context.
    """
    input_tensor = torch.tensor([input_tokens])
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()
    input_mask = torch.gt(input_tensor, 0)
    encoder_outputs = self.model.encode(input_tensor, input_mask)
    return MapDict(encoder_outputs)
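# Illustration of the padding convention used throughout these snippets:
# token id 0 is padding, so torch.gt(x, 0) (or torch.ne(x, 0)) marks the
# real tokens.
import torch

input_tensor = torch.tensor([[5, 9, 2, 0, 0]])
input_mask = torch.gt(input_tensor, 0)
# tensor([[ True,  True,  True, False, False]])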
def test_stepwise_graph(self):
    input_seq = torch.randint(0, 100, (3, 5)).long()
    target_seq = torch.randint(0, 100, (3, 2)).long()
    src_mask = input_seq.clone().fill_(1)
    context = self.model.encode(input_seq, src_mask)
    # Decode the whole target sequence in one pass
    self.model.set_stepwise_training(False)
    context, states = self.model.pre_decode(context, target_seq, src_mask=src_mask)
    context = MapDict(context)
    full_states = self.model.decode(context, states, False)
    # Decode the same sequence step by step
    self.model.set_stepwise_training(True)
    context, states = self.model.pre_decode(context, target_seq, src_mask=src_mask)
    context = MapDict(context)
    stepwise_states = self.model.decode(context, states, False)
    # Both decoding modes must produce identical final hidden states
    self.assertEqual(
        torch.eq(full_states.final_hidden, stepwise_states.final_hidden).prod().numpy(),
        1)
def test_stepwise_graph(self):
    input_seq = torch.randint(0, 100, (3, 5)).long()
    target_seq = torch.randint(0, 100, (3, 6)).long()
    src_mask = input_seq.clone().fill_(1)
    context = self.model.encode(input_seq, src_mask)
    # Decode the whole target sequence in one pass
    self.model.set_stepwise_training(False)
    context, states = self.model.pre_decode(context, target_seq, src_mask=src_mask)
    context = MapDict(context)
    full_states = self.model.decode(context, states, False)
    # Decode the same sequence step by step
    self.model.set_stepwise_training(True)
    context, states = self.model.pre_decode(context, target_seq, src_mask=src_mask)
    context = MapDict(context)
    stepwise_states = self.model.decode(context, states, False)
    # Compare with a tolerance; use abs() so elementwise errors cannot cancel
    self.assertAlmostEqual(
        float((full_states.final_states - stepwise_states.final_states).abs().sum()),
        0, delta=0.0001)
def __init__(self, model, source_vocab, target_vocab, start_token="<s>", end_token="</s>", beam_size=5, opts=None): assert isinstance(model, EncoderDecoderModel) if torch.cuda.is_available(): model.cuda() self.model = model self.source_vocab = source_vocab self.target_vocab = target_vocab self.start_token = start_token self.end_token = end_token self.start_token_id = self.source_vocab.encode_token(start_token) self.end_token_id = self.target_vocab.encode_token(end_token) self.opts = MapDict(opts) if opts else opts self.beam_size = beam_size self.prepare()
def forward(self, src_seq, tgt_seq):
    with torch.no_grad():
        src_mask = torch.ne(src_seq, 0).float()
        tgt_mask = torch.ne(tgt_seq, 0).float()
        encoder_outputs = MapDict(self.encode(src_seq, src_mask))
        context, states = self.pre_decode(encoder_outputs, tgt_seq,
                                          src_mask=src_mask, tgt_mask=tgt_mask)
        decoder_outputs = self.decode(context, states)
        logits = self.expand(decoder_outputs)
        logp = torch.log_softmax(logits, 2)
        # nll_loss with reduction="none" gathers -log p(target) per token,
        # so negating it recovers the per-token log-probabilities
        flat_logp = logp.view(-1, self._tgt_vocab_size)
        flat_tgt = tgt_seq[:, 1:].flatten()
        logp = -torch.nn.functional.nll_loss(flat_logp, flat_tgt, reduction="none")
        logp = logp.view(tgt_seq.shape[0], tgt_seq.shape[1] - 1)
        # Mask out padding positions and sum to get one score per sentence
        scores = (logp * tgt_mask[:, 1:]).sum(1)
    return scores.cpu().numpy()
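# Standalone check of the nll_loss trick above (illustrative shapes, not
# nmtlab code): the negated per-position nll_loss equals the gathered
# per-token log-probabilities.
import torch

B, T, V = 2, 4, 10
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))
logp = torch.log_softmax(logits, 2)
per_token = -torch.nn.functional.nll_loss(
    logp.view(-1, V), targets.flatten(), reduction="none").view(B, T)
manual = logp.gather(2, targets.unsqueeze(2)).squeeze(2)
assert torch.allclose(per_token, manual)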
def forward(self, src_seq, tgt_seq, sampling=False): """ Forward to compute the loss. """ src_mask = torch.ne(src_seq, 0) tgt_mask = torch.ne(tgt_seq, 0) encoder_outputs = MapDict(self.encode(src_seq, src_mask)) context, states = self.pre_decode(encoder_outputs, tgt_seq, src_mask=src_mask, tgt_mask=tgt_mask) decoder_outputs = self.decode(context, states) logits = self.expand(decoder_outputs) if sampling: context, states = self.pre_decode(encoder_outputs, tgt_seq, src_mask=src_mask, tgt_mask=tgt_mask) sample_outputs = self.decode(context, states, sampling=True) self.monitor("sampled_tokens", sample_outputs.sampled_token) loss = self.compute_loss(logits, tgt_seq, tgt_mask) self.monitor("loss", loss) return self._monitors