Exemplo n.º 1
0
    def forward(self, xs_pad, ilens, ys_pad):
        """Run encoder, prediction network and transducer criterion on a batch.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        # Encoder: the RNN encoder consumes lengths directly, while the
        # transformer encoder needs a (B, 1, T) non-padding mask.
        if self.etype != 'transformer':
            enc_out, enc_lens, _ = self.encoder(xs_pad, ilens)
            enc_mask = enc_lens
        else:
            xs_pad = xs_pad[:, :max(ilens)]
            valid = ~make_pad_mask(ilens.tolist())
            valid = valid.to(xs_pad.device).unsqueeze(-2)
            enc_out, enc_mask = self.encoder(xs_pad, valid)
        self.hs_pad = enc_out

        # Prediction-network inputs and transducer targets/lengths.
        dec_in, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, enc_mask)

        # Prediction network.
        if self.dtype == 'transformer':
            causal_mask = target_mask(dec_in, self.blank_id)
            dec_out, _ = self.decoder(dec_in, causal_mask, enc_out)
        elif self.rnnt_mode == 'rnnt':
            dec_out = self.decoder(enc_out, dec_in)
        else:
            dec_out = self.decoder(enc_out, dec_in, pred_len)
        self.pred_pad = dec_out

        # Transducer loss.
        self.loss = self.criterion(dec_out, target, pred_len, target_len)
        loss_value = float(self.loss)

        # Error rates are only computed outside training.
        cer = wer = None
        if not self.training and self.error_calculator is not None:
            cer, wer = self.error_calculator(enc_out, ys_pad)

        if math.isnan(loss_value):
            logging.warning('loss (=%f) is not correct', loss_value)
        else:
            self.reporter.report(loss_value, cer, wer)

        return self.loss
Exemplo n.º 2
0
    def forward(self, xs_pad, ilens, ys_pad):
        """Compute the transducer training loss for one batch.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        # Trim feature padding beyond the longest utterance in the batch.
        feats = xs_pad[:, :max(ilens)]

        # Encoder: built-in self.enc, or a custom mask-driven encoder.
        if "custom" not in self.etype:
            enc_out, enc_mask, _ = self.enc(feats, ilens)
        else:
            feat_mask = make_non_pad_mask(ilens.tolist()).to(feats.device)
            enc_out, enc_mask = self.encoder(feats, feat_mask.unsqueeze(-2))

        dec_in, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, enc_mask)

        # Prediction network.
        if "custom" not in self.dtype:
            dec_out = self.dec(enc_out, dec_in)
        else:
            dec_out, _ = self.decoder(
                dec_in, target_mask(dec_in, self.blank_id), enc_out)

        # Insert a label axis (dim 2) into the encoder output and a time axis
        # (dim 1) into the prediction output before the joint network.
        joint_out = self.joint_network(enc_out.unsqueeze(2),
                                       dec_out.unsqueeze(1))

        self.loss = self.criterion(joint_out, target, pred_len, target_len)
        loss_value = float(self.loss)

        # Error rates are only computed outside training.
        cer = wer = None
        if not self.training and self.error_calculator is not None:
            cer, wer = self.error_calculator(enc_out, ys_pad)

        if math.isnan(loss_value):
            logging.warning("loss (=%f) is not correct", loss_value)
        else:
            self.reporter.report(loss_value, cer, wer)

        return self.loss
def test_sa_transducer_mask(module):
    """NaNs in the tail of one target sequence must not reach the attention
    weights among its first three positions in the first decoder layer."""
    from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
    from espnet.nets.pytorch_backend.transducer.utils import prepare_loss_inputs
    from espnet.nets.pytorch_backend.transformer.mask import target_mask

    model, feats, feat_lens, labels, _ = prepare(module, make_train_args())

    # Dummy encoder mask derived from the input lengths.
    enc_mask = (~make_pad_mask(feat_lens.tolist())).to(feats.device).unsqueeze(-2)

    target = prepare_loss_inputs(labels, enc_mask)[1]
    tgt_mask = target_mask(target, model.blank_id)

    emb = model.decoder.embed(target.type(torch.long))
    emb[0, 3:] = float("nan")

    self_attn = model.decoder.decoders[0].self_attn
    self_attn(emb, emb, emb, tgt_mask)
    weights = self_attn.attn[0, :, :3, :3].detach().numpy()
    assert not numpy.isnan(weights).any()
Exemplo n.º 4
0
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward with a joint transducer + attention-decoder loss.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): ``self.mtlbeta * transducer loss
                + self.mtlgamma * attention cross-entropy loss``. The
                attention term is only added when the transformer decoder
                (the only one returning attention-branch logits) is used.

        """
        # 1. encoder
        if self.etype == "transformer":
            # Drop frames past the longest utterance; padding is masked below.
            xs_pad = xs_pad[:, :max(ilens)]
            src_mask = make_non_pad_mask(ilens.tolist()).to(
                xs_pad.device).unsqueeze(-2)

            hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        else:
            hs_pad, hs_mask, _ = self.encoder(xs_pad, ilens)
        self.hs_pad = hs_pad

        # 1.5. targets
        # Attention-branch targets: each label sequence followed by eos,
        # padded with ignore_id so padding is skipped by the cross-entropy.
        ys = [y[y != self.ignore_id] for y in ys_pad]
        eos = ys[0].new([self.eos])
        ys_out_pad = pad_list([torch.cat([y, eos], dim=0) for y in ys],
                              self.ignore_id)

        # Transducer-branch inputs and targets.
        ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)

        # 2. decoder
        # Only the transformer decoder returns attention-branch logits; with
        # the RNN prediction networks pred_att stays None.
        pred_att = None
        if self.dtype == "transformer":
            ys_mask = target_mask(ys_in_pad, self.blank_id)
            pred_pad, pred_att, _ = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                                 hs_mask)
        elif self.rnnt_mode == "rnnt":
            pred_pad = self.dec(hs_pad, ys_in_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
        self.pred_pad = pred_pad

        # 3. loss computation
        loss_rnnt = self.criterion(pred_pad, target, pred_len, target_len)
        self.loss_rnnt = loss_rnnt
        loss = self.mtlbeta * loss_rnnt

        if pred_att is not None:
            loss_att = F.cross_entropy(
                # Flatten (B, L, odim) logits to (B * L, odim) so they line up
                # with the flattened targets; a no-op for 2-D logits.
                pred_att.reshape(-1, pred_att.size(-1)),
                ys_out_pad.view(-1),
                ignore_index=self.ignore_id,
            )
            # Scale the per-token mean by the mean number of labels per
            # utterance to get a per-sequence magnitude.
            loss_att = loss_att * float(np.mean([len(y) for y in ys]))
            loss = loss + self.mtlgamma * loss_att
            loss_att_data = float(loss_att)
        else:
            loss_att = None
            loss_att_data = None
        self.loss_att = loss_att
        self.loss = loss

        loss_data = float(self.loss)
        loss_rnnt_data = float(loss_rnnt)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        if not math.isnan(loss_data):
            # NOTE(review): loss_att_data is None when no attention loss was
            # computed. cer/wer are already passed as None during training,
            # but confirm the reporter accepts None for this slot as well.
            self.reporter.report(loss_data, loss_rnnt_data, loss_att_data,
                                 cer, wer)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
Exemplo n.º 5
0
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): weighted sum of the transducer loss and the
                optional auxiliary transducer, CTC and cross-entropy losses

        """
        # 1. encoder
        xs_pad = xs_pad[:, :max(ilens)]

        if "custom" in self.etype:
            src_mask = make_non_pad_mask(ilens.tolist()).to(
                xs_pad.device).unsqueeze(-2)

            _hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        else:
            _hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)

        # With the auxiliary task enabled the encoder output is a pair
        # (main output, auxiliary output).
        if self.use_aux_task:
            hs_pad, aux_hs_pad = _hs_pad[0], _hs_pad[1]
        else:
            hs_pad, aux_hs_pad = _hs_pad, None

        # 1.5. transducer preparation related
        ys_in_pad, ys_out_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)

        # 2. decoder
        if "custom" in self.dtype:
            ys_mask = target_mask(ys_in_pad, self.blank_id)
            pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad)

        z = self.joint_network(hs_pad.unsqueeze(2), pred_pad.unsqueeze(1))

        # 3. loss computation
        loss_trans = self.criterion(z, target, pred_len, target_len)

        if self.use_aux_task and aux_hs_pad is not None:
            # Out-of-place add: avoid mutating an autograd output in place.
            loss_trans = loss_trans + self.auxiliary_task(
                aux_hs_pad, pred_pad, z, target, pred_len, target_len)

        if self.use_aux_ctc:
            if "custom" in self.etype:
                # The custom encoder returns a padding mask, not lengths.
                # Count the valid (non-zero) frames per utterance; taking
                # each row's size would give the padded width for everyone.
                # Assumes a (B, 1, T') mask like src_mask -- TODO confirm.
                # reshape (not view): subsampled masks may be non-contiguous.
                hs_len = hs_mask.reshape(hs_mask.size(0), -1).sum(dim=1).int()
            else:
                hs_len = hs_mask

            loss_ctc = self.aux_ctc(hs_pad, hs_len, ys_pad)
        else:
            loss_ctc = 0

        if self.use_aux_cross_entropy:
            loss_ce = self.aux_cross_entropy(self.aux_decoder_output(pred_pad),
                                             ys_out_pad)
        else:
            loss_ce = 0

        loss = (self.transducer_weight * loss_trans +
                self.aux_ctc_weight * loss_ctc +
                self.aux_cross_entropy_weight * loss_ce)

        self.loss = loss
        loss_data = float(loss)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, cer, wer)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss
Exemplo n.º 6
0
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        # 1. encoder
        xs_pad = xs_pad[:, :max(ilens)]

        if "transformer" in self.etype:
            batchsize = xs_pad.size(0)
            # Add a channel axis for the 2-D conv front-end: (B, 1, T, idim).
            inputs = xs_pad.unsqueeze(1)
            logging.debug("inputs:%s", inputs.shape)

            # Per-utterance time length after the conv front-end, found by
            # running it on each unpadded utterance. Only the output size is
            # used, so no autograd graph is built.
            # NOTE(review): if self.conv contains BatchNorm, these extra
            # passes still update its running statistics in training mode.
            inputs_length = []
            with torch.no_grad():
                for length in ilens.tolist():
                    core_out = self.conv(inputs[i:i + 1, :, :length, :]
                                         if False else
                                         inputs[len(inputs_length):len(inputs_length) + 1, :, :length, :])
                    inputs_length.append(core_out.size(2))

            # Batched conv pass, output (bsz, c, t, d). Put time first so each
            # step's (bsz, c, d) slice can be fed to self.clayers.
            inputs = self.conv(inputs).permute(2, 0, 1, 3)
            logging.debug("inputs:%s", inputs.shape)

            # Allocate on the conv output's device/dtype. A bare torch.zeros()
            # lands on the CPU and would pull the encoder input off the GPU.
            merge = inputs.new_zeros(inputs.size(0), batchsize, 512)
            for t in range(inputs.size(0)):  # max_length
                merge[t] = self.clayers(inputs[t],
                                        None)[0].reshape(batchsize, 512)

            xs = merge.permute(1, 0, 2)  # (bsz, t, 512)

            masks = make_non_pad_mask(inputs_length).to(
                xs.device).unsqueeze(-2)

            hs_pad, hs_mask = self.encoder(xs, masks)
        else:
            hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)
        self.hs_pad = hs_pad

        # 1.5. transducer preparation related
        ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask)

        # 2. decoder
        if "transformer" in self.dtype:
            ys_mask = target_mask(ys_in_pad, self.blank_id)
            pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
        elif self.rnnt_mode == "rnnt":
            pred_pad = self.dec(hs_pad, ys_in_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad, pred_len)
        self.pred_pad = pred_pad

        # 3. loss computation
        loss = self.criterion(pred_pad, target, pred_len, target_len)

        self.loss = loss
        loss_data = float(self.loss)

        # 4. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, cer, wer)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss