class E2E(ASRInterface, torch.nn.Module):
    """Transformer-based E2E ASR module with optional CTC and Mask-CTC decoding.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id used for padding in target sequences

    """

    @staticmethod
    def add_arguments(parser):
        """Register the transformer model options on ``parser``.

        :param argparse.ArgumentParser parser: parser to extend
        :return: the same parser, for chaining
        """
        group = parser.add_argument_group("transformer model setting")

        group.add_argument(
            "--transformer-init",
            type=str,
            default="pytorch",
            choices=[
                "pytorch",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
            ],
            help="how to initialize transformer parameters",
        )
        group.add_argument(
            "--transformer-input-layer",
            type=str,
            default="conv2d",
            choices=["conv2d", "linear", "embed"],
            help="transformer input layer type",
        )
        group.add_argument(
            "--transformer-attn-dropout-rate",
            default=None,
            type=float,
            help="dropout in transformer attention. use --dropout-rate if None is set",
        )
        group.add_argument(
            "--transformer-lr",
            default=10.0,
            type=float,
            help="Initial value of learning rate",
        )
        group.add_argument(
            "--transformer-warmup-steps",
            default=25000,
            type=int,
            help="optimizer warmup steps",
        )
        group.add_argument(
            "--transformer-length-normalized-loss",
            default=True,
            type=strtobool,
            help="normalize loss by length",
        )
        group.add_argument(
            "--transformer-encoder-selfattn-layer-type",
            type=str,
            default="selfattn",
            choices=[
                "selfattn",
                "rel_selfattn",
                "lightconv",
                "lightconv2d",
                "dynamicconv",
                "dynamicconv2d",
                "light-dynamicconv2d",
            ],
            help="transformer encoder self-attention layer type",
        )
        group.add_argument(
            "--transformer-decoder-selfattn-layer-type",
            type=str,
            default="selfattn",
            choices=[
                "selfattn",
                "lightconv",
                "lightconv2d",
                "dynamicconv",
                "dynamicconv2d",
                "light-dynamicconv2d",
            ],
            help="transformer decoder self-attention layer type",
        )
        # Lightweight/Dynamic convolution related parameters.
        # See https://arxiv.org/abs/1912.11793v2
        # and https://arxiv.org/abs/1901.10430 for detail of the method.
        # Configurations used in the first paper are in
        # egs/{csj, librispeech}/asr1/conf/tuning/ld_conv/
        group.add_argument(
            "--wshare",
            default=4,
            type=int,
            help="Number of parameter shargin for lightweight convolution",
        )
        group.add_argument(
            "--ldconv-encoder-kernel-length",
            default="21_23_25_27_29_31_33_35_37_39_41_43",
            type=str,
            help="kernel size for lightweight/dynamic convolution: "
            'Encoder side. For example, "21_23_25" means kernel length 21 for '
            "First layer, 23 for Second layer and so on.",
        )
        group.add_argument(
            "--ldconv-decoder-kernel-length",
            default="11_13_15_17_19_21",
            type=str,
            help="kernel size for lightweight/dynamic convolution: "
            'Decoder side. For example, "21_23_25" means kernel length 21 for '
            "First layer, 23 for Second layer and so on.",
        )
        group.add_argument(
            "--ldconv-usebias",
            type=strtobool,
            default=False,
            help="use bias term in lightweight/dynamic convolution",
        )
        group.add_argument(
            "--dropout-rate",
            default=0.0,
            type=float,
            help="Dropout rate for the encoder",
        )
        # Encoder
        group.add_argument(
            "--elayers",
            default=4,
            type=int,
            help="Number of encoder layers (for shared recognition part "
            "in multi-speaker asr mode)",
        )
        group.add_argument(
            "--eunits",
            "-u",
            default=300,
            type=int,
            help="Number of encoder hidden units",
        )
        # Attention
        group.add_argument(
            "--adim",
            default=320,
            type=int,
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--aheads",
            default=4,
            type=int,
            help="Number of heads for multi head attention",
        )
        # Decoder
        group.add_argument(
            "--dlayers", default=1, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=320, type=int, help="Number of decoder hidden units"
        )
        # Non-autoregressive training
        # NOTE: "AR" (the default) is listed in ``choices`` so that passing the
        # default value explicitly on the command line is accepted; argparse does
        # not validate defaults against ``choices`` but does validate parsed values.
        group.add_argument(
            "--decoder-mode",
            default="AR",
            type=str,
            choices=["AR", "ar", "maskctc"],
            help="AR: standard autoregressive training, "
            "maskctc: non-autoregressive training based on Mask CTC",
        )
        return parser

    @property
    def attention_plot_class(self):
        """Return PlotAttentionReport."""
        return PlotAttentionReport

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: label id used for padding in target sequences
        """
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        # attention dropout falls back to the general dropout rate
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_encoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        # the attention decoder is only built when the loss is not pure CTC
        if args.mtlalpha < 1:
            self.decoder = Decoder(
                odim=odim,
                selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                conv_wshare=args.wshare,
                conv_kernel_length=args.ldconv_decoder_kernel_length,
                conv_usebias=args.ldconv_usebias,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            )
            self.criterion = LabelSmoothingLoss(
                odim,
                ignore_id,
                args.lsm_weight,
                args.transformer_length_normalized_loss,
            )
        else:
            self.decoder = None
            self.criterion = None
        self.blank = 0
        self.decoder_mode = args.decoder_mode
        if self.decoder_mode == "maskctc":
            # the last id is the <mask> token; <sos>/<eos> share the id before it
            self.mask_token = odim - 1
            self.sos = odim - 2
            self.eos = odim - 2
        else:
            self.sos = odim - 1
            self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")
        self.reporter = Reporter()

        # NOTE: parameters are (re-)initialized before the CTC module is created,
        # so the CTC module keeps its own default initialization.
        self.reset_parameters(args)
        self.adim = args.adim  # used for CTC (equal to d_model)
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(
                odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
            )
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None

    def reset_parameters(self, args):
        """Initialize parameters according to ``args.transformer_init``."""
        # initialize parameters
        initialize(self, args.transformer_init)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: combined loss value (CTC and attention losses weighted by
            ``mtlalpha``); the individual losses and accuracy are reported
            through ``self.reporter``
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        if self.decoder is not None:
            if self.decoder_mode == "maskctc":
                ys_in_pad, ys_out_pad = mask_uniform(
                    ys_pad, self.mask_token, self.eos, self.ignore_id
                )
                ys_mask = (ys_in_pad != self.ignore_id).unsqueeze(-2)
            else:
                ys_in_pad, ys_out_pad = add_sos_eos(
                    ys_pad, self.sos, self.eos, self.ignore_id
                )
                ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            self.acc = th_accuracy(
                pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
            )
        else:
            loss_att = None
            self.acc = None

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            if not self.training and self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
            # for visualization
            if not self.training:
                self.ctc.softmax(hs_pad)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None or self.decoder is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        # only report sane loss values; exploding/NaN losses are logged instead
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss

    def scorers(self):
        """Return the decoder and CTC prefix scorers keyed by name."""
        return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))

    def encode(self, x):
        """Encode acoustic features.

        Switches the module to evaluation mode as a side effect.

        :param ndarray x: source acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0)
        enc_output, _ = self.encoder(x, None)
        return enc_output.squeeze(0)

    def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
        """Recognize input speech.

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :param bool use_jit: trace the one-step decoder with ``torch.jit.trace``
        :return: N-best decoding results
        :rtype: list
        """
        enc_output = self.encode(x).unsqueeze(0)
        if self.mtlalpha == 1.0:
            recog_args.ctc_weight = 1.0
            logging.info("Set to pure CTC decoding mode.")

        if self.mtlalpha > 0 and recog_args.ctc_weight == 1.0:
            from itertools import groupby

            # greedy CTC: collapse repeated ids, then drop blanks
            lpz = self.ctc.argmax(enc_output)
            collapsed_indices = [x[0] for x in groupby(lpz[0])]
            hyp = [idx for idx in collapsed_indices if idx != self.blank]
            nbest_hyps = [{"score": 0.0, "yseq": hyp}]
            if recog_args.beam_size > 1:
                raise NotImplementedError("Pure CTC beam search is not implemented.")
            # TODO(hirofumi0810): Implement beam search
            return nbest_hyps
        elif self.mtlalpha > 0 and recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)

        logging.info("input lengths: " + str(h.size(0)))
        # search parms
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # preprare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
        else:
            hyp = {"score": 0.0, "yseq": [y]}
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
            hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
            hyp["ctc_score_prev"] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        import six

        traced_decoder = None
        for i in six.moves.range(maxlen):
            logging.debug("position " + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp["yseq"][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp["yseq"]).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.forward_one_step, (ys, ys_mask, enc_output)
                        )
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(
                        ys, ys_mask, enc_output
                    )[0]

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                    local_scores = (
                        local_att_scores + recog_args.lm_weight * local_lm_scores
                    )
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    # joint CTC/attention scoring on the attention-pruned candidates
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1
                    )
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"]
                    )
                    local_scores = (1.0 - ctc_weight) * local_att_scores[
                        :, local_best_ids[0]
                    ] + ctc_weight * torch.from_numpy(
                        ctc_scores - hyp["ctc_score_prev"]
                    )
                    if rnnlm:
                        local_scores += (
                            recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                        )
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                    new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                    new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                    new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp["rnnlm_prev"] = rnnlm_state
                    if lpz is not None:
                        new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[0, j]]
                        new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x["score"], reverse=True
                )[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug("number of pruned hypothes: " + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    "best hypo: "
                    + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
                )

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info("adding <eos> in the last postion in the loop")
                for hyp in hyps:
                    hyp["yseq"].append(self.eos)

            # add ended hypothes to a final list, and removed them from current hypothes
            # (this will be a probmlem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp["score"] += recog_args.lm_weight * rnnlm.final(
                                hyp["rnnlm_prev"]
                            )
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info("end detected at %d", i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug("remeined hypothes: " + str(len(hyps)))
            else:
                logging.info("no hypothesis. Finish decoding.")
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                    )

            logging.debug("number of ended hypothes: " + str(len(ended_hyps)))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps), recog_args.nbest)
        ]

        # check number of hypotheis
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            # should copy becasuse Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            # keep the caller's use_jit choice on the retry
            return self.recognize(x, recog_args, char_list, rnnlm, use_jit)

        logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
        logging.info(
            "normalized log probability: "
            + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
        )
        return nbest_hyps

    def recognize_maskctc(self, x, recog_args, char_list=None):
        """Non-autoregressive decoding using Mask CTC.

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argument Namespace containing options
            (reads ``maskctc_probability_threshold`` and ``maskctc_n_iterations``)
        :param list char_list: list of characters; when ``None`` the readable
            intermediate results are not logged
        :return: decoding result
        :rtype: list
        """
        # local import, mirroring ``recognize``, so this method does not depend
        # on a module-level import being present
        from itertools import groupby

        self.eval()
        h = self.encode(x).unsqueeze(0)

        # greedy CTC output: collapse repeats, then locate the non-blank tokens
        ctc_probs, ctc_ids = torch.exp(self.ctc.log_softmax(h)).max(dim=-1)
        y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])])
        y_idx = torch.nonzero(y_hat != 0).squeeze(-1)

        # for every collapsed token keep the highest frame-level probability
        probs_hat = []
        cnt = 0
        for i, y in enumerate(y_hat.tolist()):
            probs_hat.append(-1)
            while cnt < ctc_ids.shape[1] and y == ctc_ids[0][cnt]:
                if probs_hat[i] < ctc_probs[0][cnt]:
                    probs_hat[i] = ctc_probs[0][cnt].item()
                cnt += 1
        probs_hat = torch.from_numpy(numpy.array(probs_hat))

        char_mask = "_"
        p_thres = recog_args.maskctc_probability_threshold
        mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1)
        confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
        mask_num = len(mask_idx)

        # low-confidence positions start as <mask>; the last slot holds <eos>
        y_in = torch.zeros(1, len(y_idx) + 1, dtype=torch.long) + self.mask_token
        y_in[0][confident_idx] = y_hat[y_idx][confident_idx]
        y_in[0][-1] = self.eos

        def log_tokens(tag):
            # char_list defaults to None; indexing it would raise, so skip logging
            if char_list is None:
                return
            logging.info(
                "{}:{}".format(
                    tag,
                    "".join(
                        [
                            char_list[y] if y != self.mask_token else char_mask
                            for y in y_in[0].tolist()
                        ]
                    ).replace("<space>", " "),
                )
            )

        log_tokens("ctc")

        if not mask_num == 0:
            K = recog_args.maskctc_n_iterations
            num_iter = K if mask_num >= K and K > 0 else mask_num

            # fill in the most confident masked positions a few at a time
            for _ in range(1, num_iter):
                pred, _ = self.decoder(
                    y_in, (y_in != self.ignore_id).unsqueeze(-2), h, None
                )
                pred_sc, pred_id = pred[0][mask_idx].max(dim=-1)
                cand = torch.topk(pred_sc, mask_num // num_iter, -1)[1]
                y_in[0][mask_idx[cand]] = pred_id[cand]
                mask_idx = torch.nonzero(y_in[0] == self.mask_token).squeeze(-1)

                log_tokens("msk")

            # final pass: predict every remaining masked position at once
            pred, pred_mask = self.decoder(
                y_in, (y_in != self.ignore_id).unsqueeze(-2), h, None
            )
            y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1)
            log_tokens("msk")

        ret = y_in.tolist()[0][:-1]
        hyp = {"score": 0.0, "yseq": [self.sos] + ret + [self.eos]}

        return [hyp]

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Runs a forward pass in evaluation mode and restores the previous
        training mode afterwards.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: attention weights (B, H, Lmax, Tmax), keyed by module name
        :rtype: dict of float ndarray
        """
        was_training = self.training
        self.eval()
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(
                m,
                (
                    MultiHeadedAttention,
                    DynamicConvolution,
                    RelPositionMultiHeadedAttention,
                ),
            ):
                ret[name] = m.attn.cpu().numpy()
            if isinstance(m, DynamicConvolution2D):
                ret[name + "_time"] = m.attn_t.cpu().numpy()
                ret[name + "_freq"] = m.attn_f.cpu().numpy()
        # only switch back to training if that is the mode we started in
        if was_training:
            self.train()
        return ret

    def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad):
        """E2E CTC probability calculation.

        Runs a forward pass in evaluation mode and restores the previous
        training mode afterwards.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: CTC probability (B, Tmax, vocab), or None when CTC is disabled
        :rtype: float ndarray
        """
        ret = None
        if self.mtlalpha == 0:
            return ret

        was_training = self.training
        self.eval()
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        for name, m in self.named_modules():
            if isinstance(m, CTC) and m.probs is not None:
                ret = m.probs.cpu().numpy()
        # only switch back to training if that is the mode we started in
        if was_training:
            self.train()
        return ret
class E2E(STInterface, torch.nn.Module):
    """RNN-based E2E speech translation module with auxiliary ASR and MT tasks.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options

    """

    @staticmethod
    def add_arguments(parser):
        """Add encoder, attention and decoder arguments to ``parser``."""
        E2E.encoder_add_arguments(parser)
        E2E.attention_add_arguments(parser)
        E2E.decoder_add_arguments(parser)
        return parser

    @staticmethod
    def encoder_add_arguments(parser):
        """Add arguments for the encoder."""
        group = parser.add_argument_group("E2E encoder setting")
        group = add_arguments_rnn_encoder_common(group)
        return parser

    @staticmethod
    def attention_add_arguments(parser):
        """Add arguments for the attention."""
        group = parser.add_argument_group("E2E attention setting")
        group = add_arguments_rnn_attention_common(group)
        return parser

    @staticmethod
    def decoder_add_arguments(parser):
        """Add arguments for the decoder."""
        group = parser.add_argument_group("E2E decoder setting")
        group = add_arguments_rnn_decoder_common(group)
        return parser

    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
        return self.enc.conv_subsampling_factor * int(np.prod(self.subsample))

    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        self.asr_weight = args.asr_weight
        self.mt_weight = args.mt_weight
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be [0.0, 1.0)"
        assert 0.0 <= self.mt_weight < 1.0, "mt_weight should be [0.0, 1.0)"
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1
        self.pad = 0
        # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
        # in ASR. However, blank labels are not used in MT.
        # To keep the vocabulary size,
        # we use index:0 for padding instead of adding one more class.

        # subsample info
        self.subsample = get_subsample(args, mode="st", arch="rnn")

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json
            )
        else:
            labeldist = None

        # multilingual related
        self.multilingual = getattr(args, "multilingual", False)
        self.replace_sos = getattr(args, "replace_sos", False)

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # attention (ST)
        self.att = att_for(args)
        # decoder (ST)
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

        # submodule for ASR task
        self.ctc = None
        self.att_asr = None
        self.dec_asr = None
        if self.asr_weight > 0:
            if self.mtlalpha > 0.0:
                self.ctc = CTC(
                    odim,
                    args.eprojs,
                    args.dropout_rate,
                    ctc_type=args.ctc_type,
                    reduce=True,
                )
            if self.mtlalpha < 1.0:
                # attention (asr)
                self.att_asr = att_for(args)
                # decoder (asr)
                args_asr = copy.deepcopy(args)
                args_asr.atype = "location"  # TODO(hirofumi0810): make this option
                self.dec_asr = decoder_for(
                    args_asr, odim, self.sos, self.eos, self.att_asr, labeldist
                )

        # submodule for MT task
        if self.mt_weight > 0:
            self.embed_mt = torch.nn.Embedding(odim, args.eunits, padding_idx=self.pad)
            self.dropout_mt = torch.nn.Dropout(p=args.dropout_rate)
            # ``np.int`` was only an alias of the builtin ``int`` and has been
            # removed from NumPy; ``int`` yields the same dtype.
            self.enc_mt = encoder_for(
                args, args.eunits, subsample=np.ones(args.elayers + 1, dtype=int)
            )

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        # NOTE: parenthesized so that CER/WER reporting is only enabled when the
        # auxiliary ASR task is active (``and`` binds tighter than ``or``).
        if self.asr_weight > 0 and (args.report_cer or args.report_wer):
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
                "tgt_lang": False,
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        if args.report_bleu:
            trans_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": 0,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
                "tgt_lang": False,
            }

            self.trans_args = argparse.Namespace(**trans_args)
            self.report_bleu = args.report_bleu
        else:
            self.report_bleu = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None

    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        lecun_normal_init_parameters(self)
        # exceptions
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # forget-bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for i in six.moves.range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[i].bias_ih)

    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of padded source-side token id
            sequence tensor (B, Lmax), used by the auxiliary ASR and MT tasks
        :return: loss value
        :rtype: torch.Tensor
        """
        # 0. Extract target language ID
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID in the beginning
        else:
            tgt_lang_ids = None

        # 1. Encoder
        hs_pad, hlens, _ = self.enc(xs_pad, ilens)

        # 2. ST attention loss
        self.loss_st, self.acc, _ = self.dec(
            hs_pad, hlens, ys_pad, lang_ids=tgt_lang_ids
        )

        # 3. ASR loss
        (
            self.loss_asr_att,
            acc_asr,
            self.loss_asr_ctc,
            cer_ctc,
            cer,
            wer,
        ) = self.forward_asr(hs_pad, hlens, ys_pad_src)

        # 4. MT attention loss
        # forward_mt takes (source, target): the source text is encoded and the
        # shared ST decoder is trained to produce the translation.
        self.loss_mt, acc_mt = self.forward_mt(ys_pad_src, ys_pad)

        # 5. Compute BLEU
        if self.training or not self.report_bleu:
            self.bleu = 0.0
        else:
            lpz = None

            nbest_hyps = self.dec.recognize_beam_batch(
                hs_pad,
                torch.tensor(hlens),
                lpz,
                self.trans_args,
                self.char_list,
                self.rnnlm,
                lang_ids=tgt_lang_ids.squeeze(1).tolist()
                if self.multilingual
                else None,
            )
            # remove <sos> and <eos>
            list_of_refs = []
            hyps = []
            y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps]
            for i, y_hat in enumerate(y_hats):
                y_true = ys_pad[i]

                # -1 marks padding in the id sequences
                seq_hat = [
                    self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                ]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(self.trans_args.space, " ")
                seq_hat_text = seq_hat_text.replace(self.trans_args.blank, "")
                seq_true_text = "".join(seq_true).replace(self.trans_args.space, " ")

                hyps += [seq_hat_text.split(" ")]
                list_of_refs += [[seq_true_text.split(" ")]]

            self.bleu = nltk.bleu_score.corpus_bleu(list_of_refs, hyps) * 100

        asr_ctc_weight = self.mtlalpha
        self.loss_asr = (
            asr_ctc_weight * self.loss_asr_ctc
            + (1 - asr_ctc_weight) * self.loss_asr_att
        )
        self.loss = (
            (1 - self.asr_weight - self.mt_weight) * self.loss_st
            + self.asr_weight * self.loss_asr
            + self.mt_weight * self.loss_mt
        )
        loss_st_data = float(self.loss_st)
        loss_asr_data = float(self.loss_asr)
        loss_mt_data = float(self.loss_mt)
        loss_data = float(self.loss)
        # only report sane loss values; exploding/NaN losses are logged instead
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_asr_data,
                loss_mt_data,
                loss_st_data,
                acc_asr,
                acc_mt,
                self.acc,
                cer_ctc,
                cer,
                wer,
                self.bleu,
                loss_data,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss

    def forward_asr(self, hs_pad, hlens, ys_pad):
        """Forward pass in the auxiliary ASR task.

        :param torch.Tensor hs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor hlens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ASR attention loss value
        :rtype: torch.Tensor
        :return: accuracy in ASR attention decoder
        :rtype: float
        :return: ASR CTC loss value
        :rtype: torch.Tensor
        :return: character error rate from CTC prediction
        :rtype: float
        :return: character error rate from attetion decoder prediction
        :rtype: float
        :return: word error rate from attetion decoder prediction
        :rtype: float
        """
        import editdistance

        loss_att, loss_ctc = 0.0, 0.0
        acc = None
        cer, wer = None, None
        cer_ctc = None
        if self.asr_weight == 0:
            return loss_att, acc, loss_ctc, cer_ctc, cer, wer

        # attention
        if self.mtlalpha < 1:
            # bind to the names that are returned, otherwise the attention loss
            # and accuracy would silently stay at their 0.0 / None defaults
            loss_att, acc, _ = self.dec_asr(hs_pad, hlens, ys_pad)

            # Compute wer and cer
            if not self.training and (self.report_cer or self.report_wer):
                if self.mtlalpha > 0 and self.recog_args.ctc_weight > 0.0:
                    lpz = self.ctc.log_softmax(hs_pad).data
                else:
                    lpz = None

                word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
                nbest_hyps_asr = self.dec_asr.recognize_beam_batch(
                    hs_pad,
                    torch.tensor(hlens),
                    lpz,
                    self.recog_args,
                    self.char_list,
                    self.rnnlm,
                )
                # remove <sos> and <eos>
                y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps_asr]
                for i, y_hat in enumerate(y_hats):
                    y_true = ys_pad[i]

                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(
                        self.recog_args.space, " "
                    )
                    seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
                    seq_true_text = "".join(seq_true).replace(
                        self.recog_args.space, " "
                    )

                    hyp_words = seq_hat_text.split()
                    ref_words = seq_true_text.split()
                    word_eds.append(editdistance.eval(hyp_words, ref_words))
                    word_ref_lens.append(len(ref_words))
                    hyp_chars = seq_hat_text.replace(" ", "")
                    ref_chars = seq_true_text.replace(" ", "")
                    char_eds.append(editdistance.eval(hyp_chars, ref_chars))
                    char_ref_lens.append(len(ref_chars))

                wer = (
                    0.0
                    if not self.report_wer
                    else float(sum(word_eds)) / sum(word_ref_lens)
                )
                cer = (
                    0.0
                    if not self.report_cer
                    else float(sum(char_eds)) / sum(char_ref_lens)
                )

        # CTC
        if self.mtlalpha > 0:
            loss_ctc = self.ctc(hs_pad, hlens, ys_pad)

            # Compute cer with CTC prediction
            if self.char_list is not None:
                cers = []
                y_hats = self.ctc.argmax(hs_pad).data
                for i, y in enumerate(y_hats):
                    # collapse repeated ids before mapping them to characters
                    y_hat = [x[0] for x in groupby(y)]
                    y_true = ys_pad[i]

                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(self.space, " ")
                    seq_hat_text = seq_hat_text.replace(self.blank, "")
                    seq_true_text = "".join(seq_true).replace(self.space, " ")

                    hyp_chars = seq_hat_text.replace(" ", "")
                    ref_chars = seq_true_text.replace(" ", "")
                    if len(ref_chars) > 0:
                        cers.append(
                            editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)
                        )
                cer_ctc = sum(cers) / len(cers) if cers else None

        return loss_att, acc, loss_ctc, cer_ctc, cer, wer

    def forward_mt(self, xs_pad, ys_pad):
        """Forward pass in the auxiliary MT task.

        :param torch.Tensor xs_pad: batch of padded source token id sequences,
            padded with -1
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: MT loss value
        :rtype: torch.Tensor
        :return: accuracy in MT decoder
        :rtype: float
        """
        loss = 0.0
        acc = 0.0
        if self.mt_weight == 0:
            return loss, acc

        ilens = torch.sum(xs_pad != -1, dim=1).cpu().numpy()
        # NOTE: xs_pad is padded with -1
        ys_src = [y[y != -1] for y in xs_pad]  # parse padded ys_src
        xs_zero_pad = pad_list(ys_src, self.pad)  # re-pad with zero
        hs_pad, hlens, _ = self.enc_mt(
            self.dropout_mt(self.embed_mt(xs_zero_pad)), ilens
        )
        loss, acc, _ = self.dec(hs_pad, hlens, ys_pad)
        return loss, acc

    def scorers(self):
        """Return the decoder scorer keyed by name."""
        return dict(decoder=self.dec)

    def encode(self, x):
        """Encode acoustic features.

        Switches the module to evaluation mode as a side effect.

        :param ndarray x: input acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()

        # subsample frame
        x = x[:: self.subsample[0], :]
        # the length must describe the frames actually fed to the encoder,
        # i.e. it is taken after frame subsampling (as in subsample_frames)
        ilens = [x.shape[0]]
        p = next(self.parameters())
        h = torch.as_tensor(x, device=p.device, dtype=p.dtype)
        # make a utt list (1) to use the same interface for encoder
        hs = h.contiguous().unsqueeze(0)

        # 1. encoder
        hs, _, _ = self.enc(hs, ilens)
        return hs.squeeze(0)

    def translate(self, x, trans_args, char_list, rnnlm=None):
        """E2E beam search.

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace trans_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        logging.info("input lengths: " + str(x.shape[0]))
        hs = self.encode(x).unsqueeze(0)
        logging.info("encoder output lengths: " + str(hs.size(1)))

        # 2. Decoder
        # decode the first utterance
        y = self.dec.recognize_beam(hs[0], None, trans_args, char_list, rnnlm)
        return y

    def translate_batch(self, xs, trans_args, char_list, rnnlm=None):
        """E2E batch beam search.

        :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
        :param Namespace trans_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()

        # subsample frame
        xs = [xx[:: self.subsample[0], :] for xx in xs]
        # lengths are taken after subsampling so they match the padded batch
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)

        # 1. Encoder
        hs_pad, hlens, _ = self.enc(xs_pad, ilens)

        # 2. Decoder
        hlens = torch.tensor(list(map(int, hlens)))  # make sure hlens is tensor
        y = self.dec.recognize_beam_batch(
            hs_pad, hlens, None, trans_args, char_list, rnnlm
        )

        if prev:
            self.train()
        return y

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E attention calculation.

        Runs in evaluation mode and restores the previous training mode
        afterwards.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :param torch.Tensor ys_pad_src:
            batch of padded token id sequence tensor (B, Lmax); not used here
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        prev = self.training
        self.eval()
        with torch.no_grad():
            # 1. Encoder
            if self.multilingual:
                tgt_lang_ids = ys_pad[:, 0:1]
                ys_pad = ys_pad[:, 1:]  # remove target language ID in the beginning
            else:
                tgt_lang_ids = None
            hpad, hlens, _ = self.enc(xs_pad, ilens)

            # 2. Decoder
            att_ws = self.dec.calculate_all_attentions(
                hpad, hlens, ys_pad, lang_ids=tgt_lang_ids
            )
        # only switch back to training if that is the mode we started in
        if prev:
            self.train()
        return att_ws

    def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E CTC probability calculation.

        Runs in evaluation mode and restores the previous training mode
        afterwards.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :param torch.Tensor ys_pad_src:
            batch of padded token id sequence tensor (B, Lmax)
        :return: CTC probability (B, Tmax, vocab), or None when the auxiliary
            CTC task is disabled
        :rtype: float ndarray
        """
        probs = None
        if self.asr_weight == 0 or self.mtlalpha == 0:
            return probs

        prev = self.training
        self.eval()
        with torch.no_grad():
            # 1. Encoder
            hpad, hlens, _ = self.enc(xs_pad, ilens)

            # 2. CTC probs
            probs = self.ctc.softmax(hpad).cpu().numpy()
        # only switch back to training if that is the mode we started in
        if prev:
            self.train()
        return probs

    def subsample_frames(self, x):
        """Subsample speech frames in the encoder.

        :param ndarray x: input acoustic feature (T, D)
        :return: subsampled features on the model device and their length
        :rtype: tuple(torch.Tensor, list)
        """
        # subsample frame
        x = x[:: self.subsample[0], :]
        ilen = [x.shape[0]]
        h = to_device(self, torch.from_numpy(np.array(x, dtype=np.float32)))
        # ``contiguous`` is not in-place; keep its result
        h = h.contiguous()
        return h, ilen
