Example #1
    def forward_att(self, eouts, elens, ys):
        """Compute XE loss for the sequence-to-sequence model.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (list): A list of length `[B]`
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity

        """
        bs = eouts.size(0)

        # Append <sos> and <eos>
        eos = eouts.new_zeros((1,)).fill_(self.eos).long()
        ylens = [len(y) for y in ys]
        ys = [np2tensor(np.fromiter(y[::-1] if self.backward else y, dtype=np.int64), self.device_id).long()
              for y in ys]
        ys_in = [torch.cat([eos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]
        ys_in_pad = pad_list(ys_in, self.pad)
        ys_out_pad = pad_list(ys_out, self.pad)

        # Add positional embedding
        ys_emb = self.embed(ys_in_pad) * (self.d_model ** 0.5)
        if self.pe_type:
            ys_emb = self.pos_emb_out(ys_emb)

        for l in range(self.n_layers):
            ys_emb, yy_aw, xy_aw = self.layers[l](eouts, elens, ys_emb, ylens)

        logits = self.norm_top(ys_emb)
        if self.adaptive_softmax is None:
            logits = self.output(logits)

        # Compute XE sequence loss
        if self.adaptive_softmax is None:
            if self.lsm_prob > 0:
                # Label smoothing
                loss = cross_entropy_lsm(logits, ys_out_pad,
                                         ylens=[y.size(0) for y in ys_out],
                                         lsm_prob=self.lsm_prob, size_average=False) / bs
            else:
                loss = F.cross_entropy(logits.view((-1, logits.size(2))), ys_out_pad.view(-1),
                                       ignore_index=self.pad, size_average=False) / bs
        else:
            loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                         ys_out_pad.view(-1)).loss

        # Compute token-level accuracy in teacher-forcing
        if self.adaptive_softmax is None:
            acc = compute_accuracy(logits, ys_out_pad, pad=self.pad)
        else:
            acc = compute_accuracy(self.adaptive_softmax.log_prob(
                logits.view((-1, logits.size(2)))), ys_out_pad, pad=self.pad)
        ppl = min(np.exp(loss.item()), np.inf)

        return loss, acc, ppl
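
A minimal, self-contained sketch of the target preparation and loss used above: wrap each label sequence with <eos> on both sides, pad to a common length, and let `ignore_index` drop the padded positions from the cross entropy. The ids, shapes, and the use of `reduction='sum'` (instead of the deprecated `size_average=False`) are illustrative assumptions, not taken from the class above.

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

eos, pad, vocab = 2, 3, 10
ys = [[5, 6, 7], [8, 9]]                      # two label sequences of different length

ys = [torch.tensor(y, dtype=torch.long) for y in ys]
eos_t = torch.tensor([eos], dtype=torch.long)
ys_in = pad_sequence([torch.cat([eos_t, y]) for y in ys], batch_first=True, padding_value=pad)
ys_out = pad_sequence([torch.cat([y, eos_t]) for y in ys], batch_first=True, padding_value=pad)

# dummy decoder output `[B, L, vocab]`; a real model would compute this from ys_in and eouts
logits = torch.randn(ys_in.size(0), ys_in.size(1), vocab)

# padded positions contribute nothing to the loss; divide the summed loss by the batch size
loss = F.cross_entropy(logits.view(-1, vocab), ys_out.view(-1),
                       ignore_index=pad, reduction='sum') / len(ys)
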
Example #2
    def forward_lmobj(self, ys):
        """Compute XE loss for LM objective.

        Args:
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy
            ppl (float): perplexity

        """
        w = next(self.parameters())

        # Append <sos> and <eos>
        eos = w.new_zeros(1).fill_(self.eos)
        ys = [
            np2tensor(np.fromiter(y, dtype=np.int64), self.device_id)
            for y in ys
        ]
        ylens = np2tensor(
            np.fromiter([y.size(0) + 1 for y in ys],
                        dtype=np.int32))  # +1 for <eos>
        ys_in_pad = pad_list([torch.cat([eos, y], dim=0) for y in ys],
                             self.pad)
        ys_out_pad = pad_list([torch.cat([y, eos], dim=0) for y in ys],
                              self.pad)

        # Update prediction network
        dout, _ = self.recurrency(self.embed(ys_in_pad), None)
        logits = self.output_lmobj(dout)

        # Compute XE loss for LM objective
        loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                               ys_out_pad.view(-1),
                               ignore_index=self.pad,
                               size_average=True)

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out_pad, self.pad)
        ppl = min(np.exp(loss.item()), np.inf)

        # scale loss for CTC
        loss *= ylens.float().mean()

        return loss, acc, ppl
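
The perplexity above is exp of the token-averaged cross entropy, computed before the loss is rescaled by the mean target length (the comment indicates this keeps it on a scale comparable to the CTC loss). A small sketch with made-up tensors:

import numpy as np
import torch
import torch.nn.functional as F

pad, vocab = 3, 10
logits = torch.randn(2, 4, vocab)                 # `[B, L, vocab]`
targets = torch.tensor([[5, 6, 2, pad],
                        [7, 2, pad, pad]])        # `[B, L]`, padded with `pad`

# token-averaged XE over the non-pad positions
loss = F.cross_entropy(logits.view(-1, vocab), targets.view(-1),
                       ignore_index=pad, reduction='mean')
ppl = float(np.exp(loss.item()))                  # perplexity = exp(mean token NLL)

# rescale by the mean target length, as in the example, before mixing with CTC
ylens = (targets != pad).sum(1).float()
loss_scaled = loss * ylens.mean()
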
Example #3
    def forward_att(self, eouts, elens, ys, trigger_points=None):
        """Compute XE loss for the Transformer decoder.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (List): length `[B]`, each of which contains a list of size `[L]`
            trigger_points (IntTensor): `[B, L]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity
            losses_auxiliary (dict): auxiliary losses (e.g., the quantity loss)

        """
        losses_auxiliary = {}

        # Append <sos> and <eos>
        ys_in, ys_out, ylens = append_sos_eos(ys, self.eos, self.eos, self.pad, self.device, self.bwd)
        if not self.training:
            self.data_dict['elens'] = tensor2np(elens)
            self.data_dict['ylens'] = tensor2np(ylens)
            self.data_dict['ys'] = tensor2np(ys_out)

        # Create target self-attention mask
        bs, ymax = ys_in.size()[:2]
        tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
        causal_mask = tgt_mask.new_ones(ymax, ymax, dtype=tgt_mask.dtype)
        causal_mask = torch.tril(causal_mask).unsqueeze(0)
        tgt_mask = tgt_mask & causal_mask  # `[B, L (query), L (key)]`

        # Create source-target mask
        src_mask = make_pad_mask(elens.to(self.device)).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

        # Create attention padding mask for quantity loss
        if self.attn_type == 'mocha':
            attn_mask = (ys_out != self.pad).unsqueeze(1).unsqueeze(3)  # `[B, 1, L, 1]`
        else:
            attn_mask = None

        # external LM integration
        lmout = None
        if self.lm is not None:
            self.lm.eval()
            with torch.no_grad():
                lmout, lmstate, _ = self.lm.predict(ys_in, None)
            lmout = self.lm_output_proj(lmout)

        out = self.pos_enc(self.embed_token_id(ys_in), scale=True)  # scaled + dropout

        xy_aws_layers = []
        xy_aws = None
        for lth, layer in enumerate(self.layers):
            out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout)
            # Attention padding
            xy_aws = layer.xy_aws
            if xy_aws is not None and self.attn_type == 'mocha':
                xy_aws_masked = xy_aws.masked_fill_(attn_mask.expand_as(xy_aws) == 0, 0)
                # NOTE: attention padding is quite effective for quantity loss
                xy_aws_layers.append(xy_aws_masked.clone())
            if not self.training:
                self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws)
                self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws)
                self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta)
                self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose)
                self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm)
        logits = self.output(self.norm_out(out))

        # Compute XE loss (+ label smoothing)
        loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)

        # Quantity loss
        losses_auxiliary['loss_quantity'] = 0.
        if self.attn_type == 'mocha':
            # Average over all heads across all layers
            n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
            # NOTE: count <eos> tokens
            n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                                 for aws in xy_aws_layers])  # `[B]`
            n_tokens_pred /= len(xy_aws_layers)
            losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out, self.pad)

        return loss, acc, ppl, losses_auxiliary
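
The target mask above is the conjunction of a key-padding mask and a causal (lower-triangular) mask, so query position i only attends to non-pad keys j <= i. A standalone sketch of the same construction; the comparison-based causal mask is equivalent to the `torch.tril` call above, and the `pad` id and shapes are made up:

import torch

pad = 3
ys_out = torch.tensor([[5, 6, 2, pad],
                       [7, 2, pad, pad]])                          # `[B, L]`
bs, ymax = ys_out.size()

tgt_mask = (ys_out != pad).unsqueeze(1).repeat([1, ymax, 1])       # `[B, L, L]`, key padding
idx = torch.arange(ymax)
causal_mask = (idx.unsqueeze(0) <= idx.unsqueeze(1)).unsqueeze(0)  # `[1, L, L]`, key <= query
tgt_mask = tgt_mask & causal_mask                                  # `[B, L (query), L (key)]`
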
Example #4
    def forward_att(self, eouts, elens, ys,
                    return_logits=False, teacher_logits=None, trigger_points=None):
        """Compute XE loss for the Transformer decoder.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): length `B`, each of which contains a list of size `[L]`
            return_logits (bool): return logits for knowledge distillation
            teacher_logits (FloatTensor): `[B, L, vocab]`
            trigger_points (IntTensor): `[B, T]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity
            losses_auxiliary (dict): auxiliary losses (e.g., the quantity loss)

        """
        # Append <sos> and <eos>
        ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.eos, self.pad, self.bwd)
        if not self.training:
            self.data_dict['elens'] = tensor2np(elens)
            self.data_dict['ylens'] = tensor2np(ylens)
            self.data_dict['ys'] = tensor2np(ys_out)

        # Create target self-attention mask
        xmax = eouts.size(1)
        bs, ymax = ys_in.size()[:2]
        mlen = 0
        tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
        causal_mask = tgt_mask.new_ones(ymax, ymax).byte()
        causal_mask = torch.tril(causal_mask, diagonal=0 + mlen, out=causal_mask).unsqueeze(0)
        tgt_mask = tgt_mask & causal_mask  # `[B, L (query), L (key)]`

        # Create source-target mask
        src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

        # external LM integration
        lmout = None
        if self.lm is not None:
            self.lm.eval()
            with torch.no_grad():
                lmout, lmstate, _ = self.lm.predict(ys_in, None)
            lmout = self.lm_output_proj(lmout)

        out = self.pos_enc(self.embed(ys_in))  # scaled

        mems = self.init_memory()
        pos_embs = None
        if self.memory_transformer:
            out = self.dropout_emb(out)
            # NOTE: TransformerXL does not use positional encoding in the token embedding
            # adopt zero-centered offset
            pos_idxs = torch.arange(mlen - 1, -ymax - 1, -1.0, dtype=torch.float)
            pos_embs = self.pos_emb(pos_idxs, self.device_id)

        hidden_states = [out]
        xy_aws_layers = []
        for lth, (mem, layer) in enumerate(zip(mems, self.layers)):
            out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout,
                        pos_embs=pos_embs, memory=mem, u=self.u, v=self.v)
            if lth < self.n_layers - 1:
                hidden_states.append(out)
                # NOTE: outputs from the last layer are not used for memory
            # Attention padding
            xy_aws = layer.xy_aws
            if xy_aws is not None and 'mocha' in self.attn_type:
                tgt_mask_v2 = (ys_out != self.pad).unsqueeze(1).unsqueeze(3)  # `[B, 1, L, 1]`
                xy_aws = xy_aws.masked_fill_(tgt_mask_v2.repeat([1, xy_aws.size(1), 1, xmax]) == 0, 0)
                # NOTE: attention padding is quite effective for quantity loss
                xy_aws_layers.append(xy_aws.clone())
            if not self.training:
                if layer.yy_aws is not None:
                    self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws)
                if layer.xy_aws is not None:
                    self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws)
                if layer.xy_aws_beta is not None:
                    self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta)
                if layer.xy_aws_p_choose is not None:
                    self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose)
                if layer.yy_aws_lm is not None:
                    self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm)
        logits = self.output(self.norm_out(out))

        # for knowledge distillation
        if return_logits:
            return logits

        # Compute XE loss (+ label smoothing)
        loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)
        losses_auxiliary = {}

        # Quantity loss
        losses_auxiliary['loss_quantity'] = 0.
        if 'mocha' in self.attn_type:
            # Average over all heads across all layers
            n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
            # NOTE: count <eos> tokens
            n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                                 for aws in xy_aws_layers])  # `[B]`
            n_tokens_pred /= len(xy_aws_layers)
            losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out, self.pad)

        return loss, acc, ppl, losses_auxiliary
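
The quantity loss encourages each monotonic attention head to fire roughly once per output token: the attention weights are summed over time and tokens, averaged over heads and layers, and compared with the reference token count (including <eos>). A sketch with random tensors standing in for `layer.xy_aws`; all shapes and values are illustrative:

import torch

bs, n_heads, ymax, xmax = 2, 4, 5, 30
n_layers = 3
n_tokens_ref = torch.tensor([5., 3.])           # non-pad target lengths incl. <eos>, `[B]`

# fake per-layer source-target attention weights, `[B, H, L, T]`
xy_aws_layers = [torch.rand(bs, n_heads, ymax, xmax).softmax(-1) for _ in range(n_layers)]

# expected number of attended tokens, averaged over heads and layers
n_tokens_pred = sum(torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                    for aws in xy_aws_layers) / len(xy_aws_layers)   # `[B]`
loss_quantity = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))
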
Example #5
    def _forward(self, ys, state, reporter, n_caches=0, predict_last=False):
        ys = [np2tensor(y, self.device_id) for y in ys]  # <eos> is included
        ylens = np2tensor(
            np.fromiter([y.size(0) - 1 for y in ys],
                        dtype=np.int32))  # -1 for <eos>
        ys = pad_list(ys, self.pad)
        ys_in = ys[:, :-1]
        ys_out = ys[:, 1:]

        out, state = self.decode(ys_in, state)
        if self.adaptive_softmax is None:
            logits = self.output(out)
        else:
            logits = out

        if predict_last:
            ys_out = ys_out[:, -1].unsqueeze(1)
            logits = logits[:, -1].unsqueeze(1)

        # Compute XE sequence loss
        if n_caches > 0 and len(self.cache_ids) > 0:
            assert ys_out.size(1) == 1
            assert ys_out.size(0) == 1
            if self.adaptive_softmax is None:
                probs = F.softmax(logits, dim=-1)
            else:
                probs = self.adaptive_softmax.log_prob(logits).exp()
            cache_probs = probs.new_zeros(probs.size())

            # Truncate cache
            self.cache_ids = self.cache_ids[-n_caches:]  # list of `[B, 1]`
            self.cache_keys = self.cache_keys[
                -n_caches:]  # list of `[B, 1, n_units]`

            # Compute inner-product over caches
            cache_attn = F.softmax(
                self.cache_theta *
                torch.matmul(torch.cat(self.cache_keys, dim=1),
                             out.transpose(2, 1)).squeeze(2),
                dim=1)

            # For visualization
            if len(self.cache_ids) == n_caches:
                self.cache_attn += [cache_attn.cpu().numpy()]
                self.cache_attn = self.cache_attn[-n_caches:]

            # Sum all probabilities
            for offset, idx in enumerate(self.cache_ids):
                cache_probs[:, :, idx] += cache_attn[:, offset]
            probs = (1 - self.cache_lambda
                     ) * probs + self.cache_lambda * cache_probs
            loss = -torch.log(probs[:, :, ys_out[:, -1]])
        else:
            if self.adaptive_softmax is None:
                if self.lsm_prob > 0 and self.training:
                    # Label smoothing
                    loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                             ys_out.contiguous().view(-1),
                                             self.lsm_prob, self.pad)
                else:
                    loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                           ys_out.contiguous().view(-1),
                                           ignore_index=self.pad,
                                           size_average=True)
            else:
                loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                             ys_out.contiguous().view(-1)).loss

        if n_caches > 0:
            # Register to cache
            self.cache_ids += [ys_out[0, -1].item()]
            self.cache_keys += [out]

        # Compute token-level accuracy in teacher-forcing
        if self.adaptive_softmax is None:
            acc = compute_accuracy(logits, ys_out, pad=self.pad)
        else:
            acc = compute_accuracy(self.adaptive_softmax.log_prob(
                logits.view((-1, logits.size(2)))),
                                   ys_out,
                                   pad=self.pad)

        observation = {
            'loss.lm': loss.item(),
            'acc.lm': acc,
            'ppl.lm': min(np.exp(loss.item()), np.inf)
        }

        # Report here
        if reporter is not None:
            is_eval = not self.training
            reporter.add(observation, is_eval)

        return loss, state, reporter
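
The cache branch is a continuous-cache LM: recent hidden states serve as keys, attention of the current state over them redistributes probability mass onto recently seen token ids, and the result is interpolated with the model distribution. A self-contained sketch for a single step with batch size 1; `cache_theta`, `cache_lambda`, and all tensors are made up:

import torch

vocab, n_units, n_caches = 10, 8, 4
cache_theta, cache_lambda = 0.3, 0.2

probs = torch.softmax(torch.randn(1, 1, vocab), dim=-1)      # model distribution, `[1, 1, V]`
out = torch.randn(1, 1, n_units)                             # current hidden state
cache_keys = [torch.randn(1, 1, n_units) for _ in range(n_caches)]
cache_ids = [5, 7, 5, 2]                                     # token ids stored with the keys

# attention of the current state over the cached states
cache_attn = torch.softmax(
    cache_theta * torch.matmul(torch.cat(cache_keys, dim=1), out.transpose(2, 1)).squeeze(2),
    dim=1)                                                   # `[1, n_caches]`

# scatter the attention mass onto the cached token ids
cache_probs = probs.new_zeros(probs.size())
for offset, idx in enumerate(cache_ids):
    cache_probs[:, :, idx] += cache_attn[:, offset]

probs = (1 - cache_lambda) * probs + cache_lambda * cache_probs
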
Example #6
    def forward_att(self, eouts, elens, ys, return_logits=False):
        """Compute XE loss for the sequence-to-sequence model.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
            return_logits (bool): return logits for knowledge distillation
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity

        """
        bs = eouts.size(0)

        # Append <sos> and <eos>
        eos = eouts.new_zeros(1).fill_(self.eos).long()
        ys = [
            np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int64),
                      self.device_id) for y in ys
        ]
        ylens = np2tensor(
            np.fromiter([y.size(0) + 1 for y in ys],
                        dtype=np.int32))  # +1 for <eos>
        ys_in_pad = pad_list([torch.cat([eos, y], dim=0) for y in ys],
                             self.pad)
        ys_out_pad = pad_list([torch.cat([y, eos], dim=0) for y in ys],
                              self.pad)

        # Create the self-attention mask
        bs, ymax = ys_in_pad.size()[:2]
        yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand(
            bs, ymax, ymax)
        yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax,
                                              ymax)
        subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(),
                                     diagonal=0)
        subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, ymax)
        yy_mask = yy_mask & subsequent_mask

        # Create the source-target mask
        xmax = eouts.size(1)
        x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand(
            bs, ymax, xmax)
        y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).expand(
            bs, ymax, xmax)
        xy_mask = (x_mask * y_mask).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, xmax)

        ys_emb = self.pos_enc(self.embed(ys_in_pad))
        for l in range(self.n_layers):
            ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts,
                                                    xy_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
                setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
        logits = self.norm_out(ys_emb)
        if self.adaptive_softmax is None:
            logits = self.output(logits)
        if return_logits:
            return logits

        # Compute XE sequence loss
        if self.adaptive_softmax is None:
            if self.lsm_prob > 0 and self.training:
                # Label smoothing
                loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                         ys_out_pad.view(-1), self.lsm_prob,
                                         self.pad)
            else:
                loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                       ys_out_pad.view(-1),
                                       ignore_index=self.pad,
                                       size_average=True)

            # Focal loss
            if self.focal_loss_weight > 0:
                fl = focal_loss(logits,
                                ys_out_pad,
                                ylens,
                                alpha=self.focal_loss_weight,
                                gamma=self.focal_loss_gamma)
                loss = loss * (
                    1 - self.focal_loss_weight) + fl * self.focal_loss_weight
        else:
            loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                         ys_out_pad.view(-1)).loss

        # Compute token-level accuracy in teacher-forcing
        if self.adaptive_softmax is None:
            acc = compute_accuracy(logits, ys_out_pad, self.pad)
        else:
            acc = compute_accuracy(
                self.adaptive_softmax.log_prob(
                    logits.view((-1, logits.size(2)))), ys_out_pad, self.pad)
        ppl = min(np.exp(loss.item()), np.inf)

        # scale loss for CTC
        loss *= ylens.float().mean()

        return loss, acc, ppl
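
`focal_loss` is a repo helper; a generic per-token focal term, which down-weights tokens the model already predicts confidently, might look like the sketch below (an assumption about the idea, not the helper's actual code; shapes and ids are made up):

import torch
import torch.nn.functional as F

pad, vocab, gamma = 3, 10, 2.0
logits = torch.randn(2, 4, vocab)                  # `[B, L, vocab]`
targets = torch.tensor([[5, 6, 2, pad],
                        [7, 2, pad, pad]])

log_probs = F.log_softmax(logits, dim=-1)          # `[B, L, vocab]`
nll = -log_probs.gather(2, targets.unsqueeze(2)).squeeze(2)   # `[B, L]`
pt = torch.exp(-nll)                               # probability of the target class
focal = ((1 - pt) ** gamma) * nll                  # down-weights easy (high-confidence) tokens

mask = (targets != pad).float()
loss_focal = (focal * mask).sum() / mask.sum()     # mean over non-pad tokens
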
Example #7
    def _forward(self, ys, hidden, reporter, n_caches=0):
        ys = [np2tensor(y, self.device_id).long() for y in ys]
        ys = pad_list(ys, self.pad)
        ys_in = ys[:, :-1]
        ys_out = ys[:, 1:]

        lmout, hidden = self.decode(self.encode(ys_in), hidden)
        if self.adaptive_softmax is None:
            logits = self.generate(lmout)
        else:
            logits = lmout

        # Compute XE sequence loss
        if n_caches > 0 and len(self.cache_ids) > 0:
            assert ys_out.size(1) == 1
            assert ys_out.size(0) == 1
            if self.adaptive_softmax is None:
                probs = F.softmax(logits, dim=-1)
            else:
                probs = self.adaptive_softmax.log_prob(logits).exp()
            cache_probs = probs.new_zeros(probs.size())

            # Truncate cache
            self.cache_ids = self.cache_ids[-n_caches:]  # list of `[B, 1]`
            self.cache_keys = self.cache_keys[-n_caches:]  # list of `[B, 1, n_units]`

            # Compute inner-product over caches
            cache_attn = F.softmax(self.cache_theta * torch.matmul(
                torch.cat(self.cache_keys, dim=1), lmout.transpose(2, 1)).squeeze(2), dim=1)

            # For visualization
            if len(self.cache_ids) == n_caches:
                self.cache_attn += [cache_attn.cpu().numpy()]
                self.cache_attn = self.cache_attn[-n_caches:]

            # Sum all probabilities
            for offset, idx in enumerate(self.cache_ids):
                cache_probs[:, :, idx] += cache_attn[:, offset]
            probs = (1 - self.cache_lambda) * probs + self.cache_lambda * cache_probs
            loss = -torch.log(probs[:, :, ys_out[:, -1]])
        else:
            if self.adaptive_softmax is None:
                loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                       ys_out.contiguous().view(-1),
                                       ignore_index=self.pad, size_average=True)
            else:
                loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                             ys_out.contiguous().view(-1)).loss

        if n_caches > 0:
            # Register to cache
            self.cache_ids += [ys_out[0, -1].item()]
            self.cache_keys += [lmout]

        # Compute token-level accuracy in teacher-forcing
        if self.adaptive_softmax is None:
            acc = compute_accuracy(logits, ys_out, pad=self.pad)
        else:
            acc = compute_accuracy(self.adaptive_softmax.log_prob(
                logits.view((-1, logits.size(2)))), ys_out, pad=self.pad)

        observation = {'loss.lm': loss.item(),
                       'acc.lm': acc,
                       'ppl.lm': np.exp(loss.item())}

        # Report here
        if reporter is not None:
            is_eval = not self.training
            reporter.add(observation, is_eval)

        return loss, hidden, reporter
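
`compute_accuracy` is a repo helper; its behaviour can be approximated by taking the argmax at each position and counting matches over non-pad targets, as in this assumed sketch:

import torch

def compute_accuracy_sketch(logits, ys_out, pad):
    """Teacher-forcing token accuracy over non-pad positions.

    logits: `[B, L, vocab]`, ys_out: `[B, L]` padded with `pad`.
    """
    pred = logits.argmax(dim=-1)
    mask = ys_out != pad
    n_correct = (pred.eq(ys_out) & mask).sum().item()
    return float(n_correct) / float(mask.sum().item())

logits = torch.randn(2, 4, 10)
ys_out = torch.tensor([[5, 6, 2, 3], [7, 2, 3, 3]])   # pad = 3
print(compute_accuracy_sketch(logits, ys_out, pad=3))
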
Example #8
    def forward_att(self, eouts, elens, ys):
        """Compute XE loss for the CIF model.

        Args:
            eouts (FloatTensor): `[B, T, dec_n_units]`
            elens (IntTensor): `[B]`
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity

        """
        bs, xmax = eouts.size()[:2]

        # Append <sos> and <eos>
        ys_in_pad, ys_out_pad, ylens = self.append_sos_eos(ys, self.bwd)

        # Initialization
        dstate = self.zero_state(bs)
        lmouts, lmstate = None, None

        # CIF
        cvs, alpha, aws = self.score(eouts, elens, ylens)

        # Update LM states for LM fusion
        if self.lm is not None:
            lmouts, _ = self.lm.decode(ys_in_pad, lmstate)

        # Recurrency -> Score -> Generate
        ys_emb = self.embed(ys_in_pad)
        dstate, logits = self.decode_step(cvs, dstate, ys_emb, lmouts)
        logits = self.output(logits)

        # for attention plot
        if not self.training:
            self.aws = tensor2np(aws)  # `[B, n_heads, L, T]`

        # Compute XE sequence loss
        if self.lsm_prob > 0 and self.training:
            # Label smoothing
            loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                     ys_out_pad.view(-1), self.lsm_prob,
                                     self.pad)
        else:
            loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                   ys_out_pad.view(-1),
                                   ignore_index=self.pad,
                                   size_average=True)

        # Quantity loss for CIF
        loss += torch.mean(
            torch.abs(alpha.sum(1) - ylens.float().cuda(self.device_id))
        ) * self.quantity_loss_weight

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out_pad, self.pad)
        ppl = np.exp(loss.item())

        # scale loss for CTC
        loss *= ylens.float().mean()

        return loss, acc, ppl
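
For CIF, the accumulated firing weights `alpha` should sum to the number of emitted tokens per utterance, so the quantity loss penalizes |sum_t alpha_t - ylen|. A minimal sketch with random `alpha` and made-up lengths and weight:

import torch

quantity_loss_weight = 1.0
alpha = torch.rand(2, 30)                       # CIF weights per encoder frame, `[B, T]`
ylens = torch.tensor([5, 3]).float()            # target lengths (incl. <eos>), `[B]`

loss_quantity = torch.mean(torch.abs(alpha.sum(1) - ylens)) * quantity_loss_weight
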
Example #9
    def _forward(self, ys, state, n_caches=0, predict_last=False):
        ys = [np2tensor(y, self.device) for y in ys]  # <eos> is included
        ys = pad_list(ys, self.pad)
        ys_in, ys_out = ys[:, :-1], ys[:, 1:]

        logits, out, new_state = self.decode(ys_in, state=state, mems=state)
        # NOTE: state=state is used for RNNLM while mems=state is used for TransformerXL.
        # TransformerLM ignores both of them.

        if predict_last:
            ys_out = ys_out[:, -1].unsqueeze(1)
            logits = logits[:, -1].unsqueeze(1)

        # Compute XE sequence loss
        if n_caches > 0 and len(self.cache_ids) > 0:
            assert ys_out.size(1) == 1
            assert ys_out.size(0) == 1
            if self.adaptive_softmax is None:
                probs = torch.softmax(logits, dim=-1)
            else:
                probs = self.adaptive_softmax.log_prob(logits).exp()
            cache_probs = probs.new_zeros(probs.size())

            # Truncate cache
            self.cache_ids = self.cache_ids[-n_caches:]  # list of `[B, 1]`
            self.cache_keys = self.cache_keys[
                -n_caches:]  # list of `[B, 1, n_units]`

            # Compute inner-product over caches
            cache_attn = torch.softmax(
                self.cache_theta *
                torch.matmul(torch.cat(self.cache_keys, dim=1),
                             out.transpose(2, 1)).squeeze(2),
                dim=1)

            # For visualization
            if len(self.cache_ids) == n_caches:
                self.cache_attn += [cache_attn.cpu().numpy()]
                self.cache_attn = self.cache_attn[-n_caches:]

            # Sum all probabilities
            for offset, idx in enumerate(self.cache_ids):
                cache_probs[:, :, idx] += cache_attn[:, offset]
            probs = (1 - self.cache_lambda
                     ) * probs + self.cache_lambda * cache_probs
            loss = -torch.log(probs[:, :, ys_out[:, -1]])
        else:
            if self.adaptive_softmax is None:
                loss, ppl = cross_entropy_lsm(logits,
                                              ys_out.contiguous(),
                                              self.lsm_prob,
                                              self.pad,
                                              self.training,
                                              normalize_length=True)
            else:
                loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                             ys_out.contiguous().view(-1)).loss
                ppl = np.exp(loss.item())

        if n_caches > 0:
            # Register to cache
            self.cache_ids += [ys_out[0, -1].item()]
            self.cache_keys += [out]

        # Compute token-level accuracy in teacher-forcing
        if self.adaptive_softmax is None:
            acc = compute_accuracy(logits, ys_out, pad=self.pad)
        else:
            acc = compute_accuracy(self.adaptive_softmax.log_prob(
                logits.view((-1, logits.size(2)))),
                                   ys_out,
                                   pad=self.pad)

        observation = {'loss.lm': loss.item(), 'acc.lm': acc, 'ppl.lm': ppl}
        return loss, new_state, observation
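
`cross_entropy_lsm` here returns both the label-smoothed loss and a perplexity. A guess at what such a helper computes, normalizing over non-pad tokens and taking the perplexity from the unsmoothed NLL (this is an assumption, not the repo implementation):

import numpy as np
import torch
import torch.nn.functional as F

def cross_entropy_lsm_sketch(logits, ys_out, lsm_prob, pad):
    """Label-smoothed XE averaged over non-pad tokens; returns (loss, ppl)."""
    log_probs = F.log_softmax(logits, dim=-1)                      # `[B, L, V]`
    nll = -log_probs.gather(2, ys_out.unsqueeze(2)).squeeze(2)     # `[B, L]`
    uniform = -log_probs.mean(-1)                                  # `[B, L]`
    mask = (ys_out != pad).float()
    n_tokens = mask.sum()
    loss = (((1 - lsm_prob) * nll + lsm_prob * uniform) * mask).sum() / n_tokens
    ppl = float(np.exp(((nll * mask).sum() / n_tokens).item()))
    return loss, ppl

logits = torch.randn(2, 4, 10)
ys_out = torch.tensor([[5, 6, 2, 3], [7, 2, 3, 3]])                # pad = 3
loss, ppl = cross_entropy_lsm_sketch(logits, ys_out, lsm_prob=0.1, pad=3)
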
Example #10
    def forward_att(self,
                    eouts,
                    elens,
                    ys,
                    ys_hist=[],
                    return_logits=False,
                    teacher_logits=None):
        """Compute XE loss for the Transformer model.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
            ys_hist (list):
            return_logits (bool): return logits for knowledge distillation
            teacher_logits (FloatTensor): `[B, L, vocab]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity

        """
        bs = eouts.size(0)

        # Append <sos> and <eos>
        ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.pad,
                                              self.bwd)

        # Create the self-attention mask
        bs, ytime = ys_in.size()[:2]
        tgt_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat(
            [1, ytime, 1])
        subsequent_mask = tgt_mask.new_ones(ytime, ytime).byte()
        subsequent_mask = torch.tril(subsequent_mask,
                                     out=subsequent_mask).unsqueeze(0)
        tgt_mask = tgt_mask & subsequent_mask

        # Create the source-target mask
        src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat(
            [1, ytime, 1])

        out = self.pos_enc(self.embed(ys_in))
        for l in range(self.n_layers):
            out, yy_aws, xy_aws = self.layers[l](out, tgt_mask, eouts,
                                                 src_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
                setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
        logits = self.output(self.norm_out(out))

        # for knowledge distillation
        if return_logits:
            return logits

        # Compute XE sequence loss (+ label smoothing)
        loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad,
                                      self.training)

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out, self.pad)

        return loss, acc, ppl
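
`make_pad_mask` is a repo helper. The snippets above compare its output against 0 to drop attention to padded frames, which suggests a mask that is True inside each sequence length; a sketch under that assumption:

import torch

def make_pad_mask_sketch(lens):
    """`[B, T_max]` boolean mask, True for valid (non-padded) positions."""
    bs = lens.size(0)
    t_max = int(lens.max())
    idx = torch.arange(t_max).unsqueeze(0).expand(bs, -1)     # `[B, T_max]` frame indices
    return idx < lens.unsqueeze(1)

elens = torch.tensor([5, 3])
ymax = 4
src_mask = make_pad_mask_sketch(elens).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`
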
Example #11
    def forward_att(self,
                    eouts,
                    elens,
                    ys,
                    return_logits=False,
                    teacher_logits=None,
                    trigger_points=None):
        """Compute XE loss for the Transformer decoder.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): length `B`, each of which contains a list of size `[L]`
            return_logits (bool): return logits for knowledge distillation
            teacher_logits (FloatTensor): `[B, L, vocab]`
            trigger_points (IntTensor): `[B, T]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity
            loss_quantity (FloatTensor): `[1]`
            loss_headdiv (FloatTensor): `[1]`
            loss_latency (FloatTensor): `[1]`

        """
        # Append <sos> and <eos>
        ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.eos,
                                              self.pad, self.bwd)
        if not self.training:
            self.data_dict['elens'] = tensor2np(elens)
            self.data_dict['ylens'] = tensor2np(ylens)
            self.data_dict['ys'] = tensor2np(ys_out)

        # Create target self-attention mask
        xtime = eouts.size(1)
        bs, ymax = ys_in.size()[:2]
        tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
        causal_mask = tgt_mask.new_ones(ymax, ymax).byte()
        causal_mask = torch.tril(causal_mask, out=causal_mask).unsqueeze(0)
        tgt_mask = tgt_mask & causal_mask  # `[B, L, L]`

        # Create source-target mask
        src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat(
            [1, ymax, 1])  # `[B, L, T]`

        # external LM integration
        lmout = None
        if self.lm is not None:
            self.lm.eval()
            with torch.no_grad():
                lmout, lmstate, _ = self.lm.predict(ys_in, None)
            lmout = self.lm_output_proj(lmout)

        out = self.embed(ys_in)

        mlen = 0  # TODO: fix later
        if self.memory_transformer:
            # NOTE: TransformerXL does not use positional encoding in the token embedding
            mems = self.init_memory()
            # adopt zero-centered offset
            pos_idxs = torch.arange(mlen - 1,
                                    -ymax - 1,
                                    -1.0,
                                    dtype=torch.float)
            if self.device_id >= 0:
                pos_idxs = pos_idxs.cuda(self.device_id)
            pos_embs = self.dropout_emb(self.pos_emb(pos_idxs))
            out = self.dropout_emb(out)
            hidden_states = [out]
        else:
            out = self.pos_enc(out)

        xy_aws_layers = []
        for l, layer in enumerate(self.layers):
            if self.memory_transformer:
                out, yy_aws, xy_aws, xy_aws_beta, yy_aws_lm = layer(
                    out,
                    tgt_mask,
                    eouts,
                    src_mask,
                    mode='parallel',
                    lmout=lmout,
                    pos_embs=pos_embs,
                    memory=mems[l],
                    u=self.u,
                    v=self.v)
                hidden_states.append(out)
            else:
                out, yy_aws, xy_aws, xy_aws_beta, yy_aws_lm = layer(
                    out,
                    tgt_mask,
                    eouts,
                    src_mask,
                    mode='parallel',
                    lmout=lmout)

            xy_aws_layers.append(xy_aws.clone() if xy_aws is not None else out.
                                 new_zeros(bs, yy_aws.size(1), ymax, xtime))
            if not self.training:
                if yy_aws is not None:
                    self.aws_dict['yy_aws_layer%d' % l] = tensor2np(yy_aws)
                if xy_aws is not None:
                    self.aws_dict['xy_aws_layer%d' % l] = tensor2np(xy_aws)
                if xy_aws_beta is not None:
                    self.aws_dict['xy_aws_beta_layer%d' %
                                  l] = tensor2np(xy_aws_beta)
                if yy_aws_lm is not None:
                    self.aws_dict['yy_aws_lm_layer%d' %
                                  l] = tensor2np(yy_aws_lm)
        logits = self.output(self.norm_out(out))

        # TODO: Update memory
        # if self.memory_transformer:
        #     new_mems = self.update_memory(mems, hidden_states)

        # for knowledge distillation
        if return_logits:
            return logits

        # Compute XE sequence loss (+ label smoothing)
        loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad,
                                      self.training)

        # Attention padding
        if self.quantity_loss_weight > 0 or self.headdiv_loss_weight > 0 or self.latency_loss_weight > 0:
            for l in range(self.mocha_first_layer - 1, self.n_layers):
                n_heads = xy_aws_layers[l].size(1)
                xy_aws_layers[l] = xy_aws_layers[l].masked_fill_(
                    src_mask.unsqueeze(1).repeat([1, n_heads, 1, 1]) == 0, 0)
                xy_aws_layers[l] = xy_aws_layers[l].masked_fill_(
                    tgt_mask[:, :,
                             -1:].unsqueeze(1).repeat([1, n_heads, 1,
                                                       xtime]) == 0, 0)
                # NOTE: attention padding is quite effective for quantity loss
        n_heads = xy_aws_layers[-1].size(1)  # mono
        # NOTE: debug for multihead mono + multihead chunk

        # Quantity loss
        loss_quantity = 0.
        if 'mocha' in self.attn_type:
            # Average over all heads across all layers
            n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
            # NOTE: count <eos> tokens
            n_tokens_pred = sum([
                torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                for aws in xy_aws_layers[self.mocha_first_layer - 1:]
            ])  # `[B]`
            n_tokens_pred /= (self.n_layers - self.mocha_first_layer + 1)
            loss_quantity = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

        # Head divergence loss
        loss_headdiv = 0.
        if self.headdiv_loss_weight > 0.:
            # Calculate variance over all heads across all layers
            js = torch.arange(xtime, dtype=torch.float).cuda(self.device_id)
            js = js.repeat([bs, n_heads, ymax, 1])
            avg_head_pos = sum([
                (js * aws).sum(3).sum(1) for aws in xy_aws_layers
            ]) / (n_heads * self.n_layers)  # `[B, L]`
            loss_headdiv = sum([((js * aws).sum(3).sum(1) - avg_head_pos)**2
                                for aws in xy_aws_layers]) / (
                                    n_heads * self.n_layers)  # `[B, L]`
            loss_headdiv = loss_headdiv.sum() / ylens.sum()

        # Latency loss
        loss_latency = 0.
        if self.latency_metric == 'interval':
            raise NotImplementedError
        elif trigger_points is not None:
            assert self.latency_loss_weight > 0
            # Calculate weight average latency
            js = torch.arange(xtime, dtype=torch.float).cuda(self.device_id)
            js = js.repeat([bs, n_heads, ymax, 1])
            weighted_avg_head_pos = torch.cat(
                [(js * aws).sum(3) for aws in xy_aws_layers],
                dim=1)  # `[B, H_mono * n_layers, L]`
            weighted_avg_head_pos *= torch.softmax(
                weighted_avg_head_pos.clone(), dim=1)
            trigger_points = trigger_points.float().cuda(
                self.device_id)  # `[B, L]`
            trigger_points = trigger_points.unsqueeze(1)
            if self.latency_metric == 'ctc_sync':
                loss_latency = torch.abs(
                    weighted_avg_head_pos -
                    trigger_points)  # `[B, H_mono * n_layers, L]`
            else:
                raise NotImplementedError(self.latency_metric)
            # NOTE: trigger_points are padded with 0
            loss_latency = loss_latency.sum() / ylens.sum()

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out, self.pad)

        return loss, acc, ppl, loss_quantity, loss_headdiv, loss_latency
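
The `ctc_sync` latency loss compares each head's expected attended frame (the attention-weighted average of frame indices) against a per-token CTC trigger point. A rough sketch with random tensors for a single layer, omitting the head re-weighting the example applies:

import torch

bs, n_heads, ymax, xmax = 2, 4, 5, 30
ylens = torch.tensor([5, 3]).float()

# monotonic source-target attention of one layer, `[B, H, L, T]`
xy_aws = torch.rand(bs, n_heads, ymax, xmax).softmax(-1)
trigger_points = torch.randint(0, xmax, (bs, ymax)).float()   # CTC spike frames, `[B, L]`

js = torch.arange(xmax, dtype=torch.float).repeat([bs, n_heads, ymax, 1])
expected_pos = (js * xy_aws).sum(3)                           # `[B, H, L]` expected frame
loss_latency = torch.abs(expected_pos - trigger_points.unsqueeze(1))  # `[B, H, L]`
loss_latency = loss_latency.sum() / ylens.sum()
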
Example #12
    def forward_att(self, eouts, elens, ys, return_logits=False):
        """Compute XE loss for the sequence-to-sequence model.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
            return_logits (bool): return logits for knowledge distillation
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity

        """
        bs = eouts.size(0)

        # Append <sos> and <eos>
        ys_in_pad, ys_out_pad, ylens = self.append_sos_eos(ys, self.bwd)

        # Create the self-attention mask
        bs, ymax = ys_in_pad.size()[:2]
        yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat(
            [1, ymax, 1])
        yy_mask = yy_mask.unsqueeze(1).repeat([1, self.attn_n_heads, 1, 1])
        subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(),
                                     diagonal=0)
        subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).repeat(
            [bs, self.attn_n_heads, 1, 1])
        yy_mask = yy_mask & subsequent_mask

        # Create the source-target mask
        xmax = eouts.size(1)
        x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat(
            [1, ymax, 1])
        y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).repeat(
            [1, 1, xmax])
        xy_mask = (x_mask * y_mask).unsqueeze(1).repeat(
            [1, self.attn_n_heads, 1, 1])

        ys_emb = self.pos_enc(self.embed(ys_in_pad))
        for l in range(self.n_layers):
            ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts,
                                                    xy_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
                setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
        ys_emb = self.norm_out(ys_emb)
        logits = self.output(ys_emb)

        # for knowledge distillation
        if return_logits:
            return logits

        # Compute XE sequence loss
        if self.lsm_prob > 0 and self.training:
            # Label smoothing
            loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                     ys_out_pad.view(-1), self.lsm_prob,
                                     self.pad)
        else:
            loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                   ys_out_pad.view(-1),
                                   ignore_index=self.pad,
                                   size_average=True)

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out_pad, self.pad)
        ppl = min(np.exp(loss.item()), np.inf)

        # scale loss for CTC
        loss *= ylens.float().mean()

        return loss, acc, ppl
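
The final `loss *= ylens.float().mean()` turns the token-averaged cross entropy into a per-utterance scale so it can be mixed with a summed CTC loss; with a per-token mean this is algebraically the same as dividing the summed loss by the batch size. A quick numerical check with made-up tensors:

import torch
import torch.nn.functional as F

pad, vocab = 3, 10
logits = torch.randn(2, 4, vocab)
ys_out = torch.tensor([[5, 6, 2, pad],
                       [7, 2, pad, pad]])
ylens = (ys_out != pad).sum(1).float()            # non-pad targets per utterance (incl. <eos>)

loss_mean = F.cross_entropy(logits.view(-1, vocab), ys_out.view(-1),
                            ignore_index=pad, reduction='mean')
loss_sum = F.cross_entropy(logits.view(-1, vocab), ys_out.view(-1),
                           ignore_index=pad, reduction='sum')

print(loss_mean * ylens.mean())                   # scaled as in the example
print(loss_sum / ys_out.size(0))                  # identical: summed loss / batch size
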