Code example #1
    def forward(self, features, masked_tokens=None, tags=None, **kwargs):
        """
        Args:
            features:       (seq_length, batch_size, hidden_dim)
            masked_tokens:  (seq_length, batch_size)
            tags:           (seq_length, batch_size)

        Returns:
            nll_loss:       negative log-likelihood (sum-reduced)
            ncorrect:       number of correct tag predictions, or None when the CRF is used
        """
        x = self.dropout(features)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        ncorrect = None
        if self.use_crf:
            nll_loss = -self.crf_proj(
                emissions=x, tags=tags, mask=masked_tokens)
        else:
            x = x[masked_tokens]
            tags = tags[masked_tokens]
            nll_loss = cross_entropy(
                x.view(-1, x.size(-1)),
                tags.view(-1),
                reduction='sum',
            )
            preds = x.argmax(dim=1)
            ncorrect = (preds == tags).sum()
        return nll_loss, ncorrect
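The else-branch above relies on boolean-mask indexing to flatten the (seq_length, batch_size) tensors down to the selected positions before calling cross_entropy. A minimal, self-contained sketch of just that branch, with made-up shapes and random tensors (none of these names come from the original code):

import torch
from torch.nn.functional import cross_entropy

seq_length, batch_size, num_tags = 5, 2, 4
logits = torch.randn(seq_length, batch_size, num_tags)    # stand-in for the out_proj output
tags = torch.randint(0, num_tags, (seq_length, batch_size))
masked_tokens = torch.rand(seq_length, batch_size) > 0.3   # positions that count

# Boolean indexing drops the unselected positions and flattens the rest,
# so cross_entropy sees (num_selected, num_tags) against (num_selected,).
x = logits[masked_tokens]
t = tags[masked_tokens]
nll_loss = cross_entropy(x, t, reduction='sum')
ncorrect = (x.argmax(dim=-1) == t).sum()
print(nll_loss.item(), ncorrect.item())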
Code example #2
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        # compute MLM loss
        masked_tokens = sample['target'].ne(self.padding_idx)
        sample_size = masked_tokens.int().sum().item()

        # Rare case: when the batch contains no masked tokens, indexing with an
        # all-False mask would produce empty tensors and trigger a CUDA error,
        # so fall back to selecting every position. torch.where keeps the check
        # on-device.
        masked_tokens = torch.where(
            masked_tokens.any(),
            masked_tokens,
            masked_tokens.new([True]),
        )

        logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
        targets = model.get_targets(sample, [logits])
        if masked_tokens is not None:
            targets = targets[masked_tokens]

        loss = modules.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )

        logging_output = {
            'loss': loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
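The torch.where fallback used here (and in most of the examples below) keeps the "no masked tokens" check on the device: masked_tokens.any() stays a 0-dim tensor, so no .item() call and no host synchronization is needed. A small standalone sketch of just that trick, with made-up shapes:

import torch

def ensure_nonempty_mask(masked_tokens):
    # If nothing is masked, fall back to an all-True mask; the one-element
    # tensor created by new([True]) broadcasts to the full mask shape.
    return torch.where(
        masked_tokens.any(),
        masked_tokens,
        masked_tokens.new([True]),
    )

mask = torch.zeros(3, 4, dtype=torch.bool)       # pathological batch: nothing masked
print(ensure_nonempty_mask(mask).sum().item())   # 12 -> every position selected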
Code example #3
    def _compute_clm_loss(self, model, net_output, sample, masked_tokens):
        logits = net_output[1]['clm_out']
        targets = model.get_clm_targets(sample, net_output)[masked_tokens]
        loss = modules.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )
        return loss
Code example #4
    def compute_masked_loss(self, targets, net_output):

        encoder_logits = net_output[1]['masked_encoder_out'][0]
        assert encoder_logits.size(0) == targets.size(0), (
            encoder_logits.size(), targets.size())
        loss = modules.cross_entropy(
            encoder_logits.view(-1, encoder_logits.size(-1)),
            targets.view(-1),
            reduction="sum",
            ignore_index=self.padding_idx,
        )

        return loss
Code example #5
File: masked_lm.py  Project: DylanZSZ/fairseq
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        masked_tokens = sample["target"].ne(self.padding_idx)
        sample_size = masked_tokens.int().sum()

        # Rare: when all tokens are masked, project all tokens.
        # We use torch.where to avoid device-to-host transfers,
        # except on CPU where torch.where is not well supported
        # (see github.com/pytorch/pytorch/issues/26247).
        if self.tpu:
            masked_tokens = None  # always project all tokens on TPU
        elif masked_tokens.device == torch.device("cpu"):
            if not masked_tokens.any():
                masked_tokens = None
        else:
            masked_tokens = torch.where(
                masked_tokens.any(),
                masked_tokens,
                masked_tokens.new([True]),
            )

        logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0]
        targets = model.get_targets(sample, [logits])
        if masked_tokens is not None:
            targets = targets[masked_tokens]

        loss = modules.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            reduction="sum",
            ignore_index=self.padding_idx,
        )

        logging_output = {
            "loss": loss if self.tpu else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output
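Because the loss uses reduction="sum", the returned number is only meaningful together with sample_size: the trainer divides the aggregated loss by the aggregated sample_size, giving a mean NLL per masked token no matter how the batch was sharded. A toy sketch of that aggregation with made-up numbers (this mirrors the criterion's reduce_metrics step in spirit; exact details omitted):

import torch

# Hypothetical logging outputs from two workers, in the format returned above.
logging_outputs = [
    {'loss': torch.tensor(412.7), 'sample_size': 96},
    {'loss': torch.tensor(389.1), 'sample_size': 91},
]

loss_sum = sum(out['loss'] for out in logging_outputs)
sample_size = sum(out['sample_size'] for out in logging_outputs)
print((loss_sum / sample_size).item())   # mean NLL per masked token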
Code example #6
File: masked_lm.py  Project: NJUNLP/TMM-for-MAMS
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        # compute MLM loss
        masked_tokens = sample['target'].ne(self.padding_idx)

        # Rare: when all tokens are masked, project all tokens.
        # We use torch.where to avoid device-to-host transfers,
        # except on CPU where torch.where is not well supported
        # (see github.com/pytorch/pytorch/issues/26247).
        if masked_tokens.device == torch.device('cpu'):
            if not masked_tokens.any():
                masked_tokens.fill_(True)
        else:
            masked_tokens = torch.where(
                masked_tokens.any(),
                masked_tokens,
                masked_tokens.new([True]),
            )

        logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
        targets = model.get_targets(sample, [logits])
        targets = targets[masked_tokens]

        loss = modules.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )

        sample_size = masked_tokens.int().sum()
        logging_output = {
            'loss': loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Code example #7
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        non_pad = sample['target'].ne(
            self.padding_idx
        )  # select labels that correspond to start-of-word BPE tokens

        if hasattr(model,
                   'tagging_heads') and 'tagging_head' in model.tagging_heads:
            logits, _ = model(**sample['net_input'],
                              features_only=True,
                              tagging_head_name='tagging_head',
                              non_pad=non_pad)
        else:
            logits = model(**sample['net_input'], non_pad=non_pad)[0]

        targets = model.get_targets(sample, [logits])[non_pad]
        loss = modules.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )

        sample_size = targets.ne(self.padding_idx).int().sum()
        logging_output = {
            'loss': loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Code example #8
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        masked_code = sample["target"]['tgt_tokens'][configs.static_field].ne(
            self.padding_idx_dict[configs.static_field])
        masked_value = sample["target"]['tgt_values'][
            configs.byte_fields[0]].ne(
                self.padding_idx_dict[configs.byte_fields[0]])

        # Rare: when no tokens are masked, project all tokens.
        # We use torch.where to avoid device-to-host transfers,
        # except on CPU where torch.where is not well supported
        # (see github.com/pytorch/pytorch/issues/26247).
        if self.tpu:
            masked_code = None  # always project all tokens on TPU
            masked_value = None  # always project all tokens on TPU
        elif masked_code.device == torch.device("cpu"):
            if not masked_code.any():
                masked_code = None
            if not masked_value.any():
                masked_value = None
        else:
            masked_code = torch.where(
                masked_code.any(),
                masked_code,
                masked_code.new([True]),
            )
            masked_value = torch.where(
                masked_value.any(),
                masked_value,
                masked_value.new([True]),
            )

        output = model(**sample["net_input"],
                       masked_code=masked_code,
                       masked_value=masked_value)[0]

        pred_logits_code, pred_value = output['code'], output['value']
        targets_code, targets_value = sample["target"]["tgt_tokens"], sample[
            "target"]["tgt_values"]

        if masked_code is not None:
            targets_code = targets_code[configs.static_field][masked_code]

        if masked_value is not None:
            targets_value_stacked = torch.stack(
                [targets_value[field][masked_value] for field in configs.byte_fields],
                dim=1,
            )

        sample_size_code = masked_code.int().sum()
        sample_size_value = masked_value.int().sum() * configs.byte_len
        sample_size = sample_size_code + sample_size_value

        code_loss = modules.cross_entropy(
            pred_logits_code.view(-1, pred_logits_code.size(-1)),
            targets_code.view(-1),
            reduction="sum",
            ignore_index=self.padding_idx_dict[configs.static_field],
        )

        value_loss = F.mse_loss(pred_value.float(),
                                targets_value_stacked.float(),
                                reduction='sum')

        loss = code_loss + value_loss

        # Only log a few predictions occasionally, to avoid flooding the screen.
        if random.random() < 0.001:
            for i, field in enumerate(configs.byte_fields):
                print(f'{field} tgt value:',
                      targets_value[field][masked_value].view(-1)[5:10].tolist())
                print(f'{field} pred value:',
                      pred_value[5:10, i].view(-1).tolist())

            targets_code_idx = targets_code.view(-1)[5:10]
            pred_code_idx = torch.argmax(
                pred_logits_code.view(-1, pred_logits_code.size(-1))[5:10],
                dim=-1,
            )
            print('tgt code:',
                  self.task.source_dictionary[configs.static_field].string(targets_code_idx))
            print('pred code:',
                  self.task.source_dictionary[configs.static_field].string(pred_code_idx))

        logging_output = {
            "loss": loss.data,
            'code_loss': code_loss.data,
            'value_loss': value_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "sample_size_code": sample_size_code,
            "sample_size_value": sample_size_value,
        }
        return loss, sample_size, logging_output
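The value branch above compares pred_value against byte-value targets stacked field by field. A tiny shape sketch of that stacking and the MSE term, with invented field names, shapes, and values:

import torch
import torch.nn.functional as F

byte_fields = ['byte1', 'byte2', 'byte3', 'byte4']   # hypothetical field names
num_masked = 6

targets_value = {f: torch.rand(num_masked) for f in byte_fields}  # already mask-indexed
pred_value = torch.rand(num_masked, len(byte_fields))

# Stack per-field targets into (num_masked, byte_len) so they line up with pred_value.
targets_value_stacked = torch.stack([targets_value[f] for f in byte_fields], dim=1)
value_loss = F.mse_loss(pred_value.float(), targets_value_stacked.float(), reduction='sum')
print(targets_value_stacked.shape, value_loss.item())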
Code example #9
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
                hasattr(model, 'classification_heads')
                and self.classification_head_name in model.classification_heads
        ), 'model must provide sentence classification head for --criterion=sentence_prediction'

        if 'parallel_data_mask' in sample:
            parallel_data_mask = sample['parallel_data_mask'].ne(self.padding_idx)
        else:
            parallel_data_mask = None
        logits, extra = model(
            sample['net_input']['src_tokens'],
            features_only=True,
            classification_head_name=self.classification_head_name,
            target_mask=sample['target_mask'],
            parallel_data_mask=parallel_data_mask,
            parallel_data=sample['net_input']['parallel_src_tokens'] if parallel_data_mask is not None else None,
        )

        targets = model.get_targets(sample, [logits]).view(-1)  # K (K=\sum_i B_i)
        sample_size = targets.numel()

        target_lengths = sample['target_lengths']
        assert sum(target_lengths) == sample_size

        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
        if self.upweight_minority_labels:
            loss = F.nll_loss(lprobs, targets, reduction='sum', weight=torch.FloatTensor([1., 2.]).cuda())
        else:
            loss = F.nll_loss(lprobs, targets, reduction='sum')

        if parallel_data_mask is not None:
            # compute masked LM loss on the target side
            masked_logits = extra
            parallel_target = sample['parallel_target']
            target_mask = parallel_target.ne(self.padding_idx)
            total_tokens = target_mask.int().sum()
            parallel_target = parallel_target[target_mask]
            masked_prediction_loss = modules.cross_entropy(
                masked_logits.view(-1, masked_logits.size(-1)),
                parallel_target.view(-1),
                reduction='sum',
                ignore_index=self.padding_idx,
            )
            masked_lm_loss = masked_prediction_loss / total_tokens
            hallucination_pred_loss = loss / sample_size
            loss = hallucination_pred_loss + self.masked_lm_loss_weight * masked_lm_loss

        logging_output = {
            'loss': loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample_size,
            'sample_size': sample_size if self.masked_lm_loss_weight <= 0 else 1,
        }
        preds = logits.argmax(dim=1)

        nt_correct = sum([1 for p, t in zip(preds, targets) if p.item() == 1 and t.item() == 1])
        nf_correct = sum([1 for p, t in zip(preds, targets) if p.item() == 0 and t.item() == 0])
        nt_precision_denom = sum(preds == 1)
        nt_recall_denom = sum(targets == 1)
        nf_precision_denom = sum(preds == 0)
        nf_recall_denom = sum(targets == 0)

        logging_output['ncorrect'] = (preds == targets).sum()
        logging_output['nt_correct'] = nt_correct
        logging_output['nf_correct'] = nf_correct
        logging_output['nt_precision_denom'] = nt_precision_denom
        logging_output['nt_recall_denom'] = nt_recall_denom
        logging_output['nf_precision_denom'] = nf_precision_denom
        logging_output['nf_recall_denom'] = nf_recall_denom

        if parallel_data_mask is not None:
            logging_output['hallucination_pred_loss'] = hallucination_pred_loss.data
            logging_output['masked_lm_loss'] = masked_lm_loss.data

        return loss, sample_size, logging_output
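The extra counters logged here (nt_correct, nt_precision_denom, ...) are raw counts rather than ratios, so they can be summed across batches and workers before computing metrics. A hedged sketch of how they might be turned into precision/recall/F1 for the positive class (the aggregation step itself is not shown in the snippet):

# Aggregated counts, invented for illustration.
agg = {
    'nt_correct': 42,          # predicted 1 and target 1
    'nt_precision_denom': 50,  # predicted 1
    'nt_recall_denom': 60,     # target 1
}

precision = agg['nt_correct'] / max(agg['nt_precision_denom'], 1)
recall = agg['nt_correct'] / max(agg['nt_recall_denom'], 1)
f1 = 2 * precision * recall / max(precision + recall, 1e-8)
print(f'precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}')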
Code example #10
File: masked_lm_distill.py  Project: yrchen92/CoDIR
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        # compute MLM loss
        masked_tokens = sample['target'].ne(self.padding_idx)

        # Rare: when all tokens are masked, project all tokens.
        # We use torch.where to avoid device-to-host transfers,
        # except on CPU where torch.where is not well supported
        # (see github.com/pytorch/pytorch/issues/26247).
        if masked_tokens.device == torch.device('cpu'):
            if not masked_tokens.any():
                masked_tokens.fill_(True)
        else:
            masked_tokens = torch.where(
                masked_tokens.any(),
                masked_tokens,
                masked_tokens.new([True]),
            )

        logits_student = model(**sample['net_input'],
                               masked_tokens=masked_tokens)[0]
        with torch.no_grad():
            logits_teacher = self.teacher_model(**sample['net_input'],
                                                masked_tokens=masked_tokens)[0]
        targets = model.get_targets(sample, [logits_student])
        targets = targets[masked_tokens]

        loss_ce = modules.cross_entropy(
            logits_student.view(-1, logits_student.size(-1)),
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )
        if self.print_teacher_loss:
            loss_ce_teacher = modules.cross_entropy(
                logits_teacher.view(-1, logits_teacher.size(-1)),
                targets.view(-1),
                reduction='sum',
                ignore_index=self.padding_idx,
            )
        # KD loss below
        loss_kd = self.kd_loss_func(
            F.log_softmax(logits_student / self.T, dim=-1),
            F.softmax(logits_teacher / self.T, dim=-1)) * self.T**2

        loss = (1 - self.beta) * loss_ce + self.beta * loss_kd

        sample_size = masked_tokens.int().sum()
        logging_output = {
            'loss': loss.data,
            'ce_loss': loss_ce.data,
            'kd_loss': loss_kd.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
        }
        if self.print_teacher_loss:
            logging_output['ce_loss_teacher'] = loss_ce_teacher
        return loss, sample_size, logging_output
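The snippet does not show how self.kd_loss_func is defined; one common choice that matches the call pattern (student log-probabilities against teacher probabilities, scaled by T**2) is nn.KLDivLoss. A minimal sketch under that assumption, with random logits:

import torch
import torch.nn as nn
import torch.nn.functional as F

T = 2.0                                        # distillation temperature
kd_loss_func = nn.KLDivLoss(reduction='sum')   # assumed; not shown in the snippet

logits_student = torch.randn(8, 1000)
logits_teacher = torch.randn(8, 1000)

# Soften both distributions with T; scaling by T**2 keeps the gradient
# magnitude comparable to the hard-label cross-entropy term.
loss_kd = kd_loss_func(
    F.log_softmax(logits_student / T, dim=-1),
    F.softmax(logits_teacher / T, dim=-1),
) * T ** 2
print(loss_kd.item())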
Code example #11
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        masked_tokens = sample['target'].ne(self.padding_idx)
        sample_size_mask = masked_tokens.int().sum()

        decode_tokens = sample['decode_target'].ne(self.padding_idx)
        sample_size_decode = decode_tokens.int().sum()

        # Rare: when all tokens are masked, project all tokens.
        # We use torch.where to avoid device-to-host transfers,
        # except on CPU where torch.where is not well supported
        # (see github.com/pytorch/pytorch/issues/26247).
        if self.tpu:
            masked_tokens = None  # always project all tokens on TPU
        elif masked_tokens.device == torch.device('cpu'):
            if not masked_tokens.any():
                masked_tokens = None
            if not decode_tokens.any():
                decode_tokens = None
        else:
            masked_tokens = torch.where(
                masked_tokens.any(),
                masked_tokens,
                masked_tokens.new([True]),
            )
            decode_tokens = torch.where(
                decode_tokens.any(),
                decode_tokens,
                decode_tokens.new([True]),
            )

        logits, logits_decode, _ = model(**sample['net_input'], masked_tokens=masked_tokens)
        targets = model.get_targets(sample, [logits])
        if masked_tokens is not None:
            targets = targets[masked_tokens]

        decode_target = sample["decode_target"]
        if decode_tokens is not None:
            if logits_decode.shape[1] != decode_target.shape[1]:
                print(decode_target)
                print(sample['net_input']['src_tokens'])
            decode_target = decode_target[decode_tokens]
            logits_decode = logits_decode[decode_tokens]


        mask_loss = modules.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )

        decode_loss = modules.cross_entropy(
            logits_decode.view(-1, logits_decode.size(-1)),
            decode_target.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )
        
        accumulate_step = sample['accumulate_step']

        logging_output = {
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'loss_decode': decode_loss if self.tpu else decode_loss.data,
            'loss_mask': mask_loss if self.tpu else mask_loss.data,
            'sample_size_decode': sample_size_decode,
            'sample_size_mask': sample_size_mask,
            'sample_size': sample_size_mask,
            'sample_size_t': 1.0 / accumulate_step,
            'loss': mask_loss if self.tpu else mask_loss.data,
        }

        # Normalize each loss by the sample sizes carried in the batch, then mix.
        sample_size_mask = sample['sample_size_mask']
        sample_size_decode = sample['sample_size_decode']

        decode_loss = decode_loss / sample_size_decode
        mask_loss = mask_loss / sample_size_mask
        loss = 0.5 * mask_loss + 0.5 * decode_loss

        return loss, 1.0 / accumulate_step, logging_output
Code example #12
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        fields = configs.fields
        # compute MLM loss
        masked_tokens = sample['target'][fields[0]].ne(
            self.padding_idx_dict[fields[0]])
        sample_size = masked_tokens.int().sum().item()

        # Rare case: when the batch contains no masked tokens, indexing with an
        # all-False mask would produce empty tensors and trigger a CUDA error,
        # so fall back to selecting every position.
        masked_tokens = torch.where(
            masked_tokens.any(),
            masked_tokens,
            masked_tokens.new([True]),
        )

        logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
        targets = model.get_targets(sample, [logits])

        # Which field to predict
        output_langs = self.args.output_lang.split(',')
        trace_weight = float(self.args.trace_weight)
        for output_lang in output_langs:
            assert output_lang in logits.keys()

        loss = 0
        for field in output_langs:

            if masked_tokens is not None:
                targets[field] = targets[field][masked_tokens]

            if field == configs.fields[0]:  # static code loss
                loss += modules.cross_entropy(
                    logits[field].view(-1, logits[field].size(-1)),
                    targets[field].view(-1),
                    reduction='sum',
                    ignore_index=self.padding_idx_dict[field],
                )
            else:
                loss += trace_weight * modules.cross_entropy(
                    logits[field].view(-1, logits[field].size(-1)),
                    targets[field].view(-1),
                    reduction='sum',
                    ignore_index=self.padding_idx_dict[field],
                )

        logging_output = {
            'loss': loss.data / len(output_langs),
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output