class E2E(ASRInterface, torch.nn.Module):
    """E2E module with two parallel encoders fused by a fixed weight.

    Two identically configured encoders (``cn_encoder`` and ``en_encoder``)
    process the same input.  Their outputs are combined as
    ``enc_lambda * cn + (1 - enc_lambda) * en`` and the result feeds both
    the CTC layer (when ``mtlalpha > 0``) and the attention decoder.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id passed to the loss, accuracy and
        ``add_sos_eos`` helpers as the padding/ignored label

    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments.

        :param parser: argument parser to extend
        :return: the same parser, with the model options registered
        """
        group = parser.add_argument_group("transformer model setting")
        group.add_argument("--transformer-init", type=str, default="pytorch",
                           choices=[
                               "pytorch", "xavier_uniform", "xavier_normal",
                               "kaiming_uniform", "kaiming_normal"
                           ],
                           help='how to initialize transformer parameters')
        group.add_argument("--transformer-input-layer", type=str, default="conv2d",
                           choices=["conv2d", "linear", "embed"],
                           help='transformer input layer type')
        group.add_argument(
            '--transformer-attn-dropout-rate', default=None, type=float,
            help='dropout in transformer attention. use --dropout-rate if None is set')
        group.add_argument('--transformer-lr', default=10.0, type=float,
                           help='Initial value of learning rate')
        group.add_argument('--transformer-warmup-steps', default=25000, type=int,
                           help='optimizer warmup steps')
        group.add_argument('--transformer-length-normalized-loss', default=True,
                           type=strtobool,
                           help='normalize loss by length')
        group.add_argument('--dropout-rate', default=0.0, type=float,
                           help='Dropout rate for the encoder')
        # Encoder
        group.add_argument(
            '--elayers', default=4, type=int,
            help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)')
        group.add_argument('--eunits', '-u', default=300, type=int,
                           help='Number of encoder hidden units')
        # Attention
        group.add_argument('--adim', default=320, type=int,
                           help='Number of attention transformation dimensions')
        group.add_argument('--aheads', default=4, type=int,
                           help='Number of heads for multi head attention')
        # Decoder
        group.add_argument('--dlayers', default=1, type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits', default=320, type=int,
                           help='Number of decoder hidden units')
        # ctc init path
        group.add_argument('--pretrained-cn-ctc-model', default='', type=str,
                           help='pretrained cn ctc model')
        group.add_argument('--pretrained-en-ctc-model', default='', type=str,
                           help='pretrained en ctc model')
        group.add_argument('--pretrained-mlme-model', default='', type=str,
                           help='pretrained multi-lingual multi-encoder model')
        group.add_argument('--enc-lambda', default=0.5, type=float,
                           help='encoder fusion lambda params')
        return parser

    @property
    def attention_plot_class(self):
        """Return PlotAttentionReport."""
        return PlotAttentionReport

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: label id treated as padding by the losses
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate

        # Both encoders share one configuration; they are built in the same
        # order as before (cn first, then en).
        encoder_kwargs = dict(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.cn_encoder = Encoder(**encoder_kwargs)
        self.en_encoder = Encoder(**encoder_kwargs)
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate)
        # <sos> and <eos> share the last output id.
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = [1]
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            self.odim, self.ignore_id, args.lsm_weight,
            args.transformer_length_normalized_loss)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim, args.adim, args.dropout_rate,
                           ctc_type=args.ctc_type, reduce=True)
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(
                args.char_list, args.sym_space, args.sym_blank,
                args.report_cer, args.report_wer)
        else:
            self.error_calculator = None
        self.rnnlm = None

        # yzl23 config
        self.remove_blank_in_ctc_mode = True

        self.reset_parameters(args)  # reset params at the last
        self.enc_lambda = args.enc_lambda
        logging.warning("Using fixed encoder lambda: {}".format(
            self.enc_lambda))
        logging.warning(
            "Model total size: {}M, requires_grad size: {}M".format(
                self.count_parameters(),
                self.count_parameters(requires_grad=True)))

    def count_parameters(self, requires_grad=False):
        """Return the number of parameters in units of 1024 * 1024.

        :param bool requires_grad: count only parameters with
            ``requires_grad`` set when True
        :rtype: float
        """
        if requires_grad:
            return sum(p.numel() for p in self.parameters()
                       if p.requires_grad) / 1024 / 1024
        return sum(p.numel() for p in self.parameters()) / 1024 / 1024

    def reset_parameters(self, args):
        """Initialize parameters, optionally loading pretrained weights.

        All parameters are first initialized with ``args.transformer_init``.
        Then, with ``strict=False``:

        * if ``args.pretrained_mlme_model`` is set, that checkpoint is loaded
          as is;
        * otherwise, if both ``args.pretrained_cn_ctc_model`` and
          ``args.pretrained_en_ctc_model`` are set, only their encoder
          weights are loaded into ``cn_encoder`` and ``en_encoder``.
        """

        def load_checkpoint(path):
            # Snapshot files keep the weights under the 'model' key.
            state = torch.load(path, map_location=lambda storage, loc: storage)
            return state['model'] if 'snapshot' in path else state

        # load state_dict, and keeps only encoder part
        # note that self.ctc.ctc_lo is also removed
        # prefix is added to meet the needs of moe structure
        def load_state_dict_encoder(path, prefix=''):
            model_state_dict = load_checkpoint(path)
            for k in list(model_state_dict.keys()):
                if 'encoder' not in k:
                    # remove this key
                    del model_state_dict[k]
                else:
                    new_k = k.replace('encoder.', prefix + 'encoder.')
                    model_state_dict[new_k] = model_state_dict.pop(k)
            return model_state_dict

        # initialize parameters
        if args.pretrained_mlme_model:
            logging.warning(
                "loading pretrained mlme model for parallel encoder")
            # still need to initialize the 'other' params
            initialize(self, args.transformer_init)
            model_state_dict = load_checkpoint(args.pretrained_mlme_model)
            self.load_state_dict(model_state_dict, strict=False)
            del model_state_dict
        elif args.pretrained_cn_ctc_model and args.pretrained_en_ctc_model:
            logging.warning(
                "loading pretrained ctc model for parallel encoder")
            # still need to initialize the 'other' params
            initialize(self, args.transformer_init)
            cn_state_dict = load_state_dict_encoder(
                args.pretrained_cn_ctc_model, prefix='cn_')
            self.load_state_dict(cn_state_dict, strict=False)
            del cn_state_dict
            en_state_dict = load_state_dict_encoder(
                args.pretrained_en_ctc_model, prefix='en_')
            self.load_state_dict(en_state_dict, strict=False)
            del en_state_dict
        else:
            initialize(self, args.transformer_init)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Computes the CTC loss (unless ``mtlalpha == 0``) and the attention
        loss (unless ``mtlalpha == 1``), stores their ``mtlalpha``-weighted
        combination in ``self.loss`` and reports the statistics.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: combined loss value (also stored in ``self.loss``)
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)

        # parallel encoders, fused with a fixed weight:
        #   hs = lambda * cn_hs + (1 - lambda) * en_hs
        cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)
        lambda_ = self.enc_lambda
        hs_pad = lambda_ * cn_hs_pad + (1 - lambda_) * en_hs_pad
        self.hs_pad = hs_pad

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                                hs_len, ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(
                    hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                                is_ctc=True)

        # pred_pad stays None in pure-CTC mode, where the decoder is skipped.
        pred_pad = None
        loss_att = None
        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                               hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                              ignore_label=self.ignore_id)
        self.acc = acc

        # 5. compute cer/wer
        # Without decoder output (pure-CTC mode) there is nothing to score;
        # the previous code raised NameError on ``pred_pad`` in that case.
        if self.training or self.error_calculator is None or pred_pad is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copyied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss

    def encode(self, x):
        """Encode acoustic features.

        Switches the model to eval mode and returns the weighted sum of both
        encoder outputs.

        :param ndarray x: source acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0)  # (B, T, D) with #B=1
        cn_enc_output, _ = self.cn_encoder(x, None)
        en_enc_output, _ = self.en_encoder(x, None)
        lambda_ = self.enc_lambda
        enc_output = lambda_ * cn_enc_output + (1 - lambda_) * en_enc_output
        return enc_output.squeeze(0)  # returns tensor(T, D)

    def recognize(self, x, recog_args, char_list=None, rnnlm=None,
                  use_jit=False):
        """Recognize input speech.

        Dispatches to CTC greedy decoding when
        ``recog_args.ctc_greedy_decoding`` is set, otherwise to the joint
        CTC/attention beam search.

        :param ndnarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argment Namespace contraining options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :param bool use_jit: trace the decoder step with ``torch.jit``
        :return: N-best decoding results
        :rtype: list
        """
        if recog_args.ctc_greedy_decoding:
            return self.recognize_ctc_greedy(x, recog_args)
        return self.recognize_jca(x, recog_args, char_list, rnnlm, use_jit)

    def store_penultimate_state(self, xs_pad, ilens, ys_pad):
        """Return both encoder outputs concatenated on the last axis.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: not used by this method
        :return: concatenated encoder outputs; a leading axis of size 1 is
            squeezed away
        :rtype: ndarray
        """
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)

        # multi-encoder forward
        cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
        en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)

        # NOTE: the fused encoder output and the decoder are deliberately
        # not evaluated here; only the raw encoder outputs are exported.
        penultimate_state = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)

        # plot penultimate_state, (B,T,att_dim)
        return penultimate_state.squeeze(0).detach().cpu().numpy()

    def recognize_ctc_greedy(self, x, recog_args):
        """Recognize input speech with ctc greedy decoding.

        :param ndnarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argment Namespace contraining options
        :return: N-best decoding results (fake results for compatibility)
        :rtype: list
        """
        enc_output = self.encode(x).unsqueeze(0)  # (1, T, D)
        lpz = self.ctc.log_softmax(enc_output)
        lpz = lpz.squeeze(0)  # shape of (T, D)
        idx = lpz.argmax(-1).cpu().numpy().tolist()
        hyp = {}
        if recog_args.ctc_raw_results:
            # not apply ctc mapping, to get ctc alignment
            hyp['yseq'] = [self.sos] + idx
        else:
            # <sos> is added here to be compatible with S2S decoding,
            # file: espnet/asr/asr_utils/parse_hypothesis
            hyp['yseq'] = [self.sos] + self.ctc_mapping(idx)
        logging.info(hyp['yseq'])
        hyp['score'] = -1
        return [hyp]

    def ctc_mapping(self, x, blank=0):
        """Collapse a frame-level CTC path into a label sequence.

        Consecutive repeats are merged and ``blank`` entries are dropped.

        :param list x: frame-level label ids
        :param int blank: id of the blank label
        :return: collapsed label ids
        :rtype: list
        """
        prev = blank
        y = []
        for i in x:
            if i != blank and i != prev:
                y.append(i)
            prev = i
        return y

    def recognize_jca(self, x, recog_args, char_list=None, rnnlm=None,
                      use_jit=False):
        """Recognize input speech with joint CTC/attention beam search.

        :param ndnarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argment Namespace contraining options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :param bool use_jit: trace the decoder step with ``torch.jit``
        :return: N-best decoding results
        :rtype: list
        """
        enc_output = self.encode(x).unsqueeze(0)  # (1, T, D)
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)  # shape of (T, D)
        else:
            lpz = None

        h = enc_output.squeeze(0)  # (B, T, D), #B=1

        logging.info('input lengths: ' + str(h.size(0)))
        # search parms
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # preprare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        if lpz is not None:
            import numpy

            from espnet.nets.ctc_prefix_score import CTCPrefixScore

            # ``.cpu()`` keeps this working when the encoder output lives on
            # an accelerator; it is a no-op for CPU tensors.
            ctc_prefix_score = CTCPrefixScore(lpz.detach().cpu().numpy(), 0,
                                              self.eos, numpy)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
            if self.remove_blank_in_ctc_mode:
                # <blank> is excluded from the candidates below, so at most
                # (#vocab - 1) of them exist; a larger k would make topk fail.
                ctc_beam = min(ctc_beam, lpz.shape[-1] - 1)
        hyps = [hyp]
        ended_hyps = []

        import six
        traced_decoder = None
        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)  # mask scores of future state
                ys = torch.tensor(hyp['yseq']).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.forward_one_step,
                            (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(
                        ys, ys_mask, enc_output)[0]

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(
                        hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    if self.remove_blank_in_ctc_mode:
                        # here we need to filter out <blank> in local_best_ids
                        # it happens in pure ctc-mode, when ctc_beam equals to #vocab
                        local_best_scores, local_best_ids = torch.topk(
                            local_att_scores[:, 1:], ctc_beam, dim=1)
                        # column 0 was sliced off, shift back to vocabulary ids
                        local_best_ids += 1
                    else:
                        local_best_scores, local_best_ids = torch.topk(
                            local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                        + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(
                        local_best_scores[0, j])
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypothes: ' + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    'best hypo: ' +
                    ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last postion in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypothes to a final list, and removed them from current hypothes
            # (this will be a probmlem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            from espnet.nets.e2e_asr_common import end_detect_yzl23
            if end_detect_yzl23(ended_hyps, remained_hyps,
                                penalty) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remeined hypothes: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        'hypo: ' +
                        ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'],
            reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check number of hypotheis
        if len(nbest_hyps) == 0:
            logging.warning(
                'there is no N-best results, perform recognition again with smaller minlenratio.'
            )
            # should copy becasuse Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            # keep the caller's use_jit choice on the retry
            return self.recognize(x, recog_args, char_list, rnnlm, use_jit)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' +
                     str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Runs ``forward`` without gradients and collects the ``attn``
        attribute of every ``MultiHeadedAttention`` submodule.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: attention weights keyed by submodule name
        :rtype: dict
        """
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                ret[name] = m.attn.cpu().numpy()
        return ret
