Example #1
    def fill_mask(self, masked_input: str, topk: int = 5):
        masked_token = '<mask>'
        assert masked_token in masked_input and masked_input.count(masked_token) == 1, \
            "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(masked_token)

        tokens = self.task.source_dictionary.encode_line(
            '<s> ' + masked_input,
            append_eos=True,
            add_if_not_exist=False,  # look up tokens without growing the dictionary
        )

        masked_index = (tokens == self.task.mask_idx).nonzero()
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)

        with utils.eval(self.model):
            features, extra = self.model(
                tokens.long().to(device=self.device),
                features_only=False,
                return_all_hiddens=False,
            )

        logits = features[0, masked_index, :].squeeze()
        prob = logits.softmax(dim=0)
        values, index = prob.topk(k=topk, dim=0)
        topk_predicted_token = self.task.source_dictionary.string(index)

        topk_filled_outputs = []
        for index, predicted_token in enumerate(
                topk_predicted_token.split(' ')):
            topk_filled_outputs.append((
                masked_input.replace(masked_token, predicted_token),
                values[index].item(),
                predicted_token,
            ))
        return topk_filled_outputs
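A minimal usage sketch for the method above, assuming a fairseq-style RoBERTa hub interface (the torch.hub entry point named here is an assumption, not part of the snippet):

    # Hedged usage sketch; 'pytorch/fairseq' / 'roberta.base' are assumed hub names.
    import torch
    roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
    roberta.eval()
    for filled_sentence, score, token in roberta.fill_mask('He is a <mask> guy', topk=3):
        print('%.4f' % score, token, '->', filled_sentence)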
Example #2
    def fill_mask(self, masked_input: str, topk: int = 5):
        masked_token = "<mask>"
        assert (
            masked_token in masked_input
            and masked_input.count(masked_token) == 1
        ), "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(
            masked_token)

        text_spans = masked_input.split(masked_token)
        text_spans_bpe = ((" {0} ".format(masked_token)).join([
            self.bpe.encode(text_span.rstrip()) for text_span in text_spans
        ]).strip())
        tokens = self.task.source_dictionary.encode_line(
            "<s> " + text_spans_bpe + " </s>",
            append_eos=False,
            add_if_not_exist=False,
        )

        masked_index = (tokens == self.task.mask_idx).nonzero()
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)

        with utils.eval(self.model):
            features, extra = self.model(
                tokens.long().to(device=self.device),
                features_only=False,
                return_all_hiddens=False,
            )
        logits = features[0, masked_index, :].squeeze()
        prob = logits.softmax(dim=0)
        values, index = prob.topk(k=topk, dim=0)
        topk_predicted_token_bpe = self.task.source_dictionary.string(index)

        topk_filled_outputs = []
        for index, predicted_token_bpe in enumerate(
                topk_predicted_token_bpe.split(" ")):
            predicted_token = self.bpe.decode(predicted_token_bpe)
            # Quick hack to fix https://github.com/pytorch/fairseq/issues/1306
            if predicted_token_bpe.startswith("\u2581"):
                predicted_token = " " + predicted_token
            if " {0}".format(masked_token) in masked_input:
                topk_filled_outputs.append((
                    masked_input.replace(" {0}".format(masked_token),
                                         predicted_token),
                    values[index].item(),
                    predicted_token,
                ))
            else:
                topk_filled_outputs.append((
                    masked_input.replace(masked_token, predicted_token),
                    values[index].item(),
                    predicted_token,
                ))
        return topk_filled_outputs
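The '\u2581' check above handles the SentencePiece word-boundary marker: when the predicted BPE piece opens a new word, the decoded token needs its leading space back before it is substituted for ' <mask>'. A tiny standalone illustration (the piece and token values are made up):

    # Illustration only: '\u2581' is the SentencePiece new-word marker.
    piece = '\u2581nice'                  # predicted BPE piece (made up)
    token = 'nice'                        # what self.bpe.decode(piece) would return
    if piece.startswith('\u2581'):
        token = ' ' + token               # restore the word boundary
    print('He is a <mask> guy'.replace(' <mask>', token))  # He is a nice guy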
Example #3
    def disambiguate_pronoun(self, sentence: str) -> bool:
        """
        Usage::

            >>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
            True

            >>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
            'The trophy'
        """
        assert hasattr(self.task, 'disambiguate_pronoun'), \
            'roberta.disambiguate_pronoun() requires a model trained with the WSC task.'
        with utils.eval(self.model):
            return self.task.disambiguate_pronoun(self.model, sentence, use_cuda=self.device.type == 'cuda')
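A usage sketch for the method above; it needs a WSC-finetuned checkpoint. The hub entry and user_dir follow fairseq's published WSC example, but treat them as assumptions:

    # Hedged usage sketch; 'roberta.large.wsc' is the assumed WSC-finetuned hub entry.
    import torch
    roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.wsc',
                             user_dir='examples/roberta/wsc')
    print(roberta.disambiguate_pronoun(
        'The _trophy_ would not fit in the brown suitcase because [it] was too big.'))  # True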
Example #4
    def _get_loss(self, sample, model, criterion):
        assert hasattr(criterion, 'compute_loss'), \
            'language_model_moe task requires the criterion to implement the compute_loss() method'

        bsz = sample['target'].size(0)
        src_tokens = sample['net_input']['src_tokens']
        src_lengths = sample['net_input']['src_lengths']

        #### E-STEP
        with utils.eval(model):  # disable dropout
            with torch.no_grad():  # disable autograd
                net_output = model(src_tokens=src_tokens,
                                   src_lengths=src_lengths)
        # pass net output to gating network to compute expert probabilities
        expert_probs = model.gating_network(net_output)
        # hard selection of experts
        expert_assignments = [
            self.expert_index(x) for x in expert_probs.argmax(dim=1)
        ]
        # add expert assignments as BOS tokens
        src_tokens[:, 0] = torch.Tensor(expert_assignments).long()

        #### M-STEP
        net_output = model(src_tokens=src_tokens, src_lengths=src_lengths)
        loss, _ = criterion.compute_loss(model,
                                         net_output,
                                         sample,
                                         reduce=False)
        loss = loss.view(bsz, -1)
        loss = loss.sum(dim=1, keepdim=True)

        loss = loss.sum()
        sample_size = (
            sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        )
        logging_output = {
            'loss': utils.item(loss.data),
            'ntokens': sample['ntokens'],
            'nsentences': bsz,
            'sample_size': sample_size,
            "expert_assignments": expert_probs.argmax(dim=1)
        }
        return loss, sample_size, logging_output
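The E-step above hard-assigns each sentence to the expert with the highest gating probability, and the chosen expert id is written into the BOS slot so the M-step conditions the model on it. A toy sketch of that assignment (the gating values and the expert-id-to-symbol mapping are made up):

    # Toy sketch of the hard E-step; numbers and the index offset are assumptions.
    import torch
    expert_probs = torch.tensor([[0.1, 0.7, 0.2],   # sentence 0 -> expert 1
                                 [0.6, 0.3, 0.1]])  # sentence 1 -> expert 0
    def expert_index(i):
        return 4 + int(i)  # assumed: expert symbols live at a fixed dictionary offset
    bos = torch.tensor([expert_index(i) for i in expert_probs.argmax(dim=1)]).long()
    print(bos)  # tensor([5, 4]) -- written into src_tokens[:, 0]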
Example #5
    def fill_single_mask(self, masked_inputs, topk=3):
        if isinstance(masked_inputs, str):
            masked_inputs = [masked_inputs]
        assert all(self.masked_token in masked_input for masked_input in masked_inputs), \
            "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(self.masked_token)

        tokens = [
            self.encode_masked_input(masked_input)
            for masked_input in masked_inputs
        ]
        pad_to_length = max(len(token) for token in tokens)

        tokens = data_utils.collate_tokens(
            tokens,
            self.task.source_dictionary.pad(),
            self.task.source_dictionary.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
            pad_to_length=pad_to_length,
        )
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        src_lengths = tokens.ne(self.task.source_dictionary.pad()).sum(dim=-1)
        masked_tokens = tokens.eq(self.task.source_dictionary.mask_index)
        # with utils.model_eval(self.model):  # new version
        with utils.eval(self.model):
            logits = self.model.forward_encoder(
                tokens.long().to(device=self.device),
                src_lengths=src_lengths.to(device=self.device),
                masked_tokens=masked_tokens)
        prob = logits.softmax(dim=-1)
        all_values, all_index = prob.topk(k=topk, dim=-1)
        topk_predicted_token_bpe = self.task.source_dictionary.string(all_index)

        topk_predicted_token_bpe = [
            line.split(' ')
            for line in topk_predicted_token_bpe.split('\n')
        ]
        return topk_predicted_token_bpe
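collate_tokens above right-pads a batch of variable-length 1-D tensors to a common length with the pad symbol. A minimal pure-PyTorch sketch of that padding behavior (no fairseq required; the pad index 1 is just an example):

    # Minimal sketch of right-padding a batch, as collate_tokens does here.
    import torch
    def pad_batch(seqs, pad_idx, pad_to_length):
        out = torch.full((len(seqs), pad_to_length), pad_idx, dtype=torch.long)
        for i, seq in enumerate(seqs):
            out[i, :len(seq)] = seq  # tokens first, padding after
        return out
    batch = pad_batch([torch.tensor([5, 6, 7]), torch.tensor([8, 9])],
                      pad_idx=1, pad_to_length=3)
    print(batch)  # tensor([[5, 6, 7], [8, 9, 1]])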
Example #6
    def _get_loss(self, sample, model, criterion):
        assert hasattr(criterion, 'compute_loss'), \
            'translation_moe task requires the criterion to implement the compute_loss() method'

        k = self.args.num_experts
        bsz = sample['target'].size(0)

        def get_lprob_y(encoder_out, prev_output_tokens_k):
            net_output = model.decoder(
                prev_output_tokens=prev_output_tokens_k,
                encoder_out=encoder_out,
            )
            loss, _ = criterion.compute_loss(model,
                                             net_output,
                                             sample,
                                             reduce=False)
            loss = loss.view(bsz, -1)
            return -loss.sum(dim=1, keepdim=True)  # -> B x 1

        def get_lprob_yz(winners=None):
            encoder_out = model.encoder(
                src_tokens=sample['net_input']['src_tokens'],
                src_lengths=sample['net_input']['src_lengths'],
            )

            if winners is None:
                lprob_y = []
                for i in range(k):
                    prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
                    assert not prev_output_tokens_k.requires_grad
                    prev_output_tokens_k[:, 0] = self.expert_index(i)
                    lprob_y.append(
                        get_lprob_y(encoder_out, prev_output_tokens_k))
                lprob_y = torch.cat(lprob_y, dim=1)  # -> B x K
            else:
                prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
                prev_output_tokens_k[:, 0] = self.expert_index(winners)
                lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k)  # -> B

            if self.uniform_prior:
                lprob_yz = lprob_y
            else:
                lprob_z = model.gating_network(encoder_out)  # B x K
                if winners is not None:
                    lprob_z = lprob_z.gather(dim=1,
                                             index=winners.unsqueeze(-1))
                lprob_yz = lprob_y + lprob_z.type_as(lprob_y)  # B x K

            return lprob_yz

        # compute responsibilities without dropout
        with utils.eval(model):  # disable dropout
            with torch.no_grad():  # disable autograd
                lprob_yz = get_lprob_yz()  # B x K
                prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
        assert not prob_z_xy.requires_grad

        # compute loss with dropout
        if self.hard_selection:
            winners = prob_z_xy.max(dim=1)[1]
            loss = -get_lprob_yz(winners)
        else:
            lprob_yz = get_lprob_yz()  # B x K
            loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)

        loss = loss.sum()
        sample_size = (
            sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        )
        logging_output = {
            'loss': utils.item(loss.data),
            'ntokens': sample['ntokens'],
            'nsentences': bsz,
            'sample_size': sample_size,
            'posterior': prob_z_xy.float().sum(dim=0).cpu(),
        }
        return loss, sample_size, logging_output
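In the soft-selection branch, LogSumExpMoE.apply computes a standard logsumexp over the per-expert log-likelihoods in the forward pass while using the fixed posterior prob_z_xy as the gradient weights in the backward pass (the EM trick: gradients do not flow through the responsibilities). A toy numeric sketch of the forward value only:

    # Toy sketch of the forward value of the MoE loss; numbers are made up.
    import torch
    lprob_yz = torch.tensor([[-2.0, -1.0, -3.0]])   # log p(y, z | x) for K=3 experts
    prob_z_xy = torch.softmax(lprob_yz, dim=1)      # responsibilities from the E-step
    loss = -torch.logsumexp(lprob_yz, dim=1)        # what -LogSumExpMoE.apply(...) returns
    print(prob_z_xy, loss)                          # loss is approximately tensor([0.5924])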
Example #7
    def fill_noised_mask(self, masked_inputs: List[Tuple[str, str]], topk=1):
        masked_token = '<mask>'
        noises, topk_opt = [], []

        text_spans = [sent.split(masked_token) for src, sent in masked_inputs]
        noised_tokens = []
        targets_bpe = []
        for (src, _), segs in zip(masked_inputs, text_spans):
            bpe_src = self.bpe.encode(src.strip())
            bpe_tgt = ' {0} '.format(masked_token).join(
                [self.bpe.encode(seg.rstrip()) for seg in segs])
            bpe_idx = self.task.source_dictionary.encode_line(
                '<s> ' + bpe_src + ' </s> </s> ' + bpe_tgt + ' </s>',
                append_eos=False,
                add_if_not_exist=False,
            )
            tgt_bpe_idx = self.task.source_dictionary.encode_line(
                '<s> ' + bpe_tgt + ' </s>',
                append_eos=False,
                add_if_not_exist=False,
            )
            noised_tokens.append(bpe_idx)
            targets_bpe.append(tgt_bpe_idx)

        sample = self._build_sample(noised_tokens).long()
        masked_index = (sample == self.task.mask_idx)

        with utils.eval(self.model):
            # features: B x T x |V|
            features, extra = self.model(sample,
                                         features_only=False,
                                         return_all_hiddens=False,
                                         masked_tokens=masked_index)
        prob = features.softmax(dim=-1)
        # values, index = prob.topk(k=topk, dim=-1)
        values, index = prob.max(dim=-1)
        index = index.squeeze(-1)  # K
        extra_symbols_to_ignore = {
            self.task.source_dictionary[self.task.source_dictionary.eos()],
            self.task.source_dictionary[self.task.source_dictionary.bos()],
        }

        tot_masks = 0
        for ii, sent in enumerate(targets_bpe):
            decode_noise_tokens = self.decode(sent)
            decode_noise_tokens = decode_noise_tokens.replace(
                "<mask>", " <mask>").strip()
            K = masked_index[ii, :].sum().item()
            topk_predictions = index[tot_masks:tot_masks + K]
            tot_masks += K
            assert (len(topk_predictions) ==
                    decode_noise_tokens.split(" ").count('<mask>'))
            output = []
            mask_count = 0
            topk_predicted_token_bpe = self.task.source_dictionary.string(
                topk_predictions, skip_ignore=True).split()
            for token in decode_noise_tokens.split(" "):
                if token == "<mask>":
                    predict_bpe = topk_predicted_token_bpe[mask_count]
                    # consume the prediction even when it is skipped, so later
                    # <mask> tokens stay aligned with their predictions
                    mask_count += 1
                    if predict_bpe in extra_symbols_to_ignore:
                        continue
                    predicted_token = self.bpe.decode(predict_bpe)
                    output.append(predicted_token.strip())
                else:
                    output.append(token.strip())
            topk_opt.append(" ".join(output))
            noises.append(decode_noise_tokens)
        return topk_opt, noises
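A usage sketch for the method above. Each input is a (source, masked target) pair, matching the unpacking in the loop; the model variable and sentences are placeholders:

    # Hedged usage sketch; 'model' stands for whatever hub interface defines this method.
    topk_opt, noises = model.fill_noised_mask(
        [('He is a nice guy .', 'He is a <mask> guy .')],
    )
    print(topk_opt[0])  # the target with each <mask> replaced by the argmax prediction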
Example #8
    def fill_mask(self, masked_inputs, topk=3, return_filled_sentence=False):
        if isinstance(masked_inputs, str):
            masked_inputs = [masked_inputs]
        masked_token = '[MASK]'
        assert all(masked_token in masked_input for masked_input in masked_inputs), \
            "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(masked_token)

        def encode_masked_input(masked_input):
            text_spans = masked_input.split(masked_token)
            text_spans_bpe = (' {0} '.format(masked_token)).join([
                self.bpe.encode(text_span.rstrip()) for text_span in text_spans
            ]).strip()
            tokens = self.task.source_dictionary.encode_line(
                '[CLS] ' + text_spans_bpe + ' [SEP]',
                append_eos=False,
                add_if_not_exist=False,
            )
            return tokens

        tokens = [
            encode_masked_input(masked_input) for masked_input in masked_inputs
        ]
        pad_to_length = max(len(token) for token in tokens)

        tokens = data_utils.collate_tokens(
            tokens,
            self.task.source_dictionary.pad(),
            self.task.source_dictionary.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
            pad_to_length=pad_to_length,
        )
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        src_lengths = tokens.ne(self.task.source_dictionary.pad()).sum(dim=-1)
        masked_tokens = tokens.eq(self.task.source_dictionary.mask_index)
        # with utils.model_eval(self.model):  # new version
        with utils.eval(self.model):
            logits = self.model.forward_encoder(
                tokens.long().to(device=self.device),
                src_lengths=src_lengths.to(device=self.device),
                masked_tokens=masked_tokens)
        prob = logits.softmax(dim=-1)
        all_values, all_index = prob.topk(k=topk, dim=-1)
        topk_predicted_token_bpe = self.task.source_dictionary.string(all_index)

        topk_predicted_token_bpe = [
            line.split(' ')
            for line in topk_predicted_token_bpe.split('\n')
        ]
        if not return_filled_sentence:
            return topk_predicted_token_bpe

        # Reconstruct the filled sentences from the top-k predictions.
        # Assumes exactly one [MASK] per input, so row i of the predictions
        # (and of all_values) lines up with masked_inputs[i].
        all_outputs = []
        for i, masked_input in enumerate(masked_inputs):
            topk_filled_outputs = []
            for j, predicted_token_bpe in enumerate(topk_predicted_token_bpe[i]):
                predicted_token = self.bpe.decode(predicted_token_bpe)
                if predicted_token_bpe.startswith('\u2581'):
                    predicted_token = ' ' + predicted_token
                if ' {0}'.format(masked_token) in masked_input:
                    filled = masked_input.replace(
                        ' {0}'.format(masked_token), predicted_token)
                else:
                    filled = masked_input.replace(masked_token, predicted_token)
                topk_filled_outputs.append(
                    (filled, all_values[i, j].item(), predicted_token))
            all_outputs.append(topk_filled_outputs)
        return all_outputs
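Both this example and Example #5 pass masked_tokens to the encoder, which restricts the output projection to the masked positions, so logits has one row per masked token instead of a full B x T x |V| tensor. A toy sketch of that boolean gather, independent of fairseq:

    # Toy sketch: boolean-mask indexing keeps only the masked positions.
    import torch
    features = torch.randn(2, 5, 8)                     # B x T x H
    masked_tokens = torch.zeros(2, 5, dtype=torch.bool)
    masked_tokens[0, 2] = masked_tokens[1, 4] = True
    print(features[masked_tokens].shape)                # torch.Size([2, 8])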