Code example #1
    def forward_asr(self, hs_pad, hs_mask, ys_pad):
        """Forward pass in the auxiliary ASR task.

        :param torch.Tensor hs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor hs_mask: batch of input token mask (B, Lmax)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ASR attention loss value
        :rtype: torch.Tensor
        :return: accuracy in ASR attention decoder
        :rtype: float
        :return: ASR CTC loss value
        :rtype: torch.Tensor
        :return: character error rate from CTC prediction
        :rtype: float
        :return: character error rate from attention decoder prediction
        :rtype: float
        :return: word error rate from attention decoder prediction
        :rtype: float
        """
        loss_att, loss_ctc = 0.0, 0.0
        acc = None
        cer, wer = None, None
        cer_ctc = None
        if self.asr_weight == 0:
            return loss_att, acc, loss_ctc, cer_ctc, cer, wer

        # attention
        if self.mtlalpha < 1:
            ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(
                ys_pad, self.sos, self.eos, self.ignore_id)
            ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id)
            pred_pad, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr, hs_pad,
                                           hs_mask)
            loss_att = self.criterion(pred_pad, ys_out_pad_asr)

            acc = th_accuracy(
                pred_pad.view(-1, self.odim),
                ys_out_pad_asr,
                ignore_label=self.ignore_id,
            )
            if not self.training:
                ys_hat_asr = pred_pad.argmax(dim=-1)
                cer, wer = self.error_calculator_asr(ys_hat_asr.cpu(),
                                                     ys_pad.cpu())

        # CTC
        if self.mtlalpha > 0:
            batch_size = hs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if not self.training:
                ys_hat_ctc = self.ctc.argmax(
                    hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator_asr(ys_hat_ctc.cpu(),
                                                    ys_pad.cpu(),
                                                    is_ctc=True)
                # for visualization
                self.ctc.softmax(hs_pad)
        return loss_att, acc, loss_ctc, cer_ctc, cer, wer
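For reference, a minimal plain-PyTorch sketch of the add_sos_eos preprocessing used above (the token IDs below are hypothetical; this is an illustration of the behavior, not ESPnet's implementation):

import torch
from torch.nn.utils.rnn import pad_sequence

def add_sos_eos_sketch(ys_pad, sos, eos, ignore_id):
    """Prefix <sos> for the decoder input, append <eos> for the target."""
    ys = [y[y != ignore_id] for y in ys_pad]       # strip padding
    sos_t, eos_t = ys_pad.new_tensor([sos]), ys_pad.new_tensor([eos])
    ys_in = [torch.cat([sos_t, y]) for y in ys]    # <sos> y1 ... yn
    ys_out = [torch.cat([y, eos_t]) for y in ys]   # y1 ... yn <eos>
    return (pad_sequence(ys_in, batch_first=True, padding_value=eos),
            pad_sequence(ys_out, batch_first=True, padding_value=ignore_id))

ys_pad = torch.tensor([[3, 4, 5], [6, 7, -1]])     # -1 = ignore_id
ys_in, ys_out = add_sos_eos_sketch(ys_pad, sos=1, eos=2, ignore_id=-1)
print(ys_in)   # tensor([[1, 3, 4, 5], [1, 6, 7, 2]])
print(ys_out)  # tensor([[3, 4, 5, 2], [6, 7, 2, -1]])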
Code example #2
    def store_penultimate_state(self, xs_pad, ilens, ys_pad, moe_coes,
                                moe_coe_lens):
        moe_coes = moe_coes[:, :max(moe_coe_lens)]  # for data parallel
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        # multi-encoder forward
        cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
        moe_coes = moe_coes.unsqueeze(-1)
        hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0]
        self.hs_pad = hs_pad

        # forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask, penultimate_state = self.decoder(
            ys_in_pad,
            ys_mask,
            hs_pad,
            hs_mask,
            moe_coes,
            return_penultimate_state=True)

        # plot penultimate_state, (B,T,att_dim)
        return penultimate_state.squeeze(0).detach().cpu().numpy()
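A shape sketch of the mixture-of-encoders combination above, with hypothetical sizes: after unsqueeze, each per-frame coefficient slice is (B, T, 1) and broadcasts across the feature dimension.

import torch

B, T, D = 2, 5, 8                     # hypothetical sizes
cn_hs_pad = torch.randn(B, T, D)
en_hs_pad = torch.randn(B, T, D)
moe_coes = torch.rand(B, T, 2)        # per-frame expert weights

moe_coes = moe_coes.unsqueeze(-1)     # (B, T, 2, 1)
# moe_coes[:, :, 1] weights the cn encoder, moe_coes[:, :, 0] the en encoder
hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0]
assert hs_pad.shape == (B, T, D)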
Code example #3
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats

        # 4. compute bleu
        if self.training or self.error_calculator is None:
            bleu = 0.0
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_mt
        self.loss = loss

        loss_data = float(self.loss)
        if self.normalize_length:
            self.ppl = np.exp(loss_data)
        else:
            batch_size = ys_out_pad.size(0)
            ys_out_pad = ys_out_pad.view(-1)
            ignore = ys_out_pad == self.ignore_id  # (B*T,)
            total = len(ys_out_pad) - ignore.sum().item()
            self.ppl = np.exp(loss_data * batch_size / total)
        if not math.isnan(loss_data):
            self.reporter.report(loss_data, self.acc, self.ppl, bleu)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
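A numeric sketch of the perplexity computation above (hypothetical numbers): with normalize_length the criterion already returns a per-token loss; otherwise the per-sequence average is rescaled to a per-token loss before exponentiating.

import numpy as np

loss_data = 2.0                   # hypothetical mean loss
batch_size, total_tokens = 4, 10  # sequences / non-ignored target tokens

ppl_per_token = np.exp(loss_data)                             # normalize_length
ppl_rescaled = np.exp(loss_data * batch_size / total_tokens)  # otherwise
print(ppl_per_token, ppl_rescaled)   # 7.389... 2.2255...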
Code example #4
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        # 1. encoder
        if self.etype == 'transformer':
            xs_pad = xs_pad[:, :max(ilens)]
            src_mask = (~make_pad_mask(ilens.tolist())).to(
                xs_pad.device).unsqueeze(-2)

            hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        else:
            hs_pad, hlens = xs_pad, ilens
            hs_pad, hlens, _ = self.encoder(hs_pad, hlens)
            hs_mask = hlens
        self.hs_pad = hs_pad

        # 1.5. transducer preparation related
        ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)

        # 2. decoder
        if self.dtype == 'transformer':
            ys_mask = target_mask(ys_in_pad, self.blank_id)
            pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
        else:
            if self.rnnt_mode == 'rnnt':
                pred_pad = self.decoder(hs_pad, ys_in_pad)
            else:
                pred_pad = self.decoder(hs_pad, ys_in_pad, pred_len)
        self.pred_pad = pred_pad

        # 3. loss computation
        loss = self.criterion(pred_pad, target, pred_len, target_len)

        self.loss = loss
        loss_data = float(self.loss)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, cer, wer)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)

        return self.loss
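A minimal sketch of what prepare_loss_inputs produces for the transducer loss above, re-implemented in plain PyTorch from the shapes used here (blank-prefixed decoder input, int32 targets and lengths). The blank_id=0 and ignore_id=-1 defaults are assumptions, and the real helper also accepts plain length vectors from the RNN path.

import torch
from torch.nn.utils.rnn import pad_sequence

def prepare_loss_inputs_sketch(ys_pad, hs_mask, blank_id=0, ignore_id=-1):
    ys = [y[y != ignore_id] for y in ys_pad]            # strip padding
    blank = ys_pad.new_tensor([blank_id])
    # decoder input is the label sequence prefixed with blank
    ys_in_pad = pad_sequence([torch.cat([blank, y]) for y in ys],
                             batch_first=True, padding_value=blank_id)
    target = pad_sequence(ys, batch_first=True, padding_value=blank_id).int()
    pred_len = hs_mask.sum(-1).view(-1).int()           # encoder frame counts
    target_len = torch.tensor([y.size(0) for y in ys]).int()
    return ys_in_pad, target, pred_len, target_len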
Code example #5
def test_transformer_mask():
    args = make_arg()
    model, x, ilens, y, data, uttid_list = prepare("pytorch", args)
    yi, yo = add_sos_eos(y, model.sos, model.eos, model.ignore_id)
    y_mask = target_mask(yi, model.ignore_id)
    y = model.decoder.embed(yi)
    y[0, 3:] = float("nan")
    a = model.decoder.decoders[0].self_attn
    a(y, y, y, y_mask)
    assert not numpy.isnan(a.attn[0, :, :3, :3].detach().numpy()).any()
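The test poisons positions 3 onward with NaN and checks that the first three rows of the self-attention output stay finite; this works because target_mask combines a padding mask with a causal (subsequent-position) mask. A toy sketch of that mask:

import torch

yi = torch.tensor([[1, 3, 4, 5], [1, 6, 7, -1]])         # -1 = ignore_id
pad_mask = (yi != -1).unsqueeze(-2)                      # (B, 1, L)
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))  # (L, L)
y_mask = pad_mask & causal.unsqueeze(0)                  # (B, L, L)
print(y_mask[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])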
Code example #6
def test_transformer_mask(module):
    model, x, ilens, y, data = prepare(module)
    from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
    from espnet.nets.pytorch_backend.transformer.mask import target_mask
    yi, yo = add_sos_eos(y, model.sos, model.eos, model.ignore_id)
    y_mask = target_mask(yi, model.ignore_id)
    y = model.decoder.embed(yi)
    y[0, 3:] = float("nan")
    a = model.decoder.decoders[0].self_attn
    a(y, y, y, y_mask)
    assert not numpy.isnan(a.attn[0, :, :3, :3].detach().numpy()).any()
Code example #7
    def forward(self, feats, feats_lengths):

        src_mask = make_non_pad_mask(feats_lengths.tolist()).to(feats.device).unsqueeze(-2)
        h, hs_mask = self.encoder.encoder(feats, src_mask)
        ys_in_pad = self.encoder.init_decoder(h, .9)
        ys_mask = target_mask(ys_in_pad, ignore_id=self.encoder.eos).to(ys_in_pad.device)
        _, mask, encodings = self.encoder.decoder(ys_in_pad, ys_mask, h, hs_mask, return_hidden=True)
        encoding_lengths = mask.sum(-1)[:, -1]
        logits = self.classifier(encodings, encoding_lengths)
        return logits
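encoding_lengths above is read off the returned (B, L, L) target mask: each row counts the positions it can attend to, and the last row of a causal-plus-padding mask sees exactly the unpadded sequence. A toy check:

import torch

causal = torch.tril(torch.ones(4, 4, dtype=torch.bool)).unsqueeze(0)  # (1, L, L)
pad = torch.tensor([[[True, True, True, False]]])                     # (B, 1, L)
mask = causal & pad
print(mask.sum(-1))          # tensor([[1, 2, 3, 3]])
print(mask.sum(-1)[:, -1])   # tensor([3]) == unpadded length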
Code example #8
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        # 1. encoder
        xs_pad = xs_pad[:, :max(ilens)]

        if "custom" in self.etype:
            src_mask = make_non_pad_mask(ilens.tolist()).to(
                xs_pad.device).unsqueeze(-2)

            hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        else:
            hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)

        # 1.5. transducer preparation related
        ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)

        # 2. decoder
        if "custom" in self.dtype:
            ys_mask = target_mask(ys_in_pad, self.blank_id)
            pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad)

        z = self.joint_network(hs_pad.unsqueeze(2), pred_pad.unsqueeze(1))

        # 3. loss computation
        loss = self.criterion(z, target, pred_len, target_len)

        self.loss = loss
        loss_data = float(loss)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, cer, wer)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
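The joint step above broadcasts encoder frames against decoder label states into the lattice fed to the transducer loss. A toy additive joint with hypothetical sizes (the real joint_network is ESPnet's module; this only illustrates the broadcasting):

import torch

B, T, U, D, V = 2, 6, 4, 8, 10         # hypothetical sizes
hs_pad = torch.randn(B, T, D)          # encoder states
pred_pad = torch.randn(B, U, D)        # decoder states, one per label step

# (B, T, 1, D) + (B, 1, U, D) broadcasts to a (B, T, U, D) lattice,
# which a linear layer maps to vocabulary logits
joint = torch.nn.Linear(D, V)
z = joint(torch.tanh(hs_pad.unsqueeze(2) + pred_pad.unsqueeze(1)))
assert z.shape == (B, T, U, V)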
Code example #9
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # 3. compute attention loss
        self.loss = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute corpus-level bleu in a mini-batch
        if self.training:
            self.bleu = None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        loss_data = float(self.loss)
        if self.normalize_length:
            self.ppl = np.exp(loss_data)
        else:
            batch_size = ys_out_pad.size(0)
            ys_out_pad = ys_out_pad.view(-1)
            ignore = ys_out_pad == self.ignore_id  # (B*T,)
            total_n_tokens = len(ys_out_pad) - ignore.sum().item()
            self.ppl = np.exp(loss_data * batch_size / total_n_tokens)
        if not math.isnan(loss_data):
            self.reporter.report(loss_data, self.acc, self.ppl, self.bleu)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
Code example #10
    def store_penultimate_state(self, xs_pad, ilens, ys_pad, bnf_feats, bnf_feats_lens):
        bnf_feats = bnf_feats[:, :max(bnf_feats_lens)] # for data parallel
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask, bnf_feats)
        self.hs_pad = hs_pad

        # forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask, penultimate_state = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask, return_penultimate_state=True)

        # plot penultimate_state, (B,T,att_dim)
        return penultimate_state.squeeze(0).detach().cpu().numpy()
Code example #11
    def decoder_and_attention(self, hs_pad, hs_mask, ys_pad, batch_size):
        """Forward decoder and attention loss."""
        # forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(pred_pad.view(-1, self.odim),
                          ys_out_pad,
                          ignore_label=self.ignore_id)
        return pred_pad, pred_mask, loss_att, acc
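A minimal re-implementation sketch of the th_accuracy call above: token accuracy over positions whose target is not ignore_id, with the flattened logits reshaped back to the padded target layout.

import torch

def th_accuracy_sketch(pad_outputs, pad_targets, ignore_label):
    """pad_outputs: (B*Lmax, D) flattened logits; pad_targets: (B, Lmax)."""
    pred = pad_outputs.argmax(dim=-1).view(pad_targets.size())
    mask = pad_targets != ignore_label
    correct = ((pred == pad_targets) & mask).sum()
    return float(correct) / float(mask.sum())

logits = torch.eye(4)                           # perfect one-hot predictions
targets = torch.tensor([[0, 1], [2, -1]])       # -1 = ignore_id
print(th_accuracy_sketch(logits, targets, -1))  # 1.0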
Code example #12
def test_sa_transducer_mask(module):
    from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
    from espnet.nets.pytorch_backend.transducer.utils import prepare_loss_inputs
    from espnet.nets.pytorch_backend.transformer.mask import target_mask

    train_args = make_train_args()
    model, x, ilens, y, data = prepare(module, train_args)

    # dummy mask
    x_mask = (~make_pad_mask(ilens.tolist())).to(x.device).unsqueeze(-2)

    _, target, _, _ = prepare_loss_inputs(y, x_mask)
    y_mask = target_mask(target, model.blank_id)

    y = model.decoder.embed(target.type(torch.long))
    y[0, 3:] = float("nan")

    a = model.decoder.decoders[0].self_attn
    a(y, y, y, y_mask)
    assert not numpy.isnan(a.attn[0, :, :3, :3].detach().numpy()).any()
Code example #13
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        # 1. encoder
        xs_pad = xs_pad[:, :max(ilens)]

        if "transformer" in self.etype:

            src_mask = make_non_pad_mask(ilens.tolist()).to(
                xs_pad.device).unsqueeze(-2)

            batchsize = xs_pad.size(0)
            inputs = xs_pad.unsqueeze(1)
            logging.info("inputs:{}".format(inputs.shape))
            logging.info("src_mask:{}".format(src_mask.shape))

            inputs_length = []
            if src_mask is not None:
                for mask in src_mask.tolist():
                    inputs_length.append(mask[0].count(True))

                for i in range(batchsize):
                    inputs_s = inputs[i].unsqueeze(0)[:, :,
                                                      0:inputs_length[i], :]
                    core_out = self.conv(inputs_s)
                    inputs_length[i] = core_out.size(2)

                inputs_length = torch.as_tensor(inputs_length)

            else:
                core_out = self.conv(inputs)
                inputs_length = core_out.size(2)
                inputs_length = torch.as_tensor(inputs_length)
                logging.info("inputs_length:{}".format(inputs_length))

            # block 1
            # the inputs shape of Conv2d is 4-dim of (bsz * c * l * w)
            # the inputs shape of Conv1d is 3-dim of (bsz * c * l)
            # the inputs shape of transformer is 3-dim of (l * bsz * c)
            # conv output format: (bsz * c * t * d)
            inputs = self.conv(inputs)

            # we obtain 16-channel feature maps at every time step;
            # merge the 16 channels of each timestep into one self-attention input (batch, 16, dim)
            inputs = inputs.permute(2, 0, 1, 3)
            logging.info("inputs:{}".format(inputs.shape))
            merge = torch.zeros(inputs.size(0), batchsize, 512)

            for t in range(inputs.size(0)):  # max_length
                merge[t] = self.clayers(inputs[t],
                                        None)[0].reshape(batchsize, 512)

            xs = merge.permute(1, 0, 2)

            if inputs_length.dim() == 0:
                masks = make_non_pad_mask([inputs_length]).unsqueeze(-2)
            else:
                masks = make_non_pad_mask(inputs_length.tolist()).unsqueeze(-2)

            hs_pad, hs_mask = self.encoder(xs, masks)
        else:
            hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)
        self.hs_pad = hs_pad

        # 1.5. transducer preparation related
        ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)

        # 2. decoder
        if "transformer" in self.dtype:
            ys_mask = target_mask(ys_in_pad, self.blank_id)
            pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
        else:
            if self.rnnt_mode == "rnnt":
                pred_pad = self.dec(hs_pad, ys_in_pad)
            else:
                pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
        self.pred_pad = pred_pad

        # 3. loss computation
        loss = self.criterion(pred_pad, target, pred_len, target_len)

        self.loss = loss
        loss_data = float(self.loss)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, cer, wer)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
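The per-sample loop above recovers each utterance's post-convolution length by running the conv on the truncated input and reading core_out.size(2); for a plain strided convolution this matches the usual output-length formula. A toy check with a hypothetical conv:

import torch

conv = torch.nn.Conv2d(1, 16, kernel_size=3, stride=2)  # hypothetical
x = torch.randn(1, 1, 37, 83)          # (bsz, c, time, feat)
t_out = conv(x).size(2)
# floor((L_in + 2*pad - kernel) / stride) + 1 = floor((37 - 3) / 2) + 1 = 18
assert t_out == (37 - 3) // 2 + 1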
Code example #14
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        # mlp moe forward
        cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, en_hs_mask = self.en_encoder(xs_pad, src_mask)
        hs_mask = cn_hs_mask  # cn_hs_mask & en_hs_mask are identical
        # gated add module
        """ lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b)  #(B, T, 1)
            xs = lambda * cn_xs + (1-lambda) * en_xs 
        """
        hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
        lambda_ = self.aggregation_module(
            hs_pad)  # (B,T,1)/(B,T,D), range from (0, 1)
        hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
        self.hs_pad = hs_pad

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)

            # divide ys_pad into cn_ys & en_ys;
            # note that this target can directly pass to ctc module
            cn_ys, en_ys = partial_target(ys_pad, self.language_divider)
            cn_loss_ctc = self.cn_ctc(
                cn_hs_pad.view(batch_size, -1, self.adim), hs_len, cn_ys)
            en_loss_ctc = self.en_ctc(
                en_hs_pad.view(batch_size, -1, self.adim), hs_len, en_ys)
            loss_ctc = 0.5 * (cn_loss_ctc + en_loss_ctc)

        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                               hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, None, None, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
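A self-contained sketch of the gated-add module described in the docstring above, with hypothetical layer shapes: the gate is a sigmoid over the concatenated encodings and interpolates the two encoders frame by frame.

import torch

B, T, D = 2, 5, 8                      # hypothetical sizes
cn_hs_pad = torch.randn(B, T, D)
en_hs_pad = torch.randn(B, T, D)

# lambda = sigmoid(W [cn_xs; en_xs] + b), shape (B, T, 1), range (0, 1)
aggregation_module = torch.nn.Sequential(
    torch.nn.Linear(2 * D, 1), torch.nn.Sigmoid())
lambda_ = aggregation_module(torch.cat((cn_hs_pad, en_hs_pad), dim=-1))
hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
assert hs_pad.shape == (B, T, D)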
Code example #15
    def forward(self, feats: torch.Tensor, feats_len: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """E2E forward.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            feats_len: Feature sequence lengths. (B,)
            labels: Label ID sequences. (B, L)

        Returns:
            loss: Transducer loss value

        """
        # 1. encoder
        feats = feats[:, :max(feats_len)]

        if self.etype == "custom":
            feats_mask = (make_non_pad_mask(feats_len.tolist()).to(
                feats.device).unsqueeze(-2))

            _enc_out, _enc_out_len = self.encoder(feats, feats_mask)
        else:
            _enc_out, _enc_out_len, _ = self.enc(feats, feats_len)

        if self.use_auxiliary_enc_outputs:
            enc_out, aux_enc_out = _enc_out[0], _enc_out[1]
            enc_out_len, aux_enc_out_len = _enc_out_len[0], _enc_out_len[1]
        else:
            enc_out, aux_enc_out = _enc_out, None
            enc_out_len, aux_enc_out_len = _enc_out_len, None

        # 2. decoder
        dec_in = get_decoder_input(labels, self.blank_id, self.ignore_id)

        if self.dtype == "custom":
            self.decoder.set_device(enc_out.device)

            dec_in_mask = target_mask(dec_in, self.blank_id)
            dec_out, _ = self.decoder(dec_in, dec_in_mask)
        else:
            self.dec.set_device(enc_out.device)

            dec_out = self.dec(dec_in)

        # 3. transducer tasks computation
        losses = self.transducer_tasks(
            enc_out,
            aux_enc_out,
            dec_out,
            labels,
            enc_out_len,
            aux_enc_out_len,
        )

        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(
                enc_out, self.transducer_tasks.get_target())

        self.loss = sum(losses)
        loss_data = float(self.loss)

        if not math.isnan(loss_data):
            self.reporter.report(
                loss_data,
                *[float(loss) for loss in losses],
                cer,
                wer,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
Code example #16
File: e2e_asr_transformer.py Project: sas91/espnet
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        if self.attention_enc_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ]:
            for layer in self.encoder.encoders:
                layer.self_attn.clamp_param()
        if self.attention_dec_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ]:
            for layer in self.decoder.decoders:
                layer.self_attn.clamp_param()

        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)
        # xkc09 Span attention loss computation
        # xkc09 Span attention size loss computation
        loss_span = 0
        if self.attention_enc_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_dynamic_span2'
        ]:
            loss_span += sum([
                layer.self_attn.get_mean_span()
                for layer in self.encoder.encoders
            ])
        if self.attention_dec_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_dynamic_span2'
        ]:
            loss_span += sum([
                layer.self_attn.get_mean_span()
                for layer in self.decoder.decoders
            ])
        # xkc09 Span attention ratio loss computation
        loss_ratio = 0
        if self.ratio_adaptive:
            # target_ratio = 0.5
            if self.attention_enc_type in [
                    'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                    'self_attn_dynamic_span2'
            ]:
                loss_ratio += sum([
                    1 - layer.self_attn.get_mean_ratio()
                    for layer in self.encoder.encoders
                ])
            if self.attention_dec_type in [
                    'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                    'self_attn_dynamic_span2'
            ]:
                loss_ratio += sum([
                    1 - layer.self_attn.get_mean_ratio()
                    for layer in self.decoder.decoders
                ])
        if (self.attention_enc_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ] or self.attention_dec_type in [
                'self_attn_dynamic_span', 'self_attn_adaptive_span',
                'self_attn_adaptive_span2', 'self_attn_fixed_span2',
                'self_attn_dynamic_span2'
        ]):
            if getattr(self, 'span_loss_coef', None):
                self.loss += (loss_span + loss_ratio) * self.span_loss_coef

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
Code example #17
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """

        # 1. encoder
        # if etype is "transformer", handle the padding
        # xs_pad:[8, 393, 83]
        # ilens:[393 * 8]
        # src_mask:[8, 1, 393]
        # hs_mask:[8, 1, 65]
        if self.etype == "transformer":
            xs_pad = xs_pad[:, :max(ilens)]
            src_mask = make_non_pad_mask(ilens.tolist()).to(
                xs_pad.device).unsqueeze(-2)

            hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        else:
            logging.info("enc!!!")
            hs_pad, hs_mask, _ = self.encoder(xs_pad, ilens)
        self.hs_pad = hs_pad

        # 1.5. transducer preparation related
        # ys_in_pad: sos,1,2,...,0 [8, 14]
        # target: 1,2,... [8, 13]
        # pred_len: [8]
        # target_len: [8]
        # ys_out_pad:1,2,...,eos,-1

        ys = [y[y != self.ignore_id] for y in ys_pad]

        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])

        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        ys_out_pad = pad_list(ys_out, self.ignore_id)

        ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)
        # 2. decoder
        # ys_mask:[8, 16, 16]

        if self.dtype == "transformer":
            ys_mask = target_mask(ys_in_pad, self.blank_id)

            pred_pad, pred_att, _ = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                                 hs_mask)
        else:
            if self.rnnt_mode == "rnnt":
                pred_pad = self.dec(hs_pad, ys_in_pad)
            else:
                pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
        self.pred_pad = pred_pad

        # 3. loss computation
        loss_att = F.cross_entropy(
            pred_att,
            ys_out_pad.view(-1),  # batch x olength
            ignore_index=self.ignore_id,
        )
        # compute perplexity
        # ppl = math.exp(loss_att.item())
        # -1: eos, which is removed in the loss computation
        loss_att *= np.mean([len(x) for x in ys_in]) - 1

        loss_rnnt = self.criterion(pred_pad, target, pred_len, target_len)

        # loss_ctc = self.ctc(hs_pad, pred_len, ys_pad)

        alpha = self.mtlalpha
        beta = self.mtlbeta
        gamma = self.mtlgamma

        self.loss_rnnt = loss_rnnt
        self.loss_att = loss_att
        # self.loss_ctc = loss_ctc

        # self.loss = alpha * self.loss_ctc + beta * self.loss_rnnt + gamma * self.loss_att
        self.loss = beta * self.loss_rnnt + gamma * self.loss_att
        # self.loss = alpha * self.loss_ctc
        loss_data = float(self.loss)
        # loss_ctc_data = float(self.loss_ctc)
        loss_att_data = float(self.loss_att)
        loss_rnnt_data = float(self.loss_rnnt)

        # loss_att_data = None
        # loss_rnnt_data = None
        # 4. compute cer/wer

        if self.training or self.error_calculator is None:
            logging.info("ALL none!!!!!")
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        # with open('/home/oshindo/espnet/egs/aishell/asr1/exp/train_sp_pytorch_e2e_asr_transducer/blstmp_ctc.txt', "a+") as fid:
        #     fid.write("loss:" + str(loss_ctc_data) + '\n')

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, loss_rnnt_data, loss_att_data, cer,
                                 wer)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
Code example #18
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        # mlp moe forward
        cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
        # gated add module
        """ lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b)  #(B, T, 1)
            xs = lambda * cn_xs + (1-lambda) * en_xs 
        """
        hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
        lambda_ = self.enc_lambda
        hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
        self.hs_pad = hs_pad

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                               hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
Code example #19
    def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor moe_coes: batch of padded expert coefficients (B, Tmax, 2)
        :param torch.Tensor moe_coe_lens: batch of lengths of coefficient sequences (B)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        moe_coes = moe_coes[:, :max(moe_coe_lens)]  # for data parallel
        # here we use interpolation_coe to 'fix' initial moe_coes
        interp_factor = self.interp_factor  # 0.1 for example, similar to lsm
        moe_coes = (
            1 - interp_factor) * moe_coes + interp_factor / moe_coes.shape[2]

        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)

        # multi-encoder forward
        cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
        moe_coes = moe_coes.unsqueeze(-1)
        hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0]
        self.hs_pad = hs_pad

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                               hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
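The interp_factor correction above smooths the (possibly one-hot) expert posteriors the same way label smoothing redistributes probability mass; a numeric sketch:

import torch

moe_coes = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])   # (B, T, 2), one-hot
interp_factor = 0.1
moe_coes = (1 - interp_factor) * moe_coes + interp_factor / moe_coes.shape[2]
print(moe_coes)   # tensor([[[0.9500, 0.0500], [0.0500, 0.9500]]])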
Code example #20
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # CTC forward
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max([len(y) for y in ys])
        ys_pad = ys_pad[:, :y_len]
        self.hs_pad = hs_pad
        cer_ctc = None
        batch_size = xs_pad.size(0)
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        return self.loss, loss_ctc_data, loss_att_data, self.acc
Code example #21
    def forward(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor moe_coes: batch of padded expert coefficients (B, Tmax, 2)
        :param torch.Tensor moe_coe_lens: batch of lengths of coefficient sequences (B)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        moe_coes = moe_coes[:, :max(moe_coe_lens)].long()  # for data parallel
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        # mlp moe forward
        cn_hs_pad, cn_hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, en_hs_mask = self.en_encoder(xs_pad, src_mask)
        # gated add module
        """ lambda = sigmoid(W_cn * cn_xs + w_en * en_xs + b)  #(B, T, 1)
            xs = lambda * cn_xs + (1-lambda) * en_xs 
        """
        hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
        lambda_ = F.softmax(
            self.aggre_scaling * self.aggregation_module(hs_pad),
            -1).unsqueeze(-1)
        ctc_hs_pad = lambda_[:, :, 0] * cn_hs_pad + lambda_[:, :,
                                                            1] * en_hs_pad
        ctc_hs_mask = cn_hs_mask

        # plat attention mode, (B,T,D)*2 --> (B,2T,D)
        s2s_hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=1)
        # mask: (B,1,T) --> (B,1,2T)
        s2s_hs_mask = torch.cat((cn_hs_mask, en_hs_mask), dim=-1)
        # self.hs_pad = hs_pad

        # compute lid loss here, using lambda_
        # moe_coes (B, T, 2) ==> (B,T)
        moe_coes = moe_coes[:, :, 0]  # 0 for cn, 1 for en
        lambda_ = lambda_.squeeze(-1)
        if self.lid_mtl_alpha == 0.0:
            loss_lid = 0.0
        else:
            loss_lid = self.lid_criterion(lambda_, moe_coes)
        lid_acc = th_accuracy(
            lambda_.view(-1, 2), moe_coes,
            ignore_label=self.ignore_id) if self.log_lid_mtl_acc else None
        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = ctc_hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(ctc_hs_pad.view(batch_size, -1, self.adim),
                                hs_len, ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(
                    ctc_hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, s2s_hs_pad,
                                               s2s_hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim),
                              ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        lid_alpha = self.lid_mtl_alpha
        if alpha == 0:
            self.loss = loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc + lid_alpha * loss_lid
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (
                1 - alpha) * loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data, lid_acc)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
Code example #22
    def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor enc_mask: optional encoder self-attention mask (B, Tmax, Tmax)
        :param torch.Tensor dec_mask: optional decoder-to-encoder trigger mask (B, Lmax, Tmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        batch_size = xs_pad.shape[0]
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        if isinstance(self.encoder.embed, EncoderConv2d):
            xs, hs_mask = self.encoder.embed(xs_pad,
                                             torch.sum(src_mask, 2).squeeze())
            hs_mask = hs_mask.unsqueeze(1)
        else:
            xs, hs_mask = self.encoder.embed(xs_pad, src_mask)

        if enc_mask is not None:
            enc_mask = enc_mask[:, :hs_mask.shape[2], :hs_mask.shape[2]]
        enc_mask = enc_mask & hs_mask if enc_mask is not None else hs_mask
        hs_pad, _ = self.encoder.encoders(xs, enc_mask)
        if self.encoder.normalize_before:
            hs_pad = self.encoder.after_norm(hs_pad)

        # CTC forward
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max([len(y) for y in ys])
        ys_pad = ys_pad[:, :y_len]
        if dec_mask is not None:
            dec_mask = dec_mask[:, :y_len + 1, :hs_pad.shape[1]]
        self.hs_pad = hs_pad
        batch_size = xs_pad.size(0)
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)

        # trigger mask
        hs_mask = hs_mask & dec_mask if dec_mask is not None else hs_mask
        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        return self.loss, loss_ctc_data, loss_att_data, self.acc
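The trigger-mask combination above relies on broadcasting: the (B, 1, T) padding mask is ANDed with a (B, L, T) per-step trigger mask, so each decoder step attends only to encoder frames its trigger allows and that are not padding. A toy example:

import torch

hs_mask = torch.tensor([[[True, True, True, False]]])        # (B, 1, T)
dec_mask = torch.tensor([[[True, False, False, False],
                          [True, True, False, False],
                          [True, True, True, True]]])        # (B, L, T)
print((hs_mask & dec_mask).int())
# tensor([[[1, 0, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0]]])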
Code example #23
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask, penultimate_state = self.decoder(
            ys_in_pad, ys_mask, hs_pad, hs_mask, return_penultimate_state=True)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)
        # 4. compute lid multitask loss
        src_att = self.lid_src_att(penultimate_state, hs_pad, hs_pad, hs_mask)
        pred_lid_pad = self.lid_output_layer(src_att)
        loss_lid, lid_ys_out_pad = self.lid_criterion(pred_lid_pad, ys_out_pad)
        lid_acc = th_accuracy(
            pred_lid_pad.view(-1, self.lid_odim),
            lid_ys_out_pad,
            ignore_label=self.ignore_id) if self.log_lid_mtl_acc else None

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1,
                                                     self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copyied from e2e_asr
        alpha = self.mtlalpha
        lid_alpha = self.lid_mtl_alpha
        if alpha == 0:
            self.loss = loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            raise NotImplementedError(
                "LID MTL does not support pure CTC mode")
        else:
            self.loss = alpha * loss_ctc + (
                1 - alpha) * loss_att + lid_alpha * loss_lid
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data, lid_acc)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
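
Step 4 above adds a language-ID head: the decoder's penultimate states query the encoder output through a source attention, and a linear layer emits per-token LID logits. A hedged sketch of that shape flow, with torch.nn.MultiheadAttention standing in for the model's own lid_src_att module (the class name and all sizes here are hypothetical, not from the snippet):

import torch
import torch.nn as nn

class LIDHead(nn.Module):
    """Sketch of an LID auxiliary head; the real lid_src_att /
    lid_output_layer modules in this model may differ."""

    def __init__(self, adim, lid_odim, n_heads=4):
        super().__init__()
        self.src_att = nn.MultiheadAttention(adim, n_heads, batch_first=True)
        self.output = nn.Linear(adim, lid_odim)

    def forward(self, penultimate_state, hs_pad, hs_key_padding_mask=None):
        # query: decoder penultimate states (B, Lmax, adim)
        # key/value: encoder states (B, Tmax, adim)
        ctx, _ = self.src_att(penultimate_state, hs_pad, hs_pad,
                              key_padding_mask=hs_key_padding_mask)
        return self.output(ctx)  # per-token LID logits (B, Lmax, lid_odim)

head = LIDHead(adim=256, lid_odim=3)
logits = head(torch.randn(4, 30, 256), torch.randn(4, 120, 256))
print(logits.shape)  # torch.Size([4, 30, 3])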
Code example #24
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of padded target sequences in the source language (B, Lmax_src)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 0. Extract target language ID
        tgt_lang_ids = None
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning

        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # replace <sos> with target language ID
        if self.replace_sos:
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # 3. compute ST loss
        loss_att = self.criterion(pred_pad, ys_out_pad)

        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )

        # 4. compute corpus-level bleu in a mini-batch
        if self.training:
            self.bleu = None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # 5. compute auxiliary ASR loss
        loss_asr_att, acc_asr, loss_asr_ctc, cer_ctc, cer, wer = self.forward_asr(
            hs_pad, hs_mask, ys_pad_src
        )

        # 6. compute auxiliary MT loss
        loss_mt, acc_mt = 0.0, None
        if self.mt_weight > 0:
            loss_mt, acc_mt = self.forward_mt(
                ys_pad_src, ys_in_pad, ys_out_pad, ys_mask
            )

        asr_ctc_weight = self.mtlalpha
        self.loss = (
            (1 - self.asr_weight - self.mt_weight) * loss_att
            + self.asr_weight
            * (asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att)
            + self.mt_weight * loss_mt
        )
        loss_asr_data = float(
            asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att
        )
        loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
        loss_st_data = float(loss_att)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_asr_data,
                loss_mt_data,
                loss_st_data,
                acc_asr,
                acc_mt,
                self.acc,
                cer_ctc,
                cer,
                wer,
                self.bleu,
                loss_data,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
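
The final objective above interpolates the three tasks: L = (1 - w_asr - w_mt) * L_st + w_asr * (alpha * L_asr_ctc + (1 - alpha) * L_asr_att) + w_mt * L_mt. A standalone sketch of the same arithmetic (the helper name is hypothetical; it only restates the weighting already in the snippet):

import torch

def st_multitask_loss(loss_st, loss_asr_att, loss_asr_ctc, loss_mt,
                      asr_weight, mt_weight, mtlalpha):
    # ASR branch is itself a CTC/attention interpolation.
    loss_asr = mtlalpha * loss_asr_ctc + (1 - mtlalpha) * loss_asr_att
    return ((1 - asr_weight - mt_weight) * loss_st
            + asr_weight * loss_asr
            + mt_weight * loss_mt)

# dummy scalars; in training these come from the criterion / CTC modules
total = st_multitask_loss(torch.tensor(2.0), torch.tensor(3.0),
                          torch.tensor(4.0), torch.tensor(2.5),
                          asr_weight=0.3, mt_weight=0.1, mtlalpha=0.3)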
Code example #25
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of padded target sequences in the source language (B, Lmax_src)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 0. Extract target language ID
        tgt_lang_ids = None
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning

        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        # replace <sos> with target language ID
        if self.replace_sos:
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # 3. compute ST loss
        loss_asr_att, loss_asr_ctc, loss_mt = 0.0, 0.0, 0.0
        acc_asr, acc_mt = 0.0, 0.0
        loss_att = self.criterion(pred_pad, ys_out_pad)

        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute corpus-level bleu in a mini-batch
        if self.training:
            self.bleu = None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # 5. compute auxiliary ASR loss
        cer, wer = None, None
        cer_ctc = None
        if self.asr_weight > 0:
            # attention
            if self.mtlalpha < 1:
                ys_in_pad_asr, ys_out_pad_asr = add_sos_eos(
                    ys_pad_src, self.sos, self.eos, self.ignore_id)
                ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id)
                pred_pad_asr, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr,
                                                   hs_pad, hs_mask)
                loss_asr_att = self.criterion(pred_pad_asr, ys_out_pad_asr)

                acc_asr = th_accuracy(
                    pred_pad_asr.view(-1, self.odim),
                    ys_out_pad_asr,
                    ignore_label=self.ignore_id,
                )
                if not self.training:
                    ys_hat_asr = pred_pad_asr.argmax(dim=-1)
                    cer, wer = self.error_calculator_asr(
                        ys_hat_asr.cpu(), ys_pad_src.cpu())

            # CTC
            if self.mtlalpha > 0:
                batch_size = xs_pad.size(0)
                hs_len = hs_mask.view(batch_size, -1).sum(1)
                loss_asr_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                                        hs_len, ys_pad_src)
                if not self.training:
                    # compute the CTC argmax only when it is actually used
                    ys_hat_ctc = self.ctc.argmax(
                        hs_pad.view(batch_size, -1, self.adim)).data
                    cer_ctc = self.error_calculator_asr(ys_hat_ctc.cpu(),
                                                        ys_pad_src.cpu(),
                                                        is_ctc=True)

        # 6. compute auxiliary MT loss
        if self.mt_weight > 0:
            ilens_mt = torch.sum(ys_pad_src != self.ignore_id,
                                 dim=1).cpu().numpy()
            # NOTE: ys_pad_src is padded with -1
            ys_src = [y[y != self.ignore_id]
                      for y in ys_pad_src]  # parse padded ys_src
            ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
            ys_zero_pad_src = ys_zero_pad_src[:, :max(
                ilens_mt)]  # for data parallel
            src_mask_mt = ((~make_pad_mask(ilens_mt.tolist())).to(
                ys_zero_pad_src.device).unsqueeze(-2))
            hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src,
                                                    src_mask_mt)
            pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt,
                                          hs_mask_mt)
            loss_mt = self.criterion(pred_pad_mt, ys_out_pad)

            acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim),
                                 ys_out_pad,
                                 ignore_label=self.ignore_id)

        alpha = self.mtlalpha
        self.loss = ((1 - self.asr_weight - self.mt_weight) * loss_att +
                     self.asr_weight * (alpha * loss_asr_ctc +
                                        (1 - alpha) * loss_asr_att) +
                     self.mt_weight * loss_mt)
        loss_asr_data = float(alpha * loss_asr_ctc +
                              (1 - alpha) * loss_asr_att)
        loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
        loss_st_data = float(loss_att)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_asr_data,
                loss_mt_data,
                loss_st_data,
                acc_asr,
                acc_mt,
                self.acc,
                cer_ctc,
                cer,
                wer,
                self.bleu,
                loss_data,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
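
Step 6 above must re-pad ys_pad_src before feeding the text encoder, because ignore_id = -1 is not a valid embedding index. A self-contained sketch of that parse-and-repad step, using torch's pad_sequence in place of espnet's pad_list (an assumed substitution; the logic mirrors the snippet):

import torch
from torch.nn.utils.rnn import pad_sequence

ignore_id, pad = -1, 0
ys_pad_src = torch.tensor([[5, 9, 2, -1, -1],
                           [7, 3, 8, 6, 1]])
ilens_mt = (ys_pad_src != ignore_id).sum(dim=1)        # true lengths: [3, 5]
ys_src = [y[y != ignore_id] for y in ys_pad_src]       # strip -1 padding
ys_zero_pad_src = pad_sequence(ys_src, batch_first=True, padding_value=pad)
ys_zero_pad_src = ys_zero_pad_src[:, :ilens_mt.max()]  # trim for data parallel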
Code example #26
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of padded target sequences in the source language (B, Lmax_src)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 0. Extract target language ID
        # src_lang_ids = None
        tgt_lang_ids, tgt_lang_ids_src = None, None
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning

        if self.one_to_many:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
            
            if self.do_asr:
                tgt_lang_ids_src = ys_pad_src[:, 0:1]
                ys_pad_src = ys_pad_src[:, 1:]  # remove target language ID at the beginning

        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel # bs x max_ilens x idim

        if self.lang_tok == "encoder-pre-sum":
            lang_embed = self.language_embeddings(tgt_lang_ids) # bs x 1 x idim
            xs_pad = xs_pad + lang_embed

        src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2) # bs x 1 x max_ilens
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask) # hs_pad: bs x (max_ilens/4) x adim; hs_mask: bs x 1 x (max_ilens/4)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) # bs x max_lens

        if self.do_asr:
            ys_in_pad_src, ys_out_pad_src = add_sos_eos(ys_pad_src, self.sos, self.eos, self.ignore_id) # bs x max_lens_src

        # replace <sos> with target language ID
        if self.replace_sos:
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)

        if self.lang_tok == "decoder-pre":
            ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1)
            if self.do_asr:
                ys_in_pad_src = torch.cat([tgt_lang_ids_src, ys_in_pad_src[:, 1:]], dim=1)

        ys_mask = target_mask(ys_in_pad, self.ignore_id) # bs x max_lens x max_lens

        if self.do_asr:
            ys_mask_src = target_mask(ys_in_pad_src, self.ignore_id) # bs x max_lens x max_lens_src

        if self.wait_k_asr > 0:
            cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id, wait_k_cross=self.wait_k_asr)
            cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad, self.ignore_id, wait_k_cross=-self.wait_k_asr)
        elif self.wait_k_st > 0:
            cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id, wait_k_cross=-self.wait_k_st)
            cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad, self.ignore_id, wait_k_cross=self.wait_k_st)
        else:
            cross_mask = create_cross_mask(ys_in_pad, ys_in_pad_src, self.ignore_id, wait_k_cross=0)
            cross_mask_asr = create_cross_mask(ys_in_pad_src, ys_in_pad, self.ignore_id, wait_k_cross=0)

        pred_pad, pred_mask, pred_pad_asr, pred_mask_asr = self.dual_decoder(
            ys_in_pad, ys_mask, ys_in_pad_src, ys_mask_src,
            hs_pad, hs_mask, cross_mask, cross_mask_asr,
            cross_self=self.cross_self, cross_src=self.cross_src,
            cross_self_from=self.cross_self_from,
            cross_src_from=self.cross_src_from)

        self.pred_pad = pred_pad
        self.pred_pad_asr = pred_pad_asr
        pred_pad_mt = None

        # 3. compute attention loss
        loss_mt = 0.0
        loss_att = self.criterion(pred_pad, ys_out_pad)

        # compute loss
        loss_asr = self.criterion(pred_pad_asr, ys_out_pad_src)
        # Multi-task w/ MT
        if self.mt_weight > 0:
            # forward MT encoder
            ilens_mt = torch.sum(ys_pad_src != self.ignore_id, dim=1).cpu().numpy()
            # NOTE: ys_pad_src is padded with -1
            ys_src = [y[y != self.ignore_id] for y in ys_pad_src]  # parse padded ys_src
            ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
            ys_zero_pad_src = ys_zero_pad_src[:, :max(ilens_mt)]  # for data parallel
            src_mask_mt = (~make_pad_mask(ilens_mt.tolist())).to(ys_zero_pad_src.device).unsqueeze(-2)
            # ys_zero_pad_src, ys_pad = self.target_forcing(ys_zero_pad_src, ys_pad)
            hs_pad_mt, hs_mask_mt = self.encoder_mt(ys_zero_pad_src, src_mask_mt)
            # forward MT decoder
            pred_pad_mt, _ = self.decoder(ys_in_pad, ys_mask, hs_pad_mt, hs_mask_mt)
            # compute loss
            loss_mt = self.criterion(pred_pad_mt, ys_out_pad)

        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                               ignore_label=self.ignore_id)
        if pred_pad_asr is not None:
            self.acc_asr = th_accuracy(pred_pad_asr.view(-1, self.odim), ys_out_pad_src,
                                       ignore_label=self.ignore_id)
        else:
            self.acc_asr = 0.0
        if pred_pad_mt is not None:
            self.acc_mt = th_accuracy(pred_pad_mt.view(-1, self.odim), ys_out_pad,
                                      ignore_label=self.ignore_id)
        else:
            self.acc_mt = 0.0

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0 or self.asr_weight == 0:
            loss_ctc = 0.0
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad_src)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad_src.cpu(), is_ctc=True)

        # 5. compute cer/wer
        cer, wer = None, None  # TODO(hirofumi0810): fix later
        # if self.training or (self.asr_weight == 0 or self.mtlalpha == 1 or not (self.report_cer or self.report_wer)):
        #     cer, wer = None, None
        # else:
        #     ys_hat = pred_pad.argmax(dim=-1)
        #     cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        self.loss = (1 - self.asr_weight - self.mt_weight) * loss_att + self.asr_weight * \
            (alpha * loss_ctc + (1 - alpha) * loss_asr) + self.mt_weight * loss_mt
        loss_asr_data = float(alpha * loss_ctc + (1 - alpha) * loss_asr)
        loss_mt_data = None if self.mt_weight == 0 else float(loss_mt)
        loss_st_data = float(loss_att)

        # logging.info(f'loss_st_data={loss_st_data}')

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_asr_data, loss_mt_data, loss_st_data,
                                 self.acc_asr, self.acc_mt, self.acc,
                                 cer_ctc, cer, wer, 0.0,  # TODO(hirofumi0810): bleu
                                 loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
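
The wait-k branches above build asymmetric cross masks so that one decoder stream can only attend to the other with a k-step lag. A heavily hedged sketch of the underlying pattern; the real create_cross_mask also folds in ignore_id padding, and its sign convention for k may differ from this illustration:

import torch

def waitk_cross_mask(tgt_len, src_len, k):
    # Position i in one stream may attend to positions j < i + k in the other.
    i = torch.arange(tgt_len).unsqueeze(1)  # (tgt_len, 1)
    j = torch.arange(src_len).unsqueeze(0)  # (1, src_len)
    return j < i + k                        # boolean (tgt_len, src_len)

print(waitk_cross_mask(4, 6, k=2).int())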
Code example #27
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        if self.decoder is not None:
            if self.decoder_mode == "maskctc":
                ys_in_pad, ys_out_pad = mask_uniform(
                    ys_pad, self.mask_token, self.eos, self.ignore_id
                )
                ys_mask = (ys_in_pad != self.ignore_id).unsqueeze(-2)
            else:
                ys_in_pad, ys_out_pad = add_sos_eos(
                    ys_pad, self.sos, self.eos, self.ignore_id
                )
                ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            self.acc = th_accuracy(
                pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
            )
        else:
            loss_att = None
            self.acc = None

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            if not self.training and self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
            # for visualization
            if not self.training:
                self.ctc.softmax(hs_pad)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None or self.decoder is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
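
In "maskctc" mode, targets are prepared with mask_uniform rather than add_sos_eos: a random subset of decoder-input positions is replaced with the mask token, and only those positions contribute to the loss. A hedged sketch of that idea (the sampling details of the real mask_uniform may differ):

import torch

def mask_uniform_sketch(ys_pad, mask_token, ignore_id=-1):
    ys_in = ys_pad.clone()
    ys_out = torch.full_like(ys_pad, ignore_id)  # loss ignored everywhere...
    for b in range(ys_pad.size(0)):
        valid = (ys_pad[b] != ignore_id).nonzero(as_tuple=True)[0]
        n_mask = torch.randint(1, len(valid) + 1, (1,)).item()
        picked = valid[torch.randperm(len(valid))[:n_mask]]
        ys_in[b, picked] = mask_token          # masked decoder input
        ys_out[b, picked] = ys_pad[b, picked]  # ...except at masked positions
    return ys_in, ys_out

ys_in, ys_out = mask_uniform_sketch(torch.tensor([[5, 9, 2, -1]]),
                                    mask_token=100)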
Code example #28
File: e2e_asr_transducer.py (project: zxpan/espnet)
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        # 1. encoder
        xs_pad = xs_pad[:, :max(ilens)]

        if "custom" in self.etype:
            src_mask = make_non_pad_mask(ilens.tolist()).to(
                xs_pad.device).unsqueeze(-2)

            _hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        else:
            _hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)

        if self.use_aux_task:
            hs_pad, aux_hs_pad = _hs_pad[0], _hs_pad[1]
        else:
            hs_pad, aux_hs_pad = _hs_pad, None

        # 1.5. transducer preparation related
        ys_in_pad, ys_out_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)

        # 2. decoder
        if "custom" in self.dtype:
            ys_mask = target_mask(ys_in_pad, self.blank_id)
            pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad)

        z = self.joint_network(hs_pad.unsqueeze(2), pred_pad.unsqueeze(1))

        # 3. loss computation
        loss_trans = self.criterion(z, target, pred_len, target_len)

        if self.use_aux_task and aux_hs_pad is not None:
            loss_trans += self.auxiliary_task(aux_hs_pad, pred_pad, z, target,
                                              pred_len, target_len)

        if self.use_aux_ctc:
            if "custom" in self.etype:
                hs_mask = torch.IntTensor(
                    [h.size(1) for h in hs_mask]).to(hs_mask.device)

            loss_ctc = self.aux_ctc(hs_pad, hs_mask, ys_pad)
        else:
            loss_ctc = 0

        if self.use_aux_cross_entropy:
            loss_ce = self.aux_cross_entropy(self.aux_decoder_output(pred_pad),
                                             ys_out_pad)
        else:
            loss_ce = 0

        loss = (self.transducer_weight * loss_trans +
                self.aux_ctc_weight * loss_ctc +
                self.aux_cross_entropy_weight * loss_ce)

        self.loss = loss
        loss_data = float(loss)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, cer, wer)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
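
The joint network call in step 2 combines every encoder frame t with every prediction-network state u by broadcasting, giving (B, T, U+1, V) logits for the transducer loss. A sketch of that broadcasting; the layer sizes and the tanh combination are assumptions, and the model's actual joint_network may differ:

import torch
import torch.nn as nn

class JointSketch(nn.Module):
    def __init__(self, dim, joint_dim, vocab):
        super().__init__()
        self.lin_enc = nn.Linear(dim, joint_dim)
        self.lin_dec = nn.Linear(dim, joint_dim)
        self.out = nn.Linear(joint_dim, vocab)

    def forward(self, hs_pad, pred_pad):
        # (B, T, 1, J) + (B, 1, U+1, J) broadcasts to (B, T, U+1, J)
        z = torch.tanh(self.lin_enc(hs_pad).unsqueeze(2)
                       + self.lin_dec(pred_pad).unsqueeze(1))
        return self.out(z)  # per-(t, u) vocab logits (B, T, U+1, vocab)

joint = JointSketch(dim=320, joint_dim=256, vocab=500)
z = joint(torch.randn(2, 100, 320), torch.randn(2, 21, 320))
print(z.shape)  # torch.Size([2, 100, 21, 500])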