Example #1
    def forward_att(self, eouts, elens, ys, return_logits=False):
        """Compute XE loss for the sequence-to-sequence model.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): A list of length `[B]`, where each element is a list of token IDs of length `[L]`
            return_logits (bool): return logits for knowledge distillation
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): token-level accuracy in teacher-forcing
            ppl (float): perplexity

        """
        bs = eouts.size(0)

        # Append <sos> and <eos>
        eos = eouts.new_zeros(1).fill_(self.eos).long()
        ys = [
            np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int64),
                      self.device_id) for y in ys
        ]
        ylens = np2tensor(
            np.fromiter([y.size(0) + 1 for y in ys],
                        dtype=np.int32))  # +1 for <eos>
        ys_in_pad = pad_list([torch.cat([eos, y], dim=0) for y in ys],
                             self.pad)
        ys_out_pad = pad_list([torch.cat([y, eos], dim=0) for y in ys],
                              self.pad)

        # Create the self-attention mask
        bs, ymax = ys_in_pad.size()[:2]
        yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand(
            bs, ymax, ymax)
        yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax,
                                              ymax)
        subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(),
                                     diagonal=0)
        subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, ymax)
        yy_mask = yy_mask & subsequent_mask

        # Create the source-target mask
        xmax = eouts.size(1)
        x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand(
            bs, ymax, xmax)
        y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).expand(
            bs, ymax, xmax)
        xy_mask = (x_mask * y_mask).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, xmax)

        ys_emb = self.pos_enc(self.embed(ys_in_pad))
        for l in range(self.n_layers):
            ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts,
                                                    xy_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
                setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
        logits = self.norm_out(ys_emb)
        if self.adaptive_softmax is None:
            logits = self.output(logits)
        if return_logits:
            return logits

        # Compute XE sequence loss
        if self.adaptive_softmax is None:
            if self.lsm_prob > 0 and self.training:
                # Label smoothing
                loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                         ys_out_pad.view(-1), self.lsm_prob,
                                         self.pad)
            else:
                loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                       ys_out_pad.view(-1),
                                       ignore_index=self.pad,
                                       reduction='mean')

            # Focal loss
            if self.focal_loss_weight > 0:
                fl = focal_loss(logits,
                                ys_out_pad,
                                ylens,
                                alpha=self.focal_loss_weight,
                                gamma=self.focal_loss_gamma)
                loss = loss * (
                    1 - self.focal_loss_weight) + fl * self.focal_loss_weight
        else:
            loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                         ys_out_pad.view(-1)).loss

        # Compute token-level accuracy in teacher-forcing
        if self.adaptive_softmax is None:
            acc = compute_accuracy(logits, ys_out_pad, self.pad)
        else:
            acc = compute_accuracy(
                self.adaptive_softmax.log_prob(
                    logits.view((-1, logits.size(2)))), ys_out_pad, self.pad)
        ppl = np.exp(loss.item())

        # scale loss for CTC
        loss *= ylens.float().mean()

        return loss, acc, ppl
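
The decoder self-attention mask above combines a key-padding mask with a causal (lower-triangular) mask so that each target position attends only to valid, non-future positions. The standalone sketch below mirrors that construction with dummy lengths; the local make_pad_mask helper and the toy shapes are illustrative stand-ins, not the original module's utilities.

import torch

def make_pad_mask(seq_lens):
    # True at valid (non-pad) positions, False at padding: [B, max_len]
    bs = seq_lens.size(0)
    max_len = int(seq_lens.max())
    ar = torch.arange(max_len).unsqueeze(0).expand(bs, max_len)
    return ar < seq_lens.unsqueeze(1)

ylens = torch.tensor([3, 2])   # target lengths including <eos>
bs, ymax, n_heads = 2, 3, 4

# Key-padding mask broadcast to [B, H, ymax, ymax]
yy_mask = make_pad_mask(ylens).unsqueeze(1).expand(bs, ymax, ymax)
yy_mask = yy_mask.unsqueeze(1).expand(bs, n_heads, ymax, ymax)

# Causal mask: position t may only attend to positions <= t
subsequent_mask = torch.tril(torch.ones(ymax, ymax, dtype=torch.bool))
subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(bs, n_heads, ymax, ymax)

mask = yy_mask & subsequent_mask   # True = attend, False = blocked
print(mask[0, 0].int())            # lower-triangular pattern for the first head
print(mask[1, 0].int())            # last column zeroed out by the padding mask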
Example #2
    def forward_att(self, eouts, elens, ys, device_id):
        """Compute XE loss for the sequence-to-sequence model.

        Args:
            eouts (FloatTensor): `[B, T, enc_nunits]`
            elens (list): A list of length `[B]`
            ys (list): A list of length `[B]`, where each element is a list of token IDs of length `[L]`
            device_id (int): GPU id
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): token-level accuracy in teacher-forcing
            ppl (float): perplexity

        """
        bs, _, enc_nunits = eouts.size()

        # Append <sos> and <eos>
        sos = eouts.new_zeros(1).fill_(self.sos).long()
        eos = eouts.new_zeros(1).fill_(self.eos).long()
        if self.backward:
            ys = [
                np2tensor(np.fromiter(y[::-1], dtype=np.int64),
                          device_id).long() for y in ys
            ]
            ys_in = [torch.cat([eos, y], dim=0) for y in ys]
            ys_out = [torch.cat([y, sos], dim=0) for y in ys]
        else:
            ys = [
                np2tensor(np.fromiter(y, dtype=np.int64), device_id).long()
                for y in ys
            ]
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
            ys_out = [torch.cat([y, eos], dim=0) for y in ys]
        ys_in_pad = pad_list(ys_in, self.pad)
        ys_out_pad = pad_list(ys_out, -1)

        # Initialization
        dout, dstate = self.init_dec_state(bs, self.nlayers, device_id, eouts,
                                           elens)
        _dout, _dstate = self.init_dec_state(bs, 1, device_id, eouts,
                                             elens)  # for internal LM
        context = eouts.new_zeros(bs, 1, enc_nunits)
        self.score.reset()
        aw = None
        rnnlm_state = None

        # Pre-computation of embedding
        ys_emb = self.embed(ys_in_pad)
        if self.rnnlm_cf:
            ys_lm_emb = self.rnnlm_cf.embed(ys_in_pad)
            # ys_lm_emb = [self.rnnlm_cf.embed(ys_in_pad[:, t:t + 1])
            #              for t in range(ys_in_pad.size(1))]
            # ys_lm_emb = torch.cat(ys_lm_emb, dim=1)

        logits = []
        for t in range(ys_in_pad.size(1)):
            # Sample for scheduled sampling
            is_sample = (t > 0 and self.ss_prob > 0
                         and random.random() < self.ss_prob)
            if is_sample:
                y_emb = self.embed(torch.argmax(logits[-1].detach(), dim=-1))
            else:
                y_emb = ys_emb[:, t:t + 1]

            # Recurrency
            dout, dstate, _dout, _dstate = self.recurrency(
                y_emb, context, dstate, _dstate)

            # Update RNNLM states for cold fusion
            if self.rnnlm_cf:
                if is_sample:
                    y_lm_emb = self.rnnlm_cf.embed(
                        torch.argmax(logits[-1].detach(), dim=-1))
                else:
                    y_lm_emb = ys_lm_emb[:, t:t + 1]
                logits_lm_t, lm_out, rnnlm_state = self.rnnlm_cf.predict(
                    y_lm_emb, rnnlm_state)
            else:
                logits_lm_t, lm_out = None, None

            # Score
            context, aw = self.score(eouts, elens, dout, aw)

            # Generate
            attentional_t = self.generate(context, dout, logits_lm_t, lm_out)
            if self.rnnlm_init and self.internal_lm:
                # Residual connection
                attentional_t += _dout
            logits.append(self.output(attentional_t))

        # Compute XE sequence loss
        logits = torch.cat(logits, dim=1) / self.logits_temp
        if self.lsm_prob > 0:
            # Label smoothing
            y_lens = [y.size(0) for y in ys_out]
            loss = cross_entropy_lsm(logits,
                                     ys=ys_out_pad,
                                     y_lens=y_lens,
                                     lsm_prob=self.lsm_prob,
                                     size_average=True)
        else:
            loss = F.cross_entropy(
                logits.view((-1, logits.size(2))),
                ys_out_pad.view(-1),  # long
                ignore_index=-1,
                reduction='sum') / bs
        ppl = math.exp(loss.item())

        # Focal loss
        if self.fl_weight > 0:
            y_lens = [y.size(0) for y in ys_out]
            fl = focal_loss(logits,
                            ys=ys_out_pad,
                            y_lens=y_lens,
                            gamma=self.fl_gamma,
                            size_average=True)
            loss = loss * (1 - self.fl_weight) + fl * self.fl_weight

        # Compute token-level accuracy in teacher-forcing
        pad_pred = logits.view(ys_out_pad.size(0), ys_out_pad.size(1),
                               logits.size(-1)).argmax(2)
        mask = ys_out_pad != -1
        numerator = torch.sum(
            pad_pred.masked_select(mask) == ys_out_pad.masked_select(mask))
        denominator = torch.sum(mask)
        acc = float(numerator) * 100 / float(denominator)

        return loss, acc, ppl
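
The inlined accuracy computation at the end of this example can be factored into a small helper. The sketch below is a standalone illustration; the function name token_accuracy and the toy tensors are invented here for clarity and are not part of the original code.

import torch

def token_accuracy(logits, ys_out_pad, ignore_index=-1):
    # logits: [B, L, vocab]; ys_out_pad: [B, L] with ignore_index at padded positions
    pred = logits.argmax(dim=-1)          # [B, L] predicted token ids
    mask = ys_out_pad != ignore_index     # skip padded targets
    n_correct = (pred.masked_select(mask) == ys_out_pad.masked_select(mask)).sum()
    return float(n_correct) * 100 / float(mask.sum())

# Toy check: the predictions match 3 of the 4 non-padded targets -> 75.0
logits = torch.zeros(2, 3, 5)
logits[0, 0, 1] = 1.   # predicts 1, target 1 (correct)
logits[0, 1, 2] = 1.   # predicts 2, target 2 (correct)
logits[1, 0, 3] = 1.   # predicts 3, target 3 (correct)
logits[1, 1, 0] = 1.   # predicts 0, target 4 (wrong)
ys_out_pad = torch.tensor([[1, 2, -1], [3, 4, -1]])
print(token_accuracy(logits, ys_out_pad))   # 75.0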