class E2E(torch.nn.Module):
    """E2E module.

    Single-encoder transformer model with an attention decoder and an
    optional CTC layer (created when ``args.mtlalpha > 0``).

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id passed to the loss, accuracy and
        ``add_sos_eos`` helpers as the padding/ignored label

    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments.

        :param parser: argument parser to extend
        :return: the same parser, with the model options registered
        """
        group = parser.add_argument_group("transformer model setting")
        group.add_argument("--transformer-init", type=str, default="pytorch",
                           choices=[
                               "pytorch", "xavier_uniform", "xavier_normal",
                               "kaiming_uniform", "kaiming_normal"
                           ],
                           help='how to initialize transformer parameters')
        group.add_argument("--transformer-input-layer", type=str, default="conv2d",
                           choices=["conv2d", "linear", "embed", "custom"],
                           help='transformer input layer type')
        group.add_argument(
            '--transformer-attn-dropout-rate', default=None, type=float,
            help='dropout in transformer attention. use --dropout-rate if None is set')
        group.add_argument('--transformer-lr', default=10.0, type=float,
                           help='Initial value of learning rate')
        group.add_argument('--transformer-warmup-steps', default=25000, type=int,
                           help='optimizer warmup steps')
        group.add_argument('--transformer-length-normalized-loss', default=True,
                           type=strtobool,
                           help='normalize loss by length')
        group.add_argument('--dropout-rate', default=0.0, type=float,
                           help='Dropout rate for the encoder')
        # Encoder
        group.add_argument(
            '--elayers', default=4, type=int,
            help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)')
        group.add_argument('--eunits', '-u', default=300, type=int,
                           help='Number of encoder hidden units')
        # Attention
        group.add_argument('--adim', default=320, type=int,
                           help='Number of attention transformation dimensions')
        group.add_argument('--aheads', default=4, type=int,
                           help='Number of heads for multi head attention')
        # Decoder
        group.add_argument('--dlayers', default=1, type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits', default=320, type=int,
                           help='Number of decoder hidden units')
        return parser

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: label id treated as padding by the losses
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate)
        # <sos> and <eos> share the last output id.
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = [1]
        self.criterion = LabelSmoothingLoss(
            self.odim, self.ignore_id, args.lsm_weight,
            args.transformer_length_normalized_loss)
        # NOTE: parameters are reset before the CTC layer is created, so the
        # CTC layer keeps its own default initialization.
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim, args.adim, args.dropout_rate,
                           ctc_type=args.ctc_type, reduce=True)
        else:
            self.ctc = None
        self.rnnlm = None

    def reset_parameters(self, args):
        """Initialize parameters."""
        # initialize parameters
        initialize(self, args.transformer_init)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: tuple ``(loss, loss_ctc_data, loss_att_data, acc)``: the
            ``mtlalpha``-weighted loss tensor, the CTC and attention loss
            values as floats (``None`` for a disabled branch) and the
            accuracy in the attention decoder
        :rtype: tuple
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # Trim the target padding to the longest target in this batch.
        y_len = max(len(y[y != self.ignore_id]) for y in ys_pad)
        ys_pad = ys_pad[:, :y_len]
        self.hs_pad = hs_pad

        # CTC forward
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                                hs_len, ys_pad)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                               ignore_label=self.ignore_id)

        # copyied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            # Previously missing: the return below raised UnboundLocalError
            # for 0 < mtlalpha < 1.
            loss_ctc_data = float(loss_ctc)
        return self.loss, loss_ctc_data, loss_att_data, self.acc

    def encode(self, x, mask=None):
        """Encode acoustic features.

        Switches the model to eval mode.  The input (and mask) are moved to
        the device holding the model parameters.

        :param ndarray x: source acoustic feature (T, D)
        :param torch.Tensor mask: optional mask passed to the encoder layers
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        # Follow the model's device instead of assuming CUDA.
        device = next(self.parameters()).device
        x = torch.as_tensor(x).unsqueeze(0).to(device)
        if mask is not None:
            mask = mask.to(device)
        if isinstance(self.encoder.embed, EncoderConv2d):
            # This input layer also takes the number of input frames.
            hs, _ = self.encoder.embed(
                x, torch.tensor([float(x.shape[1])], device=device))
        else:
            hs, _ = self.encoder.embed(x, None)
        enc_output, _ = self.encoder.encoders(hs, mask)
        if self.encoder.normalize_before:
            enc_output = self.encoder.after_norm(enc_output)
        return enc_output.squeeze(0)

    def viterbi_decode(self, x, y, mask=None):
        """Align ``y`` to the CTC output of ``x`` with ``viterbi_align``.

        :param ndarray x: source acoustic feature (T, D)
        :param y: target sequence handed to ``viterbi_align``
        :param mask: accepted for interface compatibility; currently unused
        :return: first element of the ``viterbi_align`` result
        """
        # NOTE(review): ``mask`` is not forwarded to ``encode`` -- confirm
        # whether callers expect it to be applied.
        enc_output = self.encode(x)
        logits = self.ctc.ctc_lo(enc_output).detach().data
        logit = np.array(logits.cpu().data).T
        align = viterbi_align(logit, y)[0]
        return align

    def ctc_decode(self, x, mask=None):
        """Return the frame-level CTC argmax path for ``x``.

        :param ndarray x: source acoustic feature (T, D)
        :param mask: accepted for interface compatibility; currently unused
        :return: per-frame label ids
        :rtype: ndarray
        """
        enc_output = self.encode(x)
        # Use the configured attention dimension rather than a fixed 512.
        logits = self.ctc.argmax(
            enc_output.view(1, -1, self.adim)).detach().data
        path = np.array(logits.cpu()[0])
        return path

    def recognize(self, x, recog_args, char_list=None, rnnlm=None,
                  use_jit=False):
        """Recognize input speech.

        :param ndnarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argment Namespace contraining options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :param bool use_jit: trace the decoder step with ``torch.jit``
        :return: N-best decoding results
        :rtype: list
        """
        enc_output = self.encode(x).unsqueeze(0)
        device = enc_output.device
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)

        logging.info('input lengths: ' + str(h.size(0)))
        # search parms
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # preprare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        if lpz is not None:
            import numpy

            from espnet.nets.ctc_prefix_score import CTCPrefixScore

            ctc_prefix_score = CTCPrefixScore(lpz.cpu().detach().numpy(), 0,
                                              self.eos, numpy)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        import six
        traced_decoder = None
        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0).to(device)
                ys = torch.tensor(hyp['yseq']).unsqueeze(0).to(device)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.forward_one_step,
                            (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(
                        ys, ys_mask, enc_output)[0]

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(
                        hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    # CTC prefix scoring runs on the CPU with numpy, so the
                    # candidate scores are brought to the CPU before mixing.
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0].cpu(),
                        hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]].cpu() \
                        + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * \
                            local_lm_scores[:, local_best_ids[0]].cpu()
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(
                        local_best_scores[0, j])
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypothes: ' + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    'best hypo: ' +
                    ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last postion in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypothes to a final list, and removed them from current hypothes
            # (this will be a probmlem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            from espnet.nets.e2e_asr_common import end_detect
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remeined hypothes: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        'hypo: ' +
                        ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'],
            reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check number of hypotheis
        if len(nbest_hyps) == 0:
            logging.warning(
                'there is no N-best results, perform recognition again with smaller minlenratio.'
            )
            # should copy becasuse Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            # keep the caller's use_jit choice on the retry
            return self.recognize(x, recog_args, char_list, rnnlm, use_jit)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' +
                     str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
        return nbest_hyps
class E2E(STInterface, torch.nn.Module):
    """E2E speech-translation (ST) module with optional ASR / MT auxiliary tasks.

    The main branch is encoder -> attention decoder (``self.dec``) trained on
    the target sequence.  When ``asr_weight > 0`` a CTC head and/or a second
    attention decoder are trained on the source transcription, and when
    ``mt_weight > 0`` a text encoder feeds the same decoder.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options

    """

    @staticmethod
    def add_arguments(parser):
        """Add encoder, attention and decoder arguments to ``parser``."""
        E2E.encoder_add_arguments(parser)
        E2E.attention_add_arguments(parser)
        E2E.decoder_add_arguments(parser)
        return parser

    @staticmethod
    def encoder_add_arguments(parser):
        """Add arguments for the encoder."""
        group = parser.add_argument_group("E2E encoder setting")
        # encoder
        group.add_argument('--etype', default='blstmp', type=str,
                           choices=[
                               'lstm', 'blstm', 'lstmp', 'blstmp',
                               'vgglstmp', 'vggblstmp', 'vgglstm', 'vggblstm',
                               'gru', 'bgru', 'grup', 'bgrup',
                               'vgggrup', 'vggbgrup', 'vgggru', 'vggbgru'
                           ],
                           help='Type of encoder network architecture')
        group.add_argument(
            '--elayers', default=4, type=int,
            help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)'
        )
        group.add_argument('--eunits', '-u', default=300, type=int,
                           help='Number of encoder hidden units')
        group.add_argument('--eprojs', default=320, type=int,
                           help='Number of encoder projection units')
        group.add_argument(
            '--subsample', default="1", type=str,
            help='Subsample input frames x_y_z means subsample every x frame at 1st layer, '
                 'every y frame at 2nd layer etc.')
        return parser

    @staticmethod
    def attention_add_arguments(parser):
        """Add arguments for the attention."""
        group = parser.add_argument_group("E2E attention setting")
        # attention
        group.add_argument('--atype', default='dot', type=str,
                           choices=[
                               'noatt', 'dot', 'add', 'location', 'coverage',
                               'coverage_location', 'location2d',
                               'location_recurrent', 'multi_head_dot',
                               'multi_head_add', 'multi_head_loc',
                               'multi_head_multi_res_loc'
                           ],
                           help='Type of attention architecture')
        group.add_argument('--adim', default=320, type=int,
                           help='Number of attention transformation dimensions')
        group.add_argument('--awin', default=5, type=int,
                           help='Window size for location2d attention')
        group.add_argument('--aheads', default=4, type=int,
                           help='Number of heads for multi head attention')
        group.add_argument('--aconv-chans', default=-1, type=int,
                           help='Number of attention convolution channels '
                                '(negative value indicates no location-aware attention)')
        group.add_argument('--aconv-filts', default=100, type=int,
                           help='Number of attention convolution filters '
                                '(negative value indicates no location-aware attention)')
        group.add_argument('--dropout-rate', default=0.0, type=float,
                           help='Dropout rate for the encoder')
        return parser

    @staticmethod
    def decoder_add_arguments(parser):
        """Add arguments for the decoder."""
        # The group title used to read "E2E encoder setting" (copy/paste slip).
        group = parser.add_argument_group("E2E decoder setting")
        group.add_argument('--dtype', default='lstm', type=str,
                           choices=['lstm', 'gru'],
                           help='Type of decoder network architecture')
        group.add_argument('--dlayers', default=1, type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits', default=320, type=int,
                           help='Number of decoder hidden units')
        group.add_argument('--dropout-rate-decoder', default=0.0, type=float,
                           help='Dropout rate for the decoder')
        group.add_argument('--sampling-probability', default=0.0, type=float,
                           help='Ratio of predicted labels fed back to decoder')
        return parser

    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :raises AssertionError: if ``asr_weight`` / ``mt_weight`` are outside
            [0.0, 1.0) or ``mtlalpha`` is outside [0.0, 1.0]
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.asr_weight = getattr(args, "asr_weight", 0)
        self.mt_weight = getattr(args, "mt_weight", 0)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be [0.0, 1.0)"
        assert 0.0 <= self.mt_weight < 1.0, "mt_weight should be [0.0, 1.0)"
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1
        self.pad = 0
        # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
        # in ASR. However, blank labels are not used in NMT. To keep the vocabulary size,
        # we use index:0 for padding instead of adding one more class.

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayer)
        # NOTE: ``np.int`` was only an alias of the builtin ``int`` and has been
        # removed from NumPy (1.24); the builtin gives the same dtype.
        subsample = np.ones(args.elayers + 1, dtype=int)
        if args.etype.endswith("p") and not args.etype.startswith("vgg"):
            ss = args.subsample.split("_")
            for j in range(min(args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.'
            )
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim, args.lsm_type,
                                             transcript=args.train_json)
        else:
            labeldist = None

        # multilingual E2E-ST related
        self.multilingual = getattr(args, "multilingual", False)
        self.joint_asr = getattr(args, "joint_asr", False)
        self.replace_sos = getattr(args, "replace_sos", False)

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # attention (ST)
        self.att = att_for(args)
        # decoder (ST)
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                               labeldist)

        # submodule for ASR task
        self.ctc = None
        self.att_asr = None
        self.dec_asr = None
        if self.asr_weight > 0:
            if self.mtlalpha > 0.0:
                self.ctc = CTC(odim, args.eprojs, args.dropout_rate,
                               ctc_type=args.ctc_type, reduce=True)
            if self.mtlalpha < 1.0:
                # attention (asr)
                self.att_asr = att_for(args)
                # decoder (asr)
                args_asr = copy.deepcopy(args)
                args_asr.atype = 'location'  # TODO(hirofumi0810): make this option
                self.dec_asr = decoder_for(args_asr, odim, self.sos, self.eos,
                                           self.att_asr, labeldist)

        # submodule for MT task
        if self.mt_weight > 0:
            self.embed_mt = torch.nn.Embedding(odim, args.eunits,
                                               padding_idx=self.pad)
            self.dropout_mt = torch.nn.Dropout(p=args.dropout_rate)
            self.enc_mt = encoder_for(
                args, args.eunits,
                subsample=np.ones(args.elayers + 1, dtype=int))

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        # CER/WER reporting needs the ASR branch.  The parentheses matter:
        # without them ``and`` binds tighter than ``or`` and ``report_wer``
        # alone would enable reporting even when asr_weight == 0.
        if self.asr_weight > 0 and (args.report_cer or args.report_wer):
            recog_args = {
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': args.ctc_weight,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank,
                'tgt_lang': False
            }
            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        if args.report_bleu:
            trans_args = {
                'beam_size': args.beam_size,
                'penalty': args.penalty,
                'ctc_weight': 0,
                'maxlenratio': args.maxlenratio,
                'minlenratio': args.minlenratio,
                'lm_weight': args.lm_weight,
                'rnnlm': args.rnnlm,
                'nbest': args.nbest,
                'space': args.sym_space,
                'blank': args.sym_blank,
                'tgt_lang': False
            }
            self.trans_args = argparse.Namespace(**trans_args)
            self.report_bleu = args.report_bleu
        else:
            self.report_bleu = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None

    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)

        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        lecun_normal_init_parameters(self)
        # exceptions
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # forget-bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for layer_idx in range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[layer_idx].bias_ih)

    def _ids_to_text(self, token_ids, space, blank=None):
        """Convert a sequence of token ids to text using ``self.char_list``.

        Ids equal to -1 are skipped, every occurrence of the ``space`` symbol
        becomes a blank character and, when ``blank`` is given, that symbol
        is removed.

        :param token_ids: iterable of integer-like token ids
        :param str space: symbol that stands for a word boundary
        :param str blank: symbol to strip from the text (optional)
        :return: decoded text
        :rtype: str
        """
        text = "".join(
            self.char_list[int(idx)] for idx in token_ids if int(idx) != -1
        ).replace(space, ' ')
        if blank is not None:
            text = text.replace(blank, '')
        return text

    def forward(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :param torch.Tensor ys_pad_src: batch of source-side token id sequences
            used by the ASR / MT branches, padded with -1
        :return: loss value
        :rtype: torch.Tensor
        """
        # 0. Extract target language ID
        if self.multilingual:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID in the beginning
        else:
            tgt_lang_ids = None

        # 1. Encoder
        hs_pad, hlens, _ = self.enc(xs_pad, ilens)

        # 2. ST attention loss
        self.loss_st, acc, _ = self.dec(hs_pad, hlens, ys_pad,
                                        lang_ids=tgt_lang_ids)
        self.acc = acc

        # 2. ASR CTC loss
        if self.asr_weight == 0 or self.mtlalpha == 0:
            self.loss_ctc = 0.0
        else:
            self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad_src)

        # 3. ASR attention loss
        if self.asr_weight == 0 or self.mtlalpha == 1:
            self.loss_asr = 0.0
            self.acc_asr = 0.0
        else:
            self.loss_asr, acc_asr, _ = self.dec_asr(hs_pad, hlens, ys_pad_src)
            self.acc_asr = acc_asr

        # 3. MT attention loss
        if self.mt_weight == 0:
            self.loss_mt = 0.0
            self.acc_mt = 0.0
        else:
            ilens_mt = torch.sum(ys_pad_src != -1, dim=1).cpu().numpy()
            # NOTE: ys_pad_src is padded with -1
            ys_src = [y[y != -1] for y in ys_pad_src]  # parse padded ys_src
            ys_zero_pad_src = pad_list(ys_src, self.pad)  # re-pad with zero
            hs_pad_mt, hlens_mt, _ = self.enc_mt(
                self.dropout_mt(self.embed_mt(ys_zero_pad_src)), ilens_mt)
            self.loss_mt, acc_mt, _ = self.dec(hs_pad_mt, hlens_mt, ys_pad)
            self.acc_mt = acc_mt

        # 4. compute cer without beam search
        if (self.asr_weight == 0 or self.mtlalpha == 0) or self.char_list is None:
            cer_ctc = None
        else:
            cers = []
            y_hats = self.ctc.argmax(hs_pad).data
            for i, y in enumerate(y_hats):
                # collapse consecutive repeats of the frame-wise CTC output
                y_hat = [group[0] for group in groupby(y)]
                seq_hat_text = self._ids_to_text(y_hat, self.space, self.blank)
                seq_true_text = self._ids_to_text(ys_pad_src[i], self.space)

                hyp_chars = seq_hat_text.replace(' ', '')
                ref_chars = seq_true_text.replace(' ', '')
                if len(ref_chars) > 0:
                    cers.append(
                        editdistance.eval(hyp_chars, ref_chars) / len(ref_chars))
            cer_ctc = sum(cers) / len(cers) if cers else None

        # 5. compute cer/wer
        if self.training or (self.asr_weight == 0 or self.mtlalpha == 1
                             or not (self.report_cer or self.report_wer)):
            cer, wer = 0.0, 0.0
        else:
            if (self.asr_weight > 0 and self.mtlalpha > 0) \
                    and self.recog_args.ctc_weight > 0.0:
                lpz = self.ctc.log_softmax(hs_pad).data
            else:
                lpz = None

            word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
            nbest_hyps_asr = self.dec_asr.recognize_beam_batch(
                hs_pad, torch.tensor(hlens), lpz, self.recog_args,
                self.char_list, self.rnnlm)
            # remove <sos> and <eos>
            y_hats = [
                nbest_hyp[0]['yseq'][1:-1] for nbest_hyp in nbest_hyps_asr
            ]
            for i, y_hat in enumerate(y_hats):
                # The ASR decoder is trained on ys_pad_src, so its hypotheses
                # are scored against the source transcription (as the CTC CER
                # above does), not against the translation target ys_pad.
                y_true = ys_pad_src[i]

                seq_hat_text = self._ids_to_text(
                    y_hat, self.recog_args.space, self.recog_args.blank)
                seq_true_text = self._ids_to_text(y_true, self.recog_args.space)

                hyp_words = seq_hat_text.split()
                ref_words = seq_true_text.split()
                word_eds.append(editdistance.eval(hyp_words, ref_words))
                word_ref_lens.append(len(ref_words))
                hyp_chars = seq_hat_text.replace(' ', '')
                ref_chars = seq_true_text.replace(' ', '')
                char_eds.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

            # guard against an all-empty reference set (division by zero)
            total_words = sum(word_ref_lens)
            total_chars = sum(char_ref_lens)
            wer = 0.0
            if self.report_wer and total_words > 0:
                wer = float(sum(word_eds)) / total_words
            cer = 0.0
            if self.report_cer and total_chars > 0:
                cer = float(sum(char_eds)) / total_chars

        # 6. compute bleu
        if self.training or not self.report_bleu:
            bleu = 0.0
        else:
            lpz = None

            bleus = []
            nbest_hyps = self.dec.recognize_beam_batch(
                hs_pad, torch.tensor(hlens), lpz, self.trans_args,
                self.char_list, self.rnnlm,
                lang_ids=tgt_lang_ids.squeeze(1).tolist()
                if self.multilingual else None)
            # remove <sos> and <eos>
            y_hats = [nbest_hyp[0]['yseq'][1:-1] for nbest_hyp in nbest_hyps]
            for i, y_hat in enumerate(y_hats):
                seq_hat_text = self._ids_to_text(
                    y_hat, self.trans_args.space, self.trans_args.blank)
                seq_true_text = self._ids_to_text(ys_pad[i],
                                                  self.trans_args.space)

                # NOTE(review): strings are passed as-is, so nltk iterates over
                # characters rather than words -- confirm this is intended.
                bleus.append(
                    nltk.bleu_score.sentence_bleu([seq_true_text],
                                                  seq_hat_text) * 100)
            bleu = sum(bleus) / len(bleus) if bleus else 0.0

        alpha = self.mtlalpha
        loss_asr = alpha * self.loss_ctc + (1 - alpha) * self.loss_asr
        self.loss = ((1 - self.asr_weight - self.mt_weight) * self.loss_st
                     + self.asr_weight * loss_asr
                     + self.mt_weight * self.loss_mt)
        loss_st_data = float(self.loss_st)
        loss_asr_data = float(loss_asr)
        loss_mt_data = None if self.mt_weight == 0 else float(self.loss_mt)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_asr_data, loss_mt_data, loss_st_data,
                                 self.acc_asr, self.acc_mt, acc,
                                 cer_ctc, cer, wer, bleu, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss

    def scorers(self):
        """Scorers."""
        return dict(decoder=self.dec)

    def encode(self, x):
        """Encode acoustic features.

        Puts the module into eval mode as a side effect.

        :param ndarray x: input acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()

        # subsample frame first so that the length handed to the encoder
        # matches the tensor it receives (as subsample_frames does)
        x = x[::self.subsample[0], :]
        ilens = [x.shape[0]]
        p = next(self.parameters())
        h = torch.as_tensor(x, device=p.device, dtype=p.dtype)
        # make a utt list (1) to use the same interface for encoder
        hs = h.contiguous().unsqueeze(0)

        # 1. encoder
        hs, _, _ = self.enc(hs, ilens)
        return hs.squeeze(0)

    def translate(self, x, trans_args, char_list, rnnlm=None):
        """E2E beam search.

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace trans_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        hs = self.encode(x).unsqueeze(0)
        lpz = None  # no CTC scores are used for translation

        # 2. Decoder
        # decode the first utterance
        y = self.dec.recognize_beam(hs[0], lpz, trans_args, char_list, rnnlm)
        return y

    def translate_batch(self, xs, trans_args, char_list, rnnlm=None):
        """E2E beam search.

        The module is switched to eval mode for decoding and switched back to
        train mode afterwards if it was training on entry.

        :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
        :param Namespace trans_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()

        # subsample frame, then measure lengths on the subsampled features
        xs = [xx[::self.subsample[0], :] for xx in xs]
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)

        # 1. Encoder
        hs_pad, hlens, _ = self.enc(xs_pad, ilens)
        lpz = None

        # 2. Decoder
        hlens = torch.tensor(list(map(int, hlens)))  # make sure hlens is tensor
        y = self.dec.recognize_beam_batch(hs_pad, hlens, lpz, trans_args,
                                          char_list, rnnlm)

        if prev:
            self.train()
        return y

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_pad_src):
        """E2E attention calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :param torch.Tensor ys_pad_src: source-side sequences (not used here)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        with torch.no_grad():
            # 1. Encoder
            if self.multilingual:
                tgt_lang_ids = ys_pad[:, 0:1]
                ys_pad = ys_pad[:, 1:]  # remove target language ID in the beginning
            else:
                tgt_lang_ids = None
            hpad, hlens, _ = self.enc(xs_pad, ilens)

            # 2. Decoder
            att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad,
                                                       lang_ids=tgt_lang_ids)

        return att_ws

    def subsample_frames(self, x):
        """Subsample speech frames in the encoder.

        :param ndarray x: input acoustic feature (T, D)
        :return: subsampled features as a float32 tensor on the model device
            and a one-element list holding the subsampled length
        :rtype: tuple
        """
        # subsample frame
        x = x[::self.subsample[0], :]
        ilen = [x.shape[0]]
        h = to_device(self, torch.from_numpy(np.array(x, dtype=np.float32)))
        # ``contiguous`` is not in-place; keep the returned tensor
        h = h.contiguous()
        return h, ilen
