Example #1
    def fill_mask(self, masked_input: str, topk: int = 5):
        masked_token = "<mask>"
        assert (
            masked_token in masked_input
            and masked_input.count(masked_token) == 1
        ), "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(
            masked_token)

        # BPE-encode the text around the mask, keeping the mask token itself
        # out of BPE so it maps directly to the dictionary's mask index.
        text_spans = masked_input.split(masked_token)
        text_spans_bpe = ((" {0} ".format(masked_token)).join([
            self.bpe.encode(text_span.rstrip()) for text_span in text_spans
        ]).strip())
        # Wrap with <s>/</s> and map the BPE symbols to dictionary indices.
        tokens = self.task.source_dictionary.encode_line(
            "<s> " + text_spans_bpe + " </s>",
            append_eos=False,
            add_if_not_exist=False,
        )

        # Locate the position of the mask token and add a batch dimension.
        masked_index = (tokens == self.task.mask_idx).nonzero()
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)

        # Forward pass with dropout disabled; `features` holds the LM logits
        # over the vocabulary for every position.
        with utils.model_eval(self.model):
            features, extra = self.model(
                tokens.long().to(device=self.device),
                features_only=False,
                return_all_hiddens=False,
            )
        # Softmax over the vocabulary at the masked position, then take top-k.
        logits = features[0, masked_index, :].squeeze()
        prob = logits.softmax(dim=0)
        values, index = prob.topk(k=topk, dim=0)
        topk_predicted_token_bpe = self.task.source_dictionary.string(index)

        # Decode each candidate and substitute it back into the input
        # (the loop counter shadows the `index` tensor from topk above).
        topk_filled_outputs = []
        for index, predicted_token_bpe in enumerate(
                topk_predicted_token_bpe.split(" ")):
            predicted_token = self.bpe.decode(predicted_token_bpe)
            # Quick hack to fix https://github.com/pytorch/fairseq/issues/1306
            if predicted_token_bpe.startswith("\u2581"):
                predicted_token = " " + predicted_token
            if " {0}".format(masked_token) in masked_input:
                topk_filled_outputs.append((
                    masked_input.replace(" {0}".format(masked_token),
                                         predicted_token),
                    values[index].item(),
                    predicted_token,
                ))
            else:
                topk_filled_outputs.append((
                    masked_input.replace(masked_token, predicted_token),
                    values[index].item(),
                    predicted_token,
                ))
        return topk_filled_outputs
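
A minimal usage sketch for the method above. It assumes `fill_mask` is exposed on fairseq's RoBERTa hub interface and that a pretrained checkpoint is available locally; the paths and the example sentence are illustrative.

# Usage sketch (illustrative paths; assumes a local RoBERTa checkpoint).
from fairseq.models.roberta import RobertaModel

roberta = RobertaModel.from_pretrained("checkpoints/roberta.large",
                                        checkpoint_file="model.pt")
roberta.eval()  # put the model in evaluation mode

# Each entry is (filled sentence, probability, predicted token).
for filled, prob, token in roberta.fill_mask("The capital of France is <mask>.", topk=3):
    print("{:.3f}\t{!r}\t{}".format(prob, token, filled))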
Example #2
    def disambiguate_pronoun(self, sentence: str) -> bool:
        """
        Resolve the pronoun marked with brackets. If a candidate span is
        marked with underscores, return whether the model picks it as the
        referent; otherwise return the predicted referent span.

        Usage::

            >>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
            True

            >>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
            'The trophy'
        """
        assert hasattr(self.task, 'disambiguate_pronoun'), \
            'roberta.disambiguate_pronoun() requires a model trained with the WSC task.'
        with utils.model_eval(self.model):
            return self.task.disambiguate_pronoun(
                self.model, sentence, use_cuda=self.device.type == 'cuda')
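
A usage sketch, assuming a RoBERTa checkpoint fine-tuned on the WSC task is available (the checkpoint path is illustrative); without a WSC-trained model the assert above fires.

# Usage sketch (illustrative path; requires a WSC-finetuned RoBERTa).
from fairseq.models.roberta import RobertaModel

roberta_wsc = RobertaModel.from_pretrained("checkpoints/roberta.wsc",
                                           checkpoint_file="model.pt")
roberta_wsc.eval()

print(roberta_wsc.disambiguate_pronoun(
    "The trophy would not fit in the brown suitcase because [it] was too big."))
# -> 'The trophy'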
Example #3
    def _get_loss(self, sample, model, criterion):
        assert hasattr(
            criterion, "compute_loss"
        ), "translation_moe task requires the criterion to implement the compute_loss() method"

        k = self.cfg.num_experts
        bsz = sample["target"].size(0)

        # log p(y | x, z=k): decode with expert k's token prepended and sum token-level losses
        def get_lprob_y(encoder_out, prev_output_tokens_k):
            net_output = model.decoder(
                prev_output_tokens=prev_output_tokens_k,
                encoder_out=encoder_out,
            )
            loss, _ = criterion.compute_loss(model,
                                             net_output,
                                             sample,
                                             reduce=False)
            loss = loss.view(bsz, -1)
            return -loss.sum(dim=1, keepdim=True)  # -> B x 1

        # log p(y, z | x): joint log-prob, for all K experts or only for `winners`
        def get_lprob_yz(winners=None):
            encoder_out = model.encoder(
                src_tokens=sample["net_input"]["src_tokens"],
                src_lengths=sample["net_input"]["src_lengths"],
            )

            if winners is None:
                lprob_y = []
                for i in range(k):
                    prev_output_tokens_k = sample["net_input"][
                        "prev_output_tokens"].clone()
                    assert not prev_output_tokens_k.requires_grad
                    prev_output_tokens_k[:, 0] = self.expert_index(i)
                    lprob_y.append(
                        get_lprob_y(encoder_out, prev_output_tokens_k))
                lprob_y = torch.cat(lprob_y, dim=1)  # -> B x K
            else:
                prev_output_tokens_k = sample["net_input"][
                    "prev_output_tokens"].clone()
                prev_output_tokens_k[:, 0] = self.expert_index(winners)
                lprob_y = get_lprob_y(encoder_out,
                                      prev_output_tokens_k)  # -> B x 1

            if self.uniform_prior:
                lprob_yz = lprob_y
            else:
                lprob_z = model.gating_network(encoder_out)  # B x K
                if winners is not None:
                    lprob_z = lprob_z.gather(dim=1,
                                             index=winners.unsqueeze(-1))  # B x 1
                lprob_yz = lprob_y + lprob_z.type_as(lprob_y)  # B x K (B x 1 if winners given)

            return lprob_yz

        # compute responsibilities without dropout
        with utils.model_eval(model):  # disable dropout
            with torch.no_grad():  # disable autograd
                lprob_yz = get_lprob_yz()  # B x K
                prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
        assert not prob_z_xy.requires_grad

        # compute loss with dropout
        if self.hard_selection:
            winners = prob_z_xy.max(dim=1)[1]
            loss = -get_lprob_yz(winners)
        else:
            lprob_yz = get_lprob_yz()  # B x K
            loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)

        loss = loss.sum()
        sample_size = (sample["target"].size(0)
                       if self.cfg.sentence_avg else sample["ntokens"])
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": bsz,
            "sample_size": sample_size,
            "posterior": prob_z_xy.float().sum(dim=0).cpu(),
        }
        return loss, sample_size, logging_output
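
For reference, the soft-selection branch marginalizes the per-expert log-likelihoods over the K experts; LogSumExpMoE evaluates that logsumexp in the forward pass while using the fixed posterior prob_z_xy for the backward pass. The toy stand-in below shows only the forward value and is not the fairseq autograd Function.

# Toy sketch of the forward value of the soft-selection loss
# (the real LogSumExpMoE additionally routes gradients through prob_z_xy).
import torch

def soft_moe_loss(lprob_yz):
    # lprob_yz: B x K log p(y, z | x); marginalize over the K experts.
    return -torch.logsumexp(lprob_yz, dim=1)

lprob_yz = torch.randn(4, 3)           # toy batch: 4 sentences, 3 experts
print(soft_moe_loss(lprob_yz).sum())   # analogous to `loss` above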
Example #4
    def batch_fill_mask(self, masked_inputs: List[str], topk: int = 5):
        """
        Batched variant of fill_mask(): predicts the masked token in each
        input sentence in a single forward pass. Each input is expected to
        contain exactly one <mask> token.
        """
        masked_token = "<mask>"

        collect_tokens = []
        for masked_input in masked_inputs:
            # Same encoding as fill_mask(): BPE-encode the spans around the
            # mask, then map to dictionary indices with <s>/</s> markers.
            text_spans = masked_input.split(masked_token)
            text_spans_bpe = ((" {0} ".format(masked_token)).join([
                self.bpe.encode(text_span.rstrip()) for text_span in text_spans
            ]).strip())
            tokens = self.task.source_dictionary.encode_line(
                "<s> " + text_spans_bpe + " </s>",
                append_eos=False,
                add_if_not_exist=False,
            )
            collect_tokens.append(tokens)

        # Pad the per-sentence tokens into a single 2D batch tensor
        # (pad_idx=1 is the <pad> index in the RoBERTa dictionary).
        tokens = collate_tokens(collect_tokens, pad_idx=1)
        # One (row, column) pair per <mask> occurrence, in batch order.
        masked_index = (tokens == self.task.mask_idx).nonzero()

        with utils.model_eval(self.model):
            features, extra = self.model(
                tokens.long().to(device=self.device),
                features_only=False,
                return_all_hiddens=False,
            )

        all_outputs = []
        idx = 0
        # masked_index[:, -1] holds the column of the <mask> in each row,
        # in the same order as masked_inputs.
        for mask_idx, masked_input in zip(masked_index[:, -1], masked_inputs):
            logits = features[idx, mask_idx, :].squeeze()
            prob = logits.softmax(dim=0)
            values, index = prob.topk(k=topk, dim=0)

            topk_predicted_token_bpe = self.task.source_dictionary.string(
                index)

            topk_filled_outputs = []
            for index, predicted_token_bpe in enumerate(
                    topk_predicted_token_bpe.split(" ")):
                predicted_token = self.bpe.decode(predicted_token_bpe)
                # Quick hack to fix https://github.com/pytorch/fairseq/issues/1306
                if predicted_token_bpe.startswith("\u2581"):
                    predicted_token = " " + predicted_token

                if " {0}".format(masked_token) in masked_input:
                    topk_filled_outputs.append((
                        masked_input.replace(" {0}".format(masked_token),
                                             predicted_token),
                        values[index].item(),
                        predicted_token,
                    ))
                else:
                    topk_filled_outputs.append((
                        masked_input.replace(masked_token, predicted_token),
                        values[index].item(),
                        predicted_token,
                    ))
            all_outputs.append(topk_filled_outputs)
            idx += 1
        return all_outputs
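
A usage sketch for the batched variant, reusing the `roberta` interface from the first sketch. batch_fill_mask is not part of stock fairseq, so this assumes the method above has been patched onto the hub interface; collate_tokens would come from fairseq.data.data_utils and List from typing. Every input is expected to contain exactly one <mask> token.

# Usage sketch (assumes the batch_fill_mask method above is available on the
# `roberta` hub interface; each sentence contains exactly one <mask>).
sentences = [
    "The capital of France is <mask>.",
    "She poured a glass of <mask>.",
]
for sentence, candidates in zip(sentences, roberta.batch_fill_mask(sentences, topk=3)):
    best_fill, best_prob, best_token = candidates[0]
    print("{} -> {!r} ({:.3f})".format(sentence, best_token, best_prob))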