Example #1
    def forward(self, ys, ylens=None, labels=None, ps=None, plens=None):
        if ylens is None:
            attention_mask = None
        else:
            attention_mask = make_nopad_mask(ylens).float().to(ys.device)
            ys = ys[:, :max(ylens)]  # DataParallel

        gloss, glogits = self.gmodel(ys,
                                     attention_mask=attention_mask,
                                     labels=labels)

        generated_ids = ys.clone()
        masked_indices = labels.long() != -100
        original_ids = ys.clone()
        original_ids[masked_indices] = labels[masked_indices]
        sample_ids = sample_temp(glogits)  # sampling
        generated_ids[masked_indices] = sample_ids[masked_indices]
        labels_replaced = (generated_ids.long() != original_ids.long()).long()

        dloss, dlogits = self.dmodel(generated_ids,
                                     attention_mask=attention_mask,
                                     labels=labels_replaced)

        loss = gloss + self.electra_disc_weight * dloss
        loss_dict = {}

        loss_dict["loss_gen"] = gloss
        loss_dict["loss_disc"] = dloss
        loss_dict["num_replaced"] = labels_replaced.sum().long() / ys.size(0)
        loss_dict["num_masked"] = masked_indices.sum().long() / ys.size(0)

        return loss, loss_dict
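
The helpers `make_nopad_mask` and `sample_temp` are called above but not shown in this listing. A minimal sketch of what they presumably do, assuming the usual padding-mask and temperature-sampling conventions (the signatures and defaults below are assumptions, not the repository's actual code):

import torch


def make_nopad_mask(lens):
    # Assumed helper: boolean mask of shape (batch, max(lens)) that is True
    # at real-token positions and False at padded positions.
    lens = torch.as_tensor(lens)
    positions = torch.arange(int(lens.max()), device=lens.device)
    return positions.unsqueeze(0) < lens.unsqueeze(1)


def sample_temp(logits, temp=1.0):
    # Assumed helper: draw one token id per position from the generator
    # logits softened by a temperature; (batch, length, vocab) -> (batch, length).
    probs = torch.softmax(logits / temp, dim=-1)
    bs, length, vocab = probs.size()
    ids = torch.multinomial(probs.reshape(-1, vocab), num_samples=1)
    return ids.reshape(bs, length)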
Example #2
    def forward(
        self,
        eouts,
        elens,
        eouts_inter=None,
        ys=None,
        ylens=None,
        ys_in=None,
        ys_out=None,  # labels
        soft_labels=None,
        ps=None,
        plens=None,
    ):
        loss = 0
        loss_dict = {}

        bs = eouts.size(0)

        ys_emb = self.dropout_emb(self.embed(ys_in))
        dstate = None
        # context vector
        ctx = eouts.new_zeros(bs, 1, self.enc_hidden_size)
        attn_weight = None
        attn_mask = make_nopad_mask(elens).unsqueeze(2)
        logits = []

        for i in range(ys_in.size(1)):
            y_emb = ys_emb[:, i:i + 1]  # (bs, 1, embedding_size)
            logit, ctx, dstate, attn_weight = self.forward_one_step(
                y_emb, ctx, eouts, dstate, attn_weight, attn_mask)
            logits.append(logit)  # (bs, 1, dec_intermediate_size)

        logits = self.output(torch.cat(logits, dim=1))  # (bs, ylen, vocab)

        if self.kd_weight > 0 and soft_labels is not None:
            # NOTE: ys_out (label) have length ylens+1
            loss_att_kd, loss_kd, loss_att = self.loss_fn(
                logits, ys_out, soft_labels, ylens + 1)
            loss += loss_att_kd
            loss_dict["loss_kd"] = loss_kd
            loss_dict["loss_att"] = loss_att
        else:
            loss_att = self.loss_fn(logits, ys_out, ylens + 1)
            loss += loss_att
            loss_dict["loss_att"] = loss_att

        if self.mtl_ctc_weight > 0:
            # NOTE: KD is not applied to auxiliary CTC
            loss_ctc, _, _ = self.ctc(eouts=eouts,
                                      elens=elens,
                                      ys=ys,
                                      ylens=ylens,
                                      soft_labels=None)
            loss += self.mtl_ctc_weight * loss_ctc  # auxiliary loss
            loss_dict["loss_ctc"] = loss_ctc

        loss_dict["loss_total"] = loss

        return loss, loss_dict, logits
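
The criterion behind `self.loss_fn(logits, ys_out, soft_labels, ylens + 1)` is not shown. A plausible sketch, assuming a standard interpolation between cross-entropy on the hard labels and a soft cross-entropy against the teacher distributions (the weighting, the `pad_id` default, and the function name are assumptions; it reuses the `make_nopad_mask` sketch above):

import torch
import torch.nn.functional as F


def kd_attention_loss(logits, ys_out, soft_labels, ylens, kd_weight=0.5, pad_id=-100):
    # Sketch only: interpolate hard-label cross-entropy with a soft
    # cross-entropy against the teacher distribution, masking padded steps.
    vocab = logits.size(-1)
    mask = make_nopad_mask(ylens).to(logits.device)  # (bs, len), True = valid

    loss_att = F.cross_entropy(
        logits.reshape(-1, vocab), ys_out.reshape(-1), ignore_index=pad_id)

    log_probs = torch.log_softmax(logits, dim=-1)
    loss_kd = -(soft_labels * log_probs).sum(dim=-1)  # (bs, len)
    loss_kd = (loss_kd * mask).sum() / mask.sum()

    loss_att_kd = (1 - kd_weight) * loss_att + kd_weight * loss_kd
    return loss_att_kd, loss_kd, loss_att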
Example #3
    def forward_disc(self, ys, ylens=None, error_labels=None):
        if ylens is None:
            attention_mask = None
        else:
            attention_mask = make_nopad_mask(ylens).float().to(ys.device)
            ys = ys[:, :max(ylens)]  # DataParallel

        loss, _ = self.dmodel(ys,
                              attention_mask=attention_mask,
                              labels=error_labels)
        loss_dict = {"loss_total": loss}

        return loss, loss_dict
Example #4
    def predict(self, ys, ylens, states=None):
        """ predict next token for Shallow Fusion
        """
        attention_mask = make_nopad_mask(ylens).float().to(ys.device)

        with torch.no_grad():
            (logits,) = self.transformer(ys, attention_mask, causal=True)

        log_probs = torch.log_softmax(logits, dim=-1)

        log_probs_next = []
        bs = len(ys)
        for b in range(bs):
            log_probs_next.append(tensor2np(log_probs[b, ylens[b] - 1]))

        return torch.tensor(log_probs_next).to(ys.device), states
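
`tensor2np` is another helper not shown here; it presumably just detaches a tensor and converts it to NumPy. In shallow fusion, the returned next-token log-probabilities would typically be added, with a weight, to the decoder's scores inside beam search. A hedged sketch (`fuse_scores` and `lm_weight` are illustrative names, not from the listing):

def tensor2np(x):
    # Assumed helper: detach a tensor and convert it to a NumPy array.
    return x.detach().cpu().numpy()


def fuse_scores(asr_log_probs, lm_log_probs, lm_weight=0.3):
    # Illustrative shallow-fusion step: weighted sum of decoder and LM
    # next-token log-probabilities for each hypothesis.
    return asr_log_probs + lm_weight * lm_log_probs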
Example #5
    def score(self, ys, ylens, batch_size=None):
        """ score token sequence for Rescoring
        """
        attention_mask = make_nopad_mask(ylens).float().to(ys.device)
        logits, = self.dmodel(ys, attention_mask=attention_mask)
        probs = torch.sigmoid(logits)

        if ys.size(0) == 1:
            # single-hypothesis shortcut; keep the sign and length handling
            # consistent with the batched path below
            return [(-1) * torch.sum(probs[0, :ylens[0]], dim=-1).item()]

        score_lms = []
        bs = len(ys)
        for b in range(bs):
            score_lm = (-1) * torch.sum(probs[b, :ylens[b]], dim=-1).item()
            score_lms.append(score_lm)

        return score_lms
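
Here the discriminator's per-token replaced-probabilities are summed into a penalty, so lower is better and the score enters rescoring with a negative sign. A hedged usage sketch of N-best rescoring (names such as `scores_asr`, `hyps`, and `lm_weight` are hypothetical, not from the listing):

def rescore_nbest(model, hyps, ys, ylens, scores_asr, lm_weight=0.5):
    # Hypothetical rescoring helper: combine first-pass ASR scores with the
    # discriminator-based LM scores returned by `score` and pick the best.
    score_lms = model.score(ys, ylens)
    totals = [s_asr + lm_weight * s_lm for s_asr, s_lm in zip(scores_asr, score_lms)]
    best = max(range(len(totals)), key=lambda i: totals[i])
    return hyps[best]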
Example #6
    def forward(self, ys, ylens=None, labels=None, ps=None, plens=None):
        if ylens is None:
            attention_mask = None
        else:
            attention_mask = make_nopad_mask(ylens).float().to(ys.device)
            # DataParallel
            ys = ys[:, :max(ylens)]

        if labels is None:
            (logits, ) = self.bert(ys, attention_mask=attention_mask)
            return logits

        if ylens is not None:
            labels = labels[:, :max(ylens)]
        loss, logits = self.bert(ys,
                                 attention_mask=attention_mask,
                                 labels=labels)
        loss_dict = {"loss_total": loss}

        return loss, loss_dict
Example #7
    def score(self, ys, ylens, batch_size=None):
        """ score token sequence for Rescoring
        """
        attention_mask = make_nopad_mask(ylens).float().to(ys.device)

        with torch.no_grad():
            (logits,) = self.transformer(ys, attention_mask, causal=True)

        log_probs = torch.log_softmax(logits, dim=-1)

        score_lms = []
        bs = len(ys)
        for b in range(bs):
            score_lm = 0

            for i in range(0, ylens[b] - 1):
                v = ys[b, i + 1].item()  # predict next
                score_lm += log_probs[b, i, v].item()
            score_lms.append(score_lm)

        return score_lms
Example #8
    def forward(self, ys, ylens=None, labels=None, ps=None, plens=None):
        if ylens is None:
            attention_mask = None
        else:
            attention_mask = make_nopad_mask(ylens).float().to(ys.device)
            # DataParallel
            ys = ys[:, : max(ylens)]
        
        if labels is None:
            # NOTE: causal attention mask
            (logits,) = self.transformer(ys, attention_mask=attention_mask, causal=True)
            return logits
        
        if ylens is not None:
            labels = labels[:, : max(ylens)]
        # NOTE: causal attention mask
        loss, logits = self.transformer(
            ys, attention_mask=attention_mask, causal=True, labels=labels
        )
        loss_dict = {"loss_total": loss}

        return loss, loss_dict
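
`causal=True` is not defined in this listing; presumably the Transformer combines a lower-triangular mask with the padding mask so that position i only attends to non-padded positions j <= i. A minimal sketch of that combination, as an assumption about the unseen implementation (reusing the `make_nopad_mask` sketch above):

import torch


def causal_pad_mask(ylens):
    # Sketch: True where attention is allowed. Combines the non-padding mask
    # with a causal lower-triangular mask; result shape (batch, length, length).
    pad = make_nopad_mask(ylens)  # (bs, L), True = valid token
    L = pad.size(1)
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=pad.device))
    return pad.unsqueeze(1) & causal.unsqueeze(0)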