class E2E(ASRInterface, torch.nn.Module):
    """E2E module (transformer encoder/decoder with optional CTC head).

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options

    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        group = parser.add_argument_group("transformer model setting")

        group = add_arguments_transformer_common(group)

        return parser

    @property
    def attention_plot_class(self):
        """Return PlotAttentionReport."""
        return PlotAttentionReport

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: label id stored as ``self.ignore_id`` and passed
            to the label-smoothing loss
        """
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_encoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        # the attention decoder is only built when it contributes to the loss
        if args.mtlalpha < 1:
            self.decoder = Decoder(
                odim=odim,
                selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                conv_wshare=args.wshare,
                conv_kernel_length=args.ldconv_decoder_kernel_length,
                conv_usebias=args.ldconv_usebias,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            )
            self.criterion = LabelSmoothingLoss(
                odim,
                ignore_id,
                args.lsm_weight,
                args.transformer_length_normalized_loss,
            )
        else:
            self.decoder = None
            self.criterion = None
        self.blank = 0
        # sos and eos share the last output index
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = get_subsample(args, mode="asr", arch="transformer")
        self.reporter = Reporter()

        self.reset_parameters(args)
        self.adim = args.adim  # used for CTC (equal to d_model)
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(
                odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
            )
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None

    def reset_parameters(self, args):
        """Initialize parameters."""
        # initialize parameters
        initialize(self, args.transformer_init)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Besides returning the loss, this stores ``self.hs_pad``,
        ``self.pred_pad`` (when a decoder exists), ``self.acc`` and
        ``self.loss`` and reports statistics through ``self.reporter``.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: total loss; the attention loss, the CTC loss, or their
            ``mtlalpha``-weighted sum depending on ``mtlalpha``
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        if self.decoder is not None:
            ys_in_pad, ys_out_pad = add_sos_eos(
                ys_pad, self.sos, self.eos, self.ignore_id
            )
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            self.acc = th_accuracy(
                pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
            )
        else:
            loss_att = None
            self.acc = None

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            if not self.training and self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
            # for visualization
            if not self.training:
                self.ctc.softmax(hs_pad)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None or self.decoder is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss

    def scorers(self):
        """Scorers."""
        return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))

    def encode(self, x):
        """Encode acoustic features.

        Puts the module into eval mode as a side effect.

        :param ndarray x: source acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0)
        enc_output, _ = self.encoder(x, None)
        return enc_output.squeeze(0)

    def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
        """Recognize input speech.

        When ``self.mtlalpha == 1.0`` the passed ``recog_args.ctc_weight`` is
        overwritten with 1.0 and greedy CTC decoding is used.  If the beam
        search ends without any finished hypothesis, recognition is retried
        with ``minlenratio`` lowered by 0.1 (on a copy of ``recog_args``).

        :param ndnarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argment Namespace contraining options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :param bool use_jit: trace the decoder step with ``torch.jit.trace``
        :return: N-best decoding results
        :rtype: list
        :raises NotImplementedError: pure CTC decoding with beam size > 1
        """
        enc_output = self.encode(x).unsqueeze(0)
        if self.mtlalpha == 1.0:
            recog_args.ctc_weight = 1.0
            logging.info("Set to pure CTC decoding mode.")

        if self.mtlalpha > 0 and recog_args.ctc_weight == 1.0:
            from itertools import groupby

            lpz = self.ctc.argmax(enc_output)
            # collapse repeated frame labels, then drop blanks
            collapsed_indices = [group[0] for group in groupby(lpz[0])]
            hyp = [idx for idx in collapsed_indices if idx != self.blank]
            nbest_hyps = [{"score": 0.0, "yseq": [self.sos] + hyp}]
            if recog_args.beam_size > 1:
                raise NotImplementedError("Pure CTC beam search is not implemented.")
            # TODO(hirofumi0810): Implement beam search
            return nbest_hyps
        elif self.mtlalpha > 0 and recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)

        logging.info("input lengths: " + str(h.size(0)))
        # search parms
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # preprare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
        else:
            hyp = {"score": 0.0, "yseq": [y]}
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
            hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
            hyp["ctc_score_prev"] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        traced_decoder = None
        for i in range(maxlen):
            logging.debug("position " + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp["yseq"][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp["yseq"]).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.forward_one_step, (ys, ys_mask, enc_output)
                        )
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(
                        ys, ys_mask, enc_output
                    )[0]

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                    local_scores = (
                        local_att_scores + recog_args.lm_weight * local_lm_scores
                    )
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    # keep the ctc_beam best attention candidates, rescore them
                    # jointly with CTC (and LM), then keep the best `beam`
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1
                    )
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"]
                    )
                    local_scores = (1.0 - ctc_weight) * local_att_scores[
                        :, local_best_ids[0]
                    ] + ctc_weight * torch.from_numpy(
                        ctc_scores - hyp["ctc_score_prev"]
                    )
                    if rnnlm:
                        local_scores += (
                            recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                        )
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )

                for j in range(beam):
                    new_hyp = {}
                    new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                    new_hyp["yseq"] = hyp["yseq"] + [int(local_best_ids[0, j])]
                    if rnnlm:
                        new_hyp["rnnlm_prev"] = rnnlm_state
                    if lpz is not None:
                        new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[0, j]]
                        new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda cand: cand["score"], reverse=True
                )[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug("number of pruned hypothes: " + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    "best hypo: "
                    + "".join([char_list[int(tok)] for tok in hyps[0]["yseq"][1:]])
                )

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info("adding <eos> in the last postion in the loop")
                for hyp in hyps:
                    hyp["yseq"].append(self.eos)

            # add ended hypothes to a final list, and removed them from current hypothes
            # (this will be a probmlem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp["score"] += recog_args.lm_weight * rnnlm.final(
                                hyp["rnnlm_prev"]
                            )
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info("end detected at %d", i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug("remeined hypothes: " + str(len(hyps)))
            else:
                logging.info("no hypothesis. Finish decoding.")
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        "hypo: "
                        + "".join([char_list[int(tok)] for tok in hyp["yseq"][1:]])
                    )

            logging.debug("number of ended hypothes: " + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda cand: cand["score"], reverse=True
        )[: min(len(ended_hyps), recog_args.nbest)]

        # check number of hypotheis
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            # should copy becasuse Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            # keep the caller's use_jit choice on the retry
            return self.recognize(x, recog_args, char_list, rnnlm, use_jit)

        logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
        logging.info(
            "normalized log probability: "
            + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
        )
        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Runs a forward pass in eval mode and collects the ``attn`` attribute
        of every attention / dynamic-convolution sub-module.  The module's
        train/eval mode is restored to what it was on entry.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: attention weights (B, H, Lmax, Tmax), keyed by module name
        :rtype: dict of float ndarray
        """
        was_training = self.training
        self.eval()
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(
                m,
                (MultiHeadedAttention, DynamicConvolution,
                 RelPositionMultiHeadedAttention),
            ):
                ret[name] = m.attn.cpu().numpy()
            if isinstance(m, DynamicConvolution2D):
                ret[name + "_time"] = m.attn_t.cpu().numpy()
                ret[name + "_freq"] = m.attn_f.cpu().numpy()
        # restore the previous mode instead of unconditionally enabling training
        self.train(was_training)
        return ret

    def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad):
        """E2E CTC probability calculation.

        Returns ``None`` when ``mtlalpha == 0`` (no CTC head).  The module's
        train/eval mode is restored to what it was on entry.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: CTC probability (B, Tmax, vocab)
        :rtype: float ndarray
        """
        ret = None
        if self.mtlalpha == 0:
            return ret

        was_training = self.training
        self.eval()
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        for name, m in self.named_modules():
            if isinstance(m, CTC) and m.probs is not None:
                ret = m.probs.cpu().numpy()
        # restore the previous mode instead of unconditionally enabling training
        self.train(was_training)
        return ret
