Example #1
    def translate_batch(self, xs, trans_args, char_list, rnnlm=None):
        """E2E batch beam search.

        :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
        :param Namespace trans_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[::self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)

        # 1. Encoder
        hs_pad, hlens, _ = self.enc(xs_pad, ilens)

        # 2. Decoder
        hlens = torch.tensor(list(map(int,
                                      hlens)))  # make sure hlens is tensor
        y = self.dec.recognize_beam_batch(hs_pad, hlens, None, trans_args,
                                          char_list, rnnlm)

        if prev:
            self.train()
        return y
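A pattern shared by most of these decoding methods is saving the module's training flag, switching to eval() for inference, and restoring the previous mode afterwards. A minimal, runnable sketch of that guard follows; TinyModel is purely illustrative and not part of the library.

import torch


class TinyModel(torch.nn.Module):
    """Stand-in module used only to illustrate the train/eval guard."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def decode(self, x):
        prev = self.training              # remember the current mode
        self.eval()                       # disable dropout / batch-norm updates
        with torch.no_grad():
            y = self.linear(x)
        if prev:                          # restore training mode if it was on
            self.train()
        return y


model = TinyModel()
model.train()
out = model.decode(torch.randn(1, 4))
print(out.shape, model.training)          # torch.Size([1, 2]) True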
Example #2
    def forward(self, xs, labels=None):
        """SLU forward: encode padded acoustic features and run the classifier."""
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
        xs = [to_device(self.slu, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)
        embeddings = self.slu(xs_pad, ilens, None)
        outputs = self.classifier(embeddings, labels)
        return outputs
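Most of these methods start with the same batching steps: record per-utterance lengths with np.fromiter, subsample frames, convert to tensors, and zero-pad with pad_list. A runnable sketch of those steps, using a minimal stand-in for ESPnet's pad_list helper (an approximation, not the library code):

import numpy as np
import torch


def pad_list(xs, pad_value):
    """Minimal stand-in for ESPnet's pad_list: right-pad a list of (T_i, D) tensors."""
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new_full((n_batch, max_len, xs[0].size(1)), pad_value)
    for i, x in enumerate(xs):
        pad[i, : x.size(0)] = x
    return pad


subsample = 2                                   # plays the role of self.subsample[0]
feats = [np.random.randn(t, 40).astype(np.float32) for t in (123, 87, 200)]

ilens = np.fromiter((x.shape[0] for x in feats), dtype=np.int64)
feats = [x[::subsample, :] for x in feats]      # frame subsampling
xs = [torch.from_numpy(x) for x in feats]       # to_device / to_torch_tensor analogue
xs_pad = pad_list(xs, 0.0)
print(xs_pad.shape)                             # (3, longest subsampled length, 40)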
Example #3
    def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
        """E2E beam search.

        :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: batch of output scores (B, T, odim)
        :rtype: torch.Tensor
        """
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[:: self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)

        # 0. Frontend
        if self.frontend is not None:
            enhanced, hlens, mask = self.frontend(xs_pad, ilens)
            hs_pad, hlens = self.feature_transform(enhanced, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        batchsize = hs_pad.size(0)

        # 1. Encoder
        hyps, hlens, _ = self.enc(hs_pad, hlens)
        hyps = hyps.view(batchsize, -1, self.odim)

        return hyps
Example #4
    def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad):
        """E2E CTC probability calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: CTC probability (B, Tmax, vocab)
        :rtype: float ndarray
        """
        probs = None
        if self.mtlalpha == 0:
            return probs

        self.eval()
        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad),
                                                    ilens)
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
            else:
                hs_pad, hlens = xs_pad, ilens

            # 1. Encoder
            hpad, hlens, _ = self.enc(hs_pad, hlens)

            # 2. CTC probs
            probs = self.ctc.softmax(hpad).cpu().numpy()
        self.train()
        return probs
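The CTC probabilities returned above are a per-frame softmax over the encoder output projected to the vocabulary. A short sketch of the shapes involved, with a plain linear layer standing in for the model's CTC head:

import torch

B, Tmax, eprojs, vocab = 2, 50, 320, 52
hpad = torch.randn(B, Tmax, eprojs)            # pretend encoder output
ctc_lo = torch.nn.Linear(eprojs, vocab)        # stand-in for the CTC output layer

with torch.no_grad():
    probs = torch.softmax(ctc_lo(hpad), dim=-1).cpu().numpy()

print(probs.shape)                             # (B, Tmax, vocab)
print(probs.sum(axis=-1)[0, :3])               # each frame sums to ~1.0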
Example #5
    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad),
                                                    ilens)
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
            else:
                hs_pad, hlens = xs_pad, ilens

            # 1. Encoder
            if self.replace_sos:
                tgt_lang_ids = ys_pad[:, 0:1]
                ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
            else:
                tgt_lang_ids = None
            hpad, hlens, _ = self.enc(hs_pad, hlens)

            # 2. Decoder
            att_ws = self.dec.calculate_all_attentions(
                hpad, hlens, ys_pad, tgt_lang_ids=tgt_lang_ids)

        return att_ws
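The returned attention weights come in two layouts, and downstream plotting code usually tells them apart by dimensionality. A brief sketch with random arrays standing in for real attention maps:

import numpy as np

single_head = np.random.rand(4, 20, 100)       # (B, Lmax, Tmax)
multi_head = np.random.rand(4, 8, 20, 100)     # (B, H, Lmax, Tmax)

for att_ws in (single_head, multi_head):
    if att_ws.ndim == 4:                       # multi-head case
        b, h, lmax, tmax = att_ws.shape
        print(f"multi-head: {h} heads of {lmax}x{tmax} weights per utterance")
    else:                                      # one attention matrix per utterance
        b, lmax, tmax = att_ws.shape
        print(f"single-head: {lmax}x{tmax} weights per utterance")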
Example #6
    def enhance(self, xs):
        """Forward only the frontend stage.

        :param ndarray xs: input acoustic feature (T, C, F)
        """
        if self.frontend is None:
            raise RuntimeError('Frontend doesn\'t exist')
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[::self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)
        enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
        if prev:
            self.train()

        if isinstance(enhanced, (tuple, list)):
            enhanced = list(enhanced)
            mask = list(mask)
            for idx in range(len(enhanced)):  # number of speakers
                enhanced[idx] = enhanced[idx].cpu().numpy()
                mask[idx] = mask[idx].cpu().numpy()
            return enhanced, mask, ilens
        return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
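When the frontend separates several speakers, enhanced and mask come back as lists of tensors rather than single tensors, and each element is moved to CPU and converted individually. A compact sketch of that branch with dummy tensors:

import torch


def to_numpy(enhanced, mask):
    """Convert a tensor, or a per-speaker list of tensors, to numpy arrays."""
    if isinstance(enhanced, (tuple, list)):
        return ([e.cpu().numpy() for e in enhanced],
                [m.cpu().numpy() for m in mask])
    return enhanced.cpu().numpy(), mask.cpu().numpy()


# single-speaker output: plain tensors
e, m = to_numpy(torch.randn(2, 100, 83), torch.rand(2, 100, 83))
print(type(e), e.shape)

# two-speaker output: one tensor per speaker
e, m = to_numpy([torch.randn(2, 100, 83) for _ in range(2)],
                [torch.rand(2, 100, 83) for _ in range(2)])
print(type(e), len(e), e[0].shape)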
Example #7
    def enhance(self, xs):
        """Forward only the frontend stage.

        Args:
            xs (ndarray): input acoustic feature (T, C, F)

        Returns:
            enhanced (ndarray): enhanced acoustic features
            mask (ndarray): estimated mask
            ilens (ndarray): batch of lengths of input sequences (B)

        """
        if self.frontend is None:
            raise RuntimeError("Frontend doesn't exist")
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[::self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)
        enhanced, hlensm, mask = self.frontend(xs_pad, ilens)

        if prev:
            self.train()

        return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
Example #8
    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Args:
            xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax)

        Returns:
            att_ws (ndarray): attention weights with the following shape,
                1) multi-head case => attention weights (B, H, Lmax, Tmax),
                2) other case => attention weights (B, Lmax, Tmax).

        """
        if self.rnnt_mode == 'rnnt':
            return []

        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
            else:
                hs_pad, hlens = xs_pad, ilens

            # encoder
            hpad, hlens, _ = self.enc(hs_pad, hlens)

            # decoder
            att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad)

        return att_ws
Example #9
    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, num_spkrs, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad),
                                                    ilens)
                hlens_n = [None] * self.num_spkrs
                for i in range(self.num_spkrs):
                    hs_pad[i], hlens_n[i] = self.feature_transform(
                        hs_pad[i], hlens)
                hlens = hlens_n
            else:
                hs_pad, hlens = xs_pad, ilens

            # 1. Encoder
            if not isinstance(hs_pad,
                              list):  # single-channel multi-speaker input x
                hs_pad, hlens, _ = self.enc(hs_pad, hlens)
            else:  # multi-channel multi-speaker input x
                for i in range(self.num_spkrs):
                    hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

            # Permutation
            ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
            if self.num_spkrs <= 3:
                loss_ctc = torch.stack(
                    [
                        self.ctc(
                            hs_pad[i // self.num_spkrs],
                            hlens[i // self.num_spkrs],
                            ys_pad[i % self.num_spkrs],
                        ) for i in range(self.num_spkrs**2)
                    ],
                    1,
                )  # (B, num_spkrs^2)
                loss_ctc, min_perm = self.pit.pit_process(loss_ctc)
            for i in range(ys_pad.size(1)):  # B
                ys_pad[:, i] = ys_pad[min_perm[i], i]

            # 2. Decoder
            att_ws = [
                self.dec.calculate_all_attentions(hs_pad[i],
                                                  hlens[i],
                                                  ys_pad[i],
                                                  strm_idx=i)
                for i in range(self.num_spkrs)
            ]

        return att_ws
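The permutation step scores every hypothesis/reference pairing with CTC and lets PIT pick the assignment with the lowest total loss. For two speakers this reduces to comparing two assignments; a minimal sketch with made-up pairwise losses (self.pit.pit_process itself is not reproduced here):

import torch

# pairwise CTC losses for one utterance, ordered h1r1, h1r2, h2r1, h2r2
pairwise = torch.tensor([2.3, 0.4, 0.5, 2.1])

# the two possible assignments for num_spkrs == 2
perm_identity = pairwise[0] + pairwise[3]      # h1->r1, h2->r2
perm_swapped = pairwise[1] + pairwise[2]       # h1->r2, h2->r1

if perm_swapped < perm_identity:
    loss, perm = perm_swapped / 2, (1, 0)      # per-speaker loss, references swapped
else:
    loss, perm = perm_identity / 2, (0, 1)

print(float(loss), perm)                       # about 0.45, (1, 0): reorder ys_pad accordingly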
Example #10
    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        self.eval()
        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad),
                                                    ilens)
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
            else:
                hs_pad, hlens = xs_pad, ilens

            # 1. Encoder
            hpad, hlens, _ = self.enc(hs_pad, hlens)

            # 2. Decoder
            att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad)
        self.train()
        return att_ws
Example #11
    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """E2E beam search.

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens = [x.shape[0]]

        # subsample frame
        x = x[::self.subsample[0], :]
        h = to_device(self, to_torch_tensor(x).float())
        # make a utt list (1) to use the same interface for encoder
        hs = h.contiguous().unsqueeze(0)

        # 0. Frontend
        if self.frontend is not None:
            hs, hlens, mask = self.frontend(hs, ilens)
            hlens_n = [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                hs[i], hlens_n[i] = self.feature_transform(hs[i], hlens)
            hlens = hlens_n
        else:
            hs, hlens = hs, ilens

        # 1. Encoder
        if not isinstance(hs, list):  # single-channel multi-speaker input x
            hs, hlens, _ = self.enc(hs, hlens)
        else:  # multi-channel multi-speaker input x
            for i in range(self.num_spkrs):
                hs[i], hlens[i], _ = self.enc(hs[i], hlens[i])

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = [self.ctc.log_softmax(i)[0] for i in hs]
        else:
            lpz = None

        # 2. decoder
        # decode the first utterance
        y = [
            self.dec.recognize_beam(hs[i][0],
                                    lpz[i],
                                    recog_args,
                                    char_list,
                                    rnnlm,
                                    strm_idx=i) for i in range(self.num_spkrs)
        ]

        if prev:
            self.train()
        return y
Example #12
    def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
        """E2E batch beam search.

        :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[::self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)

        # 0. Frontend
        if self.frontend is not None:
            enhanced, hlens, mask = self.frontend(xs_pad, ilens)
            hs_pad, hlens = self.feature_transform(enhanced, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(hs_pad)
            normalize_score = False
        else:
            lpz = None
            normalize_score = True

        # 2. Decoder
        hlens = torch.tensor(list(map(int,
                                      hlens)))  # make sure hlens is tensor
        y = self.dec.recognize_beam_batch(
            hs_pad,
            hlens,
            lpz,
            recog_args,
            char_list,
            rnnlm,
            normalize_score=normalize_score,
        )

        if prev:
            self.train()
        return y
Example #13
    def forward_frontend_and_encoder(self, xs_pad, ilens):
        """Forward front-end and encoder."""
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, _ = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        return hs_pad, hlens
Example #14
    def recognize_batch(self, xs_list, recog_args, char_list, rnnlm=None):
        """E2E beam search.

        :param list xs_list: list of list of input acoustic feature arrays
                [[(T1_1, D), (T1_2, D), ...],[(T2_1, D), (T2_2, D), ...], ...]
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens_list = [np.fromiter((xx.shape[0] for xx in xs_list[idx]), dtype=np.int64) for idx in range(self.num_encs)]

        # subsample frame
        xs_list = [[xx[::self.subsample_list[idx][0], :] for xx in xs_list[idx]] for idx in range(self.num_encs)]

        xs_list = [[to_device(self, to_torch_tensor(xx).float()) for xx in xs_list[idx]] for idx in
                   range(self.num_encs)]
        xs_pad_list = [pad_list(xs_list[idx], 0.0) for idx in range(self.num_encs)]

        # 1. Encoder
        hs_pad_list, hlens_list = [], []
        for idx in range(self.num_encs):
            hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx])
            hs_pad_list.append(hs_pad)
            hlens_list.append(hlens)

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            if self.share_ctc:
                lpz_list = [self.ctc[0].log_softmax(hs_pad_list[idx]) for idx in range(self.num_encs)]
            else:
                lpz_list = [self.ctc[idx].log_softmax(hs_pad_list[idx]) for idx in range(self.num_encs)]
            normalize_score = False
        else:
            lpz_list = None
            normalize_score = True

        # 2. Decoder
        hlens_list = [torch.tensor(list(map(int, hlens_list[idx]))) for idx in
                      range(self.num_encs)]  # make sure hlens is tensor
        y = self.dec.recognize_beam_batch(hs_pad_list, hlens_list, lpz_list, recog_args, char_list,
                                          rnnlm, normalize_score=normalize_score)

        if prev:
            self.train()
        return y
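With several encoders, every per-utterance structure becomes a list indexed by encoder, and the CTC module is either shared (index 0) or one per encoder. A compact sketch of that shared/unshared selection over dummy encoder outputs (layer sizes are illustrative):

import torch

num_encs, eprojs, vocab = 2, 320, 52
hs_pad_list = [torch.randn(3, 40, eprojs), torch.randn(3, 55, eprojs)]

share_ctc = True
n_heads = 1 if share_ctc else num_encs
ctc_heads = torch.nn.ModuleList(
    [torch.nn.Linear(eprojs, vocab) for _ in range(n_heads)])

# log P(z_t | X) per encoder stream; a shared CTC reuses head 0 for every stream
lpz_list = [
    torch.log_softmax(ctc_heads[0 if share_ctc else idx](hs_pad_list[idx]), dim=-1)
    for idx in range(num_encs)
]
print([tuple(lpz.shape) for lpz in lpz_list])  # [(3, 40, 52), (3, 55, 52)]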
Example #15
    def calculate_alignments(self, xs_pad, ilens, ys_pad):
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)

        # 2. decoder
        _, gammas = self.dec.rnnt_alignment(hs_pad, hlens, ys_pad)

        return gammas
Example #16
    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """E2E beam search

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens = [x.shape[0]]

        # subsample frame
        x = x[::self.subsample[0], :]
        h = to_device(self, to_torch_tensor(x).float())
        # make a utt list (1) to use the same interface for encoder
        hs = h.contiguous().unsqueeze(0)

        # 0. Frontend
        if self.frontend is not None:
            enhanced, hlens, mask = self.frontend(hs, ilens)
            hs, hlens = self.feature_transform(enhanced, hlens)
        else:
            hs, hlens = hs, ilens

        # 1. encoder
        hs, _, __ = self.enc(hs, hlens)

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(hs)[0]
        else:
            lpz = None

        # 2. Decoder
        # decode the first utterance
        y = self.dec.recognize_beam(hs[0], lpz, recog_args, char_list, rnnlm)

        if prev:
            self.train()

        return y
Example #17
    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """E2E recognize.

        Args:
            x (ndarray): input acoustic feature (T, D)
            recog_args (namespace): argument Namespace containing options
            char_list (list): list of characters
            rnnlm (torch.nn.Module): language model module

        Returns:
           y (list): n-best decoding results

        """
        prev = self.training
        self.eval()
        ilens = [x.shape[0]]

        # subsample frame
        x = x[::self.subsample[0], :]
        h = to_device(self, to_torch_tensor(x).float())
        # make a utt list (1) to use the same interface for encoder
        hs = h.contiguous().unsqueeze(0)

        # 0. Frontend
        if self.frontend is not None:
            enhanced, hlens, mask = self.frontend(hs, ilens)
            hs, hlens = self.feature_transform(enhanced, hlens)
        else:
            hs, hlens = hs, ilens

        # 1. Encoder
        h, _, _ = self.enc(hs, hlens)

        # 2. Decoder
        if recog_args.beam_size == 1:
            y = self.dec.recognize(h[0], recog_args)
        else:
            y = self.dec.recognize_beam(h[0], recog_args, rnnlm)

        if prev:
            self.train()

        return y
Example #18
    def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
        """Recognize input speech.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: batch of output scores (B, T, odim)
        :rtype: torch.Tensor
        """

        self.eval()
        ilens = numpy.fromiter((xx.shape[0] for xx in x), dtype=numpy.int64)

        # subsample frame
        x = [xx[:: self.subsample[0], :] for xx in x]
        x = [to_device(self, to_torch_tensor(xx).float()) for xx in x]
        x = pad_list(x, 0.0)

        enc_output, _ = self.encoder(x, None)
        batchsize = x.size(0)

        if self.outer:
            post_pad = self.poster(enc_output)
            post_pad = post_pad.view(post_pad.size(0), -1, self.odim)
            if post_pad.size(1) != x.size(1):
                if post_pad.size(1) < x.size(1):
                    x = x[:, :post_pad.size(1)]
                else:
                    raise ValueError(
                        "target size {} and pred size {} is mismatch".format(x.size(1), post_pad.size(1)))
            if self.residual:
                post_pad = post_pad + self.matcher_res(x)
            else:
                post_pad = torch.cat([post_pad, x], dim=-1)
            hyps = self.matcher(post_pad)
        else:
            pred_pad = self.poster(enc_output)
            hyps = pred_pad.view(pred_pad.size(0), -1, self.odim)
        hyps = hyps.view(batchsize, -1, self.odim)

        return hyps
Example #19
    def encode_rnn(self, x):
        """Encode acoustic features.

        Args:
            x (ndarray): input acoustic feature (T, D)

        Returns:
            x (torch.Tensor): encoded features (T, attention_dim)

        """
        self.eval()

        ilens = [x.shape[0]]

        x = x[::self.subsample[0], :]
        h = to_device(self, to_torch_tensor(x).float())
        hs = h.contiguous().unsqueeze(0)

        h, _, _ = self.encoder(hs, ilens)

        return h[0]
Example #20
    def enhance(self, xs):
        """Forward only in the frontend stage.

        :param ndarray xs: input acoustic feature (T, C, F)
        :return: enhanced feature, mask, and input lengths
        :rtype: tuple
        """
        if self.frontend is None:
            raise RuntimeError("Frontend does't exist")
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[::self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)
        enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
        if prev:
            self.train()
        return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
Example #21
    def forward(self, xs_pad, ilens, ys_pad, ys_pad_mono=None):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. RNN Encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)

        # 2. post-processing layer for target dimension
        pred_pad = self.poster(hs_pad)
        pred_pad = pred_pad.view(pred_pad.size(0), -1, self.odim)
        self.pred_pad = pred_pad
        if pred_pad.size(1) != ys_pad.size(1):
            if pred_pad.size(1) < ys_pad.size(1):
                ys_pad = ys_pad[:, :pred_pad.size(1)].contiguous()
            else:
                raise ValueError(
                    "target size {} and pred size {} is mismatch".format(
                        ys_pad.size(1), pred_pad.size(1)))

        if ys_pad_mono is not None:
            pred_pad_mono = self.poster_mono(hs_pad)
            pred_pad_mono = pred_pad_mono.view(pred_pad_mono.size(0), -1,
                                               self.mono_odim)
            self.pred_pad_mono = pred_pad_mono
            if pred_pad_mono.size(1) != ys_pad_mono.size(1):
                if pred_pad_mono.size(1) < ys_pad_mono.size(1):
                    ys_pad_mono = ys_pad_mono[:, :pred_pad_mono.
                                              size(1)].contiguous()
                else:
                    raise ValueError(
                        "target size {} and pred size {} is mismatch".format(
                            ys_pad_mono.size(1), pred_pad_mono.size(1)))

        # 3. CTC loss
        if self.mtlalpha == 0:
            self.loss_ctc = None
        else:
            self.loss_ctc = self.ctc(pred_pad, hlens, ys_pad)

        # 4. CE loss
        if LooseVersion(torch.__version__) < LooseVersion("1.0"):
            reduction_str = "elementwise_mean"
        else:
            reduction_str = "mean"
        self.loss_ce_tri = F.cross_entropy(
            pred_pad.view(-1, self.odim),
            ys_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction=reduction_str,
        )
        if ys_pad_mono is not None:
            self.loss_ce_mono = F.cross_entropy(
                pred_pad_mono.view(-1, self.mono_odim),
                ys_pad_mono.view(-1),
                ignore_index=self.ignore_id,
                reduction=reduction_str,
            )
        else:
            self.loss_ce_mono = 0
        self.loss_ce = 0.6 * self.loss_ce_tri + 0.4 * self.loss_ce_mono
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_pad,
                               ignore_label=self.ignore_id)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer, cer_ctc = None, None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
            cer_ctc = self.error_calculator(ys_hat.cpu(),
                                            ys_pad.cpu(),
                                            is_ctc=True)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = self.loss_ce
            loss_ce_data = float(self.loss_ce)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = self.loss_ctc
            loss_ce_data = None
            loss_ctc_data = float(self.loss_ctc)
        else:
            self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_ce
            loss_ce_data = float(self.loss_ce)
            loss_ctc_data = float(self.loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_ce_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            pass
        return self.loss
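The final loss follows the usual multitask interpolation: mtlalpha of 0 keeps only the CE (or attention) branch, 1 keeps only CTC, and anything in between mixes the two. A tiny sketch of that selection with scalar tensors:

import torch


def combine_losses(loss_ctc, loss_ce, alpha):
    """Interpolate CTC and CE losses with weight alpha (mtlalpha)."""
    if alpha == 0:
        return loss_ce
    if alpha == 1:
        return loss_ctc
    return alpha * loss_ctc + (1 - alpha) * loss_ce


loss_ctc = torch.tensor(12.0)
loss_ce = torch.tensor(3.0)
for alpha in (0.0, 0.3, 1.0):
    print(alpha, float(combine_losses(loss_ctc, loss_ce, alpha)))  # approximately 3.0, 5.7, 12.0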
Example #22
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, num_spkrs, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            if isinstance(hs_pad, list):
                hlens_n = [None] * self.num_spkrs
                for i in range(self.num_spkrs):
                    hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
                hlens = hlens_n
            else:
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if not isinstance(
            hs_pad, list
        ):  # single-channel input xs_pad (single- or multi-speaker)
            hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        else:  # multi-channel multi-speaker input xs_pad
            for i in range(self.num_spkrs):
                hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

        # 2. CTC loss
        if self.mtlalpha == 0:
            loss_ctc, min_perm = None, None
        else:
            if not isinstance(hs_pad, list):  # single-speaker input xs_pad
                loss_ctc = torch.mean(self.ctc(hs_pad, hlens, ys_pad))
            else:  # multi-speaker input xs_pad
                ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
                loss_ctc_perm = torch.stack(
                    [
                        self.ctc(
                            hs_pad[i // self.num_spkrs],
                            hlens[i // self.num_spkrs],
                            ys_pad[i % self.num_spkrs],
                        )
                        for i in range(self.num_spkrs ** 2)
                    ],
                    dim=1,
                )  # (B, num_spkrs^2)
                loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
                logging.info("ctc loss:" + str(float(loss_ctc)))

        # 3. attention loss
        if self.mtlalpha == 1:
            loss_att = None
            acc = None
        else:
            if not isinstance(hs_pad, list):  # single-speaker input xs_pad
                loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
            else:
                for i in range(ys_pad.size(1)):  # B
                    ys_pad[:, i] = ys_pad[min_perm[i], i]
                rslt = [
                    self.dec(hs_pad[i], hlens[i], ys_pad[i], strm_idx=i)
                    for i in range(self.num_spkrs)
                ]
                loss_att = sum([r[0] for r in rslt]) / float(len(rslt))
                acc = sum([r[1] for r in rslt]) / float(len(rslt))
        self.acc = acc

        # 4. compute cer without beam search
        if self.mtlalpha == 0 or self.char_list is None:
            cer_ctc = None
        else:
            cers = []
            for ns in range(self.num_spkrs):
                y_hats = self.ctc.argmax(hs_pad[ns]).data
                for i, y in enumerate(y_hats):
                    y_hat = [x[0] for x in groupby(y)]
                    y_true = ys_pad[ns][i]

                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(self.space, " ")
                    seq_hat_text = seq_hat_text.replace(self.blank, "")
                    seq_true_text = "".join(seq_true).replace(self.space, " ")

                    hyp_chars = seq_hat_text.replace(" ", "")
                    ref_chars = seq_true_text.replace(" ", "")
                    if len(ref_chars) > 0:
                        cers.append(
                            editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)
                        )

            cer_ctc = sum(cers) / len(cers) if cers else None

        # 5. compute cer/wer
        if (
            self.training
            or not (self.report_cer or self.report_wer)
            or not isinstance(hs_pad, list)
        ):
            cer, wer = 0.0, 0.0
        else:
            if self.recog_args.ctc_weight > 0.0:
                lpz = [
                    self.ctc.log_softmax(hs_pad[i]).data for i in range(self.num_spkrs)
                ]
            else:
                lpz = None

            word_eds, char_eds, word_ref_lens, char_ref_lens = [], [], [], []
            nbest_hyps = [
                self.dec.recognize_beam_batch(
                    hs_pad[i],
                    torch.tensor(hlens[i]),
                    lpz[i],
                    self.recog_args,
                    self.char_list,
                    self.rnnlm,
                    strm_idx=i,
                )
                for i in range(self.num_spkrs)
            ]
            # remove <sos> and <eos>
            y_hats = [
                [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps[i]]
                for i in range(self.num_spkrs)
            ]
            for i in range(len(y_hats[0])):
                hyp_words = []
                hyp_chars = []
                ref_words = []
                ref_chars = []
                for ns in range(self.num_spkrs):
                    y_hat = y_hats[ns][i]
                    y_true = ys_pad[ns][i]

                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ")
                    seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
                    seq_true_text = "".join(seq_true).replace(
                        self.recog_args.space, " "
                    )

                    hyp_words.append(seq_hat_text.split())
                    ref_words.append(seq_true_text.split())
                    hyp_chars.append(seq_hat_text.replace(" ", ""))
                    ref_chars.append(seq_true_text.replace(" ", ""))

                tmp_word_ed = [
                    editdistance.eval(
                        hyp_words[ns // self.num_spkrs], ref_words[ns % self.num_spkrs]
                    )
                    for ns in range(self.num_spkrs ** 2)
                ]  # h1r1,h1r2,h2r1,h2r2
                tmp_char_ed = [
                    editdistance.eval(
                        hyp_chars[ns // self.num_spkrs], ref_chars[ns % self.num_spkrs]
                    )
                    for ns in range(self.num_spkrs ** 2)
                ]  # h1r1,h1r2,h2r1,h2r2

                word_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_word_ed))[0])
                word_ref_lens.append(len(sum(ref_words, [])))
                char_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_char_ed))[0])
                char_ref_lens.append(len("".join(ref_chars)))

            wer = (
                0.0
                if not self.report_wer
                else float(sum(word_eds)) / sum(word_ref_lens)
            )
            cer = (
                0.0
                if not self.report_cer
                else float(sum(char_eds)) / sum(char_ref_lens)
            )

        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
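The CER/WER bookkeeping in these forward passes amounts to summing edit distances and reference lengths over the batch. A self-contained sketch with the editdistance package the examples already use (the sample strings are illustrative):

import editdistance

refs = ["the cat sat on the mat", "hello world"]
hyps = ["the cat sat on a mat", "hello word"]

word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
for hyp, ref in zip(hyps, refs):
    word_eds.append(editdistance.eval(hyp.split(), ref.split()))
    word_ref_lens.append(len(ref.split()))
    char_eds.append(editdistance.eval(hyp.replace(" ", ""), ref.replace(" ", "")))
    char_ref_lens.append(len(ref.replace(" ", "")))

wer = float(sum(word_eds)) / sum(word_ref_lens)
cer = float(sum(char_eds)) / sum(char_ref_lens)
print(f"WER={wer:.3f} CER={cer:.3f}")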
Example #23
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax)

        Returns:
               loss (torch.Tensor): transducer loss value

        """
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            if isinstance(hs_pad, list):
                hlens_n = [None] * self.num_spkrs
                for i in range(self.num_spkrs):
                    hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
                hlens = hlens_n
            else:
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if not isinstance(hs_pad, list):  # single-channel input xs_pad (single- or multi-speaker)
            hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        else:  # multi-channel multi-speaker input xs_pad
            for i in range(self.num_spkrs):
                hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

        # 2. decoder
        loss = self.dec(hs_pad, hlens, ys_pad)

        # 3. compute cer/wer
        # note: not recommended outside debugging right now,
        # the training time is hugely impacted.
        if self.training or not (self.report_cer or self.report_wer):
            cer, wer = 0.0, 0.0
        else:
            word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []

            batchsize = int(hs_pad.size(0))
            batch_nbest = []

            for b in six.moves.range(batchsize):
                nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args)
                batch_nbest.append(nbest_hyps)

            y_hats = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest]

            for i, y_hat in enumerate(y_hats):
                y_true = ys_pad[i]

                seq_hat = [self.char_list[int(idx)] for idx in y_hat]
                seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
                seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, ' ')
                seq_true_text = "".join(seq_true).replace(self.recog_args.space, ' ')

                hyp_words = seq_hat_text.split()
                ref_words = seq_true_text.split()
                word_eds.append(editdistance.eval(hyp_words, ref_words))
                word_ref_lens.append(len(ref_words))

                hyp_chars = seq_hat_text.replace(' ', '')
                ref_chars = seq_true_text.replace(' ', '')
                char_eds.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))
            wer = 0.0 if not self.report_wer else float(sum(word_eds)) / sum(word_ref_lens)
            cer = 0.0 if not self.report_cer else float(sum(char_eds)) / sum(char_ref_lens)

        self.loss = loss
        loss_data = float(self.loss)

        if not math.isnan(loss_data):
            self.reporter.report(loss_data, cer, wer)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)

        return self.loss
Example #24
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # TODO: data augmentation (disabled)
        #rand_idx=torch.randperm(hs_pad.size(0))
        #rand_ratio = 0.2 * torch.rand(1).to(hs_pad.device)
        #hs_pad = (1-rand_ratio) * hs_pad + rand_ratio * torch.flip(hs_pad, [1])[rand_idx].to(hs_pad.device)

        # 1. CNN Encoder 
        hs_pad, hlens, _ = self.enc(hs_pad, hlens) 

        # TODO: Not sure about oversampling & outer
        # 2. post-processing layer for target dimension
        if self.outer:
            post_pad = self.poster(hs_pad)
            post_pad = post_pad.view(post_pad.size(0), -1, self.odim)
            if post_pad.size(1) != xs_pad.size(1):
                if post_pad.size(1) < xs_pad.size(1):
                    xs_pad = xs_pad[:, :post_pad.size(1)].contiguous()
                else:
                    raise ValueError("target size {} and pred size {} is mismatch".format(xs_pad.size(1), post_pad.size(1)))
            if self.residual:
                post_pad = post_pad + self.matcher_res(xs_pad)
            else:
                post_pad = torch.cat([post_pad, xs_pad], dim=-1)
            pred_pad = self.matcher(post_pad)
        else:
            pred_pad = self.poster(hs_pad)
            pred_pad = pred_pad.view(pred_pad.size(0), -1, self.odim)
        self.pred_pad = pred_pad
        if pred_pad.size(1) != ys_pad.size(1):
            if pred_pad.size(1) < ys_pad.size(1):
                ys_pad = ys_pad[:, :pred_pad.size(1)].contiguous()
            else:
                raise ValueError("target size {} and pred size {} is mismatch".format(ys_pad.size(1), pred_pad.size(1)))

        # 3. CTC loss
        if self.mtlalpha == 0:
            self.loss_ctc = None
        else:
            self.loss_ctc = self.ctc(pred_pad, hlens, ys_pad)

        # 4. CE loss
        # print('pred_pad before loss computation', pred_pad.size()) # 64, 61, 3480
        # print('ys_pad before loss computation', ys_pad.size()) # 64, 61
        if LooseVersion(torch.__version__) < LooseVersion("1.0"):
            reduction_str = "elementwise_mean"
        else:
            reduction_str = "mean"
        self.loss_ce = F.cross_entropy(
            pred_pad.view(-1, self.odim),
            ys_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction=reduction_str,
        )
        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_pad, ignore_label=self.ignore_id
        )

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer, cer_ctc = None, None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = self.loss_ce
            loss_ce_data = float(self.loss_ce)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = self.loss_ctc
            loss_ce_data = None
            loss_ctc_data = float(self.loss_ctc)
        else:
            self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_ce
            loss_ce_data = float(self.loss_ce)
            loss_ctc_data = float(self.loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_ce_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            pass
        return self.loss
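The reduction_str logic exists because PyTorch releases before 1.0 spelled the default reduction "elementwise_mean", where newer ones use "mean". A short sketch of the same check (LooseVersion lives in the deprecated distutils module, so this only illustrates the pattern used above):

from distutils.version import LooseVersion

import torch
import torch.nn.functional as F

# "elementwise_mean" was the pre-1.0 spelling of the default reduction
if LooseVersion(torch.__version__) < LooseVersion("1.0"):
    reduction_str = "elementwise_mean"
else:
    reduction_str = "mean"

logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))
loss = F.cross_entropy(logits, targets, reduction=reduction_str)
print(reduction_str, float(loss))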
Example #25
    def forward(self, xs_pad, ilens, ys_pad, asrtts=False, ttsasr=False):
        """E2E forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if self.replace_sos:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
        else:
            tgt_lang_ids = None

        hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        # 2. CTC loss
        if self.mtlalpha == 0:
            self.loss_ctc = None
        else:
            self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad)
        if asrtts:
            acc = None
            self.acc = acc
            # 4. compute cer without beam search
            cer_ctc = None
            if self.recog_args.ctc_weight > 0.0:
                lpz = self.ctc.log_softmax(hs_pad).data
            else:
                lpz = None
            #if self.recog_args.sampling == 'multinomial':
            #self.loss_att, best_hyps = self.dec.generate(hs_pad, torch.tensor(hlens), ys_pad, self.recog_args)
            self.loss_att, best_hyps = self.dec.generate_forward(
                hs_pad, torch.tensor(hlens), ys_pad,
                self.recog_args)  #, oracle_length=self.oracle_length)
            #else:
            #    best_hyps, pred_scores = self.dec.generate_beam_batch(
            #        hs_pad, torch.tensor(hlens), lpz,
            #        self.recog_args, self.char_list,
            #        self.rnnlm, tgt_lang_ids=None, sampling=self.recog_args.sampling)
            #    self.loss_att = pred_scores.mean(2)[:, -self.recog_args.nbest:].view(-1)
            #self.loss_att *= (np.mean([len(x) for x in best_hyps]) - 1)
            logging.info(self.loss_att.mean())
        elif ttsasr:
            # 3. attention loss
            if self.mtlalpha == 1:
                self.loss_att, acc = None, None
            else:
                hs_zero_pad = torch.zeros(hs_pad.size()).cuda()
                self.loss_att, _, _ = self.dec(hs_zero_pad,
                                               hlens,
                                               ys_pad,
                                               tgt_lang_ids=tgt_lang_ids)
                # masking acc
                # logging.info("TTS2ASR, ACC: " + str(acc))
                acc = None
            self.acc = acc
            # 4. compute cer without beam search
            if self.mtlalpha == 0:
                cer_ctc = None
            else:
                cers = []

                y_hats = self.ctc.argmax(hs_pad).data
                for i, y in enumerate(y_hats):
                    y_hat = [x[0] for x in groupby(y)]
                    y_true = ys_pad[i]

                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat
                        if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true
                        if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
                    seq_hat_text = seq_hat_text.replace(self.blank, '')
                    seq_true_text = "".join(seq_true).replace(self.space, ' ')

                    hyp_chars = seq_hat_text.replace(' ', '')
                    ref_chars = seq_true_text.replace(' ', '')
                    if len(ref_chars) > 0:
                        cers.append(
                            editdistance.eval(hyp_chars, ref_chars) /
                            len(ref_chars))

                cer_ctc = sum(cers) / len(cers) if cers else None
        else:
            # 3. attention loss
            if self.mtlalpha == 1:
                self.loss_att, acc = None, None
            else:
                self.loss_att, acc, _ = self.dec(hs_pad,
                                                 hlens,
                                                 ys_pad,
                                                 tgt_lang_ids=tgt_lang_ids)
            self.acc = acc
            # 4. compute cer without beam search
            if self.mtlalpha == 0:
                cer_ctc = None
            else:
                cers = []

                y_hats = self.ctc.argmax(hs_pad).data
                for i, y in enumerate(y_hats):
                    y_hat = [x[0] for x in groupby(y)]
                    y_true = ys_pad[i]

                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat
                        if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true
                        if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
                    seq_hat_text = seq_hat_text.replace(self.blank, '')
                    seq_true_text = "".join(seq_true).replace(self.space, ' ')

                    hyp_chars = seq_hat_text.replace(' ', '')
                    ref_chars = seq_true_text.replace(' ', '')
                    if len(ref_chars) > 0:
                        cers.append(
                            editdistance.eval(hyp_chars, ref_chars) /
                            len(ref_chars))

                cer_ctc = sum(cers) / len(cers) if cers else None

        # 5. compute cer/wer
        if self.training or not (self.report_cer or self.report_wer):
            cer, wer = 0.0, 0.0
            # oracle_cer, oracle_wer = 0.0, 0.0
        else:
            if self.recog_args.ctc_weight > 0.0:
                lpz = self.ctc.log_softmax(hs_pad).data
            else:
                lpz = None
            word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
            nbest_hyps = self.dec.recognize_beam_batch(
                hs_pad,
                torch.tensor(hlens),
                lpz,
                self.recog_args,
                self.char_list,
                self.rnnlm,
                tgt_lang_ids=tgt_lang_ids.squeeze(1).tolist()
                if self.replace_sos else None)
            # remove <sos> and <eos>
            y_hats = [nbest_hyp[0]['yseq'][1:-1] for nbest_hyp in nbest_hyps]
            for i, y_hat in enumerate(y_hats):
                y_true = ys_pad[i]

                seq_hat = [
                    self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                ]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true
                    if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(
                    self.recog_args.space, ' ')
                seq_hat_text = seq_hat_text.replace(self.recog_args.blank, '')
                seq_true_text = "".join(seq_true).replace(
                    self.recog_args.space, ' ')

                hyp_words = seq_hat_text.split()
                ref_words = seq_true_text.split()
                word_eds.append(editdistance.eval(hyp_words, ref_words))
                word_ref_lens.append(len(ref_words))
                hyp_chars = seq_hat_text.replace(' ', '')
                ref_chars = seq_true_text.replace(' ', '')
                char_eds.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

            wer = 0.0 if not self.report_wer else float(
                sum(word_eds)) / sum(word_ref_lens)
            cer = 0.0 if not self.report_cer else float(
                sum(char_eds)) / sum(char_ref_lens)
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = self.loss_att
            loss_att_data = float(self.loss_att.mean())
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = self.loss_ctc
            loss_att_data = None
            loss_ctc_data = float(self.loss_ctc)
        else:
            if asrtts:
                self.loss = self.loss_att
            else:
                self.loss = alpha * self.loss_ctc + (
                    1 - alpha) * self.loss_att.mean()
            loss_att_data = float(self.loss_att.mean())
            loss_ctc_data = float(self.loss_ctc)

        loss_data = float(self.loss.mean())
        #logging.info("main acc is: " + str(acc))
        if asrtts:
            self.reporter.report(loss_ctc_data, float(self.loss_att.mean()),
                                 acc, cer_ctc, cer, wer,
                                 float(self.loss_att.mean()))
            return self.loss_att, best_hyps
        elif ttsasr:
            if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
                self.reporter.report(loss_ctc_data, loss_att_data, acc,
                                     cer_ctc, cer, wer, loss_data)
            else:
                logging.warning('loss (=%f) is not correct', loss_data)
            return self.loss
        else:
            if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
                self.reporter.report(loss_ctc_data, loss_att_data, acc,
                                     cer_ctc, cer, wer, loss_data)
            else:
                logging.warning('loss (=%f) is not correct', loss_data)
            return self.loss
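The cer_ctc branches above decode CTC greedily: take the frame-wise argmax, collapse repeats with itertools.groupby, drop blanks, and map ids to characters. A stand-alone sketch of that collapse (char_list and the blank index are illustrative values):

from itertools import groupby

char_list = ["<blank>", "a", "b", "c", " "]
blank_id = 0

# frame-wise argmax ids for one utterance
y = [0, 1, 1, 0, 0, 2, 2, 2, 0, 4, 3, 3]

y_hat = [k for k, _ in groupby(y)]                      # collapse repeated labels
seq_hat = [char_list[idx] for idx in y_hat if idx != blank_id]
print("".join(seq_hat))                                 # "ab c"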
Example #26
    def forward(self, xs_pad, ilens, ys_pad, ul_xs_pad, ul_ilens, ul_ys_pad, process_info):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """

        # Forward for cross entropy loss
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Mixup feature
        if self.mixup_alpha > 0.0:
            hs_pad, ys_pad, ys_pad_b, _, lam = mixup_data(hs_pad, ys_pad, hlens, self.mixup_alpha, self.scheme)

        # 2. RNN Encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)

        # 3. post-processing layer for target dimension
        hs_pad, ys_pad = self.match_pad(hs_pad, ys_pad)
        if self.mixup_alpha > 0.0:
            hs_pad, ys_pad_b = self.match_pad(hs_pad, ys_pad_b)

        # 4. Supervised loss
        if LooseVersion(torch.__version__) < LooseVersion("1.0"):
            reduction_str = "elementwise_mean"
        else:
            reduction_str = "mean"
        if self.mixup_alpha > 0.0:
            loss_ce_a = F.cross_entropy(
                hs_pad.view(-1, self.odim),
                ys_pad.view(-1),
                ignore_index=self.ignore_id,
                reduction=reduction_str,
            )
            loss_ce_b = F.cross_entropy(
                hs_pad.view(-1, self.odim),
                ys_pad_b.view(-1),
                ignore_index=self.ignore_id,
                reduction=reduction_str,
            )
            self.loss_ce = lam * loss_ce_a + (1 - lam) * loss_ce_b
        else:
            self.loss_ce = F.cross_entropy(
                hs_pad.view(-1, self.odim),
                ys_pad.view(-1),
                ignore_index=self.ignore_id,
                reduction=reduction_str,
            )

        # Forward for consistency loss
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(ul_xs_pad), ul_ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = ul_xs_pad, ul_ilens

        # Calculating the student model accuracy roughly doubles the forward computation.
        if self.show_student_model_acc:
            ul_pred_pad, ul_hlens, _ = self.enc(hs_pad, hlens)
            ul_pred_pad, ul_ys_pad_temp = self.match_pad(ul_pred_pad, ul_ys_pad)
            self.stu_acc = th_accuracy(
                ul_pred_pad.view(-1, self.odim), ul_ys_pad_temp, ignore_label=self.ignore_id
            )
            # free the CUDA tensors that are no longer needed
            ul_pred_pad, ul_ys_pad_temp = (None, None)
        else:
            self.stu_acc = 0

        # 1. Mixup feature
        if self.mixup_alpha > 0.0:
            hs_pad, ys_pad, _, shuf_idx, lam = mixup_data(hs_pad, ul_ys_pad, hlens, self.mixup_alpha,
                                                          self.scheme)

        # 2. RNN Encoder
        ema_ul_hs_pad, ema_ul_hlens, _ = self.ema_enc(hs_pad, hlens)
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)

        # 3. post-processing layer for target dimension
        ema_ul_hs_pad, ema_ul_ys_pad = self.match_pad(ema_ul_hs_pad, ul_ys_pad)
        hs_pad, ul_ys_pad = self.match_pad(hs_pad, ul_ys_pad)

        # 4. mixup ema model output
        # Calculate EMA model accuracy before mixup
        self.ema_acc = th_accuracy(
            ema_ul_hs_pad.view(-1, self.odim), ema_ul_ys_pad, ignore_label=self.ignore_id
        )
        if self.mixup_alpha > 0.0:
            ema_ul_hs_pad = mixup_logit(ema_ul_hs_pad, ema_ul_hlens, shuf_idx, lam, self.scheme)
        ema_ul_hs_pad = ema_ul_hs_pad.detach()  # stop gradients through the EMA (teacher) branch

        # 5. Consistency loss
        self.loss_mse = softmax_mse_loss(
            hs_pad.view(-1, self.odim),
            ema_ul_hs_pad.view(-1, self.odim),
            reduction_str=reduction_str
        )

        # 6. Total loss
        if process_info is not None:
            if process_info["epoch"] < self.consistency_rampup_starts:
                consistency_weight = 0
            else:
                consistency_weight = get_current_consistency_weight(
                    self.consistency_weight,
                    process_info["epoch"],
                    process_info["current_position"],
                    process_info["batch_len"],
                    self.consistency_rampup_starts,
                    self.consistency_rampup_ends
                )
        else:
            consistency_weight = 0
        self.loss = self.loss_ce + consistency_weight * self.loss_mse

        loss_ce_data = float(self.loss_ce)
        loss_mse_data = float(consistency_weight * self.loss_mse)
        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ce_data, loss_mse_data, self.stu_acc, self.ema_acc, loss_data
            )
        else:
            pass
        return self.loss
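The helpers used above (get_current_consistency_weight, softmax_mse_loss) are not shown in this example. Below is a minimal sketch of what they might look like, assuming the usual mean-teacher recipe (sigmoid ramp-up of the consistency weight and MSE between softmax outputs); the actual implementations in the source project may differ.

import numpy as np
import torch
import torch.nn.functional as F

def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up from 0 to 1, as in the mean-teacher recipe."""
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

def get_current_consistency_weight(max_weight, epoch, batch_idx, batch_len,
                                   rampup_starts, rampup_ends):
    # Fractional epoch so the weight grows smoothly within an epoch.
    epoch = epoch - rampup_starts + batch_idx / batch_len
    return max_weight * sigmoid_rampup(epoch, rampup_ends - rampup_starts)

def softmax_mse_loss(student_logits, teacher_logits, reduction_str="mean"):
    # MSE between class probabilities, not raw logits.
    return F.mse_loss(F.softmax(student_logits, dim=-1),
                      F.softmax(teacher_logits, dim=-1),
                      reduction=reduction_str)

The ramp-up keeps the consistency term near zero early in training, so the student first fits the labeled data before being pulled toward the EMA teacher.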
Example #27
    def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
        """E2E beam search.

        :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[::self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)

        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(xs_pad, ilens)
            hlens_n = [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                hs_pad[i], hlens_n[i] = self.feature_transform(
                    hs_pad[i], hlens)
            hlens = hlens_n
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if not isinstance(hs_pad,
                          list):  # single-channel multi-speaker input x
            hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        else:  # multi-channel multi-speaker input x
            for i in range(self.num_spkrs):
                hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = [
                self.dec.log_softmax(hs_pad[i]) for i in range(self.num_spkrs)
            ]
            normalize_score = False
        else:
            lpz = None
            normalize_score = True

        # 2. decoder
        y = [
            self.dec.recognize_beam_batch(hs_pad[i],
                                          hlens[i],
                                          lpz[i],
                                          recog_args,
                                          char_list,
                                          rnnlm,
                                          normalize_score=normalize_score,
                                          strm_idx=i)
            for i in range(self.num_spkrs)
        ]

        if prev:
            self.train()
        return y
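A sketch of how this batch decoder might be driven follows; the model instance, char_list, and the decoding options are illustrative placeholders (a typical subset of ESPnet-style beam-search arguments), not values taken from the source project.

import numpy as np
from argparse import Namespace

# Hypothetical inputs: two utterances of 83-dim filterbank features.
xs = [np.random.randn(312, 83).astype(np.float32),
      np.random.randn(255, 83).astype(np.float32)]

# Illustrative decoding options.
recog_args = Namespace(beam_size=10, penalty=0.0, ctc_weight=0.3,
                       maxlenratio=0.0, minlenratio=0.0, nbest=1,
                       lm_weight=0.0)

# `model` is assumed to be an instance of the multi-speaker E2E class above,
# and `char_list` its token inventory; both must already exist.
nbest_per_spkr = model.recognize_batch(xs, recog_args, char_list, rnnlm=None)
for i, nbest in enumerate(nbest_per_spkr):
    # first utterance, best hypothesis for speaker stream i (strip <sos>/<eos>)
    hyp = nbest[0][0]["yseq"][1:-1]
    print("speaker %d:" % i, "".join(char_list[int(y)] for y in hyp))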
Example #28
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, num_spkrs, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            if isinstance(hs_pad, list):
                hlens_n = [None] * self.num_spkrs
                for i in range(self.num_spkrs):
                    hs_pad[i], hlens_n[i] = self.feature_transform(
                        hs_pad[i], hlens)
                hlens = hlens_n
            else:
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if not isinstance(
                hs_pad, list
        ):  # single-channel input xs_pad (single- or multi-speaker)
            hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        else:  # multi-channel multi-speaker input xs_pad
            for i in range(self.num_spkrs):
                hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

        # 2. CTC loss
        with torch.no_grad():
            if self.mtlalpha == 0:
                loss_rnnt, min_perm = None, None
            else:
                if not isinstance(hs_pad, list):  # single-speaker input xs_pad
                    loss = torch.mean(self.dec(hs_pad, hlens, ys_pad))
                else:  # multi-speaker input xs_pad
                    ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
                    loss_ctc_perm = torch.stack([
                        self.ctc(hs_pad[i // self.num_spkrs],
                                 hlens[i // self.num_spkrs],
                                 ys_pad[i % self.num_spkrs])
                        for i in range(self.num_spkrs**2)
                    ],
                                                dim=1)  # (B, num_spkrs^2)
                    loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
                    logging.info('ctc loss:' + str(float(loss_ctc)))

        # 3. attention loss
        if self.mtlalpha == 1:
            loss_att = None
            acc = None
        else:
            if not isinstance(hs_pad, list):  # single-speaker input xs_pad
                loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
            else:
                for i in range(ys_pad.size(1)):  # B
                    ys_pad[:, i] = ys_pad[min_perm[i], i]
                rslt = [
                    self.dec(hs_pad[i], hlens[i], ys_pad[i], strm_idx=i)
                    for i in range(self.num_spkrs)
                ]
                loss_att = sum([r[0] for r in rslt]) / float(len(rslt))
                acc = sum([r[1] for r in rslt]) / float(len(rslt))
        self.acc = acc

        # 4. transducer loss
        for i in range(ys_pad.size(1)):  # B
            ys_pad[:, i] = ys_pad[min_perm[i], i]
        ret = [
            self.dec(hs_pad[i], hlens[i], ys_pad[i])
            for i in range(self.num_spkrs)
        ]
        loss_rnnt = torch.mean(
            torch.stack([r[0] for r in ret], dim=0).to(ret[0].device))  # (B)

        # 5. compute cer/wer
        if self.training or not (self.report_cer
                                 or self.report_wer) or not isinstance(
                                     hs_pad, list):
            cer, wer = 0.0, 0.0
            # oracle_cer, oracle_wer = 0.0, 0.0
        else:
            if self.recog_args.ctc_weight > 0.0:
                lpz = [
                    self.dec.log_softmax(hs_pad[i]).data
                    for i in range(self.num_spkrs)
                ]
            else:
                lpz = None

            word_eds, char_eds, word_ref_lens, char_ref_lens = [], [], [], []

            batchsize = int(hs_pad[0].size(0))
            batch_nbest = []
            for b in six.moves.range(batchsize):
                for i in range(self.num_spkrs):
                    nbest_hyps = self.dec.recognize_beam(
                        hs_pad[i][b], self.recog_args)
                    batch_nbest.append(nbest_hyps)

            nbest_hyps = [
                self.dec.recognize_beam_batch(hs_pad[i],
                                              torch.tensor(hlens[i]),
                                              lpz[i],
                                              self.recog_args,
                                              self.char_list,
                                              self.rnnlm,
                                              strm_idx=i)
                for i in range(self.num_spkrs)
            ]
            # remove <sos> (TODO: also strip <eos> for the attention decoder?)
            y_hats = [[
                nbest_hyp[0]['yseq'][1:] for nbest_hyp in nbest_hyps[i]
            ] for i in range(self.num_spkrs)]
            for i in range(len(y_hats[0])):
                hyp_words = []
                hyp_chars = []
                ref_words = []
                ref_chars = []
                for ns in range(self.num_spkrs):
                    y_hat = y_hats[ns][i]
                    y_true = ys_pad[ns][i]

                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat
                        if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true
                        if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(
                        self.recog_args.space, ' ')
                    seq_hat_text = seq_hat_text.replace(
                        self.recog_args.blank, '')
                    seq_true_text = "".join(seq_true).replace(
                        self.recog_args.space, ' ')

                    hyp_words.append(seq_hat_text.split())
                    ref_words.append(seq_true_text.split())
                    hyp_chars.append(seq_hat_text.replace(' ', ''))
                    ref_chars.append(seq_true_text.replace(' ', ''))

                tmp_word_ed = [
                    editdistance.eval(hyp_words[ns // self.num_spkrs],
                                      ref_words[ns % self.num_spkrs])
                    for ns in range(self.num_spkrs**2)
                ]  # h1r1,h1r2,h2r1,h2r2
                tmp_char_ed = [
                    editdistance.eval(hyp_chars[ns // self.num_spkrs],
                                      ref_chars[ns % self.num_spkrs])
                    for ns in range(self.num_spkrs**2)
                ]  # h1r1,h1r2,h2r1,h2r2

                word_eds.append(
                    self.pit.min_pit_sample(torch.tensor(tmp_word_ed))[0])
                word_ref_lens.append(len(sum(ref_words, [])))
                char_eds.append(
                    self.pit.min_pit_sample(torch.tensor(tmp_char_ed))[0])
                char_ref_lens.append(len(''.join(ref_chars)))

            wer = 0.0 if not self.report_wer else float(
                sum(word_eds)) / sum(word_ref_lens)
            cer = 0.0 if not self.report_cer else float(
                sum(char_eds)) / sum(char_ref_lens)

        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_rnnt_data = None
        elif alpha == 1:
            self.loss = loss_rnnt
            loss_att_data = None
            loss_rnnt_data = float(loss_rnnt)
        else:
            self.loss = alpha * loss_rnnt + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_rnnt_data = float(loss_rnnt)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_rnnt_data, loss_att_data, None, cer, wer,
                                 loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
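pit_process and min_pit_sample are not defined in this excerpt. For two speakers, the underlying idea is simply to pick, per utterance, the speaker permutation with the smallest summed pairwise loss. A minimal sketch over the (B, num_spkrs^2) layout used above (columns ordered h1r1, h1r2, h2r1, h2r2), not the project's actual implementation:

import torch

def min_pit_2spk(pairwise):
    """pairwise: (B, 4) tensor of losses ordered h1r1, h1r2, h2r1, h2r2."""
    perm_a = pairwise[:, 0] + pairwise[:, 3]  # keep order:  r1->h1, r2->h2
    perm_b = pairwise[:, 1] + pairwise[:, 2]  # swap order:  r1->h2, r2->h1
    losses = torch.stack([perm_a, perm_b], dim=1) / 2  # average over speakers
    min_loss, idx = losses.min(dim=1)
    # permutation index per utterance: 0 -> [0, 1], 1 -> [1, 0]
    perms = torch.tensor([[0, 1], [1, 0]], device=pairwise.device)
    return min_loss.mean(), perms[idx]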
Example #29
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if self.replace_sos:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID at the beginning
        else:
            tgt_lang_ids = None

        hs_pad, hlens, _ = self.enc(hs_pad, hlens)

        # 2. CTC loss
        if self.mtlalpha == 0:
            self.loss_ctc = None
        else:
            self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad)

        # 3. attention loss
        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            self.loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad, tgt_lang_ids=tgt_lang_ids)
        self.acc = acc

        # 4. compute cer without beam search
        if self.mtlalpha == 0 or self.char_list is None:
            cer_ctc = None
        else:
            cers = []

            y_hats = self.ctc.argmax(hs_pad).data
            for i, y in enumerate(y_hats):
                y_hat = [x[0] for x in groupby(y)]
                y_true = ys_pad[i]

                seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
                seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
                seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
                seq_hat_text = seq_hat_text.replace(self.blank, '')
                seq_true_text = "".join(seq_true).replace(self.space, ' ')

                hyp_chars = seq_hat_text.replace(' ', '')
                ref_chars = seq_true_text.replace(' ', '')
                if len(ref_chars) > 0:
                    cers.append(editdistance.eval(hyp_chars, ref_chars) / len(ref_chars))

            cer_ctc = sum(cers) / len(cers) if cers else None

        # 5. compute cer/wer
        if self.training or not (self.report_cer or self.report_wer):
            cer, wer = 0.0, 0.0
            # oracle_cer, oracle_wer = 0.0, 0.0
        else:
            if self.recog_args.ctc_weight > 0.0:
                lpz = self.ctc.log_softmax(hs_pad).data
            else:
                lpz = None

            word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
            nbest_hyps = self.dec.recognize_beam_batch(
                hs_pad, torch.tensor(hlens), lpz,
                self.recog_args, self.char_list,
                self.rnnlm,
                tgt_lang_ids=tgt_lang_ids.squeeze(1).tolist() if self.replace_sos else None)
            # remove <sos> and <eos>
            y_hats = [nbest_hyp[0]['yseq'][1:-1] for nbest_hyp in nbest_hyps]
            for i, y_hat in enumerate(y_hats):
                y_true = ys_pad[i]

                seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
                seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
                seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, ' ')
                seq_hat_text = seq_hat_text.replace(self.recog_args.blank, '')
                seq_true_text = "".join(seq_true).replace(self.recog_args.space, ' ')

                hyp_words = seq_hat_text.split()
                ref_words = seq_true_text.split()
                word_eds.append(editdistance.eval(hyp_words, ref_words))
                word_ref_lens.append(len(ref_words))
                hyp_chars = seq_hat_text.replace(' ', '')
                ref_chars = seq_true_text.replace(' ', '')
                char_eds.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

            wer = 0.0 if not self.report_wer else float(sum(word_eds)) / sum(word_ref_lens)
            cer = 0.0 if not self.report_cer else float(sum(char_eds)) / sum(char_ref_lens)

        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = self.loss_att
            loss_att_data = float(self.loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = self.loss_ctc
            loss_att_data = None
            loss_ctc_data = float(self.loss_ctc)
        else:
            self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att
            loss_att_data = float(self.loss_att)
            loss_ctc_data = float(self.loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, acc, cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss
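The CER/WER bookkeeping above reduces to summed edit distances over the decoded set. In isolation it might be written as the following sketch (the hypothesis/reference strings are illustrative only):

import editdistance

def corpus_cer_wer(hyps, refs):
    """hyps, refs: lists of plain-text sentences (space-separated words)."""
    word_eds = sum(editdistance.eval(h.split(), r.split())
                   for h, r in zip(hyps, refs))
    word_lens = sum(len(r.split()) for r in refs)
    char_eds = sum(editdistance.eval(h.replace(" ", ""), r.replace(" ", ""))
                   for h, r in zip(hyps, refs))
    char_lens = sum(len(r.replace(" ", "")) for r in refs)
    return char_eds / char_lens, word_eds / word_lens

# e.g. corpus_cer_wer(["a b c"], ["a b d"]) -> (1/3, 1/3)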
Example #30
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """

        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)

        # 2. CTC loss
        if self.mtlalpha == 0:
            self.loss_ctc = None
        else:
            self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad)

        # 3. attention loss (the attention decoder is disabled in this example,
        # so only the CTC branch contributes to the loss)
        # if self.mtlalpha == 1:
        #    self.loss_att, acc = None, None
        # else:
        #     self.loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
        self.loss_att, acc = None, None
        self.acc = acc

        # 4. compute cer without beam search
        if self.mtlalpha == 0 or self.char_list is None:
            cer_ctc = None
        else:
            cers = []

            y_hats = self.ctc.argmax(hs_pad).data
            for i, y in enumerate(y_hats):
                y_hat = [x[0] for x in groupby(y)]
                y_true = ys_pad[i]

                seq_hat = [
                    self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                ]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true
                    if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(self.space, " ")
                seq_hat_text = seq_hat_text.replace(self.blank, "")
                seq_true_text = "".join(seq_true).replace(self.space, " ")

                hyp_chars = seq_hat_text.replace(" ", "")
                ref_chars = seq_true_text.replace(" ", "")
                if len(ref_chars) > 0:
                    cers.append(
                        editdistance.eval(hyp_chars, ref_chars) /
                        len(ref_chars))

            cer_ctc = sum(cers) / len(cers) if cers else None

        # 5. compute cer/wer
        if self.training or not (self.report_cer or self.report_wer):
            cer, wer = 0.0, 0.0
            # oracle_cer, oracle_wer = 0.0, 0.0
        else:
            if self.recog_args.ctc_weight > 0.0:
                lpz = self.ctc.log_softmax(hs_pad).data
            else:
                lpz = None

            word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
            nbest_hyps = self.dec.recognize_beam_batch(
                hs_pad,
                torch.tensor(hlens),
                lpz,
                self.recog_args,
                self.char_list,
                self.rnnlm,
            )
            # remove <sos> and <eos>
            y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps]
            for i, y_hat in enumerate(y_hats):
                y_true = ys_pad[i]

                seq_hat = [
                    self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                ]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true
                    if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(
                    self.recog_args.space, " ")
                seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
                seq_true_text = "".join(seq_true).replace(
                    self.recog_args.space, " ")

                hyp_words = seq_hat_text.split()
                ref_words = seq_true_text.split()
                word_eds.append(editdistance.eval(hyp_words, ref_words))
                word_ref_lens.append(len(ref_words))
                hyp_chars = seq_hat_text.replace(" ", "")
                ref_chars = seq_true_text.replace(" ", "")
                char_eds.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

            wer = (0.0 if not self.report_wer else float(sum(word_eds)) /
                   sum(word_ref_lens))
            cer = (0.0 if not self.report_cer else float(sum(char_eds)) /
                   sum(char_ref_lens))

        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = self.loss_att
            loss_att_data = float(self.loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = self.loss_ctc
            loss_att_data = None
            acc = None
            loss_ctc_data = float(self.loss_ctc)
        else:
            self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att
            loss_att_data = float(self.loss_att)
            loss_ctc_data = float(self.loss_ctc)

        loss_data = float(self.loss)
        with open(
                '/home/oshindo/espnet/egs/aishell/asr1/exp/train_sp_pytorch_e2e_asr/Bilstm_ctc.txt',
                "a+") as fid:
            fid.write("loss:" + str(loss_data) + ';' + "cer:" + str(cer_ctc) +
                      '\n')

        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, acc, cer_ctc,
                                 cer, wer, loss_data)
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
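The cer_ctc computation above relies on greedy (best-path) CTC decoding: take the per-frame argmax, collapse repeated labels with groupby, and drop blanks. A standalone sketch of that collapse (the blank id of 0 is assumed here, as is conventional):

from itertools import groupby

def ctc_greedy_decode(frame_ids, blank_id=0):
    """frame_ids: per-frame argmax token ids from the CTC output."""
    collapsed = [k for k, _ in groupby(frame_ids)]   # merge repeated labels
    return [t for t in collapsed if t != blank_id]   # remove blank symbols

# e.g. ctc_greedy_decode([0, 3, 3, 0, 0, 5, 5, 5, 0]) -> [3, 5]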