from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class GenerateLogProbsForDecoding(nn.Module):
    def __init__(self, models, retain_dropout=False, apply_log_softmax=False):
        """Generate the neural network's output intepreted as log probabilities
        for decoding with Kaldi.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models;
                currently only fairseq.models.TransformerModel is supported for scripting
            retain_dropout (bool, optional): use dropout when generating
                (default: False)
            apply_log_softmax (bool, optional): apply log-softmax on top of the
                network's output (default: False)
        """
        super().__init__()
        from fairseq.sequence_generator import EnsembleModel
        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)
        self.retain_dropout = retain_dropout
        self.apply_log_softmax = apply_log_softmax

        if not self.retain_dropout:
            self.model.eval()

    def cuda(self):
        self.model.cuda()
        return self

    @torch.no_grad()
    def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
        """Generate a batch of translations.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
        """
        self.model.reset_incremental_state()
        return self._generate(sample, **kwargs)

    def _generate(self, sample: Dict[str, Dict[str, Tensor]], **kwargs):
        net_input = sample["net_input"]
        src_tokens = net_input["src_tokens"]
        bsz = src_tokens.size(0)

        # compute the encoder output
        encoder_outs = self.model.forward_encoder(net_input)
        logits = encoder_outs[0].encoder_out.transpose(
            0, 1).float()  # T x B x V -> B x T x V
        assert logits.size(0) == bsz
        padding_mask = encoder_outs[0].encoder_padding_mask.t() \
            if encoder_outs[0].encoder_padding_mask is not None else None
        if self.apply_log_softmax:
            return F.log_softmax(logits, dim=-1), padding_mask
        return logits, padding_mask
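

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original code): "models" is
# assumed to be an already-loaded list of fairseq models and "sample" a
# fairseq-style batch dict containing a "net_input" entry.
def _example_generate_logprobs(models, sample):
    from fairseq import utils  # used only to move the batch onto the GPU

    generator = GenerateLogProbsForDecoding(models, apply_log_softmax=True)
    if torch.cuda.is_available():
        generator = generator.cuda()
        sample = utils.move_to_cuda(sample)
    # returns a B x T x V tensor of (log-softmaxed) outputs over the vocabulary
    # and the encoder padding mask (or None if nothing is padded)
    lprobs, padding_mask = generator.generate(models, sample)
    return lprobs, padding_mask

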
class SimpleGreedyDecoder(nn.Module):
    def __init__(
        self,
        models,
        dictionary,
        max_len_a=0,
        max_len_b=200,
        retain_dropout=False,
        temperature=1.0,
        for_validation=True,
    ):
        """Decode given speech audios with the simple greedy search.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models;
                currently only fairseq.models.TransformerModel is supported for scripting
            dictionary (~fairseq.data.Dictionary): dictionary
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            retain_dropout (bool, optional): use dropout when generating
                (default: False)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0)
            for_validation (bool, optional): indicate whether the decoder is
                used for validation. It affects how max_len is determined and
                whether a tensor of lprobs is returned. If True, target must
                not be None
        """
        super().__init__()
        from fairseq.sequence_generator import EnsembleModel
        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)
        self.pad = dictionary.pad()
        self.unk = dictionary.unk()
        self.eos = dictionary.eos()
        self.vocab_size = len(dictionary)
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.retain_dropout = retain_dropout
        self.temperature = temperature
        assert temperature > 0, "--temperature must be greater than 0"
        if not self.retain_dropout:
            self.model.eval()
        self.for_validation = for_validation

    def cuda(self):
        self.model.cuda()
        return self

    @torch.no_grad()
    def decode(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
        """Generate a batch of translations. Match the api of other fairseq generators.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)
        """
        self.model.reset_incremental_state()
        return self._decode(sample, **kwargs)

    @torch.no_grad()
    def _decode(self,
                sample: Dict[str, Dict[str, Tensor]],
                bos_token: Optional[int] = None):
        net_input = sample["net_input"]
        src_tokens = net_input["src_tokens"]
        input_size = src_tokens.size()
        bsz, src_len = input_size[0], input_size[1]

        # compute the encoder output
        encoder_outs = self.model.forward_encoder(net_input)
        target = sample["target"]
        # target can only be None if not for validation
        assert target is not None or not self.for_validation
        max_encoder_output_length = encoder_outs[0].encoder_out.size(0)
        # for validation, make the maximum decoding length equal to at least the
        # length of target, and the length of encoder_out if possible; otherwise
        # max_len is obtained from max_len_a/b
        max_len = max(max_encoder_output_length, target.size(1)) \
            if self.for_validation else \
            min(
                int(self.max_len_a * src_len + self.max_len_b),
                # exclude the EOS marker
                self.model.max_decoder_positions() - 1,
            )
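        # e.g., with the defaults max_len_a=0 and max_len_b=200, the
        # inference-time max_len above is min(200, max_decoder_positions() - 1)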

        tokens = src_tokens.new(bsz, max_len + 2).long().fill_(self.pad)
        tokens[:, 0] = self.eos if bos_token is None else bos_token
        # lprobs is only used when target is not None (i.e., for validation)
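        # it is initialized to a uniform distribution, i.e. log prob
        # -log(vocab_size) for every class, so steps beyond the actual
        # decoded length remain well-defined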
        lprobs = encoder_outs[0].encoder_out.new_full(
            (bsz, target.size(1), self.vocab_size),
            -np.log(self.vocab_size),
        ) if self.for_validation else None
        attn = None
        for step in range(max_len + 1):  # one extra step for EOS marker
            is_eos = tokens[:, step].eq(self.eos)
            if step > 0 and is_eos.sum() == is_eos.size(0):
                # all predictions are finished (i.e., ended with eos)
                tokens = tokens[:, :step + 1]
                if attn is not None:
                    attn = attn[:, :, :step + 1]
                break
            log_probs, avg_attn_scores = self.model.forward_decoder(
                tokens[:, :step + 1],
                encoder_outs,
                temperature=self.temperature,
            )
            tokens[:, step + 1] = log_probs.argmax(-1)
            if step > 0:  # deal with finished predictions
                # make log_probs uniform if the previous output token is EOS
                # and add consecutive EOS to the end of prediction
                log_probs[is_eos, :] = -np.log(log_probs.size(1))
                tokens[is_eos, step + 1] = self.eos
            if self.for_validation and step < target.size(1):
                lprobs[:, step, :] = log_probs

            # Record attention scores
            if type(avg_attn_scores) is list:
                avg_attn_scores = avg_attn_scores[0]
            if avg_attn_scores is not None:
                if attn is None:
                    attn = avg_attn_scores.new(bsz, max_encoder_output_length,
                                               max_len + 2)
                attn[:, :, step + 1].copy_(avg_attn_scores)

        return tokens, lprobs, attn
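

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original code): "models",
# "dictionary" and "sample" are assumed to be already prepared; with
# for_validation=True the batch must also carry a "target" tensor.
def _example_greedy_decode(models, dictionary, sample):
    decoder = SimpleGreedyDecoder(models, dictionary, for_validation=True)
    tokens, lprobs, attn = decoder.decode(models, sample)
    # tokens: B x (decoded length) token ids, starting with the BOS/EOS marker;
    # lprobs: B x target_len x V log probabilities; attn may be None
    return tokens, lprobs, attn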