def fill_mask(self, masked_input: str, topk: int = 5):
    """Predict the top-k fillers for the single <mask> token in *masked_input*."""
    masked_token = "<mask>"
    assert (
        masked_token in masked_input and masked_input.count(masked_token) == 1
    ), "Please add one {0} token to the input, e.g.: 'He is a {0} guy'".format(
        masked_token
    )

    # BPE-encode the text around the mask, keeping the mask token itself intact.
    text_spans = masked_input.split(masked_token)
    text_spans_bpe = (
        (" {0} ".format(masked_token))
        .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans])
        .strip()
    )
    tokens = self.task.source_dictionary.encode_line(
        "<s> " + text_spans_bpe + " </s>",
        append_eos=False,
        add_if_not_exist=False,
    )

    masked_index = (tokens == self.task.mask_idx).nonzero()
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)

    with utils.model_eval(self.model):  # fairseq.utils; disables dropout
        features, extra = self.model(
            tokens.long().to(device=self.device),
            features_only=False,
            return_all_hiddens=False,
        )

    # Softmax over the vocabulary at the masked position, then take the top-k.
    logits = features[0, masked_index, :].squeeze()
    prob = logits.softmax(dim=0)
    values, index = prob.topk(k=topk, dim=0)
    topk_predicted_token_bpe = self.task.source_dictionary.string(index)

    topk_filled_outputs = []
    # `rank` avoids shadowing the top-k `index` tensor above.
    for rank, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(" ")):
        predicted_token = self.bpe.decode(predicted_token_bpe)
        # Quick hack to fix https://github.com/pytorch/fairseq/issues/1306
        if predicted_token_bpe.startswith("\u2581"):
            predicted_token = " " + predicted_token
        if " {0}".format(masked_token) in masked_input:
            topk_filled_outputs.append(
                (
                    masked_input.replace(" {0}".format(masked_token), predicted_token),
                    values[rank].item(),
                    predicted_token,
                )
            )
        else:
            topk_filled_outputs.append(
                (
                    masked_input.replace(masked_token, predicted_token),
                    values[rank].item(),
                    predicted_token,
                )
            )
    return topk_filled_outputs
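# Example usage (a sketch; assumes `roberta` is a loaded hub interface, e.g.
# roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')):
#
#   roberta.eval()
#   roberta.fill_mask('The first Star Wars movie came out in <mask>', topk=3)
#   # -> list of (filled_sentence, probability, predicted_token) tuples,
#   #    ordered by probability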
def disambiguate_pronoun(self, sentence: str) -> bool:
    """
    Returns True/False when the candidate referent is marked with underscores;
    otherwise returns the resolved noun phrase for the bracketed pronoun.

    Usage::

        >>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
        True

        >>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
        'The trophy'
    """
    assert hasattr(self.task, 'disambiguate_pronoun'), \
        'roberta.disambiguate_pronoun() requires a model trained with the WSC task.'
    with utils.model_eval(self.model):
        return self.task.disambiguate_pronoun(
            self.model, sentence, use_cuda=self.device.type == 'cuda'
        )
def _get_loss(self, sample, model, criterion):
    """Hard/soft EM loss for the mixture-of-experts translation objective."""
    assert hasattr(
        criterion, "compute_loss"
    ), "translation_moe task requires the criterion to implement the compute_loss() method"

    k = self.cfg.num_experts
    bsz = sample["target"].size(0)

    def get_lprob_y(encoder_out, prev_output_tokens_k):
        # Per-sentence log-likelihood of the target under one expert.
        net_output = model.decoder(
            prev_output_tokens=prev_output_tokens_k,
            encoder_out=encoder_out,
        )
        loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
        loss = loss.view(bsz, -1)
        return -loss.sum(dim=1, keepdim=True)  # -> B x 1

    def get_lprob_yz(winners=None):
        encoder_out = model.encoder(
            src_tokens=sample["net_input"]["src_tokens"],
            src_lengths=sample["net_input"]["src_lengths"],
        )

        if winners is None:
            # Score the target under every expert; the expert identity is
            # injected as the first decoder input token.
            lprob_y = []
            for i in range(k):
                prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
                assert not prev_output_tokens_k.requires_grad
                prev_output_tokens_k[:, 0] = self.expert_index(i)
                lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
            lprob_y = torch.cat(lprob_y, dim=1)  # -> B x K
        else:
            # Score only the winning expert per sentence.
            prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
            prev_output_tokens_k[:, 0] = self.expert_index(winners)
            lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k)  # -> B

        if self.uniform_prior:
            lprob_yz = lprob_y
        else:
            lprob_z = model.gating_network(encoder_out)  # B x K
            if winners is not None:
                lprob_z = lprob_z.gather(dim=1, index=winners.unsqueeze(-1))
            lprob_yz = lprob_y + lprob_z.type_as(lprob_y)  # B x K

        return lprob_yz

    # E-step: compute responsibilities without dropout
    with utils.model_eval(model):  # disable dropout
        with torch.no_grad():  # disable autograd
            lprob_yz = get_lprob_yz()  # B x K

    prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
    assert not prob_z_xy.requires_grad

    # M-step: compute loss with dropout
    if self.hard_selection:
        winners = prob_z_xy.max(dim=1)[1]
        loss = -get_lprob_yz(winners)
    else:
        # LogSumExpMoE: standard log-sum-exp over experts in the forward
        # pass, with the backward pass weighted by the fixed posterior.
        lprob_yz = get_lprob_yz()  # B x K
        loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)

    loss = loss.sum()
    sample_size = (
        sample["target"].size(0) if self.cfg.sentence_avg else sample["ntokens"]
    )
    logging_output = {
        "loss": utils.item(loss.data),
        "ntokens": sample["ntokens"],
        "nsentences": bsz,
        "sample_size": sample_size,
        "posterior": prob_z_xy.float().sum(dim=0).cpu(),
    }
    return loss, sample_size, logging_output
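# A minimal standalone sketch of the objective computed above, for reference.
# Assumes `lprob_yz` is a B x K tensor of per-expert log-likelihoods
# log p(y, z | x); `moe_loss_sketch` is a hypothetical name, not fairseq API.
import torch

def moe_loss_sketch(lprob_yz: torch.Tensor, hard: bool = False) -> torch.Tensor:
    # E-step: posterior responsibilities p(z | x, y), held fixed (no gradient)
    with torch.no_grad():
        prob_z_xy = torch.softmax(lprob_yz, dim=1)  # B x K
    if hard:
        # hard selection: train only the most responsible expert per sentence
        winners = prob_z_xy.max(dim=1)[1]                 # B
        loss = -lprob_yz.gather(1, winners.unsqueeze(1))  # B x 1
    else:
        # soft mixture: plain logsumexp shown here; fairseq's LogSumExpMoE
        # keeps this forward pass but reweights the backward by prob_z_xy
        loss = -torch.logsumexp(lprob_yz, dim=1, keepdim=True)
    return loss.sum()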
def batch_fill_mask(self, masked_inputs: List[str], topk: int = 5):
    """Batched inference for predicting the masked token in each input sentence."""
    masked_token = "<mask>"
    collect_tokens = []
    for masked_input in masked_inputs:
        assert (
            masked_token in masked_input
            and masked_input.count(masked_token) == 1
        ), "Please add one {0} token per input, e.g.: 'He is a {0} guy'".format(
            masked_token
        )
        text_spans = masked_input.split(masked_token)
        text_spans_bpe = (
            (" {0} ".format(masked_token))
            .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans])
            .strip()
        )
        tokens = self.task.source_dictionary.encode_line(
            "<s> " + text_spans_bpe + " </s>",
            append_eos=False,
            add_if_not_exist=False,
        )
        collect_tokens.append(tokens)

    # Pad the variable-length token sequences into a single 2D tensor
    # (collate_tokens comes from fairseq.data.data_utils).
    tokens = collate_tokens(collect_tokens, pad_idx=self.task.source_dictionary.pad())
    masked_index = (tokens == self.task.mask_idx).nonzero()

    with utils.model_eval(self.model):
        features, extra = self.model(
            tokens.long().to(device=self.device),
            features_only=False,
            return_all_hiddens=False,
        )

    all_outputs = []
    # masked_index is N x 2 (row, column); rows line up with the inputs
    # because each input contains exactly one mask token.
    for idx, (mask_idx, masked_input) in enumerate(
        zip(masked_index[:, -1], masked_inputs)
    ):
        logits = features[idx, mask_idx, :].squeeze()
        prob = logits.softmax(dim=0)
        values, index = prob.topk(k=topk, dim=0)
        topk_predicted_token_bpe = self.task.source_dictionary.string(index)

        topk_filled_outputs = []
        for rank, predicted_token_bpe in enumerate(
            topk_predicted_token_bpe.split(" ")
        ):
            predicted_token = self.bpe.decode(predicted_token_bpe)
            # Quick hack to fix https://github.com/pytorch/fairseq/issues/1306
            if predicted_token_bpe.startswith("\u2581"):
                predicted_token = " " + predicted_token
            if " {0}".format(masked_token) in masked_input:
                topk_filled_outputs.append(
                    (
                        masked_input.replace(
                            " {0}".format(masked_token), predicted_token
                        ),
                        values[rank].item(),
                        predicted_token,
                    )
                )
            else:
                topk_filled_outputs.append(
                    (
                        masked_input.replace(masked_token, predicted_token),
                        values[rank].item(),
                        predicted_token,
                    )
                )
        all_outputs.append(topk_filled_outputs)
    return all_outputs
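# Example usage (a sketch; assumes `roberta` is a loaded hub interface):
#
#   roberta.batch_fill_mask([
#       'The capital of France is <mask>.',
#       'He is a <mask> guy.',
#   ], topk=3)
#   # -> one list per input, each holding (filled_sentence, probability,
#   #    predicted_token) tuples, mirroring fill_mask() above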