def fill_mask(self, masked_input: str, topk: int = 5):
    masked_token = '<mask>'
    assert masked_token in masked_input and masked_input.count(masked_token) == 1, \
        "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(masked_token)
    # NOTE: this variant encodes the raw input string directly, without BPE
    # preprocessing; compare the BPE-aware variant below.
    tokens = self.task.source_dictionary.encode_line(
        '<s> ' + masked_input,
        append_eos=True,
    )

    masked_index = (tokens == self.task.mask_idx).nonzero()
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)

    with utils.eval(self.model):
        features, extra = self.model(
            tokens.long().to(device=self.device),
            features_only=False,
            return_all_hiddens=False,
        )
    logits = features[0, masked_index, :].squeeze()
    prob = logits.softmax(dim=0)
    values, index = prob.topk(k=topk, dim=0)
    topk_predicted_token = self.task.source_dictionary.string(index)

    topk_filled_outputs = []
    for index, predicted_token in enumerate(topk_predicted_token.split(' ')):
        topk_filled_outputs.append((
            masked_input.replace(masked_token, predicted_token),
            values[index].item(),
            predicted_token,
        ))
    return topk_filled_outputs
def fill_mask(self, masked_input: str, topk: int = 5):
    masked_token = "<mask>"
    assert (
        masked_token in masked_input and masked_input.count(masked_token) == 1
    ), "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(
        masked_token
    )

    text_spans = masked_input.split(masked_token)
    text_spans_bpe = (
        (" {0} ".format(masked_token))
        .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans])
        .strip()
    )
    tokens = self.task.source_dictionary.encode_line(
        "<s> " + text_spans_bpe + " </s>",
        append_eos=False,
        add_if_not_exist=False,
    )

    masked_index = (tokens == self.task.mask_idx).nonzero()
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)

    with utils.eval(self.model):
        features, extra = self.model(
            tokens.long().to(device=self.device),
            features_only=False,
            return_all_hiddens=False,
        )
    logits = features[0, masked_index, :].squeeze()
    prob = logits.softmax(dim=0)
    values, index = prob.topk(k=topk, dim=0)
    topk_predicted_token_bpe = self.task.source_dictionary.string(index)

    topk_filled_outputs = []
    for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(" ")):
        predicted_token = self.bpe.decode(predicted_token_bpe)
        # Quick hack to fix https://github.com/pytorch/fairseq/issues/1306
        if predicted_token_bpe.startswith("\u2581"):
            predicted_token = " " + predicted_token
        if " {0}".format(masked_token) in masked_input:
            topk_filled_outputs.append((
                masked_input.replace(" {0}".format(masked_token), predicted_token),
                values[index].item(),
                predicted_token,
            ))
        else:
            topk_filled_outputs.append((
                masked_input.replace(masked_token, predicted_token),
                values[index].item(),
                predicted_token,
            ))
    return topk_filled_outputs
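# Usage sketch for fill_mask() via the fairseq hub interface. The model name is
# the published RoBERTa checkpoint; the output values shown are illustrative,
# not exact.
#
#   import torch
#   roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
#   roberta.eval()  # disable dropout
#   roberta.fill_mask('The first Star Wars film was released in <mask>.', topk=3)
#   # -> list of (filled_sentence, probability, predicted_token) tuples,
#   #    sorted by descending probability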
def disambiguate_pronoun(self, sentence: str) -> bool:
    """
    Usage::

        >>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
        True

        >>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.')
        'The trophy'
    """
    assert hasattr(self.task, 'disambiguate_pronoun'), \
        'roberta.disambiguate_pronoun() requires a model trained with the WSC task.'
    with utils.eval(self.model):
        return self.task.disambiguate_pronoun(
            self.model, sentence, use_cuda=self.device.type == 'cuda'
        )
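# Usage sketch for disambiguate_pronoun(). Assumes the WSC-finetuned checkpoint
# published with fairseq (see examples/roberta/wsc); treat the call below as
# illustrative.
#
#   import torch
#   roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.wsc',
#                            user_dir='examples/roberta/wsc')
#   roberta.disambiguate_pronoun(
#       'The _trophy_ would not fit in the brown suitcase because [it] was too big.')
#   # -> True   (the _underscored_ candidate is judged to be the referent of [it])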
def _get_loss(self, sample, model, criterion):
    assert hasattr(criterion, 'compute_loss'), \
        'language_model_moe task requires the criterion to implement the compute_loss() method'

    bsz = sample['target'].size(0)
    src_tokens = sample['net_input']['src_tokens']
    src_lengths = sample['net_input']['src_lengths']

    #### E-STEP: pick the most likely expert per sentence, without dropout or autograd
    with utils.eval(model):  # disable dropout
        with torch.no_grad():  # disable autograd
            net_output = model(src_tokens=src_tokens, src_lengths=src_lengths)
            # pass net output to gating network to compute expert probabilities
            expert_probs = model.gating_network(net_output)
            # hard selection of experts
            expert_assignments = [
                self.expert_index(x) for x in expert_probs.argmax(dim=1)
            ]
            # add expert assignments as BOS tokens
            src_tokens[:, 0] = torch.Tensor(expert_assignments).long()

    #### M-STEP: train the model conditioned on the selected experts
    net_output = model(src_tokens=src_tokens, src_lengths=src_lengths)
    loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
    loss = loss.view(bsz, -1)
    loss = loss.sum(dim=1, keepdim=True)
    loss = loss.sum()

    sample_size = (
        sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
    )
    logging_output = {
        'loss': utils.item(loss.data),
        'ntokens': sample['ntokens'],
        'nsentences': bsz,
        'sample_size': sample_size,
        'expert_assignments': expert_probs.argmax(dim=1),
    }
    return loss, sample_size, logging_output
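# Minimal sketch of the hard E-step above, using plain tensors. All names and
# shapes here are illustrative stand-ins, not fairseq internals.
def _example_hard_expert_assignment():
    import torch
    bsz, num_experts = 4, 3
    gating_logits = torch.randn(bsz, num_experts)   # stand-in for model.gating_network(...)
    expert_probs = gating_logits.softmax(dim=1)     # B x K expert probabilities
    assignments = expert_probs.argmax(dim=1)        # hard selection: one expert per sentence
    # In the M-step above, each assignment is written into the BOS slot
    # (src_tokens[:, 0]) so the model is trained conditioned on its expert.
    return assignments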
def fill_single_mask(self, masked_inputs, topk=3):
    if isinstance(masked_inputs, str):
        masked_inputs = [masked_inputs]
    assert all(self.masked_token in masked_input for masked_input in masked_inputs), \
        "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(self.masked_token)

    tokens = [
        self.encode_masked_input(masked_input) for masked_input in masked_inputs
    ]
    pad_to_length = max(len(token) for token in tokens)
    tokens = data_utils.collate_tokens(
        tokens,
        self.task.source_dictionary.pad(),
        self.task.source_dictionary.eos(),
        False,
        False,
        pad_to_length=pad_to_length,
    )
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)
    src_lengths = tokens.ne(self.task.source_dictionary.pad()).sum(dim=-1)
    masked_tokens = tokens.eq(self.task.source_dictionary.mask_index)

    # with utils.model_eval(self.model):  # new version
    with utils.eval(self.model):
        logits = self.model.forward_encoder(
            tokens.long().to(device=self.device),
            src_lengths=src_lengths.to(device=self.device),
            masked_tokens=masked_tokens,
        )
    prob = logits.softmax(dim=-1)
    all_values, all_index = prob.topk(k=topk, dim=-1)
    topk_predicted_token_bpe = self.task.source_dictionary.string(all_index)
    topk_predicted_token_bpe = [
        tokens.split(' ') for tokens in topk_predicted_token_bpe.split('\n')
    ]
    return topk_predicted_token_bpe
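# Usage sketch for fill_single_mask(). Hedged: self.masked_token and
# encode_masked_input() are defined elsewhere on this class, and the return
# format below is illustrative.
#
#   predictions = model.fill_single_mask(['Paris is the capital of <mask>.'], topk=3)
#   # -> one list per input, each holding the top-k predicted BPE token strings;
#   #    decode them (e.g. with model.bpe.decode) to recover surface words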
def _get_loss(self, sample, model, criterion):
    assert hasattr(criterion, 'compute_loss'), \
        'translation_moe task requires the criterion to implement the compute_loss() method'

    k = self.args.num_experts
    bsz = sample['target'].size(0)

    def get_lprob_y(encoder_out, prev_output_tokens_k):
        net_output = model.decoder(
            prev_output_tokens=prev_output_tokens_k,
            encoder_out=encoder_out,
        )
        loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
        loss = loss.view(bsz, -1)
        return -loss.sum(dim=1, keepdim=True)  # -> B x 1

    def get_lprob_yz(winners=None):
        encoder_out = model.encoder(
            src_tokens=sample['net_input']['src_tokens'],
            src_lengths=sample['net_input']['src_lengths'],
        )
        if winners is None:
            lprob_y = []
            for i in range(k):
                prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
                assert not prev_output_tokens_k.requires_grad
                prev_output_tokens_k[:, 0] = self.expert_index(i)
                lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
            lprob_y = torch.cat(lprob_y, dim=1)  # -> B x K
        else:
            prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
            prev_output_tokens_k[:, 0] = self.expert_index(winners)
            lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k)  # -> B
        if self.uniform_prior:
            lprob_yz = lprob_y
        else:
            lprob_z = model.gating_network(encoder_out)  # B x K
            if winners is not None:
                lprob_z = lprob_z.gather(dim=1, index=winners.unsqueeze(-1))
            lprob_yz = lprob_y + lprob_z.type_as(lprob_y)  # B x K
        return lprob_yz

    # compute responsibilities without dropout
    with utils.eval(model):  # disable dropout
        with torch.no_grad():  # disable autograd
            lprob_yz = get_lprob_yz()  # B x K
    prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
    assert not prob_z_xy.requires_grad

    # compute loss with dropout
    if self.hard_selection:
        winners = prob_z_xy.max(dim=1)[1]
        loss = -get_lprob_yz(winners)
    else:
        lprob_yz = get_lprob_yz()  # B x K
        loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)
    loss = loss.sum()

    sample_size = (
        sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
    )
    logging_output = {
        'loss': utils.item(loss.data),
        'ntokens': sample['ntokens'],
        'nsentences': bsz,
        'sample_size': sample_size,
        'posterior': prob_z_xy.float().sum(dim=0).cpu(),
    }
    return loss, sample_size, logging_output
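# Sketch of the soft-selection objective above, forward pass only. In fairseq,
# LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1) computes the same marginal
# log-likelihood as torch.logsumexp but routes gradients through the fixed
# posterior prob_z_xy (the EM trick). Shapes here are illustrative.
def _example_moe_marginal():
    import torch
    bsz, k = 4, 3
    lprob_yz = torch.randn(bsz, k)               # log p(y, z | x) for each expert z
    prob_z_xy = torch.softmax(lprob_yz, dim=1)   # responsibilities p(z | x, y)
    marginal = torch.logsumexp(lprob_yz, dim=1)  # log p(y | x), maximized in training
    return marginal, prob_z_xy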
def fill_noised_mask(self, masked_inputs: List[Tuple[str, str]], topk=1):
    # each input is a (source, noised_target) pair; only the target contains masks
    masked_token = '<mask>'
    noises, topk_opt = [], []
    text_spans = [sent.split(masked_token) for src, sent in masked_inputs]
    noised_tokens = []
    targets_bpe = []
    for (src, _), segs in zip(masked_inputs, text_spans):
        bpe_src = self.bpe.encode(src.strip())
        bpe_tgt = ' {0} '.format(masked_token).join(
            [self.bpe.encode(seg.rstrip()) for seg in segs]
        )
        bpe_idx = self.task.source_dictionary.encode_line(
            '<s> ' + bpe_src + ' </s> </s> ' + bpe_tgt + ' </s>',
            append_eos=False,
            add_if_not_exist=False,
        )
        tgt_bpe_idx = self.task.source_dictionary.encode_line(
            '<s> ' + bpe_tgt + ' </s>',
            append_eos=False,
            add_if_not_exist=False,
        )
        noised_tokens.append(bpe_idx)
        targets_bpe.append(tgt_bpe_idx)

    sample = self._build_sample(noised_tokens).long()
    masked_index = (sample == self.task.mask_idx)
    with utils.eval(self.model):
        # features: B x T x |V|
        features, extra = self.model(
            sample,
            features_only=False,
            return_all_hiddens=False,
            masked_tokens=masked_index,
        )
    prob = features.softmax(dim=-1)
    # values, index = prob.topk(k=topk, dim=-1)
    values, index = prob.max(dim=-1)
    index = index.squeeze(-1)  # K

    extra_symbols_to_ignore = set()
    extra_symbols_to_ignore.add(
        self.task.source_dictionary[self.task.source_dictionary.eos()])
    extra_symbols_to_ignore.add(
        self.task.source_dictionary[self.task.source_dictionary.bos()])

    tot_masks = 0
    for ii, sent in enumerate(targets_bpe):
        decode_noise_tokens = self.decode(sent)
        decode_noise_tokens = decode_noise_tokens.replace(
            "<mask>", " <mask>").strip()
        # number of masked positions in this sentence
        K = masked_index[ii, :].sum().item()
        topk_predictions = index[tot_masks:tot_masks + K]
        tot_masks += K
        assert len(topk_predictions) == decode_noise_tokens.split(" ").count('<mask>')
        output = []
        mask_count = 0
        topk_predicted_token_bpe = self.task.source_dictionary.string(
            topk_predictions, skip_ignore=True).split()
        for token in decode_noise_tokens.split(" "):
            if token == "<mask>":
                predict_bpe = topk_predicted_token_bpe[mask_count]
                if predict_bpe in extra_symbols_to_ignore:
                    continue
                predicted_token = self.bpe.decode(predict_bpe)
                # output.append("[" + predicted_token.strip() + "]")
                output.append(predicted_token.strip())
                mask_count += 1
            else:
                output.append(token.strip())
        topk_opt.append(" ".join(output))
        noises.append(decode_noise_tokens)
    return topk_opt, noises
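# Usage sketch for fill_noised_mask(). Hedged: the (source, noised_target) input
# format is inferred from the method body, and the outputs shown are illustrative.
#
#   topk_opt, noises = model.fill_noised_mask(
#       [('The cat sat on the mat.', 'The <mask> sat on the <mask>.')])
#   # topk_opt: each noised target with every <mask> replaced by its argmax prediction
#   # noises:   the decoded noised targets, returned for reference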
def fill_mask(self, masked_inputs, topk=3, return_filled_sentence=False):
    if isinstance(masked_inputs, str):
        masked_inputs = [masked_inputs]
    masked_token = '[MASK]'
    assert all(masked_token in masked_input for masked_input in masked_inputs), \
        "Please add one {0} token for the input, eg: 'He is a {0} guy'".format(masked_token)

    def encode_masked_input(masked_input):
        text_spans = masked_input.split(masked_token)
        text_spans_bpe = (' {0} '.format(masked_token)).join([
            self.bpe.encode(text_span.rstrip()) for text_span in text_spans
        ]).strip()
        tokens = self.task.source_dictionary.encode_line(
            '[CLS] ' + text_spans_bpe + ' [SEP]',
            append_eos=False,
            add_if_not_exist=False,
        )
        return tokens

    tokens = [
        encode_masked_input(masked_input) for masked_input in masked_inputs
    ]
    pad_to_length = max(len(token) for token in tokens)
    tokens = data_utils.collate_tokens(
        tokens,
        self.task.source_dictionary.pad(),
        self.task.source_dictionary.eos(),
        False,
        False,
        pad_to_length=pad_to_length,
    )
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)
    src_lengths = tokens.ne(self.task.source_dictionary.pad()).sum(dim=-1)
    masked_tokens = tokens.eq(self.task.source_dictionary.mask_index)

    # with utils.model_eval(self.model):  # new version
    with utils.eval(self.model):
        logits = self.model.forward_encoder(
            tokens.long().to(device=self.device),
            src_lengths=src_lengths.to(device=self.device),
            masked_tokens=masked_tokens,
        )
    prob = logits.softmax(dim=-1)
    all_values, all_index = prob.topk(k=topk, dim=-1)
    topk_predicted_token_bpe = self.task.source_dictionary.string(all_index)
    topk_predicted_token_bpe = [
        tokens.split(' ') for tokens in topk_predicted_token_bpe.split('\n')
    ]
    if not return_filled_sentence:
        return topk_predicted_token_bpe

    # NOTE: filling predictions back into the sentence is not implemented yet;
    # the commented sketch below mirrors the RoBERTa-style fill_mask() above.
    # all_outputs = []
    # topk_predicted_token_bpe = iter(topk_predicted_token_bpe)
    # topk_filled_outputs = []
    # for masked_input in masked_inputs:
    #     predicted_token = self.bpe.decode(predicted_token_bpe)
    #     if predicted_token_bpe.startswith('\u2581'):
    #         predicted_token = ' ' + predicted_token
    #     if " {0}".format(masked_token) in masked_input:
    #         topk_filled_outputs.append((
    #             masked_input.replace(' {0}'.format(masked_token), predicted_token),
    #             values[index].item(),
    #             predicted_token,
    #         ))
    #     else:
    #         topk_filled_outputs.append((
    #             masked_input.replace(masked_token, predicted_token),
    #             values[index].item(),
    #             predicted_token,
    #         ))
    #     all_outputs.append(topk_filled_outputs)
    return None
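# Usage sketch for the [MASK]-style fill_mask(). Hedged: assumes a BERT-style
# dictionary with [CLS]/[SEP]/[MASK] symbols; the call and output are illustrative.
#
#   predictions = model.fill_mask('He is a [MASK] guy', topk=3)
#   # -> one list of top-k predicted BPE strings per input; note that with
#   #    return_filled_sentence=True the reinsertion path is still unfinished
#   #    and currently returns None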