def translate_batch(self, xs, trans_args, char_list, rnnlm=None): """E2E batch beam search. :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...] :param Namespace trans_args: argument Namespace containing options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ prev = self.training self.eval() ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64) # subsample frame xs = [xx[::self.subsample[0], :] for xx in xs] xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs] xs_pad = pad_list(xs, 0.0) # 1. Encoder hs_pad, hlens, _ = self.enc(xs_pad, ilens) # 2. Decoder hlens = torch.tensor(list(map(int, hlens))) # make sure hlens is tensor y = self.dec.recognize_beam_batch(hs_pad, hlens, None, trans_args, char_list, rnnlm) if prev: self.train() return y
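# A minimal, self-contained sketch of the batch preparation used above
# (frame subsampling followed by zero-padding), assuming only numpy and
# torch. `subsample_and_pad` is an illustrative helper, not part of this
# module; the real code uses `to_torch_tensor`, `to_device`, and `pad_list`.
import numpy as np
import torch


def subsample_and_pad(xs, factor, pad_value=0.0):
    """Subsample each (T_i, D) array by `factor` and pad to a (B, Tmax, D) tensor."""
    xs = [torch.as_tensor(x[::factor, :], dtype=torch.float32) for x in xs]
    ilens = torch.tensor([x.size(0) for x in xs], dtype=torch.int64)
    xs_pad = torch.full((len(xs), int(ilens.max()), xs[0].size(1)),
                        pad_value, dtype=torch.float32)
    for i, x in enumerate(xs):
        xs_pad[i, : x.size(0)] = x  # copy real frames; the tail stays padded
    return xs_pad, ilens


# e.g. two utterances of 7 and 5 frames with 3-dim features, subsampled by 2
xs_pad, ilens = subsample_and_pad([np.ones((7, 3)), np.ones((5, 3))], factor=2)
assert xs_pad.shape == (2, 4, 3) and ilens.tolist() == [4, 3]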
def forward(self, xs, labels=None):
    """SLU forward: encode a batch of feature arrays and classify.

    :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
    :param labels: optional classification targets forwarded to the classifier
    :return: classifier outputs (and loss, when labels are given)
    """
    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
    xs = [to_device(self.slu, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)
    embeddings = self.slu(xs_pad, ilens, None)
    outputs = self.classifier(embeddings, labels)
    return outputs
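# Hedged usage sketch for the SLU forward pass above: feed a small batch of
# variable-length feature arrays and, during training, the labels. `model`
# stands in for an instance of this class; the 80-dim features are an
# illustrative assumption, not a value taken from the original code.
import numpy as np

xs = [np.random.randn(120, 80).astype(np.float32),
      np.random.randn(95, 80).astype(np.float32)]
# outputs = model(xs)                  # inference: classifier outputs only
# outputs = model(xs, labels=labels)   # training: classifier can also return a loss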
def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
    """E2E batch posterior computation (no beam search in this variant).

    :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: framewise output-label posteriors (B, T', odim)
    :rtype: torch.Tensor
    """
    prev = self.training
    self.eval()
    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

    # subsample frame
    xs = [xx[:: self.subsample[0], :] for xx in xs]
    xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)

    # 0. Frontend
    if self.frontend is not None:
        enhanced, hlens, mask = self.frontend(xs_pad, ilens)
        hs_pad, hlens = self.feature_transform(enhanced, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    batchsize = hs_pad.size(0)

    # 1. Encoder (its output is reshaped directly to output-label posteriors)
    hyps, hlens, _ = self.enc(hs_pad, hlens)
    hyps = hyps.view(batchsize, -1, self.odim)

    if prev:
        self.train()
    return hyps
def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad): """E2E CTC probability calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :return: CTC probability (B, Tmax, vocab) :rtype: float ndarray """ probs = None if self.mtlalpha == 0: return probs self.eval() with torch.no_grad(): # 0. Frontend if self.frontend is not None: hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) hs_pad, hlens = self.feature_transform(hs_pad, hlens) else: hs_pad, hlens = xs_pad, ilens # 1. Encoder hpad, hlens, _ = self.enc(hs_pad, hlens) # 2. CTC probs probs = self.ctc.softmax(hpad).cpu().numpy() self.train() return probs
def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
    """E2E attention calculation.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
    :return: attention weights with the following shape,
        1) multi-head case => attention weights (B, H, Lmax, Tmax),
        2) other case => attention weights (B, Lmax, Tmax).
    :rtype: float ndarray
    """
    with torch.no_grad():
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if self.replace_sos:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove the target language ID at the beginning
        else:
            tgt_lang_ids = None
        hpad, hlens, _ = self.enc(hs_pad, hlens)

        # 2. Decoder
        att_ws = self.dec.calculate_all_attentions(
            hpad, hlens, ys_pad, tgt_lang_ids=tgt_lang_ids)

    return att_ws
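# Small sketch of how a caller can normalize the two attention-weight layouts
# documented above: (B, H, Lmax, Tmax) for the multi-head case versus
# (B, Lmax, Tmax) otherwise. Assumes only numpy; `average_heads` is an
# illustrative helper, not part of this module.
import numpy as np


def average_heads(att_ws):
    """Reduce multi-head attention weights to (B, Lmax, Tmax) by head averaging."""
    att_ws = np.asarray(att_ws)
    if att_ws.ndim == 4:  # (B, H, Lmax, Tmax) -> average over the H axis
        return att_ws.mean(axis=1)
    return att_ws  # already (B, Lmax, Tmax)


assert average_heads(np.ones((2, 4, 5, 7))).shape == (2, 5, 7)
assert average_heads(np.ones((2, 5, 7))).shape == (2, 5, 7)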
def enhance(self, xs): """Forward only the frontend stage. :param ndarray xs: input acoustic feature (T, C, F) """ if self.frontend is None: raise RuntimeError('Frontend doesn\'t exist') prev = self.training self.eval() ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64) # subsample frame xs = [xx[::self.subsample[0], :] for xx in xs] xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs] xs_pad = pad_list(xs, 0.0) enhanced, hlensm, mask = self.frontend(xs_pad, ilens) if prev: self.train() if isinstance(enhanced, (tuple, list)): enhanced = list(enhanced) mask = list(mask) for idx in range(len(enhanced)): # number of speakers enhanced[idx] = enhanced[idx].cpu().numpy() mask[idx] = mask[idx].cpu().numpy() return enhanced, mask, ilens return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
def enhance(self, xs):
    """Forward only the frontend stage.

    Args:
        xs (ndarray): input acoustic feature (T, C, F)

    Returns:
        enhanced (ndarray): enhanced feature
        mask (ndarray): estimated time-frequency mask
        ilens (ndarray): batch of lengths of input sequences (B)
    """
    if self.frontend is None:
        raise RuntimeError("Frontend doesn't exist")
    prev = self.training
    self.eval()
    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

    # subsample frame
    xs = [xx[::self.subsample[0], :] for xx in xs]
    xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)
    enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
    if prev:
        self.train()
    return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
def calculate_all_attentions(self, xs_pad, ilens, ys_pad): """E2E attention calculation. Args: xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim) ilens (torch.Tensor): batch of lengths of input sequences (B) ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax) Returns: att_ws (ndarray): attention weights with the following shape, 1) multi-head case => attention weights (B, H, Lmax, Tmax), 2) other case => attention weights (B, Lmax, Tmax). """ if self.rnnt_mode == 'rnnt': return [] with torch.no_grad(): # 0. Frontend if self.frontend is not None: hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) hs_pad, hlens = self.feature_transform(hs_pad, hlens) else: hs_pad, hlens = xs_pad, ilens # encoder hpad, hlens, _ = self.enc(hs_pad, hlens) # decoder att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad) return att_ws
def calculate_all_attentions(self, xs_pad, ilens, ys_pad): """E2E attention calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, num_spkrs, Lmax) :return: attention weights with the following shape, 1) multi-head case => attention weights (B, H, Lmax, Tmax), 2) other case => attention weights (B, Lmax, Tmax). :rtype: float ndarray """ with torch.no_grad(): # 0. Frontend if self.frontend is not None: hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) hlens_n = [None] * self.num_spkrs for i in range(self.num_spkrs): hs_pad[i], hlens_n[i] = self.feature_transform( hs_pad[i], hlens) hlens = hlens_n else: hs_pad, hlens = xs_pad, ilens # 1. Encoder if not isinstance(hs_pad, list): # single-channel multi-speaker input x hs_pad, hlens, _ = self.enc(hs_pad, hlens) else: # multi-channel multi-speaker input x for i in range(self.num_spkrs): hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i]) # Permutation ys_pad = ys_pad.transpose(0, 1) # (num_spkrs, B, Lmax) if self.num_spkrs <= 3: loss_ctc = torch.stack( [ self.ctc( hs_pad[i // self.num_spkrs], hlens[i // self.num_spkrs], ys_pad[i % self.num_spkrs], ) for i in range(self.num_spkrs**2) ], 1, ) # (B, num_spkrs^2) loss_ctc, min_perm = self.pit.pit_process(loss_ctc) for i in range(ys_pad.size(1)): # B ys_pad[:, i] = ys_pad[min_perm[i], i] # 2. Decoder att_ws = [ self.dec.calculate_all_attentions(hs_pad[i], hlens[i], ys_pad[i], strm_idx=i) for i in range(self.num_spkrs) ] return att_ws
def calculate_all_attentions(self, xs_pad, ilens, ys_pad): """E2E attention calculation. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :return: attention weights with the following shape, 1) multi-head case => attention weights (B, H, Lmax, Tmax), 2) other case => attention weights (B, Lmax, Tmax). :rtype: float ndarray """ self.eval() with torch.no_grad(): # 0. Frontend if self.frontend is not None: hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) hs_pad, hlens = self.feature_transform(hs_pad, hlens) else: hs_pad, hlens = xs_pad, ilens # 1. Encoder hpad, hlens, _ = self.enc(hs_pad, hlens) # 2. Decoder att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad) self.train() return att_ws
def recognize(self, x, recog_args, char_list, rnnlm=None):
    """E2E beam search.

    :param ndarray x: input acoustic feature (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    prev = self.training
    self.eval()
    ilens = [x.shape[0]]

    # subsample frame
    x = x[::self.subsample[0], :]
    h = to_device(self, to_torch_tensor(x).float())
    # make a utt list (1) to use the same interface for encoder
    hs = h.contiguous().unsqueeze(0)

    # 0. Frontend
    if self.frontend is not None:
        hs, hlens, mask = self.frontend(hs, ilens)
        hlens_n = [None] * self.num_spkrs
        for i in range(self.num_spkrs):
            hs[i], hlens_n[i] = self.feature_transform(hs[i], hlens)
        hlens = hlens_n
    else:
        hs, hlens = hs, ilens

    # 1. Encoder
    if not isinstance(hs, list):  # single-channel multi-speaker input x
        hs, hlens, _ = self.enc(hs, hlens)
    else:  # multi-channel multi-speaker input x
        for i in range(self.num_spkrs):
            hs[i], hlens[i], _ = self.enc(hs[i], hlens[i])

    # calculate log P(z_t|X) for CTC scores
    if recog_args.ctc_weight > 0.0:
        lpz = [self.ctc.log_softmax(i)[0] for i in hs]
    else:
        lpz = None

    # 2. decoder
    # decode the first utterance; guard lpz so attention-only decoding works
    y = [
        self.dec.recognize_beam(hs[i][0],
                                lpz[i] if lpz is not None else None,
                                recog_args, char_list, rnnlm, strm_idx=i)
        for i in range(self.num_spkrs)
    ]

    if prev:
        self.train()
    return y
def recognize_batch(self, xs, recog_args, char_list, rnnlm=None): """E2E batch beam search. :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...] :param Namespace recog_args: argument Namespace containing options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ prev = self.training self.eval() ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64) # subsample frame xs = [xx[::self.subsample[0], :] for xx in xs] xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs] xs_pad = pad_list(xs, 0.0) # 0. Frontend if self.frontend is not None: enhanced, hlens, mask = self.frontend(xs_pad, ilens) hs_pad, hlens = self.feature_transform(enhanced, hlens) else: hs_pad, hlens = xs_pad, ilens # 1. Encoder hs_pad, hlens, _ = self.enc(hs_pad, hlens) # calculate log P(z_t|X) for CTC scores if recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(hs_pad) normalize_score = False else: lpz = None normalize_score = True # 2. Decoder hlens = torch.tensor(list(map(int, hlens))) # make sure hlens is tensor y = self.dec.recognize_beam_batch( hs_pad, hlens, lpz, recog_args, char_list, rnnlm, normalize_score=normalize_score, ) if prev: self.train() return y
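# Illustrative arithmetic for the role of `lpz` above: in hybrid decoding the
# beam search combines attention and CTC log-probabilities per hypothesis as
#     score = (1 - w) * att_score + w * ctc_score,  w = recog_args.ctc_weight.
# A minimal sketch with plain floats; the actual combination lives inside
# `recognize_beam_batch`, and `combined_score` is only illustrative.
def combined_score(att_score, ctc_score, ctc_weight):
    return (1.0 - ctc_weight) * att_score + ctc_weight * ctc_score


# e.g. equal weighting of an attention score of -4.2 and a CTC score of -5.0
assert abs(combined_score(-4.2, -5.0, 0.5) - (-4.6)) < 1e-9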
def forward_frontend_and_encoder(self, xs_pad, ilens): """Forward front-end and encoder.""" # 0. Frontend if self.frontend is not None: hs_pad, hlens, _ = self.frontend(to_torch_tensor(xs_pad), ilens) hs_pad, hlens = self.feature_transform(hs_pad, hlens) else: hs_pad, hlens = xs_pad, ilens # 1. Encoder hs_pad, hlens, _ = self.enc(hs_pad, hlens) return hs_pad, hlens
def recognize_batch(self, xs_list, recog_args, char_list, rnnlm=None): """E2E beam search. :param list xs_list: list of list of input acoustic feature arrays [[(T1_1, D), (T1_2, D), ...],[(T2_1, D), (T2_2, D), ...], ...] :param Namespace recog_args: argument Namespace containing options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ prev = self.training self.eval() ilens_list = [np.fromiter((xx.shape[0] for xx in xs_list[idx]), dtype=np.int64) for idx in range(self.num_encs)] # subsample frame xs_list = [[xx[::self.subsample_list[idx][0], :] for xx in xs_list[idx]] for idx in range(self.num_encs)] xs_list = [[to_device(self, to_torch_tensor(xx).float()) for xx in xs_list[idx]] for idx in range(self.num_encs)] xs_pad_list = [pad_list(xs_list[idx], 0.0) for idx in range(self.num_encs)] # 1. Encoder hs_pad_list, hlens_list = [], [] for idx in range(self.num_encs): hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx]) hs_pad_list.append(hs_pad) hlens_list.append(hlens) # calculate log P(z_t|X) for CTC scores if recog_args.ctc_weight > 0.0: if self.share_ctc: lpz_list = [self.ctc[0].log_softmax(hs_pad_list[idx]) for idx in range(self.num_encs)] else: lpz_list = [self.ctc[idx].log_softmax(hs_pad_list[idx]) for idx in range(self.num_encs)] normalize_score = False else: lpz_list = None normalize_score = True # 2. Decoder hlens_list = [torch.tensor(list(map(int, hlens_list[idx]))) for idx in range(self.num_encs)] # make sure hlens is tensor y = self.dec.recognize_beam_batch(hs_pad_list, hlens_list, lpz_list, recog_args, char_list, rnnlm, normalize_score=normalize_score) if prev: self.train() return y
def calculate_alignments(self, xs_pad, ilens, ys_pad):
    """Compute transducer alignment posteriors (gammas) for a batch.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
    :return: alignment posteriors (gammas) from the transducer decoder
    """
    # 0. Frontend
    if self.frontend is not None:
        hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
        hs_pad, hlens = self.feature_transform(hs_pad, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. encoder
    hs_pad, hlens, _ = self.enc(hs_pad, hlens)

    # 2. decoder
    _, gammas = self.dec.rnnt_alignment(hs_pad, hlens, ys_pad)

    return gammas
def recognize(self, x, recog_args, char_list, rnnlm=None):
    """E2E beam search.

    :param ndarray x: input acoustic feature (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    prev = self.training
    self.eval()
    ilens = [x.shape[0]]

    # subsample frame
    x = x[::self.subsample[0], :]
    h = to_device(self, to_torch_tensor(x).float())
    # make a utt list (1) to use the same interface for encoder
    hs = h.contiguous().unsqueeze(0)

    # 0. Frontend
    if self.frontend is not None:
        enhanced, hlens, mask = self.frontend(hs, ilens)
        hs, hlens = self.feature_transform(enhanced, hlens)
    else:
        hs, hlens = hs, ilens

    # 1. encoder
    hs, _, _ = self.enc(hs, hlens)

    # calculate log P(z_t|X) for CTC scores
    if recog_args.ctc_weight > 0.0:
        lpz = self.ctc.log_softmax(hs)[0]
    else:
        lpz = None

    # 2. Decoder
    # decode the first utterance
    y = self.dec.recognize_beam(hs[0], lpz, recog_args, char_list, rnnlm)

    if prev:
        self.train()
    return y
def recognize(self, x, recog_args, char_list, rnnlm=None): """E2E recognize. Args: x (ndarray): input acoustic feature (T, D) recog_args (namespace): argument Namespace containing options char_list (list): list of characters rnnlm (torch.nn.Module): language model module Returns: y (list): n-best decoding results """ prev = self.training self.eval() ilens = [x.shape[0]] # subsample frame x = x[::self.subsample[0], :] h = to_device(self, to_torch_tensor(x).float()) # make a utt list (1) to use the same interface for encoder hs = h.contiguous().unsqueeze(0) # 0. Frontend if self.frontend is not None: enhanced, hlens, mask = self.frontend(hs, ilens) hs, hlens = self.feature_transform(enhanced, hlens) else: hs, hlens = hs, ilens # 1. Encoder h, _, _ = self.enc(hs, hlens) # 2. Decoder if recog_args.beam_size == 1: y = self.dec.recognize(h[0], recog_args) else: y = self.dec.recognize_beam(h[0], recog_args, rnnlm) if prev: self.train() return y
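# Hedged usage sketch for the recognizers above: `recog_args` is an
# argparse-style Namespace. The option names below (beam_size, nbest, penalty,
# ctc_weight, ...) follow common ESPnet decoding options; treat the exact set
# and values as assumptions rather than what every class here consumes.
from argparse import Namespace

recog_args = Namespace(beam_size=5, nbest=1, penalty=0.0, ctc_weight=0.3,
                       maxlenratio=0.0, minlenratio=0.0, lm_weight=0.1)
# nbest_hyps = model.recognize(x, recog_args, char_list, rnnlm=rnnlm)
# note: with beam_size == 1 the transducer variant above falls back to
# greedy decoding via self.dec.recognize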
def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
    """Recognize input speech.

    :param ndarray x: input acoustic feature (B, T, D) or (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: framewise output-label posteriors (B, T', odim)
    :rtype: torch.Tensor
    """
    self.eval()
    ilens = numpy.fromiter((xx.shape[0] for xx in x), dtype=numpy.int64)

    # subsample frame
    x = [xx[:: self.subsample[0], :] for xx in x]
    x = [to_device(self, to_torch_tensor(xx).float()) for xx in x]
    x = pad_list(x, 0.0)

    enc_output, _ = self.encoder(x, None)
    batchsize = x.size(0)

    if self.outer:
        post_pad = self.poster(enc_output)
        post_pad = post_pad.view(post_pad.size(0), -1, self.odim)
        if post_pad.size(1) != x.size(1):
            if post_pad.size(1) < x.size(1):
                x = x[:, :post_pad.size(1)]
            else:
                raise ValueError(
                    "target size {} and pred size {} mismatch".format(
                        x.size(1), post_pad.size(1)))
        if self.residual:
            post_pad = post_pad + self.matcher_res(x)
        else:
            post_pad = torch.cat([post_pad, x], dim=-1)
        hyps = self.matcher(post_pad)
    else:
        pred_pad = self.poster(enc_output)
        hyps = pred_pad.view(pred_pad.size(0), -1, self.odim)
    hyps = hyps.view(batchsize, -1, self.odim)

    return hyps
def encode_rnn(self, x): """Encode acoustic features. Args: x (ndarray): input acoustic feature (T, D) Returns: x (torch.Tensor): encoded features (T, attention_dim) """ self.eval() ilens = [x.shape[0]] x = x[::self.subsample[0], :] h = to_device(self, to_torch_tensor(x).float()) hs = h.contiguous().unsqueeze(0) h, _, _ = self.encoder(hs, ilens) return h[0]
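# Quick usage sketch for encode_rnn: one utterance in, one encoded sequence
# out, typically handed to a beam-search decoder afterwards. `model` and the
# 80-dim features are illustrative assumptions.
import numpy as np

x = np.random.randn(200, 80).astype(np.float32)  # (T, D) acoustic features
# h = model.encode_rnn(x)   # -> torch.Tensor of shape (T', attention_dim)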
def enhance(self, xs):
    """Forward only the frontend stage.

    :param ndarray xs: input acoustic feature (T, C, F)
    :return: enhanced feature, estimated mask, and input lengths
    :rtype: tuple of ndarray
    """
    if self.frontend is None:
        raise RuntimeError("Frontend doesn't exist")
    prev = self.training
    self.eval()
    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

    # subsample frame
    xs = [xx[::self.subsample[0], :] for xx in xs]
    xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)
    enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
    if prev:
        self.train()
    return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
def forward(self, xs_pad, ilens, ys_pad, ys_pad_mono=None):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
    :param torch.Tensor ys_pad_mono: optional batch of padded monophone targets
    :return: loss value
    :rtype: torch.Tensor
    """
    # 0. Frontend
    if self.frontend is not None:
        hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
        hs_pad, hlens = self.feature_transform(hs_pad, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. RNN Encoder
    hs_pad, hlens, _ = self.enc(hs_pad, hlens)

    # 2. post-processing layer for target dimension
    pred_pad = self.poster(hs_pad)
    pred_pad = pred_pad.view(pred_pad.size(0), -1, self.odim)
    self.pred_pad = pred_pad
    if pred_pad.size(1) != ys_pad.size(1):
        if pred_pad.size(1) < ys_pad.size(1):
            ys_pad = ys_pad[:, :pred_pad.size(1)].contiguous()
        else:
            raise ValueError(
                "target size {} and pred size {} mismatch".format(
                    ys_pad.size(1), pred_pad.size(1)))

    if ys_pad_mono is not None:
        pred_pad_mono = self.poster_mono(hs_pad)
        pred_pad_mono = pred_pad_mono.view(pred_pad_mono.size(0), -1, self.mono_odim)
        self.pred_pad_mono = pred_pad_mono
        if pred_pad_mono.size(1) != ys_pad_mono.size(1):
            if pred_pad_mono.size(1) < ys_pad_mono.size(1):
                ys_pad_mono = ys_pad_mono[:, :pred_pad_mono.size(1)].contiguous()
            else:
                raise ValueError(
                    "target size {} and pred size {} mismatch".format(
                        ys_pad_mono.size(1), pred_pad_mono.size(1)))

    # 3. CTC loss
    if self.mtlalpha == 0:
        self.loss_ctc = None
    else:
        self.loss_ctc = self.ctc(pred_pad, hlens, ys_pad)

    # 4. CE loss
    if LooseVersion(torch.__version__) < LooseVersion("1.0"):
        reduction_str = "elementwise_mean"
    else:
        reduction_str = "mean"
    self.loss_ce_tri = F.cross_entropy(
        pred_pad.view(-1, self.odim),
        ys_pad.view(-1),
        ignore_index=self.ignore_id,
        reduction=reduction_str,
    )
    if ys_pad_mono is not None:
        # the monophone posteriors have mono_odim classes, not odim
        self.loss_ce_mono = F.cross_entropy(
            pred_pad_mono.view(-1, self.mono_odim),
            ys_pad_mono.view(-1),
            ignore_index=self.ignore_id,
            reduction=reduction_str,
        )
    else:
        self.loss_ce_mono = 0
    self.loss_ce = 0.6 * self.loss_ce_tri + 0.4 * self.loss_ce_mono

    self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_pad,
                           ignore_label=self.ignore_id)

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer, cer_ctc = None, None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = self.loss_ce
        loss_ce_data = float(self.loss_ce)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = self.loss_ctc
        loss_ce_data = None
        loss_ctc_data = float(self.loss_ctc)
    else:
        self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_ce
        loss_ce_data = float(self.loss_ce)
        loss_ctc_data = float(self.loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_ce_data, self.acc,
                             cer_ctc, cer, wer, loss_data)
    return self.loss
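# Minimal sketch of the padded cross-entropy pattern used above: flatten
# predictions to (B*Lmax, odim) and targets to (B*Lmax,), and let
# ignore_index skip the padded positions. All values below are illustrative.
import torch
import torch.nn.functional as F

ignore_id = -1
pred_pad = torch.randn(2, 3, 5)            # (B, Lmax, odim)
ys_pad = torch.tensor([[1, 4, ignore_id],  # second sequence is shorter,
                       [2, ignore_id, ignore_id]])  # so it is padded with -1
loss = F.cross_entropy(pred_pad.view(-1, 5), ys_pad.view(-1),
                       ignore_index=ignore_id, reduction="mean")
# the three padded positions contribute nothing to the mean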
def forward(self, xs_pad, ilens, ys_pad): """E2E forward. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, num_spkrs, Lmax) :return: ctc loss value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 0. Frontend if self.frontend is not None: hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) if isinstance(hs_pad, list): hlens_n = [None] * self.num_spkrs for i in range(self.num_spkrs): hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens) hlens = hlens_n else: hs_pad, hlens = self.feature_transform(hs_pad, hlens) else: hs_pad, hlens = xs_pad, ilens # 1. Encoder if not isinstance( hs_pad, list ): # single-channel input xs_pad (single- or multi-speaker) hs_pad, hlens, _ = self.enc(hs_pad, hlens) else: # multi-channel multi-speaker input xs_pad for i in range(self.num_spkrs): hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i]) # 2. CTC loss if self.mtlalpha == 0: loss_ctc, min_perm = None, None else: if not isinstance(hs_pad, list): # single-speaker input xs_pad loss_ctc = torch.mean(self.ctc(hs_pad, hlens, ys_pad)) else: # multi-speaker input xs_pad ys_pad = ys_pad.transpose(0, 1) # (num_spkrs, B, Lmax) loss_ctc_perm = torch.stack( [ self.ctc( hs_pad[i // self.num_spkrs], hlens[i // self.num_spkrs], ys_pad[i % self.num_spkrs], ) for i in range(self.num_spkrs ** 2) ], dim=1, ) # (B, num_spkrs^2) loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm) logging.info("ctc loss:" + str(float(loss_ctc))) # 3. attention loss if self.mtlalpha == 1: loss_att = None acc = None else: if not isinstance(hs_pad, list): # single-speaker input xs_pad loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad) else: for i in range(ys_pad.size(1)): # B ys_pad[:, i] = ys_pad[min_perm[i], i] rslt = [ self.dec(hs_pad[i], hlens[i], ys_pad[i], strm_idx=i) for i in range(self.num_spkrs) ] loss_att = sum([r[0] for r in rslt]) / float(len(rslt)) acc = sum([r[1] for r in rslt]) / float(len(rslt)) self.acc = acc # 4. compute cer without beam search if self.mtlalpha == 0 or self.char_list is None: cer_ctc = None else: cers = [] for ns in range(self.num_spkrs): y_hats = self.ctc.argmax(hs_pad[ns]).data for i, y in enumerate(y_hats): y_hat = [x[0] for x in groupby(y)] y_true = ys_pad[ns][i] seq_hat = [ self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 ] seq_true = [ self.char_list[int(idx)] for idx in y_true if int(idx) != -1 ] seq_hat_text = "".join(seq_hat).replace(self.space, " ") seq_hat_text = seq_hat_text.replace(self.blank, "") seq_true_text = "".join(seq_true).replace(self.space, " ") hyp_chars = seq_hat_text.replace(" ", "") ref_chars = seq_true_text.replace(" ", "") if len(ref_chars) > 0: cers.append( editdistance.eval(hyp_chars, ref_chars) / len(ref_chars) ) cer_ctc = sum(cers) / len(cers) if cers else None # 5. 
compute cer/wer if ( self.training or not (self.report_cer or self.report_wer) or not isinstance(hs_pad, list) ): cer, wer = 0.0, 0.0 else: if self.recog_args.ctc_weight > 0.0: lpz = [ self.ctc.log_softmax(hs_pad[i]).data for i in range(self.num_spkrs) ] else: lpz = None word_eds, char_eds, word_ref_lens, char_ref_lens = [], [], [], [] nbest_hyps = [ self.dec.recognize_beam_batch( hs_pad[i], torch.tensor(hlens[i]), lpz[i], self.recog_args, self.char_list, self.rnnlm, strm_idx=i, ) for i in range(self.num_spkrs) ] # remove <sos> and <eos> y_hats = [ [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps[i]] for i in range(self.num_spkrs) ] for i in range(len(y_hats[0])): hyp_words = [] hyp_chars = [] ref_words = [] ref_chars = [] for ns in range(self.num_spkrs): y_hat = y_hats[ns][i] y_true = ys_pad[ns][i] seq_hat = [ self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 ] seq_true = [ self.char_list[int(idx)] for idx in y_true if int(idx) != -1 ] seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ") seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "") seq_true_text = "".join(seq_true).replace( self.recog_args.space, " " ) hyp_words.append(seq_hat_text.split()) ref_words.append(seq_true_text.split()) hyp_chars.append(seq_hat_text.replace(" ", "")) ref_chars.append(seq_true_text.replace(" ", "")) tmp_word_ed = [ editdistance.eval( hyp_words[ns // self.num_spkrs], ref_words[ns % self.num_spkrs] ) for ns in range(self.num_spkrs ** 2) ] # h1r1,h1r2,h2r1,h2r2 tmp_char_ed = [ editdistance.eval( hyp_chars[ns // self.num_spkrs], ref_chars[ns % self.num_spkrs] ) for ns in range(self.num_spkrs ** 2) ] # h1r1,h1r2,h2r1,h2r2 word_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_word_ed))[0]) word_ref_lens.append(len(sum(ref_words, []))) char_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_char_ed))[0]) char_ref_lens.append(len("".join(ref_chars))) wer = ( 0.0 if not self.report_wer else float(sum(word_eds)) / sum(word_ref_lens) ) cer = ( 0.0 if not self.report_cer else float(sum(char_eds)) / sum(char_ref_lens) ) alpha = self.mtlalpha if alpha == 0: self.loss = loss_att loss_att_data = float(loss_att) loss_ctc_data = None elif alpha == 1: self.loss = loss_ctc loss_att_data = None loss_ctc_data = float(loss_ctc) else: self.loss = alpha * loss_ctc + (1 - alpha) * loss_att loss_att_data = float(loss_att) loss_ctc_data = float(loss_ctc) loss_data = float(self.loss) if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report( loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data ) else: logging.warning("loss (=%f) is not correct", loss_data) return self.loss
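# Self-contained sketch of the permutation-invariant training (PIT) step used
# above for num_spkrs == 2: given per-pair losses ordered h1r1, h1r2, h2r1,
# h2r2, pick the assignment (identity vs. swap) with the smaller total loss.
# `min_pit_2spk` is illustrative; the real logic lives in self.pit.pit_process
# and generalizes to more speakers.
import torch


def min_pit_2spk(pair_losses):
    """pair_losses: (B, 4) tensor ordered [h1r1, h1r2, h2r1, h2r2]."""
    identity = pair_losses[:, 0] + pair_losses[:, 3]  # h1->r1, h2->r2
    swapped = pair_losses[:, 1] + pair_losses[:, 2]   # h1->r2, h2->r1
    loss = torch.min(identity, swapped) / 2.0         # average over speakers
    perm = (swapped < identity).long()                # 0: identity, 1: swap
    return loss, perm


loss, perm = min_pit_2spk(torch.tensor([[1.0, 9.0, 9.0, 1.0],
                                        [9.0, 1.0, 1.0, 9.0]]))
assert perm.tolist() == [0, 1] and loss.tolist() == [1.0, 1.0]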
def forward(self, xs_pad, ilens, ys_pad): """E2E forward. Args: xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim) ilens (torch.Tensor): batch of lengths of input sequences (B) ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax) Returns: loss (torch.Tensor): transducer loss value """ # 0. Frontend if self.frontend is not None: hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) if isinstance(hs_pad, list): hlens_n = [None] * self.num_spkrs for i in range(self.num_spkrs): hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens) hlens = hlens_n else: hs_pad, hlens = self.feature_transform(hs_pad, hlens) else: hs_pad, hlens = xs_pad, ilens # 1. Encoder if not isinstance(hs_pad, list): # single-channel input xs_pad (single- or multi-speaker) hs_pad, hlens, _ = self.enc(hs_pad, hlens) else: # multi-channel multi-speaker input xs_pad for i in range(self.num_spkrs): hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i]) # 2. decoder loss = self.dec(hs_pad, hlens, ys_pad) # 3. compute cer/wer # note: not recommended outside debugging right now, # the training time is hugely impacted. if self.training or not (self.report_cer or self.report_wer): cer, wer = 0.0, 0.0 else: word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], [] batchsize = int(hs_pad.size(0)) batch_nbest = [] for b in six.moves.range(batchsize): nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args) batch_nbest.append(nbest_hyps) y_hats = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest] for i, y_hat in enumerate(y_hats): y_true = ys_pad[i] seq_hat = [self.char_list[int(idx)] for idx in y_hat] seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, ' ') seq_true_text = "".join(seq_true).replace(self.recog_args.space, ' ') hyp_words = seq_hat_text.split() ref_words = seq_true_text.split() word_eds.append(editdistance.eval(hyp_words, ref_words)) word_ref_lens.append(len(ref_words)) hyp_chars = seq_hat_text.replace(' ', '') ref_chars = seq_true_text.replace(' ', '') char_eds.append(editdistance.eval(hyp_chars, ref_chars)) char_ref_lens.append(len(ref_chars)) wer = 0.0 if not self.report_wer else float(sum(word_eds)) / sum(word_ref_lens) cer = 0.0 if not self.report_cer else float(sum(char_eds)) / sum(char_ref_lens) self.loss = loss loss_data = float(self.loss) if not math.isnan(loss_data): self.reporter.report(loss_data, cer, wer) else: logging.warning('loss (=%f) is not correct', loss_data) return self.loss
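# Stand-alone sketch of the CER/WER bookkeeping repeated in the forward
# methods here: accumulate edit distances and reference lengths over the
# batch, then divide once at the end. Assumes the `editdistance` package;
# `batch_cer_wer` is an illustrative helper, not part of this module.
import editdistance


def batch_cer_wer(hyps, refs):
    """hyps/refs: lists of plain-text strings (space-separated words)."""
    word_eds = sum(editdistance.eval(h.split(), r.split())
                   for h, r in zip(hyps, refs))
    word_len = sum(len(r.split()) for r in refs)
    char_eds = sum(editdistance.eval(h.replace(" ", ""), r.replace(" ", ""))
                   for h, r in zip(hyps, refs))
    char_len = sum(len(r.replace(" ", "")) for r in refs)
    return char_eds / char_len, word_eds / word_len


cer, wer = batch_cer_wer(["a b c"], ["a b d"])
assert cer == 1 / 3 and wer == 1 / 3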
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
    :return: loss value
    :rtype: torch.Tensor
    """
    # 0. Frontend
    if self.frontend is not None:
        hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
        hs_pad, hlens = self.feature_transform(hs_pad, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. CNN Encoder
    hs_pad, hlens, _ = self.enc(hs_pad, hlens)

    # TODO: Not sure about oversampling & outer
    # 2. post-processing layer for target dimension
    if self.outer:
        post_pad = self.poster(hs_pad)
        post_pad = post_pad.view(post_pad.size(0), -1, self.odim)
        if post_pad.size(1) != xs_pad.size(1):
            if post_pad.size(1) < xs_pad.size(1):
                xs_pad = xs_pad[:, :post_pad.size(1)].contiguous()
            else:
                raise ValueError(
                    "target size {} and pred size {} mismatch".format(
                        xs_pad.size(1), post_pad.size(1)))
        if self.residual:
            post_pad = post_pad + self.matcher_res(xs_pad)
        else:
            post_pad = torch.cat([post_pad, xs_pad], dim=-1)
        pred_pad = self.matcher(post_pad)
    else:
        pred_pad = self.poster(hs_pad)
    pred_pad = pred_pad.view(pred_pad.size(0), -1, self.odim)
    self.pred_pad = pred_pad
    if pred_pad.size(1) != ys_pad.size(1):
        if pred_pad.size(1) < ys_pad.size(1):
            ys_pad = ys_pad[:, :pred_pad.size(1)].contiguous()
        else:
            raise ValueError(
                "target size {} and pred size {} mismatch".format(
                    ys_pad.size(1), pred_pad.size(1)))

    # 3. CTC loss
    if self.mtlalpha == 0:
        self.loss_ctc = None
    else:
        self.loss_ctc = self.ctc(pred_pad, hlens, ys_pad)

    # 4. CE loss
    if LooseVersion(torch.__version__) < LooseVersion("1.0"):
        reduction_str = "elementwise_mean"
    else:
        reduction_str = "mean"
    self.loss_ce = F.cross_entropy(
        pred_pad.view(-1, self.odim),
        ys_pad.view(-1),
        ignore_index=self.ignore_id,
        reduction=reduction_str,
    )

    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_pad, ignore_label=self.ignore_id
    )

    # 5. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer, cer_ctc = None, None, None
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

    # copied from e2e_asr
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = self.loss_ce
        loss_ce_data = float(self.loss_ce)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = self.loss_ctc
        loss_ce_data = None
        loss_ctc_data = float(self.loss_ctc)
    else:
        self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_ce
        loss_ce_data = float(self.loss_ce)
        loss_ctc_data = float(self.loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_ctc_data, loss_ce_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    return self.loss
def forward(self, xs_pad, ilens, ys_pad, asrtts=False, ttsasr=False):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
    :param bool asrtts: generate hypotheses with the decoder instead of
        computing the attention CE loss (ASR->TTS cycle)
    :param bool ttsasr: compute the attention loss on zeroed encoder states
        (TTS->ASR cycle)
    :return: loss value
    :rtype: torch.Tensor
    """
    # 0. Frontend
    if self.frontend is not None:
        hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
        hs_pad, hlens = self.feature_transform(hs_pad, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. Encoder
    if self.replace_sos:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove the target language ID at the beginning
    else:
        tgt_lang_ids = None
    hs_pad, hlens, _ = self.enc(hs_pad, hlens)

    # 2. CTC loss
    if self.mtlalpha == 0:
        self.loss_ctc = None
    else:
        self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad)

    if asrtts:
        acc = None
        self.acc = acc
        # 4. compute cer without beam search
        cer_ctc = None
        if self.recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(hs_pad).data
        else:
            lpz = None
        # alternative sampling strategies, kept for reference:
        # if self.recog_args.sampling == 'multinomial':
        #     self.loss_att, best_hyps = self.dec.generate(
        #         hs_pad, torch.tensor(hlens), ys_pad, self.recog_args)
        self.loss_att, best_hyps = self.dec.generate_forward(
            hs_pad, torch.tensor(hlens), ys_pad,
            self.recog_args)  # , oracle_length=self.oracle_length
        # else:
        #     best_hyps, pred_scores = self.dec.generate_beam_batch(
        #         hs_pad, torch.tensor(hlens), lpz,
        #         self.recog_args, self.char_list,
        #         self.rnnlm, tgt_lang_ids=None, sampling=self.recog_args.sampling)
        #     self.loss_att = pred_scores.mean(2)[:, -self.recog_args.nbest:].view(-1)
        # self.loss_att *= (np.mean([len(x) for x in best_hyps]) - 1)
        logging.info(self.loss_att.mean())
    elif ttsasr:
        # 3. attention loss
        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            hs_zero_pad = torch.zeros(hs_pad.size()).cuda()
            self.loss_att, _, _ = self.dec(hs_zero_pad, hlens, ys_pad,
                                           tgt_lang_ids=tgt_lang_ids)
            # masking acc
            # logging.info("TTS2ASR, ACC: " + str(acc))
            acc = None
        self.acc = acc
        # 4. compute cer without beam search
        if self.mtlalpha == 0:
            cer_ctc = None
        else:
            cers = []
            y_hats = self.ctc.argmax(hs_pad).data
            for i, y in enumerate(y_hats):
                y_hat = [x[0] for x in groupby(y)]
                y_true = ys_pad[i]
                seq_hat = [self.char_list[int(idx)] for idx in y_hat
                           if int(idx) != -1]
                seq_true = [self.char_list[int(idx)] for idx in y_true
                            if int(idx) != -1]
                seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
                seq_hat_text = seq_hat_text.replace(self.blank, '')
                seq_true_text = "".join(seq_true).replace(self.space, ' ')
                hyp_chars = seq_hat_text.replace(' ', '')
                ref_chars = seq_true_text.replace(' ', '')
                if len(ref_chars) > 0:
                    cers.append(
                        editdistance.eval(hyp_chars, ref_chars) / len(ref_chars))
            cer_ctc = sum(cers) / len(cers) if cers else None
    else:
        # 3. attention loss
        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            self.loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad,
                                             tgt_lang_ids=tgt_lang_ids)
        self.acc = acc
        # 4.
compute cer without beam search if self.mtlalpha == 0: cer_ctc = None else: cers = [] y_hats = self.ctc.argmax(hs_pad).data for i, y in enumerate(y_hats): y_hat = [x[0] for x in groupby(y)] y_true = ys_pad[i] seq_hat = [ self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 ] seq_true = [ self.char_list[int(idx)] for idx in y_true if int(idx) != -1 ] seq_hat_text = "".join(seq_hat).replace(self.space, ' ') seq_hat_text = seq_hat_text.replace(self.blank, '') seq_true_text = "".join(seq_true).replace(self.space, ' ') hyp_chars = seq_hat_text.replace(' ', '') ref_chars = seq_true_text.replace(' ', '') if len(ref_chars) > 0: cers.append( editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)) cer_ctc = sum(cers) / len(cers) if cers else None # 5. compute cer/wer if self.training or not (self.report_cer or self.report_wer): cer, wer = 0.0, 0.0 # oracle_cer, oracle_wer = 0.0, 0.0 else: if self.recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(hs_pad).data else: lpz = None word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], [] nbest_hyps = self.dec.recognize_beam_batch( hs_pad, torch.tensor(hlens), lpz, self.recog_args, self.char_list, self.rnnlm, tgt_lang_ids=tgt_lang_ids.squeeze(1).tolist() if self.replace_sos else None) # remove <sos> and <eos> y_hats = [nbest_hyp[0]['yseq'][1:-1] for nbest_hyp in nbest_hyps] for i, y_hat in enumerate(y_hats): y_true = ys_pad[i] seq_hat = [ self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 ] seq_true = [ self.char_list[int(idx)] for idx in y_true if int(idx) != -1 ] seq_hat_text = "".join(seq_hat).replace( self.recog_args.space, ' ') seq_hat_text = seq_hat_text.replace(self.recog_args.blank, '') seq_true_text = "".join(seq_true).replace( self.recog_args.space, ' ') hyp_words = seq_hat_text.split() ref_words = seq_true_text.split() word_eds.append(editdistance.eval(hyp_words, ref_words)) word_ref_lens.append(len(ref_words)) hyp_chars = seq_hat_text.replace(' ', '') ref_chars = seq_true_text.replace(' ', '') char_eds.append(editdistance.eval(hyp_chars, ref_chars)) char_ref_lens.append(len(ref_chars)) wer = 0.0 if not self.report_wer else float( sum(word_eds)) / sum(word_ref_lens) cer = 0.0 if not self.report_cer else float( sum(char_eds)) / sum(char_ref_lens) alpha = self.mtlalpha if alpha == 0: self.loss = self.loss_att loss_att_data = float(self.loss_att.mean()) loss_ctc_data = None elif alpha == 1: self.loss = self.loss_ctc loss_att_data = None loss_ctc_data = float(self.loss_ctc) else: if asrtts: self.loss = self.loss_att else: self.loss = alpha * self.loss_ctc + ( 1 - alpha) * self.loss_att.mean() loss_att_data = float(self.loss_att.mean()) loss_ctc_data = float(self.loss_ctc) loss_data = float(self.loss.mean()) #logging.info("main acc is: " + str(acc)) if asrtts: self.reporter.report(loss_ctc_data, float(self.loss_att.mean()), acc, cer_ctc, cer, wer, float(self.loss_att.mean())) return self.loss_att, best_hyps elif ttsasr: if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report(loss_ctc_data, loss_att_data, acc, cer_ctc, cer, wer, loss_data) else: logging.warning('loss (=%f) is not correct', loss_data) return self.loss else: if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report(loss_ctc_data, loss_att_data, acc, cer_ctc, cer, wer, loss_data) else: logging.warning('loss (=%f) is not correct', loss_data) return self.loss
def forward(self, xs_pad, ilens, ys_pad, ul_xs_pad, ul_ilens, ul_ys_pad, process_info): """E2E forward. :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :return: loss value :rtype: torch.Tensor """ # Forward for cross entropy loss # 0. Frontend if self.frontend is not None: hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) hs_pad, hlens = self.feature_transform(hs_pad, hlens) else: hs_pad, hlens = xs_pad, ilens # 1. Mixup feature if self.mixup_alpha > 0.0: hs_pad, ys_pad, ys_pad_b, _, lam = mixup_data(hs_pad, ys_pad, hlens, self.mixup_alpha, self.scheme) # 2. RNN Encoder hs_pad, hlens, _ = self.enc(hs_pad, hlens) # 3. post-processing layer for target dimension hs_pad, ys_pad = self.match_pad(hs_pad, ys_pad) if self.mixup_alpha > 0.0: hs_pad, ys_pad_b = self.match_pad(hs_pad, ys_pad_b) # 4. Supervised loss if LooseVersion(torch.__version__) < LooseVersion("1.0"): reduction_str = "elementwise_mean" else: reduction_str = "mean" if self.mixup_alpha > 0.0: loss_ce_a = F.cross_entropy( hs_pad.view(-1, self.odim), ys_pad.view(-1), ignore_index=self.ignore_id, reduction=reduction_str, ) loss_ce_b = F.cross_entropy( hs_pad.view(-1, self.odim), ys_pad_b.view(-1), ignore_index=self.ignore_id, reduction=reduction_str, ) self.loss_ce = lam * loss_ce_a + (1 - lam) * loss_ce_b else: self.loss_ce = F.cross_entropy( hs_pad.view(-1, self.odim), ys_pad.view(-1), ignore_index=self.ignore_id, reduction=reduction_str, ) # Forward for consistency loss # 0. Frontend if self.frontend is not None: hs_pad, hlens, mask = self.frontend(to_torch_tensor(ul_xs_pad), ul_ilens) hs_pad, hlens = self.feature_transform(hs_pad, hlens) else: hs_pad, hlens = ul_xs_pad, ul_ilens # Calculating student model accuracy consumes twice the time. if self.show_student_model_acc: ul_pred_pad, ul_hlens, _ = self.enc(hs_pad, hlens) ul_pred_pad, ul_ys_pad_temp = self.match_pad(ul_pred_pad, ul_ys_pad) self.stu_acc = th_accuracy( ul_pred_pad.view(-1, self.odim), ul_ys_pad_temp, ignore_label=self.ignore_id ) # empty used cuda variable ul_pred_pad, ul_ys_pad_temp = (None, None) else: self.stu_acc = 0 # 1. Mixup feature if self.mixup_alpha > 0.0: hs_pad, ys_pad, _, shuf_idx, lam = mixup_data(hs_pad, ul_ys_pad, hlens, self.mixup_alpha, self.scheme) # 2. RNN Encoder ema_ul_hs_pad, ema_ul_hlens, _ = self.ema_enc(hs_pad, hlens) hs_pad, hlens, _ = self.enc(hs_pad, hlens) # 3. post-processing layer for target dimension ema_ul_hs_pad, ema_ul_ys_pad = self.match_pad(ema_ul_hs_pad, ul_ys_pad) hs_pad, ul_ys_pad = self.match_pad(hs_pad, ul_ys_pad) # 4. mixup ema model output # Calculate EMA model accuracy before mixup self.ema_acc = th_accuracy( ema_ul_hs_pad.view(-1, self.odim), ema_ul_ys_pad, ignore_label=self.ignore_id ) if self.mixup_alpha > 0.0: ema_ul_hs_pad = mixup_logit(ema_ul_hs_pad, ema_ul_hlens, shuf_idx, lam, self.scheme) ema_ul_hs_pad = torch.autograd.Variable(ema_ul_hs_pad.detach().data, requires_grad=False) # 5. Consistency loss self.loss_mse = softmax_mse_loss( hs_pad.view(-1, self.odim), ema_ul_hs_pad.view(-1, self.odim), reduction_str=reduction_str ) # 6. 
Total loss if process_info is not None: if process_info["epoch"] < self.consistency_rampup_starts: consistency_weight = 0 else: consistency_weight = get_current_consistency_weight( self.consistency_weight, process_info["epoch"], process_info["current_position"], process_info["batch_len"], self.consistency_rampup_starts, self.consistency_rampup_ends ) else: consistency_weight = 0 self.loss = self.loss_ce + consistency_weight * self.loss_mse loss_ce_data = float(self.loss_ce) loss_mse_data = float(consistency_weight * self.loss_mse) loss_data = float(self.loss) if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report( loss_ce_data, loss_mse_data, self.stu_acc, self.ema_acc, loss_data ) else: pass return self.loss
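# Hedged sketch of the consistency-weight schedule used above: zero before
# `consistency_rampup_starts`, then a sigmoid-shaped ramp up to the full
# weight by `consistency_rampup_ends`. This is a common mean-teacher
# schedule; the exact shape of get_current_consistency_weight (which also
# interpolates within an epoch via current_position/batch_len) may differ.
import numpy as np


def consistency_weight(final_w, epoch, rampup_starts, rampup_ends):
    if epoch < rampup_starts:
        return 0.0
    if epoch >= rampup_ends:
        return final_w
    phase = 1.0 - (epoch - rampup_starts) / (rampup_ends - rampup_starts)
    return final_w * float(np.exp(-5.0 * phase * phase))  # sigmoid ramp-up


assert consistency_weight(1.0, 0, 5, 25) == 0.0
assert consistency_weight(1.0, 25, 5, 25) == 1.0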
def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
    """E2E batch beam search.

    :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    prev = self.training
    self.eval()
    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

    # subsample frame
    xs = [xx[::self.subsample[0], :] for xx in xs]
    xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)

    # 0. Frontend
    if self.frontend is not None:
        hs_pad, hlens, mask = self.frontend(xs_pad, ilens)
        hlens_n = [None] * self.num_spkrs
        for i in range(self.num_spkrs):
            hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
        hlens = hlens_n
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. Encoder
    if not isinstance(hs_pad, list):  # single-channel multi-speaker input x
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)
    else:  # multi-channel multi-speaker input x
        for i in range(self.num_spkrs):
            hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

    # calculate log P(z_t|X) for CTC scores
    if recog_args.ctc_weight > 0.0:
        lpz = [self.ctc.log_softmax(hs_pad[i]) for i in range(self.num_spkrs)]
        normalize_score = False
    else:
        lpz = None
        normalize_score = True

    # 2. decoder; guard lpz so attention-only decoding works
    y = [
        self.dec.recognize_beam_batch(hs_pad[i], hlens[i],
                                      lpz[i] if lpz is not None else None,
                                      recog_args, char_list, rnnlm,
                                      normalize_score=normalize_score,
                                      strm_idx=i)
        for i in range(self.num_spkrs)
    ]

    if prev:
        self.train()
    return y
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded character id sequence tensor
        (B, num_spkrs, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 0. Frontend
    if self.frontend is not None:
        hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
        if isinstance(hs_pad, list):
            hlens_n = [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
            hlens = hlens_n
        else:
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. Encoder
    if not isinstance(hs_pad, list):
        # single-channel input xs_pad (single- or multi-speaker)
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)
    else:  # multi-channel multi-speaker input xs_pad
        for i in range(self.num_spkrs):
            hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

    # 2. transducer/CTC loss for permutation search
    # (under no_grad: only used to pick the speaker permutation)
    with torch.no_grad():
        if self.mtlalpha == 0:
            loss_rnnt, min_perm = None, None
        else:
            if not isinstance(hs_pad, list):  # single-speaker input xs_pad
                loss_rnnt = torch.mean(self.dec(hs_pad, hlens, ys_pad))
            else:  # multi-speaker input xs_pad
                ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
                loss_ctc_perm = torch.stack([
                    self.ctc(hs_pad[i // self.num_spkrs],
                             hlens[i // self.num_spkrs],
                             ys_pad[i % self.num_spkrs])
                    for i in range(self.num_spkrs ** 2)
                ], dim=1)  # (B, num_spkrs^2)
                loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
                logging.info('ctc loss:' + str(float(loss_ctc)))

    # 3. attention loss
    if self.mtlalpha == 1:
        loss_att = None
        acc = None
    else:
        if not isinstance(hs_pad, list):  # single-speaker input xs_pad
            loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
        else:
            for i in range(ys_pad.size(1)):  # B
                ys_pad[:, i] = ys_pad[min_perm[i], i]
            rslt = [
                self.dec(hs_pad[i], hlens[i], ys_pad[i], strm_idx=i)
                for i in range(self.num_spkrs)
            ]
            loss_att = sum([r[0] for r in rslt]) / float(len(rslt))
            acc = sum([r[1] for r in rslt]) / float(len(rslt))
    self.acc = acc

    # 4. transducer loss
    for i in range(ys_pad.size(1)):  # B
        ys_pad[:, i] = ys_pad[min_perm[i], i]
    ret = [
        self.dec(hs_pad[i], hlens[i], ys_pad[i])
        for i in range(self.num_spkrs)
    ]
    loss_rnnt = torch.mean(
        torch.stack([r[0] for r in ret], dim=0).to(ret[0].device))  # (B)

    # 5. compute cer/wer
    if (self.training
            or not (self.report_cer or self.report_wer)
            or not isinstance(hs_pad, list)):
        cer, wer = 0.0, 0.0
        # oracle_cer, oracle_wer = 0.0, 0.0
    else:
        if self.recog_args.ctc_weight > 0.0:
            lpz = [
                self.dec.log_softmax(hs_pad[i]).data
                for i in range(self.num_spkrs)
            ]
        else:
            lpz = None
        word_eds, char_eds, word_ref_lens, char_ref_lens = [], [], [], []
        nbest_hyps = [
            self.dec.recognize_beam_batch(hs_pad[i], torch.tensor(hlens[i]),
                                          lpz[i], self.recog_args,
                                          self.char_list, self.rnnlm,
                                          strm_idx=i)
            for i in range(self.num_spkrs)
        ]
        # remove <sos> (todo: also <eos> with attention?)
        y_hats = [
            [nbest_hyp[0]['yseq'][1:] for nbest_hyp in nbest_hyps[i]]
            for i in range(self.num_spkrs)
        ]
        for i in range(len(y_hats[0])):
            hyp_words = []
            hyp_chars = []
            ref_words = []
            ref_chars = []
            for ns in range(self.num_spkrs):
                y_hat = y_hats[ns][i]
                y_true = ys_pad[ns][i]
                seq_hat = [self.char_list[int(idx)] for idx in y_hat
                           if int(idx) != -1]
                seq_true = [self.char_list[int(idx)] for idx in y_true
                            if int(idx) != -1]
                seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, ' ')
                seq_hat_text = seq_hat_text.replace(self.recog_args.blank, '')
                seq_true_text = "".join(seq_true).replace(self.recog_args.space, ' ')
                hyp_words.append(seq_hat_text.split())
                ref_words.append(seq_true_text.split())
                hyp_chars.append(seq_hat_text.replace(' ', ''))
                ref_chars.append(seq_true_text.replace(' ', ''))
            tmp_word_ed = [
                editdistance.eval(hyp_words[ns // self.num_spkrs],
                                  ref_words[ns % self.num_spkrs])
                for ns in range(self.num_spkrs ** 2)
            ]  # h1r1, h1r2, h2r1, h2r2
            tmp_char_ed = [
                editdistance.eval(hyp_chars[ns // self.num_spkrs],
                                  ref_chars[ns % self.num_spkrs])
                for ns in range(self.num_spkrs ** 2)
            ]  # h1r1, h1r2, h2r1, h2r2
            word_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_word_ed))[0])
            word_ref_lens.append(len(sum(ref_words, [])))
            char_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_char_ed))[0])
            char_ref_lens.append(len(''.join(ref_chars)))
        wer = (0.0 if not self.report_wer
               else float(sum(word_eds)) / sum(word_ref_lens))
        cer = (0.0 if not self.report_cer
               else float(sum(char_eds)) / sum(char_ref_lens))

    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_rnnt_data = None
    elif alpha == 1:
        self.loss = loss_rnnt
        loss_att_data = None
        loss_rnnt_data = float(loss_rnnt)
    else:
        self.loss = alpha * loss_rnnt + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_rnnt_data = float(loss_rnnt)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_rnnt_data, loss_att_data, None,
                             cer, wer, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
    :return: loss value
    :rtype: torch.Tensor
    """
    # 0. Frontend
    if self.frontend is not None:
        hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
        hs_pad, hlens = self.feature_transform(hs_pad, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. Encoder
    if self.replace_sos:
        tgt_lang_ids = ys_pad[:, 0:1]
        ys_pad = ys_pad[:, 1:]  # remove the target language ID at the beginning
    else:
        tgt_lang_ids = None
    hs_pad, hlens, _ = self.enc(hs_pad, hlens)

    # 2. CTC loss
    if self.mtlalpha == 0:
        self.loss_ctc = None
    else:
        self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad)

    # 3. attention loss
    if self.mtlalpha == 1:
        self.loss_att, acc = None, None
    else:
        self.loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad,
                                         tgt_lang_ids=tgt_lang_ids)
    self.acc = acc

    # 4. compute cer without beam search
    if self.mtlalpha == 0 or self.char_list is None:
        cer_ctc = None
    else:
        cers = []
        y_hats = self.ctc.argmax(hs_pad).data
        for i, y in enumerate(y_hats):
            y_hat = [x[0] for x in groupby(y)]
            y_true = ys_pad[i]
            seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
            seq_hat_text = seq_hat_text.replace(self.blank, '')
            seq_true_text = "".join(seq_true).replace(self.space, ' ')
            hyp_chars = seq_hat_text.replace(' ', '')
            ref_chars = seq_true_text.replace(' ', '')
            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars) / len(ref_chars))
        cer_ctc = sum(cers) / len(cers) if cers else None

    # 5.
compute cer/wer if self.training or not (self.report_cer or self.report_wer): cer, wer = 0.0, 0.0 # oracle_cer, oracle_wer = 0.0, 0.0 else: if self.recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(hs_pad).data else: lpz = None word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], [] nbest_hyps = self.dec.recognize_beam_batch( hs_pad, torch.tensor(hlens), lpz, self.recog_args, self.char_list, self.rnnlm, tgt_lang_ids=tgt_lang_ids.squeeze(1).tolist() if self.replace_sos else None) # remove <sos> and <eos> y_hats = [nbest_hyp[0]['yseq'][1:-1] for nbest_hyp in nbest_hyps] for i, y_hat in enumerate(y_hats): y_true = ys_pad[i] seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1] seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, ' ') seq_hat_text = seq_hat_text.replace(self.recog_args.blank, '') seq_true_text = "".join(seq_true).replace(self.recog_args.space, ' ') hyp_words = seq_hat_text.split() ref_words = seq_true_text.split() word_eds.append(editdistance.eval(hyp_words, ref_words)) word_ref_lens.append(len(ref_words)) hyp_chars = seq_hat_text.replace(' ', '') ref_chars = seq_true_text.replace(' ', '') char_eds.append(editdistance.eval(hyp_chars, ref_chars)) char_ref_lens.append(len(ref_chars)) wer = 0.0 if not self.report_wer else float(sum(word_eds)) / sum(word_ref_lens) cer = 0.0 if not self.report_cer else float(sum(char_eds)) / sum(char_ref_lens) alpha = self.mtlalpha if alpha == 0: self.loss = self.loss_att loss_att_data = float(self.loss_att) loss_ctc_data = None elif alpha == 1: self.loss = self.loss_ctc loss_att_data = None loss_ctc_data = float(self.loss_ctc) else: self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att loss_att_data = float(self.loss_att) loss_ctc_data = float(self.loss_ctc) loss_data = float(self.loss) if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report(loss_ctc_data, loss_att_data, acc, cer_ctc, cer, wer, loss_data) else: logging.warning('loss (=%f) is not correct', loss_data) return self.loss
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
    :return: loss value
    :rtype: torch.Tensor
    """
    # 0. Frontend
    if self.frontend is not None:
        hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
        hs_pad, hlens = self.feature_transform(hs_pad, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. Encoder
    hs_pad, hlens, _ = self.enc(hs_pad, hlens)

    # 2. CTC loss
    if self.mtlalpha == 0:
        self.loss_ctc = None
    else:
        self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad)

    # 3. attention loss (disabled in this CTC-only variant)
    # if self.mtlalpha == 1:
    #     self.loss_att, acc = None, None
    # else:
    #     self.loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
    # self.acc = acc

    # 4. compute cer without beam search
    if self.mtlalpha == 0 or self.char_list is None:
        cer_ctc = None
    else:
        cers = []
        y_hats = self.ctc.argmax(hs_pad).data
        for i, y in enumerate(y_hats):
            y_hat = [x[0] for x in groupby(y)]
            y_true = ys_pad[i]
            seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            seq_hat_text = "".join(seq_hat).replace(self.space, " ")
            seq_hat_text = seq_hat_text.replace(self.blank, "")
            seq_true_text = "".join(seq_true).replace(self.space, " ")
            hyp_chars = seq_hat_text.replace(" ", "")
            ref_chars = seq_true_text.replace(" ", "")
            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars) / len(ref_chars))
        cer_ctc = sum(cers) / len(cers) if cers else None

    # 5. compute cer/wer
    if self.training or not (self.report_cer or self.report_wer):
        cer, wer = 0.0, 0.0
        # oracle_cer, oracle_wer = 0.0, 0.0
    else:
        if self.recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(hs_pad).data
        else:
            lpz = None
        word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
        nbest_hyps = self.dec.recognize_beam_batch(
            hs_pad, torch.tensor(hlens), lpz,
            self.recog_args, self.char_list, self.rnnlm,
        )
        # remove <sos> and <eos>
        y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps]
        for i, y_hat in enumerate(y_hats):
            y_true = ys_pad[i]
            seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ")
            seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
            seq_true_text = "".join(seq_true).replace(self.recog_args.space, " ")
            hyp_words = seq_hat_text.split()
            ref_words = seq_true_text.split()
            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))
            hyp_chars = seq_hat_text.replace(" ", "")
            ref_chars = seq_true_text.replace(" ", "")
            char_eds.append(editdistance.eval(hyp_chars, ref_chars))
            char_ref_lens.append(len(ref_chars))
        wer = (0.0 if not self.report_wer
               else float(sum(word_eds)) / sum(word_ref_lens))
        cer = (0.0 if not self.report_cer
               else float(sum(char_eds)) / sum(char_ref_lens))

    # only the CTC objective is active here, since the attention branch
    # above is disabled; mtlalpha is therefore expected to be 1
    self.loss = self.loss_ctc
    loss_att_data = None
    acc = None
    loss_ctc_data = float(self.loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(loss_ctc_data, loss_att_data, acc,
                             cer_ctc, cer, wer, loss_data)
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss