示例#1
0
    def forward(self, input_ids_src, input_ids_tgt, output_ids, *args,
                **kwargs):
        """Teacher-forced pass through two seq2seq models with consistency penalties.

        Runs ``self.model`` on the source-language inputs and ``self.model2``
        on the target-language inputs, both decoding the same ``output_ids``,
        averages their metric dicts, and optionally adds two similarity
        penalties to the loss:

        * ``statediff`` — mean L2 distance between decoder states
          (weighted by ``self.statesimweight``);
        * ``probdiff`` — a symmetric KL (Jensen–Shannon-style) divergence
          between the two output distributions (weighted by
          ``self.probsimweight``).

        Returns ``(outputs, ret)`` where ``outputs`` is the merged metric
        dict and ``ret`` is the raw output of the SECOND model only.
        """
        # NOTE(review): `encoderconfig` is unusual — HF models expose
        # `config`; confirm this attribute exists on self.model.
        ret = self.model(input_ids_src,
                         attention_mask=input_ids_src !=
                         self.model.encoderconfig.pad_token_id,
                         decoder_input_ids=output_ids)
        # ret[0]: output logits; ret[-1]: presumably decoder hidden states —
        # TODO confirm against the model's return layout.
        probs_src, states_src = ret[0], ret[-1]
        _, predactions = probs_src.max(-1)
        outputs = [
            metric(probs_src, predactions, output_ids[:, 1:])
            for metric in self.metrics
        ]
        outputs_src = merge_metric_dicts(*outputs)

        # Same computation for the second (target-side) model.
        ret = self.model2(input_ids_tgt,
                          attention_mask=input_ids_tgt !=
                          self.model2.encoderconfig.pad_token_id,
                          decoder_input_ids=output_ids)
        probs_tgt, states_tgt = ret[0], ret[-1]
        _, predactions = probs_tgt.max(-1)
        outputs = [
            metric(probs_tgt, predactions, output_ids[:, 1:])
            for metric in self.metrics
        ]
        outputs_tgt = merge_metric_dicts(*outputs)

        # Average every metric (including "loss") across the two models.
        outputs = {
            k: (outputs_src[k] + outputs_tgt[k]) / 2
            for k in outputs_src
        }

        # penalties:
        # Mask out padding positions of the decoder sequence.
        outmask = (output_ids != self.accs.padid).float()
        if self.statesimweight > 0:
            # Per-position L2 distance between the two models' states.
            statediff = (states_src - states_tgt)
            statediff = torch.norm(statediff, 2, -1)
            statediff = (statediff * outmask).sum(-1) / outmask.sum(
                -1)  # mean/sum over seqlen with decoder mask
            statediff = statediff.mean()
            outputs["statediff"] = statediff
            outputs["loss"] = outputs["loss"] + statediff * self.statesimweight

        if self.probsimweight > 0:
            # probdiff = (probs_src - probs_tgt)
            # probdiff = torch.norm(probdiff, 2, -1)
            # Symmetric KL against the mixture m (JS-divergence style).
            m = (torch.softmax(probs_src, -1) +
                 torch.softmax(probs_tgt, -1)) / 2.
            a = self.kldiv(torch.log_softmax(probs_src, -1), m).sum(-1)
            b = self.kldiv(torch.log_softmax(probs_tgt, -1), m).sum(-1)
            probdiff = (a + b) / 2.

            probdiff = (probdiff * outmask).sum(-1) / outmask.sum(
                -1)  # mean/sum over seqlen with decoder mask
            probdiff = probdiff.mean()
            outputs["probdiff"] = probdiff
            outputs["loss"] = outputs["loss"] + probdiff * self.probsimweight
        return outputs, ret
示例#2
0
 def forward(self, input_ids, output_ids, abs_output_ids, *args, **kwargs):
     """Teacher-forced pass producing both main and abstract-level metrics.

     The model returns two logit tensors; each is scored against its own
     gold sequence. Abstract metrics are merged in under an ``abs_`` prefix
     and ``abs_loss`` is added into the combined ``loss``. Returns the
     merged metric dict and the raw model output.
     """
     enc_mask = input_ids != self.model.config.pad_token_id
     ret = self.model(input_ids, attention_mask=enc_mask,
                      decoder_input_ids=output_ids)
     probs = ret[0]
     absprobs = ret[1]
     predactions = probs.max(-1)[1]
     abspredactions = absprobs.max(-1)[1]
     # Gold sequences drop the leading (start) token.
     gold = output_ids[:, 1:]
     abs_gold = abs_output_ids[:, 1:]
     main = merge_metric_dicts(
         *[m(probs, predactions, gold) for m in self.metrics])
     abstract = merge_metric_dicts(
         *[m(absprobs, abspredactions, abs_gold) for m in self.absmetrics])
     for key, val in abstract.items():
         main[f"abs_{key}"] = val
     main["loss"] = main["loss"] + main["abs_loss"]
     return main, ret
