def forward(self, features, masked_tokens=None, tags=None, **kwargs):
    """
    Args:
        features: (seq_length, batch_size, hidden_dim)
        masked_tokens: (seq_length, batch_size)
        tags: (seq_length, batch_size)
    Returns:
        nll_loss, ncorrect
    """
    x = self.dropout(features)
    x = self.dense(x)
    x = self.activation_fn(x)
    x = self.dropout(x)
    x = self.out_proj(x)
    ncorrect = None
    if self.use_crf:
        nll_loss = -self.crf_proj(emissions=x, tags=tags, mask=masked_tokens)
    else:
        x = x[masked_tokens]
        tags = tags[masked_tokens]
        nll_loss = cross_entropy(
            x.view(-1, x.size(-1)),
            tags.view(-1),
            reduction='sum',
        )
        preds = x.argmax(dim=1)
        ncorrect = (preds == tags).sum()
    return nll_loss, ncorrect
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    # compute MLM loss
    masked_tokens = sample['target'].ne(self.padding_idx)
    sample_size = masked_tokens.int().sum().item()

    # (Rare case) When all tokens are masked, the model would otherwise
    # receive an empty tensor and raise a CUDA error, so project all
    # tokens instead.
    masked_tokens = torch.where(
        masked_tokens.any(),
        masked_tokens,
        masked_tokens.new([True]),
    )

    logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
    targets = model.get_targets(sample, [logits])

    if masked_tokens is not None:
        targets = targets[masked_tokens]

    loss = modules.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        reduction='sum',
        ignore_index=self.padding_idx,
    )

    logging_output = {
        'loss': loss.data,
        'ntokens': sample['ntokens'],
        'nsentences': sample['nsentences'],
        'sample_size': sample_size,
    }
    return loss, sample_size, logging_output
def _compute_clm_loss(self, model, net_output, sample, masked_tokens):
    logits = net_output[1]['clm_out']
    targets = model.get_clm_targets(sample, net_output)[masked_tokens]
    loss = modules.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        reduction='sum',
        ignore_index=self.padding_idx,
    )
    return loss
