Example #1
    def reorder_encoder_out(self, encoder_out, new_order):
        (x, src_tokens, encoder_padding_mask) = encoder_out
        src_tokens_tensor = pytorch_translate_utils.get_source_tokens_tensor(
            src_tokens)
        # x is laid out T x B x C (see Example #3), so its batch dimension is
        # dim 1; the source-token tensor and padding mask are B x T, so their
        # batch dimension is dim 0.
        if x is not None:
            x = x.index_select(1, new_order)
        if src_tokens_tensor is not None:
            src_tokens_tensor = src_tokens_tensor.index_select(0, new_order)
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.index_select(
                0, new_order)
        return (x, src_tokens_tensor, encoder_padding_mask)
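
Every example on this page passes src_tokens through pytorch_translate_utils.get_source_tokens_tensor before calling .size() or .index_select() on it. The helper itself is not shown here; below is a minimal sketch of what such a helper plausibly does, assuming src_tokens may arrive either as a plain B x T tensor or wrapped in a list/tuple for multi-source inputs (the actual pytorch_translate implementation may differ):

import torch


def get_source_tokens_tensor(src_tokens):
    # Hypothetical sketch: unwrap src_tokens into a single B x T tensor.
    # Assumes multi-source inputs come as a list/tuple whose first element
    # is the primary source; a plain tensor is returned unchanged.
    if isinstance(src_tokens, (list, tuple)):
        return src_tokens[0]
    return src_tokens


# Usage mirroring the call sites above:
src_tokens = torch.tensor([[4, 5, 6, 2]])  # B x T
print(get_source_tokens_tensor(src_tokens).size(1))    # 4
print(get_source_tokens_tensor([src_tokens]).size(1))  # 4
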
Example #2
    def generate_hypo(self, repacked_inputs, maxlen_a=0.0, maxlen_b=None):
        if maxlen_b is None:
            maxlen_b = self.maxlen
        src_tokens = repacked_inputs["src_tokens"]
        srclen = pytorch_translate_utils.get_source_tokens_tensor(
            src_tokens).size(1)
        hypos = self.generate(
            repacked_inputs,
            beam_size=self.beam_size,
            maxlen=int(maxlen_a * srclen + maxlen_b),
            # If we need to generate predictions with teacher forcing, this
            # won't work. Right now this is fine.
            prefix_tokens=None,
        )
        return self._pick_hypothesis_unpack_output(hypos)
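
The maximum decoding length here is a linear function of the source length: maxlen = int(maxlen_a * srclen + maxlen_b), with maxlen_b defaulting to self.maxlen when not given. A quick worked example with illustrative coefficients (only maxlen_a=0.0 is an actual default above):

# Illustrative values; srclen would come from
# get_source_tokens_tensor(src_tokens).size(1).
srclen = 12
maxlen_a, maxlen_b = 1.5, 5
maxlen = int(maxlen_a * srclen + maxlen_b)
print(maxlen)  # 23 -> the beam search may emit at most 23 target tokens
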
Example #3
    def forward(self, src_tokens, src_lengths):
        # Embed tokens
        x = self.embed_tokens(src_tokens)
        src_tokens_tensor = pytorch_translate_utils.get_source_tokens_tensor(src_tokens)
        # Add position embeddings and dropout
        x = self.embed_scale * x
        positions = self.embed_positions(src_tokens_tensor)
        x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask (B x T)
        encoder_padding_mask = src_tokens_tensor.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        return x, encoder_padding_mask, positions
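
Example #3 returns the encoder states in time-major layout (T x B x C, after the transpose) together with a B x T padding mask that is replaced by None when nothing is padded. Below is a self-contained, shape-only stand-in for that contract, using hypothetical toy modules (the dimensions, vocabulary size, and simplified learned positions are all assumptions, not pytorch_translate's actual encoder):

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyEncoder(nn.Module):
    # Hypothetical, shape-only stand-in for the encoder in Example #3.

    def __init__(self, vocab_size=32, embed_dim=8, max_positions=64, padding_idx=1):
        super().__init__()
        self.padding_idx = padding_idx
        self.dropout = 0.1
        self.embed_scale = embed_dim ** 0.5
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        # Simplified learned positional embeddings (fairseq's are more involved).
        self.position_embeddings = nn.Embedding(max_positions, embed_dim)

    def embed_positions(self, src_tokens_tensor):
        bsz, seqlen = src_tokens_tensor.size()
        position_ids = torch.arange(seqlen, device=src_tokens_tensor.device)
        return self.position_embeddings(position_ids).unsqueeze(0).expand(bsz, -1, -1)

    def forward(self, src_tokens, src_lengths=None):
        x = self.embed_scale * self.embed_tokens(src_tokens)
        x = x + self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = x.transpose(0, 1)  # B x T x C -> T x B x C
        encoder_padding_mask = src_tokens.eq(self.padding_idx)  # B x T
        if not encoder_padding_mask.any():
            encoder_padding_mask = None
        return x, encoder_padding_mask


src_tokens = torch.tensor([[4, 5, 6, 2], [7, 8, 2, 1]])  # B=2, T=4, pad index 1
x, mask = ToyEncoder()(src_tokens)
print(x.shape)     # torch.Size([4, 2, 8])  -> T x B x C
print(mask.shape)  # torch.Size([2, 4])     -> B x T
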
Example #4
    def _decode_target(
        self,
        encoder_input,
        encoder_outs,
        incremental_states,
        diversity_sibling_gamma=0.0,
        beam_size=None,
        maxlen=None,
        prefix_tokens=None,
    ):
        src_tokens_tensor = pytorch_translate_utils.get_source_tokens_tensor(
            encoder_input["src_tokens"])
        beam_size = beam_size if beam_size is not None else self.beam_size
        bsz = src_tokens_tensor.size(0)
        reorder_indices = (torch.arange(bsz).view(-1, 1).repeat(
            1, beam_size).view(-1).long())
        for i, model in enumerate(self.models):
            encoder_outs[i] = model.encoder.reorder_encoder_out(
                encoder_out=encoder_outs[i],
                new_order=reorder_indices.type_as(src_tokens_tensor),
            )
        maxlen = min(maxlen,
                     self.maxlen) if maxlen is not None else self.maxlen
        # initialize buffers
        scores = src_tokens_tensor.new(bsz * beam_size,
                                       maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens_tensor.new(bsz * beam_size,
                                       maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos

        # may differ from input length
        if isinstance(encoder_outs[0], (list, tuple)):
            src_encoding_len = encoder_outs[0][0].size(0)
        elif isinstance(encoder_outs[0], dict):
            if isinstance(encoder_outs[0]["encoder_out"], tuple):
                # Fairseq compatibility
                src_encoding_len = encoder_outs[0]["encoder_out"][0].size(1)
            else:
                src_encoding_len = encoder_outs[0]["encoder_out"].size(0)

        attn = scores.new(bsz * beam_size, src_encoding_len, maxlen + 2)
        attn_buf = attn.clone()

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{
            "idx": None,
            "score": -math.inf
        } for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) *
                        beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= (maxlen + 1)**self.len_penalty
                if worst_finalized[sent]["score"] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step,
                           bbsz_idx,
                           eos_scores,
                           unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step +
                                        2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step + 2]

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step + 1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1)**self.len_penalty

            sents_seen = set()
            for i, (idx, score) in enumerate(
                    zip(bbsz_idx.tolist(), eos_scores.tolist())):
                sent = idx // beam_size
                sents_seen.add(sent)

                def get_hypo():
                    _, alignment = attn_clone[i].max(dim=0)
                    return {
                        "tokens": tokens_clone[i],
                        "score": score,
                        "attention": attn_clone[i],  # src_len x tgt_len
                        "alignment": alignment,
                        "positional_scores": pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent][
                        "score"]:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]["idx"]
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]),
                                 key=lambda r: r[1]["score"])
                    worst_finalized[sent] = {"score": s["score"], "idx": idx}

            # return number of hypotheses finished this step
            num_finished = 0
            for sent in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step,
                                                      unfinalized_scores):
                    finished[sent] = True
                    num_finished += 1
            return num_finished

        reorder_state = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                for model in self.models:
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(
                            incremental_states[model], reorder_state)
            # Run decoder for one step
            logprobs, avg_attn, possible_translation_tokens = self._decode(
                tokens[:, :step + 1], encoder_outs, incremental_states)

            logprobs[:, self.pad] = -math.inf  # never select pad
            # apply unk reward
            if possible_translation_tokens is None:
                # No vocab reduction, so unk is represented by self.unk at
                # position self.unk
                unk_index = self.unk
                logprobs[:, unk_index] += self.unk_reward
            else:
                # When we use vocab reduction, the token value self.unk may not
                # be at the position self.unk, but somewhere else in the list
                # of possible_translation_tokens. It's also possible not to
                # show up in possible_translation_tokens at all, meaning we
                # can't generate an unk.
                unk_pos = torch.nonzero(
                    possible_translation_tokens == self.unk)
                if unk_pos.size()[0] != 0:
                    # only add unk_reward if unk index appears in
                    # possible_translation_tokens
                    unk_index = unk_pos[0][0]
                    logprobs[:, unk_index] += self.unk_reward
            # external lexicon reward
            logprobs[:, self.lexicon_indices] += self.lexicon_reward

            logprobs += self.word_reward
            logprobs[:, self.eos] -= self.word_reward
            # Record attention scores
            attn[:, :, step + 1].copy_(avg_attn)

            cand_scores = buffer("cand_scores", type_of=scores)
            cand_indices = buffer("cand_indices")
            cand_beams = buffer("cand_beams")
            eos_bbsz_idx = buffer("eos_bbsz_idx")
            eos_scores = buffer("eos_scores", type_of=scores)
            scores = scores.type_as(logprobs)
            scores_buf = scores_buf.type_as(logprobs)

            if step < maxlen:
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    logprobs_slice = logprobs.view(bsz, -1,
                                                   logprobs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        logprobs_slice,
                        dim=1,
                        index=prefix_tokens[:, step].view(-1, 1)).expand(
                            -1, cand_size)
                    cand_indices = (prefix_tokens[:, step].view(-1, 1).expand(
                        bsz, cand_size))
                    cand_beams.resize_as_(cand_indices).fill_(0)
                else:
                    possible_tokens_size = self.vocab_size
                    if possible_translation_tokens is not None:
                        possible_tokens_size = possible_translation_tokens.size(
                            0)
                    if diversity_sibling_gamma > 0:
                        logprobs = self.diversity_sibling_rank(
                            logprobs.view(bsz, -1, possible_tokens_size),
                            diversity_sibling_gamma,
                        )
                    cand_scores, cand_indices, cand_beams = self.search.step(
                        step,
                        logprobs.view(bsz, -1, possible_tokens_size),
                        scores.view(bsz, beam_size, -1)[:, :, :step],
                    )
                    # vocabulary reduction
                    if possible_translation_tokens is not None:
                        possible_translation_tokens = possible_translation_tokens.view(
                            1, possible_tokens_size).expand(
                                cand_indices.size(0), possible_tokens_size)
                        cand_indices = torch.gather(
                            possible_translation_tokens,
                            dim=1,
                            index=cand_indices,
                            out=cand_indices,
                        )
            else:
                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest log prob of EOS right now
                logprobs.add_(scores[:, step - 1].view(-1, 1))
                torch.sort(
                    logprobs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx,
                                                     eos_scores)
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add_(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    num_remaining_sent -= finalize_hypos(
                        step, eos_bbsz_idx, eos_scores, cand_scores)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer("active_mask")
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer("active_hypos"), buffer("_ignore")
            torch.topk(
                active_mask,
                k=beam_size,
                dim=1,
                largest=False,
                out=(_ignore, active_hypos),
            )
            active_bbsz_idx = buffer("active_bbsz_idx")
            torch.gather(cand_bbsz_idx,
                         dim=1,
                         index=active_hypos,
                         out=active_bbsz_idx)
            active_scores = torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )
            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1],
                dim=0,
                index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices,
                dim=1,
                index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step],
                    dim=0,
                    index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores,
                dim=1,
                index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            torch.index_select(
                attn[:, :, :step + 2],
                dim=0,
                index=active_bbsz_idx,
                out=attn_buf[:, :, :step + 2],
            )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(bsz):
            finalized[sent] = sorted(finalized[sent],
                                     key=lambda r: r["score"],
                                     reverse=True)

        return finalized
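
Much of Example #4 relies on a flattened batch-times-beam ("bbsz") indexing scheme: tokens, scores, and attn all have bsz * beam_size rows, and per-sentence candidate beam indices are converted into flattened row indices by adding bbsz_offsets (the cand_bbsz_idx computation above). A tiny standalone illustration of that conversion, with made-up values:

import torch

bsz, beam_size = 2, 3
cand_size = 2 * beam_size  # 2 x beam size, in case half the candidates are EOS

# Row offset of each sentence's first beam in the flattened bsz * beam_size layout.
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1)  # tensor([[0], [3]])

# Suppose the search picked these beam indices (one row per sentence, cand_size columns).
cand_beams = torch.tensor([[0, 0, 1, 2, 2, 1],
                           [2, 0, 1, 1, 0, 2]])

# Flattened row indices into tokens / scores / attn.
cand_bbsz_idx = cand_beams + bbsz_offsets
print(cand_bbsz_idx)
# tensor([[0, 0, 1, 2, 2, 1],
#         [5, 3, 4, 4, 3, 5]])
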