示例#3
0
    def forward(self, x: RankState):
        """Score candidates, select the best one, and compute metrics.

        Ranks every candidate with ``self.model``, gathers the token
        sequence of the top-scoring candidate, then evaluates rank metrics
        (on scores vs. gold candidate labels) and sequence metrics (on the
        chosen vs. gold token sequence). Returns the merged metric dict and
        the unchanged state.
        """
        cands = x.candtensors
        scores = self.model(x.inp_tensor, cands, x.alignments,
                            x.alignment_entropies)
        chosen = scores.max(-1)[1]
        gold_labels = x.candgold.to(torch.float)

        # Pick out the token sequence of the highest-scoring candidate.
        gather_idx = chosen[:, None, None].repeat(1, 1, cands.size(2))
        chosen_seq = cands.gather(1, gather_idx)[:, 0]

        all_metrics = [m(scores, chosen, gold_labels, x)
                       for m in self._metrics]
        all_metrics += [m(None, chosen_seq, x.gold_tensor, x)
                        for m in self._seq_metrics]
        return merge_metric_dicts(*all_metrics), x
示例#4
0
 def forward(self, input_ids, output_ids, *args, **kwargs):
     """Teacher-forced pass; returns merged metrics and the raw model output."""
     enc_mask = input_ids != self.model.config.pad_token_id
     ret = self.model(input_ids, attention_mask=enc_mask,
                      decoder_input_ids=output_ids)
     logits = ret[0]
     preds = logits.max(-1)[1]
     # Gold sequence drops the leading (start) token.
     metric_dicts = [m(logits, preds, output_ids[:, 1:])
                     for m in self.metrics]
     return merge_metric_dicts(*metric_dicts), ret
    def forward(self, x: State):
        """Decode with one transformer over ``[input ; output slots]``.

        Appends ``maxoutlen`` placeholder token ids (offset by
        ``_numinpids``) after the input, builds position ids where output
        positions start at ``maxinplen``, runs ``self.tm``, projects with
        ``self.out``, and evaluates metrics on the output-slot logits.
        Returns the merged metric dict and the unchanged state.
        """
        src = x.inp_tensor
        bsz = src.size(0)
        srclen = src.size(1)
        dev = src.device

        # Placeholder ids that occupy the decoder slots.
        slot_ids = torch.arange(
            self.maxoutlen, dtype=src.dtype,
            device=dev)[None, :].repeat(bsz, 1) + self._numinpids
        seq = torch.cat([src, slot_ids], 1)

        pos_in = torch.arange(srclen, dtype=torch.long,
                              device=dev)[None, :].repeat(bsz, 1)
        # Output positions start at maxinplen regardless of actual input length.
        pos_out = torch.arange(self.maxoutlen, dtype=torch.long,
                               device=dev)[None, :].repeat(bsz, 1) \
            + self.maxinplen
        positions = torch.cat([pos_in, pos_out], 1)

        y = self.tm(seq, attention_mask=seq != 0, position_ids=positions)
        logits = self.out(y[0])[:, self.maxinplen:]
        preds = logits.max(-1)[1]

        merged = merge_metric_dicts(
            *[m(logits, preds, x) for m in self._metrics])
        return merged, x
示例#6
0
 def forward(self, input_ids, output_ids, *args, **kwargs):
     """Beam-search decode and score generations against gold outputs.

     Seeds the decoder with the first gold token, generates up to
     ``self.maxlen`` tokens with ``self.numbeam`` beams, and evaluates
     metrics on the generated vs. gold sequences (both minus the start
     token). Returns the merged metric dict and the generated ids.
     """
     enc_mask = input_ids != self.model.config.pad_token_id
     generated = self.model.generate(input_ids,
                                     decoder_input_ids=output_ids[:, 0:1],
                                     attention_mask=enc_mask,
                                     max_length=self.maxlen,
                                     num_beams=self.numbeam)
     metric_dicts = [m(None, generated[:, 1:], output_ids[:, 1:])
                     for m in self.metrics]
     return merge_metric_dicts(*metric_dicts), generated
    def forward(self, input_ids, output_ids, adv_output_ids, *args, **kwargs):
        """Teacher-forced pass plus an adversarial KL term.

        Scores ``output_ids`` with the main model, decodes
        ``adv_output_ids`` with both the adversarial model and a language
        model, then adds a KL-based entropy metric (via ``self.kl``) over
        positions where the gold output equals ``self.absid``. Returns the
        merged metric dict and the concatenated main+adversarial outputs.
        """
        ret = self.model(
            input_ids,
            attention_mask=input_ids != self.model.config.pad_token_id,
            decoder_input_ids=output_ids)
        probs = ret[0]
        advret = self.advmodel(
            input_ids,
            attention_mask=input_ids != self.model.config.pad_token_id,
            decoder_input_ids=adv_output_ids)
        advprobs = advret[0]
        lmret = self.lm_model(input_ids, decoder_input_ids=adv_output_ids)
        lmprobs = lmret[0]
        _, predactions = probs.max(-1)

        outputs = [
            metric(probs, predactions, output_ids[:, 1:])
            for metric in self.metrics
        ]
        # Only positions holding the special "abstract" id contribute.
        mask = (output_ids == self.absid)[:, 1:]
        # NOTE(review): `_` here is the max *values* from probs.max(-1)
        # above — passing a throwaway name as a real argument looks
        # accidental; confirm self.kl's expected second argument.
        entropy = self.kl(advprobs, _, lmprobs, mask)
        outputs.append(entropy)
        outputs = merge_metric_dicts(*outputs)
        # Tuple concatenation: callers receive both models' raw outputs.
        return outputs, ret + advret