class E2E(E2EASR, ASRInterface, torch.nn.Module):
    """E2E module.

    Multi-speaker variant: replaces the parent's encoder with ``EncoderMix``
    (one output stream per speaker) and resolves the speaker/label assignment
    with ``PIT`` on the per-permutation CTC losses.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        E2EASR.add_arguments(parser)
        E2EASRMIX.encoder_mix_add_arguments(parser)
        return parser

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options;
            ``args.transformer_attn_dropout_rate`` is filled in from
            ``args.dropout_rate`` when it is None
        :param int ignore_id: label id used for padding in the targets
        """
        # Forward the caller's ignore_id (it used to be hard-coded to -1 here,
        # silently discarding any non-default value).
        super(E2E, self).__init__(idim, odim, args, ignore_id=ignore_id)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = EncoderMix(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks_sd=args.elayers_sd,
            num_blocks_rec=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            num_spkrs=args.num_spkrs,
        )

        if args.mtlalpha > 0.0:
            # reduce=False: forward() stacks one CTC loss per
            # (stream, label) pairing along dim 1 before handing them to PIT.
            self.ctc = CTC(
                odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=False
            )
        else:
            self.ctc = None

        self.num_spkrs = args.num_spkrs
        self.pit = PIT(self.num_spkrs)

    @staticmethod
    def _weighted_average(values, weights):
        """Return sum(value * weight) / sum(weights), or None if any value is None."""
        if any(v is None for v in values):
            return None
        return sum(v * w for v, w in zip(values, weights)) / sum(weights)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences
            (B, num_spkrs, Lmax)
        :return: training loss (CTC loss, attention loss, or their
            ``mtlalpha``-weighted sum); also stored in ``self.loss``
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)  # list: speaker differentiate
        self.hs_pad = hs_pad

        # 2. ctc
        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        # The label permutation below is decided by CTC, so a CTC branch is required.
        assert self.mtlalpha > 0.0
        batch_size = xs_pad.size(0)
        # NOTE: transpose returns a view, so the in-place permutation below
        # also reorders the caller's tensor.
        ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
        hs_len = [
            hs_mask[i].view(batch_size, -1).sum(1) for i in range(self.num_spkrs)
        ]
        # Entry i pairs encoder stream (i // num_spkrs) with label set (i % num_spkrs).
        loss_ctc_perm = torch.stack(
            [
                self.ctc(
                    hs_pad[i // self.num_spkrs].view(batch_size, -1, self.adim),
                    hs_len[i // self.num_spkrs],
                    ys_pad[i % self.num_spkrs],
                )
                for i in range(self.num_spkrs ** 2)
            ],
            dim=1,
        )  # (B, num_spkrs^2)
        loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
        logging.info("ctc loss:" + str(float(loss_ctc)))

        # Permute the labels according to loss
        for b in range(batch_size):  # B
            ys_pad[:, b] = ys_pad[min_perm[b], b]  # (num_spkrs, B, Lmax)
        # Number of non-padding labels per speaker; used as averaging weights.
        ys_out_len = [
            float(torch.sum(ys_pad[i] != self.ignore_id)) for i in range(self.num_spkrs)
        ]

        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        if self.error_calculator is not None:
            cer_ctc_per_spkr = []
            for i in range(self.num_spkrs):
                ys_hat = self.ctc.argmax(
                    hs_pad[i].view(batch_size, -1, self.adim)
                ).data
                cer_ctc_per_spkr.append(
                    self.error_calculator(ys_hat.cpu(), ys_pad[i].cpu(), is_ctc=True)
                )
            cer_ctc = self._weighted_average(cer_ctc_per_spkr, ys_out_len)
        else:
            cer_ctc = None

        # 3. forward decoder
        if self.mtlalpha == 1.0:
            loss_att, self.acc, cer, wer = None, None, None, None
        else:
            pred_pad = [None] * self.num_spkrs
            pred_mask = [None] * self.num_spkrs
            loss_att = [None] * self.num_spkrs
            acc = [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                (
                    pred_pad[i],
                    pred_mask[i],
                    loss_att[i],
                    acc[i],
                ) = self.decoder_and_attention(
                    hs_pad[i], hs_mask[i], ys_pad[i], batch_size
                )

            # 4. compute attention loss
            # The following is just an approximation
            loss_att = self._weighted_average(loss_att, ys_out_len)
            self.acc = self._weighted_average(acc, ys_out_len)

            # 5. compute cer/wer
            if self.training or self.error_calculator is None:
                cer, wer = None, None
            else:
                # pred_pad is a per-speaker list, so score each speaker
                # separately and average with the same length weights
                # (calling .argmax on the list itself raised AttributeError).
                cers, wers = [], []
                for i in range(self.num_spkrs):
                    ys_hat = pred_pad[i].argmax(dim=-1)
                    cer_i, wer_i = self.error_calculator(
                        ys_hat.cpu(), ys_pad[i].cpu()
                    )
                    cers.append(cer_i)
                    wers.append(wer_i)
                cer = self._weighted_average(cers, ys_out_len)
                wer = self._weighted_average(wers, ys_out_len)

        # copyied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss

    def decoder_and_attention(self, hs_pad, hs_mask, ys_pad, batch_size):
        """Forward decoder and attention loss.

        :param torch.Tensor hs_pad: encoder output for one speaker stream
        :param torch.Tensor hs_mask: mask matching ``hs_pad``
        :param torch.Tensor ys_pad: padded target ids for that speaker
        :param int batch_size: unused; kept for interface compatibility
        :return: decoder output, decoder mask, attention loss, accuracy
        :rtype: tuple
        """
        # forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)

        # compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )
        return pred_pad, pred_mask, loss_att, acc

    def encode(self, x):
        """Encode acoustic features.

        Switches the module to eval mode as a side effect.

        :param ndarray x: source acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0)
        enc_output, _ = self.encoder(x, None)
        return enc_output

    def recog(self, enc_output, recog_args, char_list=None, rnnlm=None, use_jit=False):
        """Recognize input speech of each speaker.

        Beam search over the decoder for a single encoder stream, optionally
        combined with CTC prefix scores and an RNN language model.

        :param ndnarray enc_output: encoder outputs (B, T, D) or (T, D)
        :param Namespace recog_args: argment Namespace contraining options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :param bool use_jit: trace ``decoder.forward_one_step`` with
            ``torch.jit.trace`` and use the traced function
        :return: N-best decoding results
        :rtype: list
        """
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)

        logging.info("input lengths: " + str(h.size(0)))
        # search parms
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # preprare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None}
        else:
            hyp = {"score": 0.0, "yseq": [y]}
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
            hyp["ctc_state_prev"] = ctc_prefix_score.initial_state()
            hyp["ctc_score_prev"] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        traced_decoder = None
        for i in range(maxlen):
            logging.debug("position " + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp["yseq"][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp["yseq"]).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.forward_one_step, (ys, ys_mask, enc_output)
                        )
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(
                        ys, ys_mask, enc_output
                    )[0]

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                    local_scores = (
                        local_att_scores + recog_args.lm_weight * local_lm_scores
                    )
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    # Rescore only the ctc_beam best attention candidates with CTC.
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1
                    )
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"]
                    )
                    local_scores = (1.0 - ctc_weight) * local_att_scores[
                        :, local_best_ids[0]
                    ] + ctc_weight * torch.from_numpy(
                        ctc_scores - hyp["ctc_score_prev"]
                    )
                    if rnnlm:
                        local_scores += (
                            recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                        )
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )

                for j in range(beam):
                    new_hyp = {}
                    new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j])
                    new_hyp["yseq"] = hyp["yseq"] + [int(local_best_ids[0, j])]
                    if rnnlm:
                        new_hyp["rnnlm_prev"] = rnnlm_state
                    if lpz is not None:
                        new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[0, j]]
                        new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x["score"], reverse=True
                )[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug("number of pruned hypothes: " + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    "best hypo: "
                    + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
                )

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info("adding <eos> in the last position in the loop")
                for hyp in hyps:
                    hyp["yseq"].append(self.eos)

            # add ended hypothes to a final list, and removed them from current hypothes
            # (this will be a probmlem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp["score"] += recog_args.lm_weight * rnnlm.final(
                                hyp["rnnlm_prev"]
                            )
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info("end detected at %d", i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug("remeined hypothes: " + str(len(hyps)))
            else:
                logging.info("no hypothesis. Finish decoding.")
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                    )

            logging.debug("number of ended hypothes: " + str(len(ended_hyps)))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps), recog_args.nbest)
        ]

        # check number of hypotheis
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            # should copy becasuse Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            # Keep the caller's use_jit choice on the retry as well.
            return self.recog(enc_output, recog_args, char_list, rnnlm, use_jit)

        logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
        logging.info(
            "normalized log probability: "
            + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
        )
        return nbest_hyps

    def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
        """Recognize input speech of each speaker.

        :param ndnarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argment Namespace contraining options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :param bool use_jit: forwarded to :meth:`recog`
        :return: N-best decoding results, one list per encoder output stream
        :rtype: list
        """
        # Encoder
        enc_output = self.encode(x)

        # Decoder: run an independent beam search on every encoder stream.
        return [
            self.recog(enc_out, recog_args, char_list, rnnlm, use_jit)
            for enc_out in enc_output
        ]
class E2E(torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group.add_argument("--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" ], help='how to initialize transformer parameters') group.add_argument("--transformer-input-layer", type=str, default="conv2d", choices=["conv2d", "linear", "embed", "custom"], help='transformer input layer type') group.add_argument("--transformer-output-layer", type=str, default='embed', choices=['conv', 'embed', 'linear']) group.add_argument( '--transformer-attn-dropout-rate', default=None, type=float, help= 'dropout in transformer attention. use --dropout-rate if None is set' ) group.add_argument('--transformer-lr', default=10.0, type=float, help='Initial value of learning rate') group.add_argument('--transformer-warmup-steps', default=25000, type=int, help='optimizer warmup steps') group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool, help='normalize loss by length') group.add_argument('--dropout-rate', default=0.0, type=float, help='Dropout rate for the encoder') # Encoder group.add_argument( '--elayers', default=4, type=int, help= 'Number of encoder layers (for shared recognition part in multi-speaker asr mode)' ) group.add_argument('--eunits', '-u', default=300, type=int, help='Number of encoder hidden units') # Attention group.add_argument( '--adim', default=320, type=int, help='Number of attention transformation dimensions') group.add_argument('--aheads', default=4, type=int, help='Number of heads for multi head attention') # Decoder group.add_argument('--dlayers', default=1, type=int, help='Number of decoder layers') group.add_argument('--dunits', default=320, 
type=int, help='Number of decoder hidden units') # Streaming params group.add_argument( '--chunk', default=True, type=strtobool, help= 'streaming mode, set True for chunk-encoder, False for look-ahead encoder' ) group.add_argument('--chunk-size', default=16, type=int, help='chunk size for chunk-based encoder') group.add_argument( '--left-window', default=1000, type=int, help='left window size for look-ahead based encoder') group.add_argument( '--right-window', default=1000, type=int, help='right window size for look-ahead based encoder') group.add_argument( '--dec-left-window', default=0, type=int, help='left window size for decoder (look-ahead based method)') group.add_argument( '--dec-right-window', default=6, type=int, help='right window size for decoder (look-ahead based method)') return parser def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer=args.transformer_output_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] # self.lsm_weight = a 
self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None self.rnnlm = None self.left_window = args.dec_left_window self.right_window = args.dec_right_window def reset_parameters(self, args): """Initialize parameters.""" # initialize parameters initialize(self, args.transformer_init) def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None): """E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :return: ctc loass value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 1. 
forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel batch_size = xs_pad.shape[0] src_mask = make_non_pad_mask(ilens.tolist()).to( xs_pad.device).unsqueeze(-2) if isinstance(self.encoder.embed, EncoderConv2d): xs, hs_mask = self.encoder.embed(xs_pad, torch.sum(src_mask, 2).squeeze()) hs_mask = hs_mask.unsqueeze(1) else: xs, hs_mask = self.encoder.embed(xs_pad, src_mask) if enc_mask is not None: enc_mask = enc_mask[:, :hs_mask.shape[2], :hs_mask.shape[2]] enc_mask = enc_mask & hs_mask if enc_mask is not None else hs_mask hs_pad, _ = self.encoder.encoders(xs, enc_mask) if self.encoder.normalize_before: hs_pad = self.encoder.after_norm(hs_pad) # CTC forward ys = [y[y != self.ignore_id] for y in ys_pad] y_len = max([len(y) for y in ys]) ys_pad = ys_pad[:, :y_len] if dec_mask is not None: dec_mask = dec_mask[:, :y_len + 1, :hs_pad.shape[1]] self.hs_pad = hs_pad batch_size = xs_pad.size(0) if self.mtlalpha == 0.0: loss_ctc = None else: batch_size = xs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) # trigger mask hs_mask = hs_mask & dec_mask if dec_mask is not None else hs_mask # 2. forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad # 3. 
compute attention loss loss_att = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) # copyied from e2e_asr alpha = self.mtlalpha if alpha == 0: self.loss = loss_att loss_att_data = float(loss_att) loss_ctc_data = None elif alpha == 1: self.loss = loss_ctc loss_att_data = None loss_ctc_data = float(loss_ctc) else: self.loss = alpha * loss_ctc + (1 - alpha) * loss_att loss_att_data = float(loss_att) loss_ctc_data = float(loss_ctc) return self.loss, loss_ctc_data, loss_att_data, self.acc def scorers(self): """Scorers.""" return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) def encode(self, x, mask=None): """Encode acoustic features. :param ndarray x: source acoustic feature (T, D) :return: encoder outputs :rtype: torch.Tensor """ self.eval() x = torch.as_tensor(x).unsqueeze(0).cuda() if mask is not None: mask = mask.cuda() if isinstance(self.encoder.embed, EncoderConv2d): hs, _ = self.encoder.embed( x, torch.Tensor([float(x.shape[1])]).cuda()) else: hs, _ = self.encoder.embed(x, None) hs, _ = self.encoder.encoders(hs, mask) if self.encoder.normalize_before: hs = self.encoder.after_norm(hs) return hs.squeeze(0) def viterbi_decode(self, x, y, mask=None): enc_output = self.encode(x, mask) logits = self.ctc.ctc_lo(enc_output).detach().data logit = np.array(logits.cpu().data).T align = viterbi_align(logit, y)[0] return align def ctc_decode(self, x, mask=None): enc_output = self.encode(x, mask) logits = self.ctc.argmax(enc_output.view(1, -1, 512)).detach().data path = np.array(logits.cpu()[0]) return path def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False): """Recognize input speech. 
:param ndnarray x: input acoustic feature (B, T, D) or (T, D) :param Namespace recog_args: argment Namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ enc_output = self.encode(x).unsqueeze(0) if recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(enc_output) lpz = lpz.squeeze(0) else: lpz = None h = enc_output.squeeze(0) logging.info('input lengths: ' + str(h.size(0))) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) logging.info('max output length: ' + str(maxlen)) logging.info('min output length: ' + str(minlen)) # initialize hypothesis if rnnlm: hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None} else: hyp = {'score': 0.0, 'yseq': [y]} if lpz is not None: import numpy from espnet.nets.ctc_prefix_score import CTCPrefixScore ctc_prefix_score = CTCPrefixScore(lpz.cpu().detach().numpy(), 0, self.eos, numpy) hyp['ctc_state_prev'] = ctc_prefix_score.initial_state() hyp['ctc_score_prev'] = 0.0 if ctc_weight != 1.0: # pre-pruning based on attention scores ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) else: ctc_beam = lpz.shape[-1] hyps = [hyp] ended_hyps = [] import six traced_decoder = None for i in six.moves.range(maxlen): logging.debug('position ' + str(i)) hyps_best_kept = [] for hyp in hyps: vy.unsqueeze(1) vy[0] = hyp['yseq'][i] # get nbest local scores and their ids ys_mask = subsequent_mask(i + 1).unsqueeze(0).cuda() ys = torch.tensor(hyp['yseq']).unsqueeze(0).cuda() # FIXME: jit does not match non-jit result if use_jit: if traced_decoder is None: traced_decoder = torch.jit.trace( self.decoder.forward_one_step, (ys, ys_mask, 
enc_output)) local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] else: local_att_scores = self.decoder.forward_one_step( ys, ys_mask, enc_output)[0] if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict( hyp['rnnlm_prev'], vy) local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores else: local_scores = local_att_scores if lpz is not None: local_best_scores, local_best_ids = torch.topk( local_att_scores, ctc_beam, dim=1) ctc_scores, ctc_states = ctc_prefix_score( hyp['yseq'], local_best_ids[0].cpu(), hyp['ctc_state_prev']) local_scores = \ (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]].cpu() \ + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev']) if rnnlm: local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[ 0]].cpu() local_best_scores, joint_best_ids = torch.topk( local_scores, beam, dim=1) local_best_ids = local_best_ids[:, joint_best_ids[0]] else: local_best_scores, local_best_ids = torch.topk( local_scores, beam, dim=1) for j in six.moves.range(beam): new_hyp = {} new_hyp['score'] = hyp['score'] + float( local_best_scores[0, j]) new_hyp['yseq'] = [0] * (1 + len(hyp['yseq'])) new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq'] new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j]) if rnnlm: new_hyp['rnnlm_prev'] = rnnlm_state if lpz is not None: new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[ 0, j]] new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[ 0, j]] # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug('number of pruned hypothes: ' + str(len(hyps))) if char_list is not None: logging.debug( 'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info('adding <eos> in the last postion in the loop') for hyp 
in hyps: hyp['yseq'].append(self.eos) # add ended hypothes to a final list, and removed them from current hypothes # (this will be a probmlem, number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp['yseq'][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp['yseq']) > minlen: hyp['score'] += (i + 1) * penalty if rnnlm: # Word LM needs to add final <eos> score hyp['score'] += recog_args.lm_weight * rnnlm.final( hyp['rnnlm_prev']) ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection from espnet.nets.e2e_asr_common import end_detect if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: logging.info('end detected at %d', i) break hyps = remained_hyps if len(hyps) > 0: logging.debug('remeined hypothes: ' + str(len(hyps))) else: logging.info('no hypothesis. Finish decoding.') break if char_list is not None: for hyp in hyps: logging.debug( 'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]])) logging.debug('number of ended hypothes: ' + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)] # check number of hypotheis if len(nbest_hyps) == 0: logging.warning( 'there is no N-best results, perform recognition again with smaller minlenratio.' 
) # should copy becasuse Namespace will be overwritten globally recog_args = Namespace(**vars(recog_args)) recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) return self.recognize(x, recog_args, char_list, rnnlm) logging.info('total log probability: ' + str(nbest_hyps[0]['score'])) logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq']))) return nbest_hyps def prefix_recognize(self, x, recog_args, train_args, char_list=None, rnnlm=None): '''recognize feat :param ndnarray x: input acouctic feature (B, T, D) or (T, D) :param namespace recog_args: argment namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list TODO(karita): do not recompute previous attention for faster decoding ''' pad_len = self.eos - len(char_list) + 1 for i in range(pad_len): char_list.append('<eos>') if isinstance(self.encoder.embed, EncoderConv2d): seq_len = ((x.shape[0] + 1) // 2 + 1) // 2 else: seq_len = ((x.shape[0] - 1) // 2 - 1) // 2 if train_args.chunk: s = np.arange(0, seq_len, train_args.chunk_size) mask = adaptive_enc_mask(seq_len, s).unsqueeze(0) else: mask = turncated_mask(1, seq_len, train_args.left_window, train_args.right_window) enc_output = self.encode(x, mask).unsqueeze(0) lpz = torch.nn.functional.softmax(self.ctc.ctc_lo(enc_output), dim=-1) lpz = lpz.squeeze(0) h = enc_output.squeeze(0) logging.info('input lengths: ' + str(h.size(0))) h_len = h.size(0) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) hyp = { 'score': 0.0, 'yseq': [y], 'rnnlm_prev': None, 'seq': char_list[y], 'last_time': [], "ctc_score": 0.0, 
"rnnlm_score": 0.0, "att_score": 0.0, "cache": None, "precache": None, "preatt_score": 0.0, "prev_score": 0.0 } hyps = {char_list[y]: hyp} hyps_att = {char_list[y]: hyp} Pb_prev, Pnb_prev = Counter(), Counter() Pb, Pnb = Counter(), Counter() Pjoint = Counter() lpz = lpz.cpu().detach().numpy() vocab_size = lpz.shape[1] r = np.ndarray((vocab_size), dtype=np.float32) l = char_list[y] Pb_prev[l] = 1 Pnb_prev[l] = 0 A_prev = [l] A_prev_id = [[y]] vy.unsqueeze(1) total_copy = time.time() - time.time() samelen = 0 hat_att = {} if mask is not None: chunk_pos = set(np.array(mask.sum(dim=-1))[0]) for i in chunk_pos: hat_att[i] = {} else: hat_att[enc_output.shape[1]] = {} for i in range(h_len): hyps_ctc = {} threshold = recog_args.threshold #self.threshold #np.percentile(r, 98) pos_ctc = np.where(lpz[i] > threshold)[0] #self.removeIlegal(hyps) if mask is not None: chunk_index = mask[0][i].sum().item() else: chunk_index = h_len hyps_res = {} for l, hyp in hyps.items(): if l in hat_att[chunk_index]: hyp['tmp_cache'] = hat_att[chunk_index][l]['cache'] hyp['tmp_att'] = hat_att[chunk_index][l]['att_scores'] else: hyps_res[l] = hyp tmp = self.clusterbyLength( hyps_res ) # This step clusters hyps according to length dict:{length,hyps} start = time.time() # pre-compute beam self.compute_hyps(tmp, i, h_len, enc_output, hat_att[chunk_index], mask) total_copy += time.time() - start # Assign score and tokens to hyps #print(hyps.keys()) for l, hyp in hyps.items(): if 'tmp_att' not in hyp: continue #Todo check why local_att_scores = hyp['tmp_att'] local_best_scores, local_best_ids = torch.topk( local_att_scores, 5, dim=1) pos_att = np.array(local_best_ids[0].cpu()) pos = np.union1d(pos_ctc, pos_att) hyp['pos'] = pos # pre-compute ctc beam hyps_ctc_compute = self.get_ctchyps2compute(hyps, hyps_ctc, i) hyps_res2 = {} for l, hyp in hyps_ctc_compute.items(): l_minus = ' '.join(l.split()[:-1]) if l_minus in hat_att[chunk_index]: hyp['tmp_cur_new_cache'] = hat_att[chunk_index][l_minus][ 'cache'] 
hyp['tmp_cur_att_scores'] = hat_att[chunk_index][l_minus][ 'att_scores'] else: hyps_res2[l] = hyp tmp2_cluster = self.clusterbyLength(hyps_res2) self.compute_hyps_ctc(tmp2_cluster, h_len, enc_output, hat_att[chunk_index], mask) for l, hyp in hyps.items(): start = time.time() l_id = hyp['yseq'] l_end = l_id[-1] vy[0] = l_end prefix_len = len(l_id) if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict( hyp['rnnlm_prev'], vy) else: rnnlm_state = None local_lm_scores = torch.zeros([1, len(char_list)]) r = lpz[i] * (Pb_prev[l] + Pnb_prev[l]) start = time.time() if 'tmp_att' not in hyp: continue #Todo check why local_att_scores = hyp['tmp_att'] new_cache = hyp['tmp_cache'] align = [0] * prefix_len align[:prefix_len - 1] = hyp['last_time'][:] align[-1] = i pos = hyp['pos'] if 0 in pos or l_end in pos: if l not in hyps_ctc: hyps_ctc[l] = {'yseq': l_id} hyps_ctc[l]['rnnlm_prev'] = hyp['rnnlm_prev'] hyps_ctc[l]['rnnlm_score'] = hyp['rnnlm_score'] if l_end != self.eos: hyps_ctc[l]['last_time'] = [0] * prefix_len hyps_ctc[l]['last_time'][:] = hyp['last_time'][:] hyps_ctc[l]['last_time'][-1] = i try: cur_att_scores = hyps_ctc_compute[l][ "tmp_cur_att_scores"] cur_new_cache = hyps_ctc_compute[l][ "tmp_cur_new_cache"] except: pdb.set_trace() hyps_ctc[l]['att_score'] = hyp['preatt_score'] + \ float(cur_att_scores[0, l_end].data) hyps_ctc[l]['cur_att'] = float( cur_att_scores[0, l_end].data) hyps_ctc[l]['cache'] = cur_new_cache else: if len(hyps_ctc[l]["yseq"]) > 1: hyps_ctc[l]["end"] = True hyps_ctc[l]['last_time'] = [] hyps_ctc[l]['att_score'] = hyp['att_score'] hyps_ctc[l]['cur_att'] = 0 hyps_ctc[l]['cache'] = hyp['cache'] hyps_ctc[l]['prev_score'] = hyp['prev_score'] hyps_ctc[l]['preatt_score'] = hyp['preatt_score'] hyps_ctc[l]['precache'] = hyp['precache'] hyps_ctc[l]['seq'] = hyp['seq'] for c in list(pos): if c == 0: Pb[l] += lpz[i][0] * (Pb_prev[l] + Pnb_prev[l]) else: l_plus = l + " " + char_list[c] if l_plus not in hyps_ctc: hyps_ctc[l_plus] = {} if "end" in hyp: 
hyps_ctc[l_plus]['yseq'] = True hyps_ctc[l_plus]['yseq'] = [0] * (prefix_len + 1) hyps_ctc[l_plus]['yseq'][:len(hyp['yseq'])] = l_id hyps_ctc[l_plus]['yseq'][-1] = int(c) hyps_ctc[l_plus]['rnnlm_prev'] = rnnlm_state hyps_ctc[l_plus][ 'rnnlm_score'] = hyp['rnnlm_score'] + float( local_lm_scores[0, c].data) hyps_ctc[l_plus]['att_score'] = hyp['att_score'] \ + float(local_att_scores[0, c].data) hyps_ctc[l_plus]['cur_att'] = float( local_att_scores[0, c].data) hyps_ctc[l_plus]['cache'] = new_cache hyps_ctc[l_plus]['precache'] = hyp['cache'] hyps_ctc[l_plus]['preatt_score'] = hyp['att_score'] hyps_ctc[l_plus]['prev_score'] = hyp['score'] hyps_ctc[l_plus]['last_time'] = align hyps_ctc[l_plus]['rule_penalty'] = 0 hyps_ctc[l_plus]['seq'] = l_plus if l_end != self.eos and c == l_end: Pnb[l_plus] += lpz[i][l_end] * Pb_prev[l] Pnb[l] += lpz[i][l_end] * Pnb_prev[l] else: Pnb[l_plus] += r[c] if l_plus not in hyps: Pb[l_plus] += lpz[i][0] * (Pb_prev[l_plus] + Pnb_prev[l_plus]) Pb[l_plus] += lpz[i][c] * Pnb_prev[l_plus] #total_copy += time.time() - start for l in hyps_ctc.keys(): if Pb[l] != 0 or Pnb[l] != 0: hyps_ctc[l]['ctc_score'] = np.log(Pb[l] + Pnb[l]) else: hyps_ctc[l]['ctc_score'] = float('-inf') local_score = hyps_ctc[l]['ctc_score'] + recog_args.ctc_lm_weight * hyps_ctc[l]['rnnlm_score'] + \ recog_args.penalty * (len(hyps_ctc[l]['yseq'])) hyps_ctc[l]['local_score'] = local_score hyps_ctc[l]['score'] = (1-recog_args.ctc_weight) * hyps_ctc[l]['att_score'] \ + recog_args.ctc_weight * hyps_ctc[l]['ctc_score'] + \ recog_args.penalty * (len(hyps_ctc[l]['yseq'])) + \ recog_args.lm_weight * hyps_ctc[l]['rnnlm_score'] Pb_prev = Pb Pnb_prev = Pnb Pb = Counter() Pnb = Counter() hyps1 = sorted(hyps_ctc.items(), key=lambda x: x[1]['local_score'], reverse=True)[:beam] hyps1 = dict(hyps1) hyps2 = sorted(hyps_ctc.items(), key=lambda x: x[1]['att_score'], reverse=True)[:beam] hyps2 = dict(hyps2) hyps = sorted(hyps_ctc.items(), key=lambda x: x[1]['score'], reverse=True)[:beam] hyps = 
dict(hyps) for key in hyps1.keys(): if key not in hyps: hyps[key] = hyps1[key] for key in hyps2.keys(): if key not in hyps: hyps[key] = hyps2[key] hyps = sorted(hyps.items(), key=lambda x: x[1]['score'], reverse=True)[:beam] hyps = dict(hyps) logging.info('input lengths: ' + str(h.size(0))) logging.info('max output length: ' + str(maxlen)) logging.info('min output length: ' + str(minlen)) if "<eos>" in hyps.keys(): del hyps["<eos>"] #for key in hyps.keys(): # logging.info("{0}\tctc:{1}\tatt:{2}\trnnlm:{3}\tscore:{4}".format(key,hyps[key]["ctc_score"],hyps[key]['att_score'], # hyps[key]['rnnlm_score'], hyps[key]['score'])) # print("!!!","Decoding None") best = list(hyps.keys())[0] ids = hyps[best]['yseq'] score = hyps[best]['score'] logging.info('score: ' + str(score)) #if l in hyps.keys(): # logging.info(l) #print(samelen,h_len) return best, ids, score def removeIlegal(self, hyps): max_y = max([len(hyp['yseq']) for l, hyp in hyps.items()]) for_remove = [] for l, hyp in hyps.items(): if max_y - len(hyp['yseq']) > 4: for_remove.append(l) for cur_str in for_remove: del hyps[cur_str] def clusterbyLength(self, hyps): tmp = {} for l, hyp in hyps.items(): prefix_len = len(hyp['yseq']) if prefix_len > 1 and hyp['yseq'][-1] == self.eos: continue else: if prefix_len not in tmp: tmp[prefix_len] = [] tmp[prefix_len].append(hyp) return tmp def compute_hyps(self, current_hyps, curren_frame, total_frame, enc_output, hat_att, enc_mask=None): for length, hyps_t in current_hyps.items(): ys_mask = subsequent_mask(length).unsqueeze(0).cuda() ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1) # print(ys_mask4use.shape) l_id = [hyp_t['yseq'] for hyp_t in hyps_t] ys4use = torch.tensor(l_id).cuda() enc_output4use = enc_output.repeat(len(hyps_t), 1, 1) if hyps_t[0]["cache"] is None: cache4use = None else: cache4use = [] for decode_num in range(len(hyps_t[0]["cache"])): current_cache = [] for hyp_t in hyps_t: current_cache.append( hyp_t["cache"][decode_num].squeeze(0)) # print( 
torch.stack(current_cache).shape) current_cache = torch.stack(current_cache) cache4use.append(current_cache) partial_mask4use = [] for hyp_t in hyps_t: #partial_mask4use.append(torch.ones([1, len(hyp_t['last_time'])+1, enc_mask.shape[1]]).byte()) align = [0] * length align[:length - 1] = hyp_t['last_time'][:] align[-1] = curren_frame align_tensor = torch.tensor(align).unsqueeze(0) if enc_mask is not None: partial_mask = enc_mask[0][align_tensor] else: right_window = self.right_window partial_mask = trigger_mask(1, total_frame, align_tensor, self.left_window, right_window) partial_mask4use.append(partial_mask) partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1) local_att_scores_b, new_cache_b = self.decoder.forward_one_step( ys4use, ys_mask4use, enc_output4use, partial_mask4use, cache4use) for idx, hyp_t in enumerate(hyps_t): hyp_t['tmp_cache'] = [ new_cache_b[decode_num][idx].unsqueeze(0) for decode_num in range(len(new_cache_b)) ] hyp_t['tmp_att'] = local_att_scores_b[idx].unsqueeze(0) hat_att[hyp_t['seq']] = {} hat_att[hyp_t['seq']]['cache'] = hyp_t['tmp_cache'] hat_att[hyp_t['seq']]['att_scores'] = hyp_t['tmp_att'] def get_ctchyps2compute(self, hyps, hyps_ctc, current_frame): tmp2 = {} for l, hyp in hyps.items(): l_id = hyp['yseq'] l_end = l_id[-1] if "pos" not in hyp: continue if 0 in hyp['pos'] or l_end in hyp['pos']: #l_minus = ' '.join(l.split()[:-1]) #if l_minus in hat_att: # hyps[l]['tmp_cur_new_cache'] = hat_att[l_minus]['cache'] # hyps[l]['tmp_cur_att_scores'] = hat_att[l_minus]['att_scores'] # continue if l not in hyps_ctc and l_end != self.eos: tmp2[l] = {'yseq': l_id} tmp2[l]['seq'] = l tmp2[l]['rnnlm_prev'] = hyp['rnnlm_prev'] tmp2[l]['rnnlm_score'] = hyp['rnnlm_score'] if l_end != self.eos: tmp2[l]['last_time'] = [0] * len(l_id) tmp2[l]['last_time'][:] = hyp['last_time'][:] tmp2[l]['last_time'][-1] = current_frame return tmp2 def compute_hyps_ctc(self, hyps_ctc_cluster, total_frame, enc_output, hat_att, enc_mask=None): for length, 
hyps_t in hyps_ctc_cluster.items(): ys_mask = subsequent_mask(length - 1).unsqueeze(0).cuda() ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1) l_id = [hyp_t['yseq'][:-1] for hyp_t in hyps_t] ys4use = torch.tensor(l_id).cuda() enc_output4use = enc_output.repeat(len(hyps_t), 1, 1) if "precache" not in hyps_t[0] or hyps_t[0]["precache"] is None: cache4use = None else: cache4use = [] for decode_num in range(len(hyps_t[0]["precache"])): current_cache = [] for hyp_t in hyps_t: # print(length, hyp_t["yseq"], hyp_t["cache"][0].shape, # hyp_t["cache"][2].shape, hyp_t["cache"][4].shape) current_cache.append( hyp_t["precache"][decode_num].squeeze(0)) current_cache = torch.stack(current_cache) cache4use.append(current_cache) partial_mask4use = [] for hyp_t in hyps_t: #partial_mask4use.append(torch.ones([1, len(hyp_t['last_time']), enc_mask.shape[1]]).byte()) align = hyp_t['last_time'] align_tensor = torch.tensor(align).unsqueeze(0) if enc_mask is not None: partial_mask = enc_mask[0][align_tensor] else: right_window = self.right_window partial_mask = trigger_mask(1, total_frame, align_tensor, self.left_window, right_window) partial_mask4use.append(partial_mask) partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1) local_att_scores_b, new_cache_b = \ self.decoder.forward_one_step(ys4use, ys_mask4use, enc_output4use, partial_mask4use, cache4use) for idx, hyp_t in enumerate(hyps_t): hyp_t['tmp_cur_new_cache'] = [ new_cache_b[decode_num][idx].unsqueeze(0) for decode_num in range(len(new_cache_b)) ] hyp_t['tmp_cur_att_scores'] = local_att_scores_b[ idx].unsqueeze(0) l_minus = ' '.join(hyp_t['seq'].split()[:-1]) hat_att[l_minus] = {} hat_att[l_minus]['att_scores'] = hyp_t['tmp_cur_att_scores'] hat_att[l_minus]['cache'] = hyp_t['tmp_cur_new_cache']
class E2E(ASRInterface, torch.nn.Module):
    """Transformer encoder-decoder ASR model with an optional CTC branch.

    The encoder is configured with chunk related options (center/left/right
    chunk length, hop length, absolute/relative embedding, memory); their
    exact semantics are defined by the ``Encoder`` class, not here.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: target label id that marks padding
    """

    @staticmethod
    def add_arguments(parser):
        """Add the transformer specific command line options to ``parser``.

        :param argparse.ArgumentParser parser: parser to extend
        :return: the same parser, for chaining
        """
        group = parser.add_argument_group("transformer model setting")
        group.add_argument("--transformer-init", type=str, default="pytorch",
                           choices=["pytorch", "xavier_uniform", "xavier_normal",
                                    "kaiming_uniform", "kaiming_normal"],
                           help='how to initialize transformer parameters')
        group.add_argument("--transformer-input-layer", type=str, default="conv2d",
                           choices=["conv2d", "linear", "embed"],
                           help='transformer input layer type')
        group.add_argument('--transformer-attn-dropout-rate', default=None, type=float,
                           help='dropout in transformer attention. use --dropout-rate if None is set')
        group.add_argument('--transformer-lr', default=10.0, type=float,
                           help='Initial value of learning rate')
        group.add_argument('--transformer-warmup-steps', default=25000, type=int,
                           help='optimizer warmup steps')
        group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool,
                           help='normalize loss by length')
        # Chunk-wise encoder options (forwarded verbatim to Encoder in __init__).
        group.add_argument("--transformer-encoder-center-chunk-len", type=int, default=8,
                           help='transformer chunk size, only used when transformer-encoder-type is memory')
        group.add_argument("--transformer-encoder-left-chunk-len", type=int, default=0,
                           help='transformer left chunk size')
        group.add_argument("--transformer-encoder-hop-len", type=int, default=0,
                           help='transformer encoder hop len, default')
        group.add_argument("--transformer-encoder-right-chunk-len", type=int, default=0,
                           help='future data of the encoder')
        group.add_argument("--transformer-encoder-abs-embed", type=int, default=1,
                           help='whether the network use absolute embed')
        group.add_argument("--transformer-encoder-rel-embed", type=int, default=0,
                           help='whether the network us reality embed')
        group.add_argument("--transformer-encoder-use-memory", type=int, default=0,
                           help='whether the network us memory to store history')
        return parser

    @property
    def attention_plot_class(self):
        """Return PlotAttentionReport."""
        return PlotAttentionReport

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        Note that ``args.transformer_attn_dropout_rate`` is overwritten in
        place with ``args.dropout_rate`` when it is None.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: target label id that marks padding
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            center_len=args.transformer_encoder_center_chunk_len,
            left_len=args.transformer_encoder_left_chunk_len,
            hop_len=args.transformer_encoder_hop_len,
            right_len=args.transformer_encoder_right_chunk_len,
            abs_pos=args.transformer_encoder_abs_embed,
            rel_pos=args.transformer_encoder_rel_embed,
            use_mem=args.transformer_encoder_use_memory,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate
        )
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate
        )
        # <sos> and <eos> share the last output index.
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = [1]
        self.reporter = Reporter()

        # self.lsm_weight = a
        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, args.lsm_weight,
                                            args.transformer_length_normalized_loss)
        # self.verbose = args.verbose
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        else:
            self.ctc = None

        # The CTC branch of forward() always reports a CER, so the error
        # calculator is also required whenever mtlalpha > 0.
        if args.report_cer or args.report_wer or args.mtlalpha > 0.0:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(args.char_list,
                                                    args.sym_space, args.sym_blank,
                                                    args.report_cer, args.report_wer)
        else:
            self.error_calculator = None
        self.rnnlm = None

    def reset_parameters(self, args):
        """Initialize parameters according to ``args.transformer_init``.

        "pytorch" keeps the default initialization of every module.

        :param Namespace args: argument Namespace containing options
        :raises ValueError: if the initialization name is unknown
        """
        if args.transformer_init == "pytorch":
            return
        # weight init
        for p in self.parameters():
            if p.dim() > 1:
                if args.transformer_init == "xavier_uniform":
                    torch.nn.init.xavier_uniform_(p.data)
                elif args.transformer_init == "xavier_normal":
                    torch.nn.init.xavier_normal_(p.data)
                elif args.transformer_init == "kaiming_uniform":
                    torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
                elif args.transformer_init == "kaiming_normal":
                    torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
                else:
                    raise ValueError("Unknown initialization: " + args.transformer_init)
        # bias init
        for p in self.parameters():
            if p.dim() == 1:
                p.data.zero_()
        # reset some modules with default init
        for m in self.modules():
            if isinstance(m, (torch.nn.Embedding, LayerNorm)):
                m.reset_parameters()

    def add_sos_eos(self, ys_pad):
        """Build decoder input/output targets from padded label sequences.

        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: inputs prefixed with <sos> and padded with <eos>, and
            outputs suffixed with <eos> and padded with ``ignore_id``
        :rtype: tuple of torch.Tensor
        """
        from espnet.nets.pytorch_backend.nets_utils import pad_list
        eos = ys_pad.new([self.eos])
        sos = ys_pad.new([self.sos])
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]
        return pad_list(ys_in, self.eos), pad_list(ys_out, self.ignore_id)

    def target_mask(self, ys_in_pad):
        """Combine the non-padding mask of ``ys_in_pad`` with a subsequent mask.

        :param torch.Tensor ys_in_pad: batch of padded decoder inputs (B, Lmax)
        :return: boolean style mask usable by the decoder self-attention
        :rtype: torch.Tensor
        """
        ys_mask = ys_in_pad != self.ignore_id
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(self, xs_pad, ilens, ys_pad):
        '''E2E forward

        Side effects: stores ``self.hs_pad``, ``self.pred_pad``, ``self.acc``
        and ``self.loss``, and reports statistics through ``self.reporter``
        when the loss is finite and below ``CTC_LOSS_THRESHOLD``.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: combined loss ``alpha * ctc + (1 - alpha) * att``
        :rtype: torch.Tensor
        '''
        # forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # forward decoder
        ys_in_pad, ys_out_pad = self.add_sos_eos(ys_pad)
        ys_mask = self.target_mask(ys_in_pad)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # compute loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                               ignore_label=self.ignore_id)

        # TODO(karita) show predected text
        # TODO(karita) calculate these stats
        if self.mtlalpha == 0.0:
            loss_ctc = None
            cer_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copyied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss

    def recognize(self, feat, recog_args, char_list=None, rnnlm=None, use_jit=False):
        '''recognize feat

        Beam search over the attention decoder, optionally rescored with CTC
        prefix scores (``recog_args.ctc_weight > 0``) and an RNN language
        model. If no hypothesis survives, decoding is retried with
        ``minlenratio`` lowered by 0.1 on a copy of ``recog_args``.

        :param ndnarray feat: input acoustic feature; a batch axis is added
            in front, so a single utterance (T, D) is expected
        :param namespace recog_args: argment namespace contraining options
        :param list char_list: list of characters (only used for debug logs)
        :param torch.nn.Module rnnlm: language model module
        :param bool use_jit: trace the decoder with torch.jit on first use
        :return: N-best decoding results
        :rtype: list

        TODO(karita): do not recompute previous attention for faster decoding
        '''
        from espnet.nets.e2e_asr_common import end_detect

        self.eval()
        # Keep ``feat`` itself untouched: it is handed unchanged to the retry
        # at the end of this method, which adds the batch axis again.
        xs = torch.as_tensor(feat).unsqueeze(0)
        enc_output, _ = self.encoder(xs, None)
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)

        logging.info('input lengths: ' + str(h.size(0)))
        # search parms
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # preprare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        if lpz is not None:
            import numpy

            from espnet.nets.ctc_prefix_score import CTCPrefixScore

            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        traced_decoder = None
        for i in range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                # last emitted token, fed to the language model
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp['yseq']).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(self.decoder.recognize,
                                                         (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)
                else:
                    local_att_scores = self.decoder.recognize(ys, ys_mask, enc_output)

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    # pre-prune with attention scores, then rescore the
                    # survivors jointly with the CTC prefix score
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                        + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

                for j in range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypothes: ' + str(len(hyps)))
            if char_list is not None:
                logging.debug(
                    'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last postion in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypothes to a final list, and removed them from current hypothes
            # (this will be a probmlem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remeined hypothes: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypothes: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check number of hypotheis
        if len(nbest_hyps) == 0:
            logging.warning('there is no N-best results, perform recognition again with smaller minlenratio.')
            # should copy becasuse Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            # Retry with the ORIGINAL feature (not the batched tensor) and
            # keep the caller's jit choice.
            return self.recognize(feat, recog_args, char_list, rnnlm, use_jit)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        '''E2E attention calculation

        Runs ``forward`` without gradients and collects the ``attn`` tensor
        of every ``MultiHeadedAttention`` sub-module.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: attention weights keyed by module name
        :rtype: dict of float ndarray
        '''
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                ret[name] = m.attn.cpu().numpy()
        return ret
class E2E(ASRInterface, torch.nn.Module):
    """E2E module.

    Joint CTC/attention model: a project ``Encoder`` configured with a
    kernel size and a ``causal`` flag, a transformer ``Decoder`` (only built
    when ``mtlalpha < 1.0``) and a ``CTC`` head (only built when
    ``mtlalpha > 0.0``).

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (Namespace): argument Namespace containing options

    """

    @staticmethod
    def add_arguments(parser):
        """Add the model options to ``parser``.

        Several options (e.g. ``--dtype``, ``--eprojs``, ``--subsample``,
        the RNN decoder options) are registered here but are not read by the
        code of this class; they are presumably kept for configuration
        compatibility -- verify before removing.

        Args:
            parser (argparse.ArgumentParser): parser to extend

        Returns:
            argparse.ArgumentParser: the same parser

        """
        group = parser.add_argument_group("transformer model setting")
        # Encoder - general
        group.add_argument(
            '--elayers',
            default=4,
            type=int,
            help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)'
        )
        group.add_argument('--eunits',
                           '-u',
                           default=300,
                           type=int,
                           help='Number of encoder hidden units')
        group.add_argument('--kernel-size', default=32, type=int)
        group.add_argument('--dropout-rate',
                           default=0.0,
                           type=float,
                           help='Dropout rate for the encoder')
        # Encoder - RNN
        group.add_argument('--eprojs',
                           default=320,
                           type=int,
                           help='Number of encoder projection units')
        group.add_argument(
            '--subsample',
            default="1",
            type=str,
            help='Subsample input frames x_y_z means subsample every x frame \
                               at 1st layer, every y frame at 2nd layer etc.')
        # Attention - general
        group.add_argument(
            '--transformer-attn-dropout-rate',
            default=None,
            type=float,
            help='dropout in transformer attention. use --dropout-rate if None is set'
        )
        group.add_argument(
            '--adim',
            default=320,
            type=int,
            help='Number of attention transformation dimensions')
        group.add_argument('--aheads',
                           default=4,
                           type=int,
                           help='Number of heads for multi head attention')
        # Decoder - general
        group.add_argument('--dtype',
                           default='lstm',
                           type=str,
                           choices=['lstm', 'gru', 'transformer'],
                           help='Type of decoder to use.')
        group.add_argument('--dlayers',
                           default=1,
                           type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits',
                           default=320,
                           type=int,
                           help='Number of decoder hidden units')
        group.add_argument('--dropout-rate-decoder',
                           default=0.0,
                           type=float,
                           help='Dropout rate for the decoder')
        group.add_argument('--relative-v', default=False, type=strtobool)
        # Decoder - RNN
        group.add_argument('--dec-embed-dim',
                           default=320,
                           type=int,
                           help='Number of decoder embeddings dimensions')
        group.add_argument('--dropout-rate-embed-decoder',
                           default=0.0,
                           type=float,
                           help='Dropout rate for the decoder embeddings')
        # Transformer
        group.add_argument(
            "--input-layer",
            type=str,
            default="conv2d",
            choices=["conv2d", "vgg2l", "linear", "embed", "custom"],
            help='transformer encoder input layer type')
        group.add_argument('--transformer-lr',
                           default=10.0,
                           type=float,
                           help='Initial value of learning rate')
        group.add_argument('--transformer-warmup-steps',
                           default=25000,
                           type=int,
                           help='optimizer warmup steps')
        group.add_argument('--transformer-length-normalized-loss',
                           default=False,
                           type=strtobool,
                           help='normalize loss by length')
        group.add_argument("--transformer-init",
                           type=str,
                           default="pytorch",
                           choices=[
                               "pytorch", "xavier_uniform", "xavier_normal",
                               "kaiming_uniform", "kaiming_normal"
                           ],
                           help='how to initialize transformer parameters')
        group.add_argument('--causal', type=strtobool, default=False)
        return parser

    def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0):
        """Construct an E2E object.

        ``args`` is modified in place: ``transformer_attn_dropout_rate``
        falls back to ``dropout_rate`` when None, and ``eprojs`` is set to
        ``adim``.

        Args:
            idim (int): dimension of inputs
            odim (int): dimension of outputs
            args (Namespace): argument Namespace containing options
            ignore_id (int): target label id that marks padding
            blank_id (int): accepted for interface compatibility; not used
                by this class

        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(idim=idim,
                               d_model=args.adim,
                               n_heads=args.aheads,
                               d_ffn=args.eunits,
                               layers=args.elayers,
                               kernel_size=args.kernel_size,
                               input_layer=args.input_layer,
                               dropout_rate=args.dropout_rate,
                               causal=args.causal)
        args.eprojs = args.adim
        # A pure CTC model (mtlalpha == 1.0) has no attention decoder.
        if args.mtlalpha < 1.0:
            self.decoder = Decoder(
                odim=odim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.dunits,
                num_blocks=args.dlayers,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                self_attention_dropout_rate=args.transformer_attn_dropout_rate,
                src_attention_dropout_rate=args.transformer_attn_dropout_rate)
        # <sos> and <eos> share the last output index.
        self.sos = odim - 1
        self.eos = odim - 1
        self.ignore_id = ignore_id
        self.odim = odim
        self.adim = args.adim
        self.criterion = LabelSmoothingLoss(
            self.odim, self.ignore_id, args.lsm_weight,
            args.transformer_length_normalized_loss)
        # self.verbose = args.verbose
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim,
                           args.adim,
                           args.dropout_rate,
                           ctc_type=args.ctc_type,
                           reduce=True)
        else:
            self.ctc = None

    def reset_parameters(self, args):
        """Initialize parameters."""
        # initialize parameters
        initialize(self, args.transformer_init)

    def forward(self, xs_pad, ilens, ys_pad, enc_mask=None):
        """E2E forward.

        Side effects: stores ``self.hs_pad``, ``self.acc``, ``self.loss``
        and, when the attention decoder is used, ``self.pred_pad``.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
            enc_mask: accepted for interface compatibility; not used here

        Returns:
            tuple: ``(loss, loss_att_data, loss_ctc_data, acc)`` where
                ``loss`` is the tensor to back-propagate, the two
                ``*_data`` entries are floats (or None when that branch is
                disabled by ``mtlalpha``) and ``acc`` is the attention
                accuracy (None for a pure CTC model)

        """
        # 1. encoder
        xs_pad = xs_pad[:, :max(ilens)]
        masks = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, masks)
        self.hs_pad = hs_pad

        # Trim the target padding to the longest real label sequence.
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max([len(y) for y in ys])
        ys_pad = ys_pad[:, :y_len]

        # CTC forward
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)

        if self.mtlalpha < 1.0:
            # 2. forward decoder
            ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                                self.ignore_id)
            ys_mask = target_mask(ys_in_pad, self.ignore_id)
            pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad,
                                               hs_mask)
            self.pred_pad = pred_pad

            # 3. compute attention loss
            loss_att = self.criterion(pred_pad, ys_out_pad)
            self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                                   ys_out_pad,
                                   ignore_label=self.ignore_id)
        else:
            # __init__ builds no decoder for a pure CTC model, so the
            # attention branch has to be skipped instead of crashing.
            loss_att = None
            self.acc = None

        # copyied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        return self.loss, loss_att_data, loss_ctc_data, self.acc

    def encode(self, x, streaming_mask=None):
        """Encode acoustic features.

        Switches the module to eval mode and moves the input to CUDA.

        Args:
            x (ndarray): input acoustic feature (T, D)
            streaming_mask: forwarded to the encoder as third argument

        Returns:
            x (torch.Tensor): encoded features (T, attention_dim)

        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0).cuda()
        enc_output, _ = self.encoder(x, None, streaming_mask)
        return enc_output.squeeze(0)

    def greedy_recognize(self, x, recog_args, streaming_mask=None):
        """Decode with frame-wise CTC argmax (best path decoding).

        Consecutive repeated labels are merged and label 0 (treated as the
        CTC blank) is removed.

        Args:
            x (ndarray): input acoustic feature (T, D)
            recog_args (namespace): not used by this method
            streaming_mask: forwarded to the encoder

        Returns:
            list: a single hypothesis ``[{'score': 0.0, 'yseq': ids}]``

        """
        h = self.encode(x, streaming_mask).unsqueeze(0)
        lpz = self.ctc.log_softmax(h)
        ys_hat = lpz.argmax(dim=-1)[0].cpu()
        ids = []
        # Track the previous frame explicitly: indexing ``ys_hat[i - 1]`` at
        # i == 0 would wrap around to the LAST frame and could drop the
        # first label.
        previous = None
        for token in ys_hat.tolist():
            if token != 0 and token != previous:
                ids.append(token)
            previous = token
        rets = [{'score': 0.0, 'yseq': ids}]
        return rets

    def recognize(self,
                  x,
                  recog_args,
                  char_list=None,
                  rnnlm=None,
                  streaming_mask=None):
        """Recognize input features.

        Delegates to ``self.decoder.recognize`` when ``beam_size == 1`` and
        to ``self.decoder.recognize_beam`` (with the language model as extra
        argument) otherwise.

        Args:
            x (ndarray): input acoustic feature (T, D)
            recog_args (namespace): argument Namespace containing options
            char_list (list): list of characters (not used here)
            rnnlm (torch.nn.Module): language model module
            streaming_mask: forwarded to the encoder

        Returns:
            y (list): n-best decoding results

        """
        h = self.encode(x, streaming_mask)
        params = [h, recog_args]

        if recog_args.beam_size == 1:
            nbest_hyps = self.decoder.recognize(*params)
        else:
            params.append(rnnlm)
            nbest_hyps = self.decoder.recognize_beam(*params)
            #nbest_hyps, decoder_time = self.decoder.beam_search(h, recog_args, prefix=False)

        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Runs ``forward`` without gradients and collects the ``attn`` tensor
        of every ``MultiHeadedAttention`` sub-module. The former RNN /
        transducer branches relied on ``self.etype``, ``self.dtype`` and
        ``self.rnnt_mode``, none of which this class defines, so only the
        transformer path is kept.

        Args:
            xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor):
                batch of padded character id sequence tensor (B, Lmax)

        Returns:
            ret (dict): attention weights as ndarray, keyed by module name

        """
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                ret[name] = m.attn.cpu().numpy()
        return ret
class E2E(ASRInterface, torch.nn.Module):
    """E2E module.

    NOTE(review): in this variant the conventional method names are swapped.
    ``recognize`` runs the teacher-forced encoder/decoder pass and returns the
    training loss, while ``forward`` runs greedy joint CTC/attention decoding
    through ONNX Runtime sessions.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add the transformer model options to ``parser`` and return it."""
        group = parser.add_argument_group("transformer model setting")
        group.add_argument("--transformer-init",
                           type=str,
                           default="pytorch",
                           choices=[
                               "pytorch", "xavier_uniform", "xavier_normal",
                               "kaiming_uniform", "kaiming_normal"
                           ],
                           help='how to initialize transformer parameters')
        group.add_argument("--transformer-input-layer",
                           type=str,
                           default="conv2d",
                           choices=["conv2d", "linear", "embed"],
                           help='transformer input layer type')
        group.add_argument(
            '--transformer-attn-dropout-rate',
            default=None,
            type=float,
            help='dropout in transformer attention. use --dropout-rate if None is set'
        )
        group.add_argument('--transformer-lr',
                           default=10.0,
                           type=float,
                           help='Initial value of learning rate')
        group.add_argument('--transformer-warmup-steps',
                           default=25000,
                           type=int,
                           help='optimizer warmup steps')
        group.add_argument('--transformer-length-normalized-loss',
                           default=True,
                           type=strtobool,
                           help='normalize loss by length')
        group.add_argument('--dropout-rate',
                           default=0.0,
                           type=float,
                           help='Dropout rate for the encoder')
        # Encoder
        group.add_argument(
            '--elayers',
            default=4,
            type=int,
            help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)'
        )
        group.add_argument('--eunits',
                           '-u',
                           default=300,
                           type=int,
                           help='Number of encoder hidden units')
        # Attention
        group.add_argument(
            '--adim',
            default=320,
            type=int,
            help='Number of attention transformation dimensions')
        group.add_argument('--aheads',
                           default=4,
                           type=int,
                           help='Number of heads for multi head attention')
        # Decoder
        group.add_argument('--dlayers',
                           default=1,
                           type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits',
                           default=320,
                           type=int,
                           help='Number of decoder hidden units')
        return parser

    @property
    def attention_plot_class(self):
        """Return PlotAttentionReport."""
        return PlotAttentionReport

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: target id excluded from the loss and accuracy
        """
        torch.nn.Module.__init__(self)
        # Attention dropout falls back to the general dropout rate.  This
        # writes the resolved value back onto the caller's ``args``.
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate)
        # <sos> and <eos> share the last output index.
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = [1]
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            self.odim, self.ignore_id, args.lsm_weight,
            args.transformer_length_normalized_loss)
        # Parameters are initialised here, i.e. before the CTC module below
        # is attached, so the CTC layer is not touched by this call.
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim,
                           args.adim,
                           args.dropout_rate,
                           ctc_type=args.ctc_type,
                           reduce=True)
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(args.char_list,
                                                    args.sym_space,
                                                    args.sym_blank,
                                                    args.report_cer,
                                                    args.report_wer)
        else:
            self.error_calculator = None
        self.rnnlm = None

    def reset_parameters(self, args):
        """Initialize parameters with the scheme named by ``args.transformer_init``."""
        initialize(self, args.transformer_init)

    def recognize(self, xs_pad, ilens, ys_pad):
        """Run the teacher-forced pass and return the training loss.

        Despite its name, this method is the training forward pass of this
        class: it encodes the batch, decodes against the reference targets
        and combines the attention and CTC losses using ``self.mtlalpha``.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: combined loss, also stored in ``self.loss``
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute CTC loss (and CTC CER when an error calculator exists)
        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(
                    hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)

        # 5. compute cer/wer (evaluation mode only)
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        # Only report losses that are finite and below the sanity threshold.
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss

    def scorers(self):
        """Return the decoder and a CTC prefix scorer keyed by name."""
        return dict(decoder=self.decoder,
                    ctc=CTCPrefixScorer(self.ctc, self.eos))

    def encode(self, x):
        """Encode acoustic features.

        Switches the module to evaluation mode as a side effect.

        :param torch.Tensor x: source acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = x.unsqueeze(0)
        enc_output, _ = self.encoder(x, None)
        return enc_output.squeeze(0)

    def forward1(self, x):
        """Return the CTC log-softmax of ``x``.

        :param torch.Tensor x: input passed directly to ``self.ctc.log_softmax``
        :return: CTC log probabilities
        :rtype: torch.Tensor
        """
        lpz = self.ctc.log_softmax(x)
        return lpz

    def forward(self, x):
        """Decode one utterance with ONNX Runtime and return its token ids.

        The encoder, CTC log-softmax and one-step decoder are loaded from
        ``encoder.onnx``, ``ctc_softmax.onnx`` and ``decoder_fos.onnx``,
        resolved relative to the current working directory.  Decoding is a
        joint CTC/attention search with fixed settings (beam size 1, CTC
        weight 0.5, no length penalty).

        NOTE(review): the three sessions are rebuilt on every call, and the
        <sos>/<eos> id is hard-coded to 7442 rather than read from
        ``self.sos`` -- confirm it matches the exported model's vocabulary.

        :param torch.Tensor x: input acoustic feature; it must support
            ``.numpy()`` and have the layout expected by ``encoder.onnx``
        :return: best hypothesis token ids, including the leading <sos>
        :rtype: torch.Tensor
        """
        ctc_weight = 0.5
        beam_size = 1
        penalty = 0.0
        maxlenratio = 0.0
        nbest = 1
        sos = eos = 7442

        import onnxruntime

        # Encoder.  ``run`` returns a list of outputs, so the first squeeze
        # drops that list dimension.
        encoder_rt = onnxruntime.InferenceSession("encoder.onnx")
        enc_output = encoder_rt.run(None, {"x": x.numpy()})
        enc_output = torch.tensor(enc_output)
        enc_output = enc_output.squeeze(0)

        # CTC log probabilities for prefix scoring.
        ctc_softmax_rt = onnxruntime.InferenceSession("ctc_softmax.onnx")
        lpz = ctc_softmax_rt.run(None, {"enc_output": enc_output.numpy()})
        lpz = torch.tensor(lpz)
        lpz = lpz.squeeze(0).squeeze(0)

        h = enc_output.squeeze(0)

        # prepare sos
        y = sos
        maxlen = h.shape[0]
        minlen = 0.0

        # initialize hypothesis
        hyp = {'score': 0.0, 'yseq': [y]}
        ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, eos, numpy)
        hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
        hyp['ctc_score_prev'] = 0.0
        # Number of attention candidates rescored by CTC at each step.
        ctc_beam = 1

        hyps = [hyp]
        ended_hyps = []
        decoder_fos_rt = onnxruntime.InferenceSession("decoder_fos.onnx")

        for i in range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp['yseq']).unsqueeze(0)
                ort_inputs = {
                    "ys": ys.numpy(),
                    "ys_mask": ys_mask.numpy(),
                    "enc_output": enc_output.numpy()
                }
                local_att_scores = decoder_fos_rt.run(None, ort_inputs)
                local_att_scores = torch.tensor(local_att_scores[0])

                # Pre-prune with attention scores, then rescore with CTC.
                _, local_best_ids = torch.topk(local_att_scores,
                                               ctc_beam,
                                               dim=1)
                ctc_scores, ctc_states = ctc_prefix_score(
                    hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                local_scores = \
                    (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                    + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                local_best_scores, joint_best_ids = torch.topk(local_scores,
                                                               beam_size,
                                                               dim=1)
                local_best_ids = local_best_ids[:, joint_best_ids[0]]

                for j in range(beam_size):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(
                        local_best_scores[0, j])
                    new_hyp['yseq'] = hyp['yseq'] + [int(local_best_ids[0, j])]
                    new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                    new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam_size]

            # sort and get nbest
            hyps = hyps_best_kept

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                for hyp in hyps:
                    hyp['yseq'].append(eos)

            # move ended hypotheses to the final list and keep the rest
            # (this will be a problem when number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == eos:
                    # only store the sequence that has more than minlen
                    # outputs, and add the length penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and maxlenratio == 0.0:
                break

            hyps = remained_hyps
            if not hyps:
                break

        nbest_hyps = sorted(ended_hyps,
                            key=lambda x: x['score'],
                            reverse=True)[:min(len(ended_hyps), nbest)]
        return torch.tensor(nbest_hyps[0]['yseq'])

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: attention weights keyed by module name, with shape
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: dict of float ndarray
        """
        with torch.no_grad():
            # The teacher-forced pass lives in ``recognize`` in this class;
            # ``forward`` only accepts a single feature tensor, so calling it
            # with three arguments would raise TypeError.
            self.recognize(xs_pad, ilens, ys_pad)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                ret[name] = m.attn.cpu().numpy()
        return ret
class Transformer(t.nn.Module):
    """Transformer ASR model with an auxiliary 4-class "switch" head.

    The encoder/decoder pair is trained with a label-smoothed attention loss,
    an optional CTC loss (``mtlalpha > 0``) and a loss on
    ``decoder_switch_linear``, which predicts the token class produced by
    ``get_switch_target``.
    """

    def __init__(self,
                 num_time_mask=2,
                 num_freq_mask=2,
                 freq_mask_length=15,
                 time_mask_length=15,
                 feature_dim=320,
                 model_size=512,
                 feed_forward_size=1024,
                 hidden_size=64,
                 dropout=0.1,
                 num_head=8,
                 num_encoder_layer=6,
                 num_decoder_layer=6,
                 vocab_path='testing_vocab.model',
                 max_feature_length=1024,
                 max_token_length=50,
                 enable_spec_augment=True,
                 share_weight=True,
                 smoothing=0.1,
                 restrict_left_length=20,
                 restrict_right_length=20,
                 mtlalpha=0.2,
                 report_wer=True):
        """Build the encoder, decoder, output heads and loss modules.

        ``hidden_size`` and ``share_weight`` are accepted but not used by
        this constructor.  The vocabulary is loaded from ``vocab_path``; when
        ``report_wer`` is set, the token list is read from the sibling file
        obtained by replacing ``.model`` with ``.vocab``.
        """
        super(Transformer, self).__init__()

        self.enable_spec_augment = enable_spec_augment
        self.max_token_length = max_token_length
        self.restrict_left_length = restrict_left_length
        self.restrict_right_length = restrict_right_length
        self.vocab = Vocab(vocab_path)
        self.sos = self.vocab.bos_id
        self.eos = self.vocab.eos_id
        self.adim = model_size
        self.odim = self.vocab.vocab_size
        self.ignore_id = self.vocab.pad_id

        if enable_spec_augment:
            self.spec_augment = SpecAugment(
                num_time_mask=num_time_mask,
                num_freq_mask=num_freq_mask,
                freq_mask_length=freq_mask_length,
                time_mask_length=time_mask_length,
                max_sequence_length=max_feature_length)
        else:
            # ``forward`` tests ``self.spec_augment is not None``, so the
            # attribute has to exist even when augmentation is disabled.
            self.spec_augment = None

        self.encoder = Encoder(idim=feature_dim,
                               attention_dim=model_size,
                               attention_heads=num_head,
                               linear_units=feed_forward_size,
                               num_blocks=num_encoder_layer,
                               dropout_rate=dropout,
                               positional_dropout_rate=dropout,
                               attention_dropout_rate=dropout,
                               input_layer='linear',
                               padding_idx=self.vocab.pad_id)

        # The decoder emits hidden states (``use_output_layer=False``); the
        # vocabulary and switch projections are applied by the heads below.
        self.decoder = Decoder(odim=self.vocab.vocab_size,
                               attention_dim=model_size,
                               attention_heads=num_head,
                               linear_units=feed_forward_size,
                               num_blocks=num_decoder_layer,
                               dropout_rate=dropout,
                               positional_dropout_rate=dropout,
                               self_attention_dropout_rate=dropout,
                               src_attention_dropout_rate=0,
                               input_layer='embed',
                               use_output_layer=False)
        self.decoder_linear = t.nn.Linear(model_size,
                                          self.vocab.vocab_size,
                                          bias=True)
        self.decoder_switch_linear = t.nn.Linear(model_size, 4, bias=True)

        self.criterion = LabelSmoothingLoss(size=self.odim,
                                            smoothing=smoothing,
                                            padding_idx=self.vocab.pad_id,
                                            normalize_length=True)
        self.switch_criterion = LabelSmoothingLoss(
            size=4,
            smoothing=0,
            padding_idx=self.vocab.pad_id,
            normalize_length=True)
        self.mtlalpha = mtlalpha
        if mtlalpha > 0.0:
            self.ctc = CTC(self.odim,
                           eprojs=self.adim,
                           dropout_rate=dropout,
                           ctc_type='builtin',
                           reduce=False)
        else:
            self.ctc = None

        if report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculator

            def load_token_list(path=vocab_path.replace('.model', '.vocab')):
                # Each line is tab separated; the token is the first field.
                with open(path) as reader:
                    data = reader.readlines()
                data = [i.split('\t')[0] for i in data]
                return data

            self.char_list = load_token_list()
            self.error_calculator = ErrorCalculator(
                char_list=self.char_list,
                sym_space=' ',
                sym_blank=self.vocab.blank_token,
                report_wer=True)
        else:
            self.error_calculator = None
        self.rnnlm = None
        self.reporter = Reporter()

        self.switch_loss = LabelSmoothingLoss(size=4,
                                              smoothing=0,
                                              padding_idx=0)
        print('initing')
        initialize(self, init_type='xavier_normal')
        print('inited')

    def build_sample_data(self, feature_dim=320, cuda=False):
        """Return a small random batch for smoke tests.

        :param int feature_dim: size of the last feature axis
        :param bool cuda: move all four tensors to the GPU when true
        :return: ``(feature, feature_length, target, target_length)`` where
            ``feature`` has shape (2, 120, feature_dim)
        :rtype: tuple of torch.Tensor
        """
        feature = t.randn((2, 120, feature_dim))
        feature_length = t.LongTensor(list(range(119, 121)))
        target = t.LongTensor([[1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 0]])
        target_length = t.LongTensor([7, 6])
        if cuda:
            return (feature.cuda(), feature_length.cuda(), target.cuda(),
                    target_length.cuda())
        return feature, feature_length, target, target_length

    def get_switch_target(self, ys_out_pad):
        """Map token ids to the 4-class switch target.

        Classes: 0 for id 0 (pad), 3 for ids 1..10 (other), 2 for ids
        12..4211 (ch), and 1 (eng) for every remaining id.

        :param torch.Tensor ys_out_pad: token ids of any shape
        :return: long tensor of the same shape holding class ids
        :rtype: torch.Tensor
        """
        switch = t.ones_like(ys_out_pad,
                             device=ys_out_pad.device).long()  # eng = 1
        switch.masked_fill_(ys_out_pad.eq(0), 0)  # pad = 0
        switch.masked_fill_((ys_out_pad.ge(12) & ys_out_pad.le(4211)),
                            2)  # ch = 2
        switch.masked_fill_((ys_out_pad.ge(1) & ys_out_pad.le(10)),
                            3)  # other = 3
        return switch

    def forward(self, xs_pad, ilens, ys_pad, ys_length):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param torch.Tensor ys_length: batch of target lengths, forwarded to
            ``add_sos_eos``
        :return: combined loss, also stored in ``self.loss``
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(
            xs_pad.device).unsqueeze(-2)
        # SpecAugment is applied in training mode only.
        if self.spec_augment is not None and self.training:
            xs_pad = self.spec_augment(xs_pad, ilens)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, ys_length, self.sos,
                                            self.eos, self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        switch_pred_pad = self.decoder_switch_linear(pred_pad)
        pred_pad = t.nn.functional.log_softmax(self.decoder_linear(pred_pad),
                                               -1)

        # 3. compute attention and switch losses
        loss_att = self.criterion(pred_pad, ys_out_pad)
        ys_switch_pad = self.get_switch_target(ys_out_pad)
        loss_switch = self.switch_criterion(switch_pred_pad, ys_switch_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim),
                               ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute CTC loss (and CTC CER when an error calculator exists)
        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len,
                                ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(
                    hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(),
                                                ys_pad.cpu(),
                                                is_ctc=True)
                # Kept under the guard: cer_ctc is None without a calculator
                # and float(None) would raise.
                self.cer_ctc = float(cer_ctc)

        # 5. compute cer/wer (evaluation mode only)
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            # The switch loss takes a fixed 0.1 share out of the attention
            # weight; it is only part of the total in this mixed case.
            self.loss = alpha * loss_ctc + (
                1 - 0.1 - alpha) * loss_att + loss_switch * 0.1
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)
        self.loss_att = loss_att_data
        self.loss_ctc = loss_ctc_data
        self.loss_switch = float(loss_switch)

        loss_data = float(self.loss)
        # Reporting through self.reporter is disabled in this model; only an
        # implausible loss (NaN or above the threshold) is logged.
        if not (loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data)):
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss

    def scorers(self):
        """Return the decoder and a CTC prefix scorer keyed by name."""
        return dict(decoder=self.decoder,
                    ctc=CTCPrefixScorer(self.ctc, self.eos))

    def encode(self, x):
        """Encode acoustic features.

        Switches the module to evaluation mode as a side effect.

        :param ndarray x: source acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0)
        enc_output, _ = self.encoder(x, None)
        return enc_output.squeeze(0)

    def recognize(self,
                  x,
                  recog_args,
                  char_list=None,
                  rnnlm=None,
                  use_jit=False):
        """Recognize input speech.

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :param bool use_jit: trace ``decoder.forward_one_step`` once with
            ``torch.jit.trace`` and reuse the traced function
        :return: N-best decoding results
        :rtype: list
        """
        enc_output = self.encode(x).unsqueeze(0)
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)

        print('input lengths: ' + str(h.size(0)))
        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        print('max output length: ' + str(maxlen))
        print('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        if lpz is not None:
            import numpy

            from espnet.nets.ctc_prefix_score import CTCPrefixScore

            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0,
                                              self.eos, numpy)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        # Imported once here instead of on every decoding step.
        from espnet.nets.e2e_asr_common import end_detect

        traced_decoder = None
        for i in range(maxlen):
            print('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp['yseq']).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.forward_one_step,
                            (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(ys, ys_mask,
                                                      enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(
                        ys, ys_mask, enc_output)[0]
                # The decoder has no output layer, so project to the
                # vocabulary and normalise here.
                local_att_scores = t.nn.functional.log_softmax(
                    self.decoder_linear(local_att_scores), -1)

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(
                        hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    _, local_best_ids = torch.topk(local_att_scores,
                                                   ctc_beam,
                                                   dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                        + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * local_lm_scores[
                            :, local_best_ids[0]]
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1)

                for j in range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(
                        local_best_scores[0, j])
                    new_hyp['yseq'] = hyp['yseq'] + [int(local_best_ids[0, j])]
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[
                            joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[
                            joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            print('number of pruned hypothes: ' + str(len(hyps)))
            if char_list is not None:
                print('best hypo: ' +
                      ''.join([char_list[int(x)]
                               for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                print('adding <eos> in the last postion in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # move ended hypotheses to the final list and keep the rest
            # (this will be a problem when number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen
                    # outputs, and add the length penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                # The position is formatted into the message; passing it as
                # a second print argument left the '%d' placeholder unfilled.
                print('end detected at %d' % i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                print('remeined hypothes: ' + str(len(hyps)))
            else:
                print('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    print('hypo: ' +
                          ''.join([char_list[int(x)]
                                   for x in hyp['yseq'][1:]]))

            print('number of ended hypothes: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'],
            reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check number of hypotheses
        if len(nbest_hyps) == 0:
            print(
                'there is no N-best results, perform recognition again with smaller minlenratio.'
            )
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            # Keep the caller's use_jit choice on the retry.
            return self.recognize(x, recog_args, char_list, rnnlm, use_jit)

        print('total log probability: ' + str(nbest_hyps[0]['score']))
        print('normalized log probability: ' +
              str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_length=None):
        """E2E attention calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :param torch.Tensor ys_length: batch of target lengths (B); when
            omitted it is derived by counting entries of ``ys_pad`` that
            differ from ``self.ignore_id``
        :return: attention weights keyed by module name, with shape
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: dict of float ndarray
        """
        if ys_length is None:
            # ``forward`` requires target lengths.
            # NOTE(review): assumes padded positions hold self.ignore_id --
            # confirm against the data loader.
            ys_length = ys_pad.ne(self.ignore_id).sum(dim=-1)
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad, ys_length)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                ret[name] = m.attn.cpu().numpy()
        return ret
class E2E(ASRInterface, torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group.add_argument( "--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal", ], help="how to initialize transformer parameters", ) group.add_argument( "--transformer-input-layer", type=str, default="conv2d", choices=["conv2d", "linear", "embed"], help="transformer input layer type", ) group.add_argument( "--transformer-attn-dropout-rate", default=None, type=float, help= "dropout in transformer attention. use --dropout-rate if None is set", ) group.add_argument( "--transformer-lr", default=10.0, type=float, help="Initial value of learning rate", ) group.add_argument( "--transformer-warmup-steps", default=25000, type=int, help="optimizer warmup steps", ) group.add_argument( "--transformer-length-normalized-loss", default=True, type=strtobool, help="normalize loss by length", ) group.add_argument( "--transformer-encoder-selfattn-layer-type", type=str, default="selfattn", choices=[ "selfattn", "lightconv", "lightconv2d", "dynamicconv", "dynamicconv2d", "light-dynamicconv2d", ], help="transformer encoder self-attention layer type", ) group.add_argument( "--transformer-decoder-selfattn-layer-type", type=str, default="selfattn", choices=[ "selfattn", "lightconv", "lightconv2d", "dynamicconv", "dynamicconv2d", "light-dynamicconv2d", ], help="transformer decoder self-attention layer type", ) # Lightweight/Dynamic convolution related parameters. # See https://arxiv.org/abs/1912.11793v2 # and https://arxiv.org/abs/1901.10430 for detail of the method. 
# Configurations used in the first paper are in # egs/{csj, librispeech}/asr1/conf/tuning/ld_conv/ parser.add_argument( "--wshare", default=4, type=int, help="Number of parameter shargin for lightweight convolution", ) parser.add_argument( "--ldconv-encoder-kernel-length", default="21_23_25_27_29_31_33_35_37_39_41_43", type=str, help="kernel size for lightweight/dynamic convolution: " 'Encoder side. For example, "21_23_25" means kernel length 21 for ' "First layer, 23 for Second layer and so on.", ) parser.add_argument( "--ldconv-decoder-kernel-length", default="11_13_15_17_19_21", type=str, help="kernel size for lightweight/dynamic convolution: " 'Decoder side. For example, "21_23_25" means kernel length 21 for ' "First layer, 23 for Second layer and so on.", ) parser.add_argument( "--ldconv-usebias", type=strtobool, default=False, help="use bias term in lightweight/dynamic convolution", ) group.add_argument( "--dropout-rate", default=0.0, type=float, help="Dropout rate for the encoder", ) # Encoder group.add_argument( "--elayers", default=4, type=int, help="Number of encoder layers (for shared recognition part " "in multi-speaker asr mode)", ) group.add_argument( "--eunits", "-u", default=300, type=int, help="Number of encoder hidden units", ) # Attention group.add_argument( '--adim', default=320, type=int, help='Number of attention transformation dimensions') group.add_argument('--aheads', default=4, type=int, help='Number of heads for multi head attention') group.add_argument( '--transformer-enc-attn-type', default='self_attn', type=str, choices=[ 'self_attn', 'self_attn2', 'self_attn_dynamic_span', 'self_attn_adaptive_span', 'self_attn_fixed_span', 'self_attn_adaptive_span2', 'self_attn_dynamic_span2', 'self_attn_fixed_span2', 'self_attn_fixed_span3' ], help='encoder self-attention type') group.add_argument( '--transformer-dec-attn-type', default='self_attn', type=str, choices=[ 'self_attn', 'self_attn2', 'self_attn_dynamic_span', 'self_attn_adaptive_span', 
'self_attn_fixed_span', 'self_attn_adaptive_span2', 'self_attn_dynamic_span2', 'self_attn_fixed_span2' ], help='decoder self-attention type') group.add_argument( '--enc-max-attn-span', default=[50], type=int, nargs='*', help= "Max dynamic/adaptive span (window) size for self-attention of encoder" ) group.add_argument( '--dec-max-attn-span', default=[50], type=int, nargs='*', help= "Max dynamic/adaptive span (window) size for self-attention of decoder" ) group.add_argument('--span-init', default=0, type=float, help="Dynamic/adaptive span (window) initial value") group.add_argument('--span-ratio', default=0.5, type=float, help="Dynamic/adaptive span (window) left ratio") group.add_argument('--ratio-adaptive', default=False, type=strtobool, help="Adaptive span ratio.") group.add_argument( '--span-loss-coef', default=None, type=float, help="The coefficient for computing the loss of spanned attention." ) # Decoder group.add_argument("--dlayers", default=1, type=int, help="Number of decoder layers") group.add_argument("--dunits", default=320, type=int, help="Number of decoder hidden units") return parser @property def attention_plot_class(self): """Return PlotAttentionReport.""" return PlotAttentionReport def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, selfattention_layer_type=args. 
transformer_encoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_encoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, attention_type=getattr(args, 'transformer_enc_attn_type', 'self_attn'), max_attn_span=getattr(args, 'enc_max_attn_span', [None]), span_init=getattr(args, 'span_init', None), span_ratio=getattr(args, 'span_ratio', None), ratio_adaptive=getattr(args, 'ratio_adaptive', None)) self.decoder = Decoder( odim=odim, selfattention_layer_type=args. transformer_decoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_decoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, attention_type=getattr(args, 'transformer_dec_attn_type', 'self_attn'), max_attn_span=getattr(args, 'dec_max_attn_span', [None]), span_init=getattr(args, 'span_init', None), span_ratio=getattr(args, 'span_ratio', None), ratio_adaptive=getattr(args, 'ratio_adaptive', None)) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="asr", arch="transformer") self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, 
args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: self.error_calculator = ErrorCalculator( args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) else: self.error_calculator = None self.rnnlm = None self.attention_enc_type = getattr(args, 'transformer_enc_attn_type', 'self_attn') self.attention_dec_type = getattr(args, 'transformer_dec_attn_type', 'self_attn') self.span_loss_coef = getattr(args, 'span_loss_coef', None) self.ratio_adaptive = getattr(args, 'ratio_adaptive', None) self.sym_blank = args.sym_blank def reset_parameters(self, args): """Initialize parameters.""" # initialize parameters initialize(self, args.transformer_init) def forward(self, xs_pad, ilens, ys_pad): """E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :return: ctc loass value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float ''' """ if self.attention_enc_type in [ 'self_attn_dynamic_span', 'self_attn_adaptive_span', 'self_attn_adaptive_span2', 'self_attn_fixed_span2', 'self_attn_dynamic_span2' ]: for layer in self.encoder.encoders: layer.self_attn.clamp_param() if self.attention_dec_type in [ 'self_attn_dynamic_span', 'self_attn_adaptive_span', 'self_attn_adaptive_span2', 'self_attn_fixed_span2', 'self_attn_dynamic_span2' ]: for layer in self.decoder.decoders: layer.self_attn.clamp_param() # 1. forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel src_mask = make_non_pad_mask(ilens.tolist()).to( xs_pad.device).unsqueeze(-2) hs_pad, hs_mask = self.encoder(xs_pad, src_mask) self.hs_pad = hs_pad # 2. 
forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad # 3. compute attention loss loss_att = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) # TODO(karita) show predicted text # TODO(karita) calculate these stats cer_ctc = None if self.mtlalpha == 0.0: loss_ctc = None else: batch_size = xs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) if self.error_calculator is not None: ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) # 5. compute cer/wer if self.training or self.error_calculator is None: cer, wer = None, None else: ys_hat = pred_pad.argmax(dim=-1) cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) # copyied from e2e_asr alpha = self.mtlalpha if alpha == 0: self.loss = loss_att loss_att_data = float(loss_att) loss_ctc_data = None elif alpha == 1: self.loss = loss_ctc loss_att_data = None loss_ctc_data = float(loss_ctc) else: self.loss = alpha * loss_ctc + (1 - alpha) * loss_att loss_att_data = float(loss_att) loss_ctc_data = float(loss_ctc) # xkc09 Span attention loss computation # xkc09 Span attention size loss computation loss_span = 0 if self.attention_enc_type in [ 'self_attn_dynamic_span', 'self_attn_adaptive_span', 'self_attn_adaptive_span2', 'self_attn_dynamic_span2' ]: loss_span += sum([ layer.self_attn.get_mean_span() for layer in self.encoder.encoders ]) if self.attention_dec_type in [ 'self_attn_dynamic_span', 'self_attn_adaptive_span', 'self_attn_adaptive_span2', 'self_attn_dynamic_span2' ]: loss_span += sum([ layer.self_attn.get_mean_span() for layer in self.decoder.decoders ]) # xkc09 Span 
attention ratio loss computation loss_ratio = 0 if self.ratio_adaptive: # target_ratio = 0.5 if self.attention_enc_type in [ 'self_attn_adaptive_span2', 'self_attn_fixed_span2', 'self_attn_dynamic_span2' ]: loss_ratio += sum([ 1 - layer.self_attn.get_mean_ratio() for layer in self.encoder.encoders ]) if self.attention_dec_type in [ 'self_attn_adaptive_span2', 'self_attn_fixed_span2', 'self_attn_dynamic_span2' ]: loss_ratio += sum([ 1 - layer.self_attn.get_mean_ratio() for layer in self.decoder.decoders ]) if (self.attention_enc_type in [ 'self_attn_dynamic_span', 'self_attn_adaptive_span', 'self_attn_adaptive_span2', 'self_attn_fixed_span2', 'self_attn_dynamic_span2' ] or self.attention_dec_type in [ 'self_attn_dynamic_span', 'self_attn_adaptive_span', 'self_attn_adaptive_span2', 'self_attn_fixed_span2', 'self_attn_dynamic_span2' ]): if getattr(self, 'span_loss_coef', None): self.loss += (loss_span + loss_ratio) * self.span_loss_coef loss_data = float(self.loss) if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data) else: logging.warning("loss (=%f) is not correct", loss_data) return self.loss def scorers(self): """Scorers.""" return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) def encode(self, x): """Encode acoustic features. :param ndarray x: source acoustic feature (T, D) :return: encoder outputs :rtype: torch.Tensor """ self.eval() x = torch.as_tensor(x).unsqueeze(0) enc_output, _ = self.encoder(x, None) return enc_output.squeeze(0) def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False): """Recognize input speech. 
:param ndnarray x: input acoustic feature (B, T, D) or (T, D) :param Namespace recog_args: argment Namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ enc_output = self.encode(x).unsqueeze(0) if recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(enc_output) lpz = lpz.squeeze(0) else: lpz = None h = enc_output.squeeze(0) logging.info("input lengths: " + str(h.size(0))) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) logging.info("max output length: " + str(maxlen)) logging.info("min output length: " + str(minlen)) # initialize hypothesis if rnnlm: hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None} else: hyp = {"score": 0.0, "yseq": [y]} if lpz is not None: ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy) hyp["ctc_state_prev"] = ctc_prefix_score.initial_state() hyp["ctc_score_prev"] = 0.0 if ctc_weight != 1.0: # pre-pruning based on attention scores ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) else: ctc_beam = lpz.shape[-1] hyps = [hyp] ended_hyps = [] import six traced_decoder = None for i in six.moves.range(maxlen): logging.debug("position " + str(i)) hyps_best_kept = [] for hyp in hyps: vy[0] = hyp["yseq"][i] # get nbest local scores and their ids ys_mask = subsequent_mask(i + 1).unsqueeze(0) ys = torch.tensor(hyp["yseq"]).unsqueeze(0) # FIXME: jit does not match non-jit result if use_jit: if traced_decoder is None: traced_decoder = torch.jit.trace( self.decoder.forward_one_step, (ys, ys_mask, enc_output)) local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] else: local_att_scores = 
self.decoder.forward_one_step( ys, ys_mask, enc_output)[0] if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict( hyp["rnnlm_prev"], vy) local_scores = (local_att_scores + recog_args.lm_weight * local_lm_scores) else: local_scores = local_att_scores if lpz is not None: local_best_scores, local_best_ids = torch.topk( local_att_scores, ctc_beam, dim=1) ctc_scores, ctc_states = ctc_prefix_score( hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"]) local_scores = ( 1.0 - ctc_weight) * local_att_scores[:, local_best_ids[ 0]] + ctc_weight * torch.from_numpy( ctc_scores - hyp["ctc_score_prev"]) if rnnlm: local_scores += (recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]) local_best_scores, joint_best_ids = torch.topk( local_scores, beam, dim=1) local_best_ids = local_best_ids[:, joint_best_ids[0]] else: local_best_scores, local_best_ids = torch.topk( local_scores, beam, dim=1) for j in six.moves.range(beam): new_hyp = {} new_hyp["score"] = hyp["score"] + float( local_best_scores[0, j]) new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) new_hyp["yseq"][:len(hyp["yseq"])] = hyp["yseq"] new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j]) if rnnlm: new_hyp["rnnlm_prev"] = rnnlm_state if lpz is not None: new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[ 0, j]] new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[ 0, j]] # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x["score"], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug("number of pruned hypothes: " + str(len(hyps))) if char_list is not None: logging.debug( "best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info("adding <eos> in the last postion in the loop") for hyp in hyps: hyp["yseq"].append(self.eos) # add ended hypothes to a final list, and removed them from current hypothes 
# (this will be a probmlem, number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp["yseq"][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp["yseq"]) > minlen: hyp["score"] += (i + 1) * penalty if rnnlm: # Word LM needs to add final <eos> score hyp["score"] += recog_args.lm_weight * rnnlm.final( hyp["rnnlm_prev"]) ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: logging.info("end detected at %d", i) break hyps = remained_hyps if len(hyps) > 0: logging.debug("remeined hypothes: " + str(len(hyps))) else: logging.info("no hypothesis. Finish decoding.") break if char_list is not None: for hyp in hyps: logging.debug( "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])) logging.debug("number of ended hypothes: " + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x["score"], reverse=True)[:min(len(ended_hyps), recog_args.nbest)] # check number of hypotheis if len(nbest_hyps) == 0: logging.warning("there is no N-best results, perform recognition " "again with smaller minlenratio.") # should copy becasuse Namespace will be overwritten globally recog_args = Namespace(**vars(recog_args)) recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) return self.recognize(x, recog_args, char_list, rnnlm) logging.info("total log probability: " + str(nbest_hyps[0]["score"])) logging.info("normalized log probability: " + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))) return nbest_hyps def calculate_all_attentions(self, xs_pad, ilens, ys_pad): """E2E attention calculation. 
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of input sequences (B) :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) :return: attention weights with the following shape, 1) multi-head case => attention weights (B, H, Lmax, Tmax), 2) other case => attention weights (B, Lmax, Tmax). :rtype: float ndarray """ with torch.no_grad(): self.forward(xs_pad, ilens, ys_pad) ret = dict() for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention) or isinstance( m, DynamicConvolution): ret[name] = m.attn.cpu().numpy() if isinstance(m, DynamicConvolution2D): ret[name + "_time"] = m.attn_t.cpu().numpy() ret[name + "_freq"] = m.attn_f.cpu().numpy() return ret def get_ctc_alignments(self, x, y, char_list): """E2E get alignments of CTC :param torch.Tensor x: input acoustic feature (T, D) :param torch.Tensor y: id sequence tensor (L) :param list char_list: list of characters :return: best alignment results :rtype: list """ def interpolate_blank(l, blank_id=0): l = np.expand_dims(l, 1) zero = np.zeros((l.shape[0], 1), dtype=np.int64) l = np.concatenate([zero, l], axis=1) l = l.reshape(-1) l = np.append(l, l[0]) return l h = self.encode(x).unsqueeze(0) lpc = self.ctc.log_softmax(h)[0] blank_id = char_list.index(self.sym_blank) y_int = interpolate_blank(y, blank_id) logdelta = np.zeros( (lpc.size(0), len(y_int))) - 100000000000.0 # log of zero state_path = np.zeros( (lpc.size(0), len(y_int)), dtype=np.int16) - 1 # state path logdelta[0, 0] = lpc[0][y_int[0]] logdelta[0, 1] = lpc[0][y_int[1]] for t in range(1, lpc.size(0)): for s in range(len(y_int)): if (y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]): candidates = np.array( [logdelta[t - 1, s], logdelta[t - 1, s - 1]]) prev_state = [s, s - 1] else: candidates = np.array([ logdelta[t - 1, s], logdelta[t - 1, s - 1], logdelta[t - 1, s - 2] ]) prev_state = [s, s - 1, s - 2] logdelta[t, s] = np.max(candidates) 
+ lpc[t][y_int[s]] state_path[t, s] = prev_state[np.argmax(candidates)] state_seq = -1 * np.ones((lpc.size(0), 1), dtype=np.int16) candidates = np.array( [logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]]) prev_state = [len(y_int) - 1, len(y_int) - 2] state_seq[-1] = prev_state[np.argmax(candidates)] for t in range(lpc.size(0) - 2, -1, -1): state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] output_state_seq = [] for t in range(0, lpc.size(0)): output_state_seq.append(y_int[state_seq[t, 0]]) # orig_seq = [] # for t in range(0, len(y)): # orig_seq.append(char_list[y[t]]) return output_state_seq
class E2E(ASRInterface, torch.nn.Module):
    """E2E module.

    Transformer encoder/decoder model trained with a label-smoothed attention
    loss, optionally interpolated with a CTC loss (weight ``args.mtlalpha``).

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add the transformer model options to ``parser`` and return it.

        :param argparse.ArgumentParser parser: parser to extend
        :return: the same parser, with a "transformer model setting" group added
        """
        group = parser.add_argument_group("transformer model setting")
        group.add_argument(
            "--transformer-init",
            type=str,
            default="pytorch",
            choices=[
                "pytorch",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
            ],
            help='how to initialize transformer parameters',
        )
        group.add_argument(
            "--transformer-input-layer",
            type=str,
            default="conv2d",
            choices=["conv2d", "linear", "embed"],
            help='transformer input layer type',
        )
        group.add_argument(
            '--transformer-attn-dropout-rate',
            default=None,
            type=float,
            help='dropout in transformer attention. '
                 'use --dropout-rate if None is set',
        )
        group.add_argument(
            '--transformer-lr',
            default=10.0,
            type=float,
            help='Initial value of learning rate',
        )
        group.add_argument(
            '--transformer-warmup-steps',
            default=25000,
            type=int,
            help='optimizer warmup steps',
        )
        group.add_argument(
            '--transformer-length-normalized-loss',
            default=True,
            type=strtobool,
            help='normalize loss by length',
        )
        group.add_argument(
            '--dropout-rate',
            default=0.0,
            type=float,
            help='Dropout rate for the encoder',
        )
        # Encoder
        group.add_argument(
            '--elayers',
            default=4,
            type=int,
            help='Number of encoder layers (for shared recognition part '
                 'in multi-speaker asr mode)',
        )
        group.add_argument(
            '--eunits',
            '-u',
            default=300,
            type=int,
            help='Number of encoder hidden units',
        )
        # Attention
        group.add_argument(
            '--adim',
            default=320,
            type=int,
            help='Number of attention transformation dimensions',
        )
        group.add_argument(
            '--aheads',
            default=4,
            type=int,
            help='Number of heads for multi head attention',
        )
        # Decoder
        group.add_argument(
            '--dlayers',
            default=1,
            type=int,
            help='Number of decoder layers',
        )
        group.add_argument(
            '--dunits',
            default=320,
            type=int,
            help='Number of decoder hidden units',
        )
        return parser

    @property
    def attention_plot_class(self):
        """Return PlotAttentionReport."""
        return PlotAttentionReport

    def __init__(self, idim, odim, args, ignore_id=-1, asr_model=None,
                 mt_model=None):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: label id used for padding in the targets
        :param asr_model: accepted for interface compatibility; not used here
        :param mt_model: accepted for interface compatibility; not used here
        """
        torch.nn.Module.__init__(self)
        # NOTE: this writes the fallback back into the caller's ``args``.
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        # <sos> and <eos> share the last output id.
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = [1]
        self.reporter = Reporter()

        self.criterion = LabelSmoothingLoss(
            self.odim,
            self.ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
        # Runs before the CTC module is attached below, so only the modules
        # created above are passed through ``initialize``.
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim, args.adim, args.dropout_rate,
                           ctc_type=args.ctc_type, reduce=True)
        else:
            self.ctc = None

        if args.report_cer or args.report_wer:
            from espnet.nets.e2e_asr_common import ErrorCalculator
            self.error_calculator = ErrorCalculator(
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )
        else:
            self.error_calculator = None
        self.rnnlm = None

    def reset_parameters(self, args):
        """Initialize parameters with the scheme named by ``args.transformer_init``."""
        initialize(self, args.transformer_init)

    def forward(self, xs_pad, ilens, ys_pad, ys_pad_asr=None):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :param ys_pad_asr: accepted for interface compatibility; not used here
        :return: training loss; the attention loss, the CTC loss, or their
            ``mtlalpha``-weighted sum (also stored in ``self.loss``)
        :rtype: torch.Tensor
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = (~make_pad_mask(ilens.tolist())).to(
            xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                               ignore_label=self.ignore_id)

        # 4. compute ctc loss (only when the model has a CTC branch)
        # TODO(karita) show predicted text
        # TODO(karita) calculate these stats
        cer_ctc = None
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                                hs_len, ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(
                    hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(),
                                                is_ctc=True)

        # 5. compute cer/wer (evaluation mode only)
        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # 6. combine the losses (copied from e2e_asr)
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        # Report only finite, plausible losses; otherwise just warn.
        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, self.acc,
                                 cer_ctc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss

    def scorers(self):
        """Scorers.

        NOTE: ``self.ctc`` is None when the model was built with
        ``mtlalpha == 0``; it is passed to ``CTCPrefixScorer`` unchanged.
        """
        return dict(decoder=self.decoder,
                    ctc=CTCPrefixScorer(self.ctc, self.eos))

    def encode(self, feat):
        """Encode acoustic features.

        Switches the module to eval mode, adds a leading batch axis to
        ``feat``, runs the encoder without a mask and drops the batch axis.

        :param ndarray feat: input acoustic feature
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        feat = torch.as_tensor(feat).unsqueeze(0)
        enc_output, _ = self.encoder(feat, None)
        return enc_output.squeeze(0)

    def recognize(self, feat, recog_args, char_list=None, rnnlm=None,
                  use_jit=False):
        """Recognize input speech with beam search.

        :param ndarray feat: input acoustic feature, as accepted by ``encode``
        :param Namespace recog_args: argument Namespace containing options
            (reads ``ctc_weight``, ``beam_size``, ``penalty``, ``maxlenratio``,
            ``minlenratio``, ``nbest`` and, with an LM, ``lm_weight``)
        :param list char_list: list of characters (used for debug logging only)
        :param torch.nn.Module rnnlm: language model module
        :param bool use_jit: trace ``decoder.forward_one_step`` with
            ``torch.jit.trace`` and decode with the traced function
        :return: N-best decoding results
        :rtype: list

        TODO(karita): do not recompute previous attention for faster decoding
        """
        enc_output = self.encode(feat).unsqueeze(0)
        if recog_args.ctc_weight > 0.0:
            # NOTE: requires a CTC branch, i.e. a model built with mtlalpha > 0.
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)

        logging.info('input lengths: %s', h.size(0))
        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: %s', maxlen)
        logging.info('min output length: %s', minlen)

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        if lpz is not None:
            import numpy

            from espnet.nets.ctc_prefix_score import CTCPrefixScore

            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0,
                                              self.eos, numpy)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        # Imported once here instead of on every decoding step.
        from espnet.nets.e2e_asr_common import end_detect

        traced_decoder = None
        for i in range(maxlen):
            logging.debug('position %s', i)

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0)
                ys = torch.tensor(hyp['yseq']).unsqueeze(0)
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.forward_one_step,
                            (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(
                        ys, ys_mask, enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(
                        ys, ys_mask, enc_output)[0]

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(
                        hyp['rnnlm_prev'], vy)
                    local_scores = (local_att_scores
                                    + recog_args.lm_weight * local_lm_scores)
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    # Rescore only the ctc_beam best attention candidates.
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = (
                        (1.0 - ctc_weight)
                        * local_att_scores[:, local_best_ids[0]]
                        + ctc_weight * torch.from_numpy(
                            ctc_scores - hyp['ctc_score_prev'])
                    )
                    if rnnlm:
                        local_scores += (
                            recog_args.lm_weight
                            * local_lm_scores[:, local_best_ids[0]])
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1)

                for j in range(beam):
                    new_hyp = {
                        'score': hyp['score'] + float(local_best_scores[0, j]),
                        'yseq': hyp['yseq'] + [int(local_best_ids[0, j])],
                    }
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[
                            joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[
                            joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypothes: %s', len(hyps))
            if char_list is not None:
                logging.debug(
                    'best hypo: %s',
                    ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last postion in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypotheses to a final list, and remove them from the
            # current hypotheses
            # (this will be a problem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remeined hypothes: %s', len(hyps))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug(
                        'hypo: %s',
                        ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypothes: %s', len(ended_hyps))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'],
            reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning(
                'there is no N-best results, perform recognition again with smaller minlenratio.'
            )
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            # Keep the caller's use_jit choice on the retry as well.
            return self.recognize(feat, recog_args, char_list, rnnlm, use_jit)

        logging.info('total log probability: %s', nbest_hyps[0]['score'])
        logging.info('normalized log probability: %s',
                     nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq']))
        return nbest_hyps

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_pad_asr=None):
        """E2E attention calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :param ys_pad_asr: accepted for interface compatibility; not used here
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
            Returned as a dict mapping module name to its weights.
        :rtype: float ndarray
        """
        # The forward pass makes every attention module cache its weights.
        with torch.no_grad():
            self.forward(xs_pad, ilens, ys_pad)
        ret = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                ret[name] = m.attn.cpu().numpy()
        return ret

    ##################### CPD modified #####################
    def calculate_alignments(self, xs_pad, ilens, ys_pad):
        """Extract token/frame alignments from one source-attention head.

        All attention maps whose module name contains "src_attn" are
        concatenated along axis 1.  On the first call, the head with the
        highest averaged per-frame maximum weight is chosen and cached in
        ``self.diag_head_idx``; later calls reuse it.

        :param torch.Tensor xs_pad: batch of padded input sequences
        :param torch.Tensor ilens: batch of lengths of input sequences
        :param torch.Tensor ys_pad: batch of padded token id sequences
        :return: one transposed attention map per utterance, cropped to
            ``num_tokens + 1`` tokens and the reduced frame count
        :rtype: list
        :raises NotImplementedError: if the encoder input layer is not
            Conv2dSubsampling
        """
        # Fail fast: no point running the forward pass for an unsupported layer.
        if not isinstance(self.encoder.embed, Conv2dSubsampling):
            raise NotImplementedError("Only Conv2dSubsampling is supported.")

        att_dict = self.calculate_all_attentions(xs_pad, ilens, ys_pad)
        att_ws = np.concatenate(
            [att for name, att in att_dict.items() if "src_attn" in name],
            axis=1)
        if not hasattr(self, "diag_head_idx"):
            diagnal_scores = att_ws.max(axis=-2).mean(axis=-1).mean(axis=0)
            self.diag_head_idx = diagnal_scores.argmax()
            logging.info("Using src_attn head id: " + str(self.diag_head_idx))

        olens = [y[y != self.ignore_id].size(0) for y in ys_pad]
        alignments = []
        for i, head_atts in enumerate(att_ws):
            att_w = head_atts[self.diag_head_idx]
            num_frames = ilens[i]
            num_tokens = olens[i]
            # presumably mirrors the length reduction of Conv2dSubsampling
            # (two stride-2 steps) -- TODO confirm against that module
            num_sub_frames = ((num_frames - 1) // 2 - 1) // 2
            alignments.append(
                np.transpose(att_w[:num_tokens + 1, :num_sub_frames]))
        return alignments