Example 1
    def pre_decode(self,
                   encoder_outputs,
                   tgt_seq,
                   extra_states=None,
                   src_mask=None,
                   tgt_mask=None):
        """Prepare the context and initial states for decoding.
        """
        feedback_embeds = self.lookup_feedback(tgt_seq)

        B = tgt_seq.shape[0]
        context = encoder_outputs
        states = MapDict({"t": 0})
        context["feedbacks"] = tgt_seq
        context["feedback_embeds"] = feedback_embeds
        # Process initial states
        if self._stepwise_training:
            for state_name, size in zip(self._state_names, self._state_sizes):
                if "init_{}".format(state_name) in context:
                    states[state_name] = context["init_{}".format(state_name)]
                    if len(states[state_name].shape) == 2:
                        states[state_name] = states[state_name].unsqueeze(0)
                    del context["init_{}".format(state_name)]
                else:
                    # torch.autograd.Variable is a deprecated no-op wrapper; a plain tensor suffices
                    states[state_name] = torch.zeros((1, B, size))
                    if torch.cuda.is_available():
                        states[state_name] = states[state_name].cuda()
        if extra_states is not None:
            states.update(extra_states)
        # Process mask
        context["src_mask"] = src_mask
        context["tgt_mask"] = tgt_mask
        return context, states
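
A hedged sketch of how pre_decode fits into a training step, mirroring the forward methods further below; the names model, src and tgt are placeholders for illustration, and padding is taken to be token ID 0 as in those examples:

    # Hypothetical driver; `model`, `src` and `tgt` (padded ID batches) are assumed
    src_mask = torch.ne(src, 0).float()
    tgt_mask = torch.ne(tgt, 0).float()
    encoder_outputs = MapDict(model.encode(src, src_mask))
    context, states = model.pre_decode(encoder_outputs, tgt,
                                       src_mask=src_mask, tgt_mask=tgt_mask)
    decoder_outputs = model.decode(context, states)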
Example 2
 def combine_states(self, hyps):
     """Batch all states in different hyptheses.
     """
     states = MapDict()
     # Combine states
     for name in self.model.state_names():
         states[name] = torch.cat([h["state"][name] for h in hyps], 1)
     # Combine last tokens
     last_tokens = torch.tensor([h["tokens"][-1] for h in hyps])
     if torch.cuda.is_available():
         last_tokens = last_tokens.cuda()
     # Look up feedback embeddings for the last tokens (on CPU and GPU alike)
     states.feedback_embed = self.model.lookup_feedback(last_tokens)
     return states
Example 3
 def initialize_hyps(self, encoder_outputs, items=None):
     """Initialize the first hypothesis for beam search.
     """
     final_hyps = []
     states = MapDict()
     # Create initial states
     for name, size in zip(self.model.state_names(), self.model.state_sizes()):
         if "init_{}".format(name) in encoder_outputs:
             states[name] = encoder_outputs["init_{}".format(name)]
             if len(states[name].shape) == 2:
                 states[name] = states[name].unsqueeze(0)
         else:
             states[name] = torch.zeros((1, 1, size))
             if torch.cuda.is_available():
                 states[name] = states[name].cuda()
     # Create the first hypothesis
     first_hyp = {
         "state": states,
         "tokens": [self.start_token_id],
         "score": 0.
     }
     if items:
         first_hyp.update(items)
     hyps = [first_hyp]
     return hyps, final_hyps
Example 4
 def expand_hyps(self, hyps, new_states, batch_scores, sort=True, expand_num=None):
     """Create B x B new hypotheses
     """
     if not expand_num:
         expand_num = self.beam_size
     new_hyps = []
     best_scores, best_tokens = batch_scores.topk(expand_num)
     for i, hyp in enumerate(hyps):
         new_hyp_state = MapDict()
         for sname in self.model.state_names():
             new_hyp_state[sname] = new_states[sname][:, i, :].unsqueeze(1)
         new_scores = best_scores[i].cpu().detach().numpy().tolist()
         new_tokens = best_tokens[i].cpu().detach().numpy().tolist()
         for new_token, new_score in zip(new_tokens, new_scores):
             new_hyp = {
                 "state": new_hyp_state,
                 "tokens": hyp["tokens"] + [new_token],
                 "score": new_score + hyp["score"],
                 "last_token_score": new_score,
                 "old_state": hyp["state"]
             }
             new_hyp = self.fix_new_hyp(i, hyp, new_hyp)
             # Keep old information
             for key in hyp:
                 if key not in new_hyp:
                     new_hyp[key] = copy.copy(hyp[key])
             new_hyps.append(new_hyp)
     if sort:
         new_hyps.sort(key=lambda h: h["score"], reverse=True)
     return new_hyps
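
Read together with Examples 2 and 3, these helpers suggest a beam-search loop along the following lines. This is a hedged sketch: decode_step, max_steps and the handling of finished hypotheses are assumptions, not confirmed API.

    # Hypothetical search loop built from initialize_hyps / combine_states / expand_hyps
    hyps, final_hyps = self.initialize_hyps(encoder_outputs)
    for t in range(max_steps):
        states = self.combine_states(hyps)                  # batch hypothesis states
        new_states, batch_scores = self.decode_step(context, states)  # assumed API
        hyps = self.expand_hyps(hyps, new_states, batch_scores)[:self.beam_size]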
Example 5
 def __init__(self,
              model,
              source_vocab,
              target_vocab,
              start_token="<s>",
              end_token="</s>",
              beam_size=5,
              length_norm=False,
              opts=None,
              device=None):
     assert isinstance(model, EncoderDecoderModel)
     # Initialize Horovod for multi-GPU translation
     self._is_multigpu = False
     try:
         import horovod.torch as hvd
         hvd.init()
         if torch.cuda.is_available():
             torch.cuda.set_device(hvd.local_rank())
             self._is_multigpu = True
     except ImportError:
         pass
     if torch.cuda.is_available():
         model.cuda(device)
     self.length_norm = length_norm
     self.model = model
     self.source_vocab = source_vocab
     self.target_vocab = target_vocab
     self.start_token = start_token
     self.end_token = end_token
     self.start_token_id = self.source_vocab.encode_token(start_token)
     self.end_token_id = self.target_vocab.encode_token(end_token)
     self.opts = MapDict(opts) if opts else opts
     self.beam_size = beam_size
     self.prepare()
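
A hedged instantiation sketch for this constructor; the class name BeamTranslator, MyModel and the vocabulary objects are placeholders, since only the constructor arguments above are confirmed:

    # Hypothetical setup; class and vocab names are assumptions
    model = MyModel()
    translator = BeamTranslator(model, src_vocab, tgt_vocab,
                                beam_size=5, length_norm=True)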
Example 6
 def forward(self, src_seq, tgt_seq, sampling=False):
     """Forward to compute the loss.
     """
     # Sampling is hard-disabled here, so the sampling branch below never runs
     sampling = False
     src_mask = self.to_float(torch.ne(src_seq, 0))
     tgt_mask = self.to_float(torch.ne(tgt_seq, 0))
     encoder_outputs = MapDict(self.encode(src_seq, src_mask))
     context, states = self.pre_decode(encoder_outputs,
                                       tgt_seq,
                                       src_mask=src_mask,
                                       tgt_mask=tgt_mask)
     decoder_outputs = self.decode(context, states)
     if self._shard_size is not None and self._shard_size > 0:
         from nmtlab.utils.distributed import local_rank
         if local_rank() == 0:
             # Profile the sharded loss computation line-by-line on rank 0
             from line_profiler import LineProfiler
             lp = LineProfiler()
             compute_shard_loss = lp(self.compute_shard_loss)
         else:
             compute_shard_loss = self.compute_shard_loss
         compute_shard_loss(decoder_outputs, tgt_seq, tgt_mask)
     else:
         logits = self.expand(decoder_outputs)
         loss = self.compute_loss(logits, tgt_seq, tgt_mask)
         acc = self.compute_word_accuracy(logits, tgt_seq, tgt_mask)
         self.monitor("loss", loss)
         self.monitor("word_acc", acc)
     if sampling:
         context, states = self.pre_decode(encoder_outputs,
                                           tgt_seq,
                                           src_mask=src_mask,
                                           tgt_mask=tgt_mask)
         sample_outputs = self.decode(context, states, sampling=True)
         self.monitor("sampled_tokens", sample_outputs.prev_token)
     return self._monitors
Example 7
 def forward(self, src_seq, tgt_seq, sampling=False):
     """Forward to compute the loss.
     """
     # Sampling is hard-disabled here, so the sampling branch below never runs
     sampling = False
     src_mask = torch.ne(src_seq, 0).float()
     tgt_mask = torch.ne(tgt_seq, 0).float()
     encoder_outputs = MapDict(self.encode(src_seq, src_mask))
     context, states = self.pre_decode(encoder_outputs,
                                       tgt_seq,
                                       src_mask=src_mask,
                                       tgt_mask=tgt_mask)
     decoder_outputs = self.decode(context, states)
     if self._shard_size is not None and self._shard_size > 0:
         self.compute_shard_loss(decoder_outputs, tgt_seq, tgt_mask)
     else:
         logits = self.expand(decoder_outputs)
         loss = self.compute_loss(logits, tgt_seq, tgt_mask)
         acc = self.compute_word_accuracy(logits, tgt_seq, tgt_mask)
         self.monitor("loss", loss)
         self.monitor("word_acc", acc)
     if sampling:
         context, states = self.pre_decode(encoder_outputs,
                                           tgt_seq,
                                           src_mask=src_mask,
                                           tgt_mask=tgt_mask)
         sample_outputs = self.decode(context, states, sampling=True)
         self.monitor("sampled_tokens", sample_outputs.prev_token)
     return self._monitors
Example 8
 def combine_states(self, t, hyps):
     """Batch all states in different hyptheses.
     Args:
         t - time step
         hyps - hypotheses
     """
     states = MapDict({"t": t})
     # Combine states
     for name in self.model.state_names():
         states[name] = torch.cat([h["state"][name] for h in hyps], 1)
     # Combine last tokens
     last_tokens = torch.tensor([h["tokens"][-1] for h in hyps])
     if torch.cuda.is_available():
         last_tokens = last_tokens.cuda()
     states.prev_token = last_tokens.unsqueeze(0)
     states.feedback_embed = self.model.lookup_feedback(last_tokens)
     return states
Example 9
 def encode(self, input_tokens):
     """Run the encoder to get context.
     """
     input_tensor = torch.tensor([input_tokens])
     if torch.cuda.is_available():
         input_tensor = input_tensor.cuda()
     input_mask = torch.gt(input_tensor, 0)
     encoder_outputs = self.model.encode(input_tensor, input_mask)
     return MapDict(encoder_outputs)
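
A small usage sketch for encode, assuming this method lives on the translator from Example 5 and that token IDs come from its source vocabulary via encode_token; the sentence and variable names are illustrative only:

    # Hypothetical: map tokens to IDs and encode one sentence
    tokens = "a small test".split()
    token_ids = [translator.source_vocab.encode_token(tok) for tok in tokens]
    encoder_outputs = translator.encode(token_ids)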
Example 10
 def test_stepwise_graph(self):
     input_seq = torch.randint(0, 100, (3, 5)).long()
     target_seq = torch.randint(0, 100, (3, 2)).long()
     src_mask = input_seq.clone().fill_(1)
     context = self.model.encode(input_seq, src_mask)
     self.model.set_stepwise_training(False)
     context, states = self.model.pre_decode(context,
                                             target_seq,
                                             src_mask=src_mask)
     context = MapDict(context)
     full_states = self.model.decode(context, states, False)
     self.model.set_stepwise_training(True)
     context, states = self.model.pre_decode(context,
                                             target_seq,
                                             src_mask=src_mask)
     context = MapDict(context)
     stepwise_states = self.model.decode(context, states, False)
     self.assertEqual(
         torch.eq(full_states.final_hidden,
                  stepwise_states.final_hidden).prod().numpy(), 1)
Example 11
 def test_stepwise_graph(self):
     input_seq = torch.randint(0, 100, (3, 5)).long()
     target_seq = torch.randint(0, 100, (3, 6)).long()
     src_mask = input_seq.clone().fill_(1)
     context = self.model.encode(input_seq, src_mask)
     self.model.set_stepwise_training(False)
     context, states = self.model.pre_decode(context,
                                             target_seq,
                                             src_mask=src_mask)
     context = MapDict(context)
     full_states = self.model.decode(context, states, False)
     self.model.set_stepwise_training(True)
     context, states = self.model.pre_decode(context,
                                             target_seq,
                                             src_mask=src_mask)
     context = MapDict(context)
     stepwise_states = self.model.decode(context, states, False)
     self.assertAlmostEqual(float(
         (full_states.final_states - stepwise_states.final_states).sum()),
                            0,
                            delta=0.0001)
Example 12
 def __init__(self, model, source_vocab, target_vocab,
              start_token="<s>", end_token="</s>", beam_size=5, opts=None):
     assert isinstance(model, EncoderDecoderModel)
     if torch.cuda.is_available():
         model.cuda()
     self.model = model
     self.source_vocab = source_vocab
     self.target_vocab = target_vocab
     self.start_token = start_token
     self.end_token = end_token
     self.start_token_id = self.source_vocab.encode_token(start_token)
     self.end_token_id = self.target_vocab.encode_token(end_token)
     self.opts = MapDict(opts) if opts else opts
     self.beam_size = beam_size
     self.prepare()
Example 13
 def forward(self, src_seq, tgt_seq):
     """Score each (source, target) pair with its summed target log-probability.
     """
     with torch.no_grad():
         src_mask = torch.ne(src_seq, 0).float()
         tgt_mask = torch.ne(tgt_seq, 0).float()
         encoder_outputs = MapDict(self.encode(src_seq, src_mask))
         context, states = self.pre_decode(encoder_outputs,
                                           tgt_seq,
                                           src_mask=src_mask,
                                           tgt_mask=tgt_mask)
         decoder_outputs = self.decode(context, states)
         logits = self.expand(decoder_outputs)
         logp = torch.log_softmax(logits, 2)
         flat_logp = logp.view(-1, self._tgt_vocab_size)
         flat_tgt = tgt_seq[:, 1:].flatten()
         logp = -torch.nn.functional.nll_loss(
             flat_logp, flat_tgt, reduction="none")
         logp = logp.view(tgt_seq.shape[0], tgt_seq.shape[1] - 1)
         scores = (logp * tgt_mask[:, 1:]).sum(1)
     return scores.cpu().numpy()
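
Since this forward already runs under torch.no_grad and returns a numpy array of per-pair scores, a hedged call might look like the following; scorer and the padded ID batches are assumed names:

    # Hypothetical scoring call; batches are ID matrices padded with 0
    scores = scorer(src_batch, tgt_batch)  # one summed log-prob per sentence pair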
Example 14
 def forward(self, src_seq, tgt_seq, sampling=False):
     """
     Forward to compute the loss.
     """
     src_mask = torch.ne(src_seq, 0)
     tgt_mask = torch.ne(tgt_seq, 0)
     encoder_outputs = MapDict(self.encode(src_seq, src_mask))
     context, states = self.pre_decode(encoder_outputs,
                                       tgt_seq,
                                       src_mask=src_mask,
                                       tgt_mask=tgt_mask)
     decoder_outputs = self.decode(context, states)
     logits = self.expand(decoder_outputs)
     if sampling:
         context, states = self.pre_decode(encoder_outputs,
                                           tgt_seq,
                                           src_mask=src_mask,
                                           tgt_mask=tgt_mask)
         sample_outputs = self.decode(context, states, sampling=True)
         self.monitor("sampled_tokens", sample_outputs.sampled_token)
     loss = self.compute_loss(logits, tgt_seq, tgt_mask)
     self.monitor("loss", loss)
     return self._monitors
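
Finally, a hedged sketch of how such a forward might be driven in a training loop; model, optimizer and the batches are placeholders, and indexing the returned monitors as a dict is also an assumption:

    # Hypothetical training step; names and the monitor access are assumed
    monitors = model(src_batch, tgt_batch)
    loss = monitors["loss"]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()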