class E2E(torch.nn.Module):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        group = parser.add_argument_group("transformer model setting")
        group.add_argument("--transformer-init", type=str, default="pytorch",
                           choices=["pytorch", "xavier_uniform", "xavier_normal",
                                    "kaiming_uniform", "kaiming_normal"],
                           help='how to initialize transformer parameters')
        group.add_argument("--transformer-input-layer", type=str, default="conv2d",
                           choices=["conv2d", "linear", "embed", "custom"],
                           help='transformer input layer type')
        group.add_argument('--transformer-attn-dropout-rate', default=None, type=float,
                           help='dropout in transformer attention. '
                                'use --dropout-rate if None is set')
        group.add_argument('--transformer-lr', default=10.0, type=float,
                           help='Initial value of learning rate')
        group.add_argument('--transformer-warmup-steps', default=25000, type=int,
                           help='optimizer warmup steps')
        group.add_argument('--transformer-length-normalized-loss', default=True,
                           type=strtobool, help='normalize loss by length')
        group.add_argument('--dropout-rate', default=0.0, type=float,
                           help='Dropout rate for the encoder')
        # Encoder
        group.add_argument('--elayers', default=4, type=int,
                           help='Number of encoder layers (for shared recognition '
                                'part in multi-speaker asr mode)')
        group.add_argument('--eunits', '-u', default=300, type=int,
                           help='Number of encoder hidden units')
        # Attention
        group.add_argument('--adim', default=320, type=int,
                           help='Number of attention transformation dimensions')
        group.add_argument('--aheads', default=4, type=int,
                           help='Number of heads for multi head attention')
        # Decoder
        group.add_argument('--dlayers', default=1, type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits', default=320, type=int,
                           help='Number of decoder hidden units')
        return parser
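    # Usage sketch (illustrative, not part of the original code): these options
    # attach to any argparse parser; the flag values below are assumptions.
    #
    #     parser = argparse.ArgumentParser()
    #     E2E.add_arguments(parser)
    #     args = parser.parse_args(["--adim", "256", "--aheads", "4"])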
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = [1]
        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id,
                                            args.lsm_weight,
                                            args.transformer_length_normalized_loss)
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim, args.adim, args.dropout_rate,
                           ctc_type=args.ctc_type, reduce=True)
        else:
            self.ctc = None
        self.rnnlm = None

    def reset_parameters(self, args):
        """Initialize parameters."""
        initialize(self, args.transformer_init)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

        # CTC forward
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max(len(y) for y in ys)
        ys_pad = ys_pad[:, :y_len]
        self.hs_pad = hs_pad
        batch_size = xs_pad.size(0)
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                                hs_len, ys_pad)

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                               ignore_label=self.ignore_id)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)
        return self.loss, loss_ctc_data, loss_att_data, self.acc
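    # Worked example of the multi-task objective above (illustrative numbers):
    # with mtlalpha = 0.3, loss_ctc = 42.0 and loss_att = 10.0 combine to
    #     loss = 0.3 * 42.0 + 0.7 * 10.0 = 19.6
    # i.e. mtlalpha interpolates between pure attention (0.0) and pure CTC (1.0).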
    def encode(self, x, mask=None):
        """Encode acoustic features.

        :param ndarray x: source acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0).cuda()
        if mask is not None:
            mask = mask.cuda()
        if isinstance(self.encoder.embed, EncoderConv2d):
            hs, _ = self.encoder.embed(x, torch.Tensor([float(x.shape[1])]).cuda())
        else:
            hs, _ = self.encoder.embed(x, None)
        enc_output, _ = self.encoder.encoders(hs, mask)
        if self.encoder.normalize_before:
            enc_output = self.encoder.after_norm(enc_output)
        return enc_output.squeeze(0)

    def viterbi_decode(self, x, y, mask=None):
        """Align the label sequence y to input x with CTC Viterbi alignment."""
        enc_output = self.encode(x)
        logits = self.ctc.ctc_lo(enc_output).detach().data
        logit = np.array(logits.cpu().data).T
        align = viterbi_align(logit, y)[0]
        return align

    def ctc_decode(self, x, mask=None):
        """Greedy CTC decoding (frame-wise argmax over CTC posteriors)."""
        enc_output = self.encode(x)
        # note: the hard-coded 512 assumes adim == 512
        logits = self.ctc.argmax(enc_output.view(1, -1, 512)).detach().data
        path = np.array(logits.cpu()[0])
        return path

    def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
        """Recognize input speech.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        enc_output = self.encode(x).unsqueeze(0)
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)
        logging.info('input lengths: ' + str(h.size(0)))
        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        if lpz is not None:
            import numpy
            from espnet.nets.ctc_prefix_score import CTCPrefixScore

            ctc_prefix_score = CTCPrefixScore(lpz.cpu().detach().numpy(),
                                              0, self.eos, numpy)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        import six
        traced_decoder = None
        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy.unsqueeze(1)
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0).cuda()
                ys = torch.tensor(hyp['yseq']).unsqueeze(0).cuda()
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.forward_one_step,
                            (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(
                        ys, ys_mask, enc_output)[0]

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + \
                        recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0].cpu(), hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]].cpu() \
                        + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * \
                            local_lm_scores[:, local_best_ids[0]].cpu()
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
            if char_list is not None:
                logging.debug('best hypo: ' +
                              ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last position in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypotheses to a final list, and remove them from current ones
            # (this will be a problem when the number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            from espnet.nets.e2e_asr_common import end_detect
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remained hypotheses: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug('hypo: ' +
                                  ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'],
                            reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check the number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning('there is no N-best result, '
                            'perform recognition again with smaller minlenratio.')
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            return self.recognize(x, recog_args, char_list, rnnlm)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' +
                     str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
        return nbest_hyps
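# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). A minimal forward pass,
# assuming the espnet imports at the top of this file resolve and that args
# carries the extra fields (lsm_weight, mtlalpha, ctc_type) normally set by
# the main training parser; all shapes and values below are illustrative.
#
#     import argparse
#     import torch
#
#     parser = argparse.ArgumentParser()
#     E2E.add_arguments(parser)
#     args = parser.parse_args([])
#     args.lsm_weight = 0.1      # assumption: label smoothing weight
#     args.mtlalpha = 0.3        # assumption: CTC/attention interpolation
#     args.ctc_type = "builtin"  # assumption: CTC implementation name
#
#     model = E2E(83, 52, args)           # 83-dim features, 52 output tokens
#     xs = torch.randn(2, 100, 83)        # (B, Tmax, idim)
#     ilens = torch.tensor([100, 80])     # true input lengths
#     ys = torch.randint(1, 51, (2, 20))  # (B, Lmax) dummy targets
#     loss, loss_ctc, loss_att, acc = model(xs, ilens, ys)
# ---------------------------------------------------------------------------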
class E2E(torch.nn.Module):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        group = parser.add_argument_group("transformer model setting")
        group.add_argument("--transformer-init", type=str, default="pytorch",
                           choices=["pytorch", "xavier_uniform", "xavier_normal",
                                    "kaiming_uniform", "kaiming_normal"],
                           help='how to initialize transformer parameters')
        group.add_argument("--transformer-input-layer", type=str, default="conv2d",
                           choices=["conv2d", "linear", "embed", "custom"],
                           help='transformer input layer type')
        group.add_argument("--transformer-output-layer", type=str, default='embed',
                           choices=['conv', 'embed', 'linear'])
        group.add_argument('--transformer-attn-dropout-rate', default=None, type=float,
                           help='dropout in transformer attention. '
                                'use --dropout-rate if None is set')
        group.add_argument('--transformer-lr', default=10.0, type=float,
                           help='Initial value of learning rate')
        group.add_argument('--transformer-warmup-steps', default=25000, type=int,
                           help='optimizer warmup steps')
        group.add_argument('--transformer-length-normalized-loss', default=True,
                           type=strtobool, help='normalize loss by length')
        group.add_argument('--dropout-rate', default=0.0, type=float,
                           help='Dropout rate for the encoder')
        # Encoder
        group.add_argument('--elayers', default=4, type=int,
                           help='Number of encoder layers (for shared recognition '
                                'part in multi-speaker asr mode)')
        group.add_argument('--eunits', '-u', default=300, type=int,
                           help='Number of encoder hidden units')
        # Attention
        group.add_argument('--adim', default=320, type=int,
                           help='Number of attention transformation dimensions')
        group.add_argument('--aheads', default=4, type=int,
                           help='Number of heads for multi head attention')
        # Decoder
        group.add_argument('--dlayers', default=1, type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits', default=320, type=int,
                           help='Number of decoder hidden units')
        # Streaming params
        group.add_argument('--chunk', default=True, type=strtobool,
                           help='streaming mode, set True for chunk-encoder, '
                                'False for look-ahead encoder')
        group.add_argument('--chunk-size', default=16, type=int,
                           help='chunk size for chunk-based encoder')
        group.add_argument('--left-window', default=1000, type=int,
                           help='left window size for look-ahead based encoder')
        group.add_argument('--right-window', default=1000, type=int,
                           help='right window size for look-ahead based encoder')
        group.add_argument('--dec-left-window', default=0, type=int,
                           help='left window size for decoder (look-ahead based method)')
        group.add_argument('--dec-right-window', default=6, type=int,
                           help='right window size for decoder (look-ahead based method)')
        return parser
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        torch.nn.Module.__init__(self)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer=args.transformer_output_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate)
        self.sos = odim - 1
        self.eos = odim - 1
        self.odim = odim
        self.ignore_id = ignore_id
        self.subsample = [1]
        self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id,
                                            args.lsm_weight,
                                            args.transformer_length_normalized_loss)
        self.reset_parameters(args)
        self.adim = args.adim
        self.mtlalpha = args.mtlalpha
        if args.mtlalpha > 0.0:
            self.ctc = CTC(odim, args.adim, args.dropout_rate,
                           ctc_type=args.ctc_type, reduce=True)
        else:
            self.ctc = None
        self.rnnlm = None
        self.left_window = args.dec_left_window
        self.right_window = args.dec_right_window

    def reset_parameters(self, args):
        """Initialize parameters."""
        initialize(self, args.transformer_init)
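    # Masking overview (a sketch inferred from prefix_recognize below, where
    # mask.sum(-1) serves as a per-chunk index): with chunk-based training,
    # enc_mask lets every frame attend up to the end of its own chunk, e.g.
    # chunk_size = 2 over 4 frames:
    #     [[1, 1, 0, 0],
    #      [1, 1, 0, 0],
    #      [1, 1, 1, 1],
    #      [1, 1, 1, 1]]
    # dec_mask is a trigger mask that limits each target token's source
    # attention to frames around its CTC trigger point.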
    def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
        batch_size = xs_pad.shape[0]
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        if isinstance(self.encoder.embed, EncoderConv2d):
            xs, hs_mask = self.encoder.embed(xs_pad,
                                             torch.sum(src_mask, 2).squeeze())
            hs_mask = hs_mask.unsqueeze(1)
        else:
            xs, hs_mask = self.encoder.embed(xs_pad, src_mask)
        if enc_mask is not None:
            enc_mask = enc_mask[:, :hs_mask.shape[2], :hs_mask.shape[2]]
        enc_mask = enc_mask & hs_mask if enc_mask is not None else hs_mask
        hs_pad, _ = self.encoder.encoders(xs, enc_mask)
        if self.encoder.normalize_before:
            hs_pad = self.encoder.after_norm(hs_pad)

        # CTC forward
        ys = [y[y != self.ignore_id] for y in ys_pad]
        y_len = max(len(y) for y in ys)
        ys_pad = ys_pad[:, :y_len]
        if dec_mask is not None:
            dec_mask = dec_mask[:, :y_len + 1, :hs_pad.shape[1]]
        self.hs_pad = hs_pad
        if self.mtlalpha == 0.0:
            loss_ctc = None
        else:
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim),
                                hs_len, ys_pad)

        # trigger mask
        hs_mask = hs_mask & dec_mask if dec_mask is not None else hs_mask

        # 2. forward decoder
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_mask = target_mask(ys_in_pad, self.ignore_id)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
                               ignore_label=self.ignore_id)

        # copied from e2e_asr
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)
        return self.loss, loss_ctc_data, loss_att_data, self.acc

    def scorers(self):
        """Scorers."""
        return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))

    def encode(self, x, mask=None):
        """Encode acoustic features.

        :param ndarray x: source acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        x = torch.as_tensor(x).unsqueeze(0).cuda()
        if mask is not None:
            mask = mask.cuda()
        if isinstance(self.encoder.embed, EncoderConv2d):
            hs, _ = self.encoder.embed(x, torch.Tensor([float(x.shape[1])]).cuda())
        else:
            hs, _ = self.encoder.embed(x, None)
        hs, _ = self.encoder.encoders(hs, mask)
        if self.encoder.normalize_before:
            hs = self.encoder.after_norm(hs)
        return hs.squeeze(0)

    def viterbi_decode(self, x, y, mask=None):
        """Align the label sequence y to input x with CTC Viterbi alignment."""
        enc_output = self.encode(x, mask)
        logits = self.ctc.ctc_lo(enc_output).detach().data
        logit = np.array(logits.cpu().data).T
        align = viterbi_align(logit, y)[0]
        return align

    def ctc_decode(self, x, mask=None):
        """Greedy CTC decoding (frame-wise argmax over CTC posteriors)."""
        enc_output = self.encode(x, mask)
        # note: the hard-coded 512 assumes adim == 512
        logits = self.ctc.argmax(enc_output.view(1, -1, 512)).detach().data
        path = np.array(logits.cpu()[0])
        return path
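    # Decoding sketch (hypothetical `feat`, a (T, 83) feature array; `y`, a
    # list of label ids):
    #     path = model.ctc_decode(feat)          # frame-level argmax ids (with blanks)
    #     align = model.viterbi_decode(feat, y)  # frame index for each label in y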
    def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
        """Recognize input speech.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        enc_output = self.encode(x).unsqueeze(0)
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(enc_output)
            lpz = lpz.squeeze(0)
        else:
            lpz = None

        h = enc_output.squeeze(0)
        logging.info('input lengths: ' + str(h.size(0)))
        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y]}
        if lpz is not None:
            import numpy
            from espnet.nets.ctc_prefix_score import CTCPrefixScore

            ctc_prefix_score = CTCPrefixScore(lpz.cpu().detach().numpy(),
                                              0, self.eos, numpy)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        import six
        traced_decoder = None
        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy.unsqueeze(1)
                vy[0] = hyp['yseq'][i]

                # get nbest local scores and their ids
                ys_mask = subsequent_mask(i + 1).unsqueeze(0).cuda()
                ys = torch.tensor(hyp['yseq']).unsqueeze(0).cuda()
                # FIXME: jit does not match non-jit result
                if use_jit:
                    if traced_decoder is None:
                        traced_decoder = torch.jit.trace(
                            self.decoder.forward_one_step,
                            (ys, ys_mask, enc_output))
                    local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
                else:
                    local_att_scores = self.decoder.forward_one_step(
                        ys, ys_mask, enc_output)[0]

                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + \
                        recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0].cpu(), hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]].cpu() \
                        + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * \
                            local_lm_scores[:, local_best_ids[0]].cpu()
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1)

                for j in six.moves.range(beam):
                    new_hyp = {}
                    new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
                    new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
                    new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
                    new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x['score'],
                                        reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
            if char_list is not None:
                logging.debug('best hypo: ' +
                              ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last position in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypotheses to a final list, and remove them from current ones
            # (this will be a problem when the number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            from espnet.nets.e2e_asr_common import end_detect
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remained hypotheses: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            if char_list is not None:
                for hyp in hyps:
                    logging.debug('hypo: ' +
                                  ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'],
                            reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check the number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning('there is no N-best result, '
                            'perform recognition again with smaller minlenratio.')
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            return self.recognize(x, recog_args, char_list, rnnlm)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' +
                     str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
        return nbest_hyps

    def prefix_recognize(self, x, recog_args, train_args, char_list=None, rnnlm=None):
        """Recognize input speech with a chunk-synchronous prefix beam search.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param Namespace train_args: argument Namespace used at training time
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: decoding result (best string, token ids, score)
        :rtype: tuple

        TODO(karita): do not recompute previous attention for faster decoding
        """
        pad_len = self.eos - len(char_list) + 1
        for i in range(pad_len):
            char_list.append('<eos>')
        # encoder length after the front-end's 4x subsampling
        if isinstance(self.encoder.embed, EncoderConv2d):
            seq_len = ((x.shape[0] + 1) // 2 + 1) // 2
        else:
            seq_len = ((x.shape[0] - 1) // 2 - 1) // 2
        if train_args.chunk:
            s = np.arange(0, seq_len, train_args.chunk_size)
            mask = adaptive_enc_mask(seq_len, s).unsqueeze(0)
        else:
            mask = turncated_mask(1, seq_len, train_args.left_window,
                                  train_args.right_window)
        enc_output = self.encode(x, mask).unsqueeze(0)
        lpz = torch.nn.functional.softmax(self.ctc.ctc_lo(enc_output), dim=-1)
        lpz = lpz.squeeze(0)

        h = enc_output.squeeze(0)
        logging.info('input lengths: ' + str(h.size(0)))
        h_len = h.size(0)
        # search params
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # prepare sos
        y = self.sos
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))

        hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None,
               'seq': char_list[y], 'last_time': [],
               "ctc_score": 0.0, "rnnlm_score": 0.0, "att_score": 0.0,
               "cache": None, "precache": None,
               "preatt_score": 0.0, "prev_score": 0.0}
        hyps = {char_list[y]: hyp}
        hyps_att = {char_list[y]: hyp}
        Pb_prev, Pnb_prev = Counter(), Counter()
        Pb, Pnb = Counter(), Counter()
        Pjoint = Counter()
        lpz = lpz.cpu().detach().numpy()
        vocab_size = lpz.shape[1]
        r = np.ndarray((vocab_size), dtype=np.float32)
        l = char_list[y]
        Pb_prev[l] = 1
        Pnb_prev[l] = 0
        A_prev = [l]
        A_prev_id = [[y]]
        vy.unsqueeze(1)
        total_copy = 0.0
        samelen = 0
        # attention score cache, keyed first by chunk index
        hat_att = {}
        if mask is not None:
            chunk_pos = set(np.array(mask.sum(dim=-1))[0])
            for i in chunk_pos:
                hat_att[i] = {}
        else:
            hat_att[enc_output.shape[1]] = {}

        for i in range(h_len):
            hyps_ctc = {}
            threshold = recog_args.threshold
            pos_ctc = np.where(lpz[i] > threshold)[0]
            # self.removeIlegal(hyps)
            if mask is not None:
                chunk_index = mask[0][i].sum().item()
            else:
                chunk_index = h_len
            hyps_res = {}
            for l, hyp in hyps.items():
                if l in hat_att[chunk_index]:
                    hyp['tmp_cache'] = hat_att[chunk_index][l]['cache']
                    hyp['tmp_att'] = hat_att[chunk_index][l]['att_scores']
                else:
                    hyps_res[l] = hyp
            # cluster hyps by length: {length: [hyps]}
            tmp = self.clusterbyLength(hyps_res)
            start = time.time()

            # pre-compute beam
            self.compute_hyps(tmp, i, h_len, enc_output, hat_att[chunk_index], mask)
            total_copy += time.time() - start

            # assign scores and tokens to hyps
            for l, hyp in hyps.items():
                if 'tmp_att' not in hyp:
                    continue  # TODO: check why 'tmp_att' can be missing
                local_att_scores = hyp['tmp_att']
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, 5, dim=1)
                pos_att = np.array(local_best_ids[0].cpu())
                pos = np.union1d(pos_ctc, pos_att)
                hyp['pos'] = pos

            # pre-compute ctc beam
            hyps_ctc_compute = self.get_ctchyps2compute(hyps, hyps_ctc, i)
            hyps_res2 = {}
            for l, hyp in hyps_ctc_compute.items():
                l_minus = ' '.join(l.split()[:-1])
                if l_minus in hat_att[chunk_index]:
                    hyp['tmp_cur_new_cache'] = hat_att[chunk_index][l_minus]['cache']
                    hyp['tmp_cur_att_scores'] = hat_att[chunk_index][l_minus]['att_scores']
                else:
                    hyps_res2[l] = hyp
            tmp2_cluster = self.clusterbyLength(hyps_res2)
            self.compute_hyps_ctc(tmp2_cluster, h_len, enc_output,
                                  hat_att[chunk_index], mask)

            for l, hyp in hyps.items():
                start = time.time()
                l_id = hyp['yseq']
                l_end = l_id[-1]
                vy[0] = l_end
                prefix_len = len(l_id)
                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                else:
                    rnnlm_state = None
                    local_lm_scores = torch.zeros([1, len(char_list)])
                r = lpz[i] * (Pb_prev[l] + Pnb_prev[l])
                if 'tmp_att' not in hyp:
                    continue  # TODO: check why 'tmp_att' can be missing
                local_att_scores = hyp['tmp_att']
                new_cache = hyp['tmp_cache']
                align = [0] * prefix_len
                align[:prefix_len - 1] = hyp['last_time'][:]
                align[-1] = i
                pos = hyp['pos']
                if 0 in pos or l_end in pos:
                    if l not in hyps_ctc:
                        hyps_ctc[l] = {'yseq': l_id}
                        hyps_ctc[l]['rnnlm_prev'] = hyp['rnnlm_prev']
                        hyps_ctc[l]['rnnlm_score'] = hyp['rnnlm_score']
                        if l_end != self.eos:
                            hyps_ctc[l]['last_time'] = [0] * prefix_len
                            hyps_ctc[l]['last_time'][:] = hyp['last_time'][:]
                            hyps_ctc[l]['last_time'][-1] = i
                            try:
                                cur_att_scores = hyps_ctc_compute[l]["tmp_cur_att_scores"]
                                cur_new_cache = hyps_ctc_compute[l]["tmp_cur_new_cache"]
                            except KeyError:  # debugging aid from the original code
                                pdb.set_trace()
                            hyps_ctc[l]['att_score'] = hyp['preatt_score'] + \
                                float(cur_att_scores[0, l_end].data)
                            hyps_ctc[l]['cur_att'] = float(cur_att_scores[0, l_end].data)
                            hyps_ctc[l]['cache'] = cur_new_cache
                        else:
                            if len(hyps_ctc[l]["yseq"]) > 1:
                                hyps_ctc[l]["end"] = True
                            hyps_ctc[l]['last_time'] = []
                            hyps_ctc[l]['att_score'] = hyp['att_score']
                            hyps_ctc[l]['cur_att'] = 0
                            hyps_ctc[l]['cache'] = hyp['cache']
                        hyps_ctc[l]['prev_score'] = hyp['prev_score']
                        hyps_ctc[l]['preatt_score'] = hyp['preatt_score']
                        hyps_ctc[l]['precache'] = hyp['precache']
                        hyps_ctc[l]['seq'] = hyp['seq']

                for c in list(pos):
                    if c == 0:
                        Pb[l] += lpz[i][0] * (Pb_prev[l] + Pnb_prev[l])
                    else:
                        l_plus = l + " " + char_list[c]
                        if l_plus not in hyps_ctc:
                            hyps_ctc[l_plus] = {}
                            if "end" in hyp:
                                # propagate the end flag (the original wrote to
                                # 'yseq' here, which the next line overwrote)
                                hyps_ctc[l_plus]['end'] = True
                            hyps_ctc[l_plus]['yseq'] = [0] * (prefix_len + 1)
                            hyps_ctc[l_plus]['yseq'][:len(hyp['yseq'])] = l_id
                            hyps_ctc[l_plus]['yseq'][-1] = int(c)
                            hyps_ctc[l_plus]['rnnlm_prev'] = rnnlm_state
                            hyps_ctc[l_plus]['rnnlm_score'] = hyp['rnnlm_score'] + \
                                float(local_lm_scores[0, c].data)
                            hyps_ctc[l_plus]['att_score'] = hyp['att_score'] + \
                                float(local_att_scores[0, c].data)
                            hyps_ctc[l_plus]['cur_att'] = float(local_att_scores[0, c].data)
                            hyps_ctc[l_plus]['cache'] = new_cache
                            hyps_ctc[l_plus]['precache'] = hyp['cache']
                            hyps_ctc[l_plus]['preatt_score'] = hyp['att_score']
                            hyps_ctc[l_plus]['prev_score'] = hyp['score']
                            hyps_ctc[l_plus]['last_time'] = align
                            hyps_ctc[l_plus]['rule_penalty'] = 0
                            hyps_ctc[l_plus]['seq'] = l_plus

                        if l_end != self.eos and c == l_end:
                            Pnb[l_plus] += lpz[i][l_end] * Pb_prev[l]
                            Pnb[l] += lpz[i][l_end] * Pnb_prev[l]
                        else:
                            Pnb[l_plus] += r[c]
                        if l_plus not in hyps:
                            Pb[l_plus] += lpz[i][0] * (Pb_prev[l_plus] + Pnb_prev[l_plus])
                            # standard prefix-beam-search non-blank update
                            # (accumulated into Pnb, not Pb)
                            Pnb[l_plus] += lpz[i][c] * Pnb_prev[l_plus]
            for l in hyps_ctc.keys():
                if Pb[l] != 0 or Pnb[l] != 0:
                    hyps_ctc[l]['ctc_score'] = np.log(Pb[l] + Pnb[l])
                else:
                    hyps_ctc[l]['ctc_score'] = float('-inf')
                local_score = hyps_ctc[l]['ctc_score'] + \
                    recog_args.ctc_lm_weight * hyps_ctc[l]['rnnlm_score'] + \
                    recog_args.penalty * len(hyps_ctc[l]['yseq'])
                hyps_ctc[l]['local_score'] = local_score
                hyps_ctc[l]['score'] = (1 - recog_args.ctc_weight) * hyps_ctc[l]['att_score'] + \
                    recog_args.ctc_weight * hyps_ctc[l]['ctc_score'] + \
                    recog_args.penalty * len(hyps_ctc[l]['yseq']) + \
                    recog_args.lm_weight * hyps_ctc[l]['rnnlm_score']

            Pb_prev = Pb
            Pnb_prev = Pnb
            Pb = Counter()
            Pnb = Counter()
            # keep the union of the top-beam prefixes under the CTC-local,
            # attention, and joint scores
            hyps1 = dict(sorted(hyps_ctc.items(),
                                key=lambda x: x[1]['local_score'],
                                reverse=True)[:beam])
            hyps2 = dict(sorted(hyps_ctc.items(),
                                key=lambda x: x[1]['att_score'],
                                reverse=True)[:beam])
            hyps = dict(sorted(hyps_ctc.items(),
                               key=lambda x: x[1]['score'],
                               reverse=True)[:beam])
            for key in hyps1.keys():
                if key not in hyps:
                    hyps[key] = hyps1[key]
            for key in hyps2.keys():
                if key not in hyps:
                    hyps[key] = hyps2[key]
            hyps = dict(sorted(hyps.items(),
                               key=lambda x: x[1]['score'],
                               reverse=True)[:beam])

        logging.info('input lengths: ' + str(h.size(0)))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))
        if "<eos>" in hyps.keys():
            del hyps["<eos>"]
        best = list(hyps.keys())[0]
        ids = hyps[best]['yseq']
        score = hyps[best]['score']
        logging.info('score: ' + str(score))
        return best, ids, score

    def removeIlegal(self, hyps):
        """Remove hypotheses lagging the longest one by more than 4 tokens."""
        max_y = max(len(hyp['yseq']) for l, hyp in hyps.items())
        for_remove = []
        for l, hyp in hyps.items():
            if max_y - len(hyp['yseq']) > 4:
                for_remove.append(l)
        for cur_str in for_remove:
            del hyps[cur_str]

    def clusterbyLength(self, hyps):
        """Cluster hypotheses by prefix length into {length: [hyps]}."""
        tmp = {}
        for l, hyp in hyps.items():
            prefix_len = len(hyp['yseq'])
            if prefix_len > 1 and hyp['yseq'][-1] == self.eos:
                continue
            if prefix_len not in tmp:
                tmp[prefix_len] = []
            tmp[prefix_len].append(hyp)
        return tmp

    def compute_hyps(self, current_hyps, current_frame, total_frame,
                     enc_output, hat_att, enc_mask=None):
        """Batch one decoder step for all hypotheses of equal length."""
        for length, hyps_t in current_hyps.items():
            ys_mask = subsequent_mask(length).unsqueeze(0).cuda()
            ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1)
            l_id = [hyp_t['yseq'] for hyp_t in hyps_t]
            ys4use = torch.tensor(l_id).cuda()
            enc_output4use = enc_output.repeat(len(hyps_t), 1, 1)
            if hyps_t[0]["cache"] is None:
                cache4use = None
            else:
                cache4use = []
                for decode_num in range(len(hyps_t[0]["cache"])):
                    current_cache = []
                    for hyp_t in hyps_t:
                        current_cache.append(hyp_t["cache"][decode_num].squeeze(0))
                    cache4use.append(torch.stack(current_cache))
            partial_mask4use = []
            for hyp_t in hyps_t:
                align = [0] * length
                align[:length - 1] = hyp_t['last_time'][:]
                align[-1] = current_frame
                align_tensor = torch.tensor(align).unsqueeze(0)
                if enc_mask is not None:
                    partial_mask = enc_mask[0][align_tensor]
                else:
                    right_window = self.right_window
                    partial_mask = trigger_mask(1, total_frame, align_tensor,
                                                self.left_window, right_window)
                partial_mask4use.append(partial_mask)
            partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1)
            local_att_scores_b, new_cache_b = self.decoder.forward_one_step(
                ys4use, ys_mask4use, enc_output4use, partial_mask4use, cache4use)
            for idx, hyp_t in enumerate(hyps_t):
                hyp_t['tmp_cache'] = [new_cache_b[decode_num][idx].unsqueeze(0)
                                      for decode_num in range(len(new_cache_b))]
                hyp_t['tmp_att'] = local_att_scores_b[idx].unsqueeze(0)
                hat_att[hyp_t['seq']] = {'cache': hyp_t['tmp_cache'],
                                         'att_scores': hyp_t['tmp_att']}

    def get_ctchyps2compute(self, hyps, hyps_ctc, current_frame):
        """Collect prefixes whose attention scores must be recomputed for CTC."""
        tmp2 = {}
        for l, hyp in hyps.items():
            l_id = hyp['yseq']
            l_end = l_id[-1]
            if "pos" not in hyp:
                continue
            if 0 in hyp['pos'] or l_end in hyp['pos']:
                if l not in hyps_ctc and l_end != self.eos:
                    tmp2[l] = {'yseq': l_id}
                    tmp2[l]['seq'] = l
                    tmp2[l]['rnnlm_prev'] = hyp['rnnlm_prev']
                    tmp2[l]['rnnlm_score'] = hyp['rnnlm_score']
                    if l_end != self.eos:
                        tmp2[l]['last_time'] = [0] * len(l_id)
                        tmp2[l]['last_time'][:] = hyp['last_time'][:]
                        tmp2[l]['last_time'][-1] = current_frame
        return tmp2

    def compute_hyps_ctc(self, hyps_ctc_cluster, total_frame, enc_output,
                         hat_att, enc_mask=None):
        """Batch one decoder step for the length-minus-one prefixes."""
        for length, hyps_t in hyps_ctc_cluster.items():
            ys_mask = subsequent_mask(length - 1).unsqueeze(0).cuda()
            ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1)
            l_id = [hyp_t['yseq'][:-1] for hyp_t in hyps_t]
            ys4use = torch.tensor(l_id).cuda()
            enc_output4use = enc_output.repeat(len(hyps_t), 1, 1)
            if "precache" not in hyps_t[0] or hyps_t[0]["precache"] is None:
                cache4use = None
            else:
                cache4use = []
                for decode_num in range(len(hyps_t[0]["precache"])):
                    current_cache = []
                    for hyp_t in hyps_t:
                        current_cache.append(hyp_t["precache"][decode_num].squeeze(0))
                    cache4use.append(torch.stack(current_cache))
            partial_mask4use = []
            for hyp_t in hyps_t:
                align = hyp_t['last_time']
                align_tensor = torch.tensor(align).unsqueeze(0)
                if enc_mask is not None:
                    partial_mask = enc_mask[0][align_tensor]
                else:
                    right_window = self.right_window
                    partial_mask = trigger_mask(1, total_frame, align_tensor,
                                                self.left_window, right_window)
                partial_mask4use.append(partial_mask)
            partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1)
            local_att_scores_b, new_cache_b = self.decoder.forward_one_step(
                ys4use, ys_mask4use, enc_output4use, partial_mask4use, cache4use)
            for idx, hyp_t in enumerate(hyps_t):
                hyp_t['tmp_cur_new_cache'] = [new_cache_b[decode_num][idx].unsqueeze(0)
                                              for decode_num in range(len(new_cache_b))]
                hyp_t['tmp_cur_att_scores'] = local_att_scores_b[idx].unsqueeze(0)
                l_minus = ' '.join(hyp_t['seq'].split()[:-1])
                hat_att[l_minus] = {'att_scores': hyp_t['tmp_cur_att_scores'],
                                    'cache': hyp_t['tmp_cur_new_cache']}
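# ---------------------------------------------------------------------------
# Streaming decode sketch (not part of the original module), assuming `model`
# is the streaming E2E above, `feat` is a (T, 83) feature matrix, and
# train_args/recog_args provide the fields read by prefix_recognize
# (chunk, chunk_size, beam_size, ctc_weight, threshold, ctc_lm_weight,
# lm_weight, penalty, maxlenratio, minlenratio). Token inventory is an
# assumption: index 0 is the CTC blank and the last index doubles as
# <sos>/<eos>.
#
#     char_list = ['<blank>'] + subword_tokens + ['<eos>']
#     best, ids, score = model.prefix_recognize(feat, recog_args, train_args,
#                                               char_list)
#     text = ''.join(char_list[int(t)] for t in ids[1:])  # drop leading <sos>
# ---------------------------------------------------------------------------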