def compute_masked_loss(self, targets, net_output):
    encoder_logits = net_output[1]['masked_encoder_out'][0]
    assert encoder_logits.size(0) == targets.size(0), (
        encoder_logits.size(), targets.size())
    loss = modules.cross_entropy(
        encoder_logits.view(-1, encoder_logits.size(-1)),
        targets.view(-1),
        reduction="sum",
        ignore_index=self.padding_idx,
    )
    return loss
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    masked_tokens = sample["target"].ne(self.padding_idx)
    sample_size = masked_tokens.int().sum()

    # Rare: when all tokens are masked, project all tokens.
    # We use torch.where to avoid device-to-host transfers,
    # except on CPU where torch.where is not well supported
    # (see github.com/pytorch/pytorch/issues/26247).
    if self.tpu:
        masked_tokens = None  # always project all tokens on TPU
    elif masked_tokens.device == torch.device("cpu"):
        if not masked_tokens.any():
            masked_tokens = None
    else:
        masked_tokens = torch.where(
            masked_tokens.any(),
            masked_tokens,
            masked_tokens.new([True]),
        )

    logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0]
    targets = model.get_targets(sample, [logits])
    if masked_tokens is not None:
        targets = targets[masked_tokens]

    loss = modules.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        reduction="sum",
        ignore_index=self.padding_idx,
    )

    logging_output = {
        "loss": loss if self.tpu else loss.data,
        "ntokens": sample["ntokens"],
        "nsentences": sample["nsentences"],
        "sample_size": sample_size,
    }
    return loss, sample_size, logging_output
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    # compute MLM loss
    masked_tokens = sample['target'].ne(self.padding_idx)

    # Rare: when all tokens are masked, project all tokens.
    # We use torch.where to avoid device-to-host transfers,
    # except on CPU where torch.where is not well supported
    # (see github.com/pytorch/pytorch/issues/26247).
    if masked_tokens.device == torch.device('cpu'):
        if not masked_tokens.any():
            masked_tokens.fill_(True)
    else:
        masked_tokens = torch.where(
            masked_tokens.any(),
            masked_tokens,
            masked_tokens.new([True]),
        )

    logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
    targets = model.get_targets(sample, [logits])
    targets = targets[masked_tokens]

    loss = modules.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        reduction='sum',
        ignore_index=self.padding_idx,
    )

    sample_size = masked_tokens.int().sum()
    logging_output = {
        'loss': loss.data,
        'ntokens': sample['ntokens'],
        'nsentences': sample['nsentences'],
        'sample_size': sample_size,
    }
    return loss, sample_size, logging_output
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    # select labels that correspond to the start-of-word BPE tokens
    non_pad = sample['target'].ne(self.padding_idx)

    if hasattr(model, 'tagging_heads') and 'tagging_head' in model.tagging_heads:
        logits, _ = model(
            **sample['net_input'],
            features_only=True,
            tagging_head_name='tagging_head',
            non_pad=non_pad,
        )
    else:
        logits = model(**sample['net_input'], non_pad=non_pad)[0]

    targets = model.get_targets(sample, [logits])[non_pad]

    loss = modules.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        reduction='sum',
        ignore_index=self.padding_idx,
    )

    sample_size = targets.ne(self.padding_idx).int().sum()
    logging_output = {
        'loss': loss.data,
        'ntokens': sample['ntokens'],
        'nsentences': sample['nsentences'],
        'sample_size': sample_size,
    }
    return loss, sample_size, logging_output
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    masked_code = sample["target"]['tgt_tokens'][configs.static_field].ne(
        self.padding_idx_dict[configs.static_field])
    masked_value = sample["target"]['tgt_values'][configs.byte_fields[0]].ne(
        self.padding_idx_dict[configs.byte_fields[0]])

    # Rare: when no tokens are masked, project all tokens.
    # We use torch.where to avoid device-to-host transfers,
    # except on CPU where torch.where is not well supported
    # (see github.com/pytorch/pytorch/issues/26247).
    if self.tpu:
        masked_code = None  # always project all tokens on TPU
        masked_value = None  # always project all tokens on TPU
    elif masked_code.device == torch.device("cpu"):
        if not masked_code.any():
            masked_code = None
        if not masked_value.any():
            masked_value = None
    else:
        masked_code = torch.where(
            masked_code.any(),
            masked_code,
            masked_code.new([True]),
        )
        masked_value = torch.where(
            masked_value.any(),
            masked_value,
            masked_value.new([True]),
        )

    output = model(**sample["net_input"],
                   masked_code=masked_code,
                   masked_value=masked_value)[0]
    pred_logits_code, pred_value = output['code'], output['value']

    targets_code = sample["target"]["tgt_tokens"]
    targets_value = sample["target"]["tgt_values"]

    if masked_code is not None:
        targets_code = targets_code[configs.static_field][masked_code]
    if masked_value is not None:
        targets_value_stacked = torch.stack(
            [targets_value[field][masked_value] for field in configs.byte_fields],
            dim=1,
        )

    sample_size_code = masked_code.int().sum()
    sample_size_value = masked_value.int().sum() * configs.byte_len
    sample_size = sample_size_code + sample_size_value

    code_loss = modules.cross_entropy(
        pred_logits_code.view(-1, pred_logits_code.size(-1)),
        targets_code.view(-1),
        reduction="sum",
        ignore_index=self.padding_idx_dict[configs.static_field],
    )
    value_loss = F.mse_loss(
        pred_value.float(), targets_value_stacked.float(), reduction='sum')
    loss = code_loss + value_loss

    # only log a random subset of predictions to avoid flooding the screen
    if random.random() < 0.001:
        for i, field in enumerate(configs.byte_fields):
            print(f'{field} tgt value:',
                  targets_value[field][masked_value].view(-1)[5:10].tolist())
            print(f'{field} pred value:',
                  pred_value[5:10, i].view(-1).tolist())
        targets_code_idx = targets_code.view(-1)[5:10]
        pred_code_idx = torch.argmax(
            pred_logits_code.view(-1, pred_logits_code.size(-1))[5:10], dim=-1)
        print('tgt code:',
              self.task.source_dictionary[configs.static_field].string(targets_code_idx))
        print('pred code:',
              self.task.source_dictionary[configs.static_field].string(pred_code_idx))

    logging_output = {
        "loss": loss.data,
        "code_loss": code_loss.data,
        "value_loss": value_loss.data,
        "ntokens": sample["ntokens"],
        "nsentences": sample["nsentences"],
        "sample_size": sample_size,
        "sample_size_code": sample_size_code,
        "sample_size_value": sample_size_value,
    }
    return loss, sample_size, logging_output
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    assert (
        hasattr(model, 'classification_heads')
        and self.classification_head_name in model.classification_heads
    ), 'model must provide sentence classification head for --criterion=sentence_prediction'

    if 'parallel_data_mask' in sample:
        parallel_data_mask = sample['parallel_data_mask'].ne(self.padding_idx)
    else:
        parallel_data_mask = None

    logits, extra = model(
        sample['net_input']['src_tokens'],
        features_only=True,
        classification_head_name=self.classification_head_name,
        target_mask=sample['target_mask'],
        parallel_data_mask=parallel_data_mask,
        parallel_data=sample['net_input']['parallel_src_tokens']
        if parallel_data_mask is not None else None,
    )
    targets = model.get_targets(sample, [logits]).view(-1)  # K (K = \sum_i B_i)
    sample_size = targets.numel()
    target_lengths = sample['target_lengths']
    assert sum(target_lengths) == sample_size

    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    if self.upweight_minority_labels:
        loss = F.nll_loss(lprobs, targets, reduction='sum',
                          weight=torch.FloatTensor([1., 2.]).cuda())
    else:
        loss = F.nll_loss(lprobs, targets, reduction='sum')

    if parallel_data_mask is not None:
        # compute masked LM loss on the target side
        masked_logits = extra
        parallel_target = sample['parallel_target']
        target_mask = parallel_target.ne(self.padding_idx)
        total_tokens = target_mask.int().sum()
        parallel_target = parallel_target[target_mask]
        masked_prediction_loss = modules.cross_entropy(
            masked_logits.view(-1, masked_logits.size(-1)),
            parallel_target.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )
        masked_lm_loss = masked_prediction_loss / total_tokens
        hallucination_pred_loss = loss / sample_size
        loss = hallucination_pred_loss + self.masked_lm_loss_weight * masked_lm_loss

    logging_output = {
        'loss': loss.data,
        'ntokens': sample['ntokens'],
        'nsentences': sample_size,
        'sample_size': sample_size if self.masked_lm_loss_weight <= 0 else 1,
    }

    preds = logits.argmax(dim=1)
    nt_correct = sum(1 for p, t in zip(preds, targets) if p.item() == 1 and t.item() == 1)
    nf_correct = sum(1 for p, t in zip(preds, targets) if p.item() == 0 and t.item() == 0)
    nt_precision_denom = sum(preds == 1)
    nt_recall_denom = sum(targets == 1)
    nf_precision_denom = sum(preds == 0)
    nf_recall_denom = sum(targets == 0)

    logging_output['ncorrect'] = (preds == targets).sum()
    logging_output['nt_correct'] = nt_correct
    logging_output['nf_correct'] = nf_correct
    logging_output['nt_precision_denom'] = nt_precision_denom
    logging_output['nt_recall_denom'] = nt_recall_denom
    logging_output['nf_precision_denom'] = nf_precision_denom
    logging_output['nf_recall_denom'] = nf_recall_denom
    if parallel_data_mask is not None:
        logging_output['hallucination_pred_loss'] = hallucination_pred_loss.data
        logging_output['masked_lm_loss'] = masked_lm_loss.data
    return loss, sample_size, logging_output
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    # compute MLM loss
    masked_tokens = sample['target'].ne(self.padding_idx)

    # Rare: when all tokens are masked, project all tokens.
    # We use torch.where to avoid device-to-host transfers,
    # except on CPU where torch.where is not well supported
    # (see github.com/pytorch/pytorch/issues/26247).
    if masked_tokens.device == torch.device('cpu'):
        if not masked_tokens.any():
            masked_tokens.fill_(True)
    else:
        masked_tokens = torch.where(
            masked_tokens.any(),
            masked_tokens,
            masked_tokens.new([True]),
        )

    logits_student = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
    with torch.no_grad():
        logits_teacher = self.teacher_model(
            **sample['net_input'], masked_tokens=masked_tokens)[0]

    targets = model.get_targets(sample, [logits_student])
    targets = targets[masked_tokens]

    # cross-entropy loss against the ground-truth targets
    loss_ce = modules.cross_entropy(
        logits_student.view(-1, logits_student.size(-1)),
        targets.view(-1),
        reduction='sum',
        ignore_index=self.padding_idx,
    )
    if self.print_teacher_loss:
        loss_ce_teacher = modules.cross_entropy(
            logits_teacher.view(-1, logits_teacher.size(-1)),
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )

    # knowledge-distillation loss between student and teacher distributions
    loss_kd = self.kd_loss_func(
        F.log_softmax(logits_student / self.T, dim=-1),
        F.softmax(logits_teacher / self.T, dim=-1),
    ) * self.T ** 2
    loss = (1 - self.beta) * loss_ce + self.beta * loss_kd

    sample_size = masked_tokens.int().sum()
    logging_output = {
        'loss': loss.data,
        'ce_loss': loss_ce.data,
        'kd_loss': loss_kd.data,
        'ntokens': sample['ntokens'],
        'nsentences': sample['nsentences'],
        'sample_size': sample_size,
    }
    if self.print_teacher_loss:
        logging_output['ce_loss_teacher'] = loss_ce_teacher
    return loss, sample_size, logging_output
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    masked_tokens = sample['target'].ne(self.padding_idx)
    sample_size_mask = masked_tokens.int().sum()
    decode_tokens = sample['decode_target'].ne(self.padding_idx)
    sample_size_decode = decode_tokens.int().sum()

    # Rare: when all tokens are masked, project all tokens.
    # We use torch.where to avoid device-to-host transfers,
    # except on CPU where torch.where is not well supported
    # (see github.com/pytorch/pytorch/issues/26247).
    if self.tpu:
        masked_tokens = None  # always project all tokens on TPU
    elif masked_tokens.device == torch.device('cpu'):
        if not masked_tokens.any():
            masked_tokens = None
        if not decode_tokens.any():
            decode_tokens = None
    else:
        masked_tokens = torch.where(
            masked_tokens.any(),
            masked_tokens,
            masked_tokens.new([True]),
        )
        decode_tokens = torch.where(
            decode_tokens.any(),
            decode_tokens,
            decode_tokens.new([True]),
        )

    logits, logits_decode, _ = model(**sample['net_input'], masked_tokens=masked_tokens)

    targets = model.get_targets(sample, [logits])
    if masked_tokens is not None:
        targets = targets[masked_tokens]

    decode_target = sample["decode_target"]
    if decode_tokens is not None:
        # debug output for unexpected length mismatches
        if logits_decode.shape[1] != decode_target.shape[1]:
            print(decode_target)
            print(sample['net_input']['src_tokens'])
        decode_target = decode_target[decode_tokens]
        logits_decode = logits_decode[decode_tokens]

    mask_loss = modules.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        reduction='sum',
        ignore_index=self.padding_idx,
    )
    decode_loss = modules.cross_entropy(
        logits_decode.view(-1, logits_decode.size(-1)),
        decode_target.view(-1),
        reduction='sum',
        ignore_index=self.padding_idx,
    )

    accumulate_step = sample['accumulate_step']
    logging_output = {
        'ntokens': sample['ntokens'],
        'nsentences': sample['nsentences'],
        'loss_decode': decode_loss if self.tpu else decode_loss.data,
        'loss_mask': mask_loss if self.tpu else mask_loss.data,
        'sample_size_decode': sample_size_decode,
        'sample_size_mask': sample_size_mask,
        'sample_size': sample_size_mask,
        'sample_size_t': 1.0 / accumulate_step,
        'loss': mask_loss if self.tpu else mask_loss.data,
    }

    # normalize each loss by the sample sizes provided with the batch,
    # then combine the two objectives with equal weight
    sample_size_mask = sample['sample_size_mask']
    sample_size_decode = sample['sample_size_decode']
    decode_loss = decode_loss / sample_size_decode
    mask_loss = mask_loss / sample_size_mask
    loss = 0.5 * mask_loss + 0.5 * decode_loss

    return loss, 1.0 / accumulate_step, logging_output
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    fields = configs.fields

    # compute MLM loss
    masked_tokens = sample['target'][fields[0]].ne(self.padding_idx_dict[fields[0]])
    sample_size = masked_tokens.int().sum().item()

    # (Rare case) When all tokens are masked, the model would otherwise
    # receive an empty tensor and raise a CUDA error, so project all
    # tokens instead.
    masked_tokens = torch.where(
        masked_tokens.any(),
        masked_tokens,
        masked_tokens.new([True]),
    )

    logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
    targets = model.get_targets(sample, [logits])

    # which fields to predict
    output_langs = self.args.output_lang.split(',')
    trace_weight = float(self.args.trace_weight)
    for output_lang in output_langs:
        assert output_lang in logits.keys()

    loss = 0
    for field in output_langs:
        if masked_tokens is not None:
            targets[field] = targets[field][masked_tokens]
        if field == configs.fields[0]:
            # static code loss
            loss += modules.cross_entropy(
                logits[field].view(-1, logits[field].size(-1)),
                targets[field].view(-1),
                reduction='sum',
                ignore_index=self.padding_idx_dict[field],
            )
        else:
            # remaining field losses are scaled by trace_weight
            loss += trace_weight * modules.cross_entropy(
                logits[field].view(-1, logits[field].size(-1)),
                targets[field].view(-1),
                reduction='sum',
                ignore_index=self.padding_idx_dict[field],
            )

    logging_output = {
        'loss': loss.data / len(output_langs),
        'ntokens': sample['ntokens'],
        'nsentences': sample['nsentences'],
        'sample_size': sample_size,
    }
    return loss, sample_size, logging_output