def encode(self, xs, task='all', flip=False):
    """Encode acoustic or text features.

    Args:
        xs (list): A list of length `[B]`. For `input_type == 'speech'` each
            element is an array of size `[T, input_dim]`; for
            `input_type == 'text'` each element is an iterable of token ids.
        task (str): all or ys or ys_sub*. Any task name containing `lmobj`
            bypasses the encoder entirely.
        flip (bool): if True, flip acoustic features in the time-dimension
    Returns:
        enc_outs (dict): encoder outputs. The entries `ys`, `ys_sub1` and
            `ys_sub2` each hold `xs` and `xlens`. For `lmobj` tasks every
            value is None and the `*.ctc` entries are present as well.
        perm_ids (list): indices into the original `xs`, ordered by
            descending length (None for `lmobj` tasks)
    Raises:
        ValueError: if `self.input_type` is neither `speech` nor `text`

    """
    if 'lmobj' in task:
        # LM objective: nothing to encode, return an empty placeholder per task
        keys = ('ys', 'ys.ctc',
                'ys_sub1', 'ys_sub1.ctc',
                'ys_sub2', 'ys_sub2.ctc')
        return {key: {'xs': None, 'xlens': None} for key in keys}, None

    # Sort by lengths in the descending order
    # NOTE: must be descending order for pack_padded_sequence
    perm_ids = sorted(range(len(xs)), key=lambda i: len(xs[i]), reverse=True)
    xs = [xs[i] for i in perm_ids]

    if self.input_type == 'speech':
        # Frame stacking
        if self.nstacks > 1:
            xs = [stack_frame(x, self.nstacks, self.nskips) for x in xs]

        # Splicing
        if self.nsplices > 1:
            xs = [splice(x, self.nsplices, self.nstacks) for x in xs]

        xlens = [len(x) for x in xs]

        # Flip acoustic features in the reverse order.
        # The copy makes the negatively-strided view contiguous again.
        if flip:
            xs = [np.flip(x, axis=0).copy() for x in xs]

        # Both the flipped and the regular path go through np2tensor so that
        # device placement follows self.device_id in exactly the same way
        # (the flipped path used to call .cuda() unconditionally).
        xs = pad_list([np2tensor(x, self.device_id).float() for x in xs])

    elif self.input_type == 'text':
        xlens = [len(x) for x in xs]
        xs = [np2tensor(np.fromiter(x, dtype=np.int64), self.device_id).long()
              for x in xs]
        xs = pad_list(xs, self.pad)
        xs = self.embed_in(xs)

    else:
        raise ValueError('Unsupported input_type: %r' % (self.input_type,))

    enc_outs = self.enc(xs, xlens, task)

    # A CNN encoder has no separate sub-task outputs: share the main output
    if self.main_weight < 1 and self.enc_type == 'cnn':
        for sub in ['sub1', 'sub2']:
            enc_outs['ys_' + sub]['xs'] = enc_outs['ys']['xs'].clone()
            enc_outs['ys_' + sub]['xlens'] = copy.deepcopy(enc_outs['ys']['xlens'])

    # Bridge between the encoder and decoder
    use_bridge = self.enc_type == 'cnn' or self.bridge_layer
    if self.main_weight > 0 and use_bridge and task in ['all', 'ys']:
        enc_outs['ys']['xs'] = self.bridge(enc_outs['ys']['xs'])
    if self.sub1_weight > 0 and use_bridge and task in ['all', 'ys_sub1']:
        enc_outs['ys_sub1']['xs'] = self.bridge_sub1(enc_outs['ys_sub1']['xs'])
    if self.sub2_weight > 0 and use_bridge and task in ['all', 'ys_sub2']:
        enc_outs['ys_sub2']['xs'] = self.bridge_sub2(enc_outs['ys_sub2']['xs'])

    return enc_outs, perm_ids
def beam_search(self, eouts, elens, params, idx2token, lm=None, lm_rev=None,
                ctc_log_probs=None, nbest=1, exclude_eos=False,
                refs_id=None, utt_ids=None, speakers=None,
                ensmbl_eouts=None, ensmbl_elens=None, ensmbl_decs=[]):
    """Beam search decoding.

    Utterances are decoded one at a time. At every encoder frame each
    hypothesis is expanded with the top-k outputs of the joint network; the
    prediction network state is cached per token sequence in
    `self.state_cache`, which is reset after every utterance.

    Args:
        eouts (FloatTensor): `[B, T, dec_n_units]`
        elens (IntTensor): `[B]`
        params (dict):
            recog_oracle (bool): feed reference tokens to the prediction network
            recog_beam_width (int): size of hyp
            recog_ctc_weight (float): only used to decide what is logged
            recog_lm_weight (float): weight of LM score
            recog_lm_state_carry_over (bool): start from `self.lmstate_final`
            recog_lm_usage (str): `rescoring` enables LM rescoring of the beam
        idx2token (): converter from index to token (None disables the
            Ref/Hyp logging)
        lm (RNNLM or GatedConvLM or TransformerLM):
        lm_rev: unused
        ctc_log_probs (FloatTensor): only used to decide what is logged
        nbest (int): unused, a single hypothesis is returned per utterance
        exclude_eos (bool): currently has no effect (see NOTE below)
        refs_id (list): reference token ids per utterance (needed for oracle)
        utt_ids (list):
        speakers (list):
        ensmbl_eouts (list): unused
        ensmbl_elens (list): unused
        ensmbl_decs (list): unused
    Returns:
        nbest_hyps_idx (np.ndarray): built from a list of length `[B]`, which
            contains a one-element list with the best token sequence
        aws: dummy
        scores: dummy
        cache_info: dummy

    """
    logger = logging.getLogger("decoding")

    bs = eouts.size(0)
    best_hyps = []

    oracle = params['recog_oracle']
    beam_width = params['recog_beam_width']
    ctc_weight = params['recog_ctc_weight']
    lm_weight = params['recog_lm_weight']
    lm_state_carry_over = params['recog_lm_state_carry_over']
    lm_usage = params['recog_lm_usage']
    use_lm = lm_weight > 0 and lm is not None

    if lm is not None:
        lm.eval()

    for b in range(bs):
        # Initialization
        # Hypotheses are decoded one utterance at a time, so the prediction
        # network is always run with a batch of one (every later step feeds
        # `[1, 1]` inputs and reuses the state created here).
        y = eouts.new_zeros(1, 1).fill_(self.eos)
        dout, dstate = self.recurrency(self.embed(y), None)
        lmstate = None

        if lm_state_carry_over:
            lmstate = self.lmstate_final
        if speakers is not None:
            self.prev_spk = speakers[b]

        # list() so that `[eos] + ref` concatenates even for array references
        ref = list(refs_id[b]) if oracle else None

        end_hyps = []
        hyps = [{
            'hyp': [self.eos],
            'lattice': [],
            'ref_id': [self.eos],
            'score': 0.0,
            'score_lm': 0.0,
            'score_ctc': 0.0,
            'dout': dout,
            'dstate': dstate,
            'lmstate': lmstate,
        }]
        for t in range(elens[b]):
            new_hyps = []
            for hyp in hyps:
                prev_idx = ([self.eos] + ref)[t] if oracle else hyp['hyp'][-1]
                dstate = hyp['dstate']
                lmstate = hyp['lmstate']

                # Pick up the top-k scores
                out = self.joint(eouts[b:b + 1, t:t + 1], hyp['dout'].squeeze(1))
                log_probs = F.log_softmax(out.squeeze(2), dim=-1)
                log_probs_topk, topk_ids = torch.topk(
                    log_probs[0, 0], k=min(beam_width, self.vocab),
                    dim=-1, largest=True, sorted=True)

                # Iterate over what topk actually returned: it yields fewer
                # than beam_width entries when beam_width > vocab.
                for k in range(topk_ids.size(0)):
                    idx = topk_ids[k].item()
                    # Every candidate starts from the parent's scores; the
                    # scores of sibling candidates must not leak into it.
                    score = hyp['score'] + log_probs_topk[k].item()
                    score_lm = hyp['score_lm']

                    # Update prediction network only when predicting non-blank labels
                    lattice = hyp['lattice'] + [idx]
                    if idx == self.blank:
                        hyp_id = hyp['hyp']
                    else:
                        hyp_id = hyp['hyp'] + [idx]
                    hyp_str = ' '.join(map(str, hyp_id[1:]))
                    if hyp_str in self.state_cache:
                        # from cache
                        # NOTE(review): no LM score is added on a cache hit
                        # and `lmstate` keeps whatever value it had; this
                        # mirrors the previous behaviour and should be
                        # confirmed.
                        dout = self.state_cache[hyp_str]['dout']
                        new_dstate = self.state_cache[hyp_str]['dstate']
                    else:
                        if oracle:
                            y = eouts.new_zeros(1, 1).fill_(ref[len(hyp_id) - 1])
                        else:
                            y = eouts.new_zeros(1, 1).fill_(idx)
                        dout, new_dstate = self.recurrency(self.embed(y), dstate)

                        # Update LM states for shallow fusion
                        if use_lm:
                            _, lmstate, lm_log_probs = lm.predict(
                                eouts.new_zeros(1, 1).fill_(prev_idx), hyp['lmstate'])
                            local_score_lm = lm_log_probs[0, idx].item() * lm_weight
                            score_lm += local_score_lm
                            score += local_score_lm

                        # to cache
                        self.state_cache[hyp_str] = {
                            'lattice': lattice,
                            'dout': dout,
                            'dstate': new_dstate,
                            'lmstate': lmstate,
                        }

                    new_hyps.append({
                        'hyp': hyp_id,
                        'lattice': lattice,
                        'score': score,
                        'score_lm': score_lm,
                        'score_ctc': 0,  # TODO(hirofumi):
                        'dout': dout,
                        'dstate': dstate if idx == self.blank else new_dstate,
                        'lmstate': lmstate,
                    })

            # Local pruning
            new_hyps_tmp = sorted(new_hyps, key=lambda x: x['score'],
                                  reverse=True)[:beam_width]

            # Remove complete hypotheses
            new_hyps = []
            for hyp in new_hyps_tmp:
                if oracle:
                    complete = t == len(ref)
                else:
                    complete = self.end_pointing and hyp['hyp'][-1] == self.eos
                if complete:
                    end_hyps.append(hyp)
                else:
                    new_hyps.append(hyp)
            if len(end_hyps) >= beam_width:
                end_hyps = end_hyps[:beam_width]
                logger.info('End-pointed at %d / %d frames' % (t, elens[b]))
                break
            hyps = new_hyps[:]

        # Every hypothesis may have been moved to end_hyps (e.g. in oracle
        # mode); fall back to them instead of failing on an empty beam.
        # NOTE(review): when the beam is not empty, end_hyps are still not
        # merged into the final result, as before.
        if not hyps:
            hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

        # Rescoring lattice
        if use_lm and lm_usage == 'rescoring':
            new_hyps = []
            for hyp in hyps:
                ys = [np2tensor(np.fromiter(hyp['hyp'], dtype=np.int64),
                                self.device_id)]
                ys_pad = pad_list(ys, lm.pad)
                _, _, lm_log_probs = lm.predict(ys_pad, None)
                score_ctc = 0  # TODO(hirofumi):
                score_lm = lm_log_probs.sum() * lm_weight
                new_hyps.append({
                    'hyp': hyp['hyp'],
                    'score': hyp['score'] + score_lm,
                    'score_ctc': score_ctc,
                    'score_lm': score_lm
                })
            hyps = sorted(new_hyps, key=lambda x: x['score'], reverse=True)

        # NOTE: stripping a trailing <eos> used to be guarded by
        # `if False and exclude_eos and ...`, i.e. it was switched off. The
        # behaviour is kept: only the leading <eos> is removed.
        best_hyps.append([hyps[0]['hyp'][1:]])

        # Reset state cache
        self.state_cache = OrderedDict()

        if utt_ids is not None:
            logger.info('Utt-id: %s' % utt_ids[b])
        if idx2token is not None:
            if refs_id is not None and self.vocab == idx2token.vocab:
                logger.info('Ref: %s' % idx2token(refs_id[b]))
            logger.info('Hyp: %s' % idx2token(hyps[0]['hyp'][1:]))
        logger.info('log prob (hyp): %.7f' % hyps[0]['score'])
        if ctc_weight > 0 and ctc_log_probs is not None:
            logger.info('log prob (hyp, ctc): %.7f' % (hyps[0]['score_ctc']))
        # logger.info('log prob (lp): %.7f' % hyps[0]['score_lp'])
        if use_lm:
            logger.info('log prob (hyp, lm): %.7f' % (hyps[0]['score_lm']))

    return np.array(best_hyps), None, None, None
def test_forward_streaming_chunkwise(args):
    """Compare chunk-by-chunk streaming encoding with full-utterance encoding.

    For input lengths 160..191 the encoder is run once on the whole padded
    input and once chunk by chunk with `streaming=True`; the concatenated
    streaming outputs must equal the full outputs exactly.

    Args:
        args (dict): overrides passed to `make_args`

    """
    args = make_args(**args)
    assert args['chunk_size_left'] > 0
    is_unidirectional = args['rnn_type'] in ['conv_lstm', 'conv_gru', 'lstm', 'gru']

    batch_size = 1
    device_id = -1
    n_stacks = args['n_stacks']

    # Chunk sizes in (stacked) input frames; computed before the chunk
    # arguments are cleared for unidirectional encoders below
    chunk_left = max(0, args['chunk_size_left']) // n_stacks
    chunk_right = max(0, args['chunk_size_right']) // n_stacks
    if is_unidirectional:
        args['chunk_size_left'] = 0
        args['chunk_size_right'] = 0

    rnn_module = importlib.import_module('neural_sp.models.seq2seq.encoders.rnn')
    enc = rnn_module.RNNEncoder(**args)
    factor = enc.subsampling_factor
    has_conv = enc.conv is not None
    lookback = enc.conv.n_frames_context if has_conv else 0
    lookahead = enc.conv.n_frames_context if has_conv else 0

    stacking_module = importlib.import_module(
        'neural_sp.models.seq2seq.frontends.frame_stacking')

    if has_conv:
        enc.turn_off_ceil_mode(enc)

    enc.eval()
    with torch.no_grad():
        for n_frames in range(160, 192):
            feats = np.random.randn(batch_size, n_frames,
                                    args['input_dim']).astype(np.float32)
            if n_stacks > 1:
                feats = [stacking_module.stack_frame(f, n_stacks, n_stacks)
                         for f in feats]
            xlens = torch.IntTensor([len(f) for f in feats])
            xmax = xlens.max().item()

            # all inputs
            feats_pad = pad_list(
                [np2tensor(f, device_id).float() for f in feats], 0.)
            enc_out_dict = enc(feats_pad, xlens, task='all')
            assert enc_out_dict['ys']['xs'].size(0) == batch_size
            assert enc_out_dict['ys']['xs'].size(1) == enc_out_dict['ys']['xlens'][0]

            enc.reset_cache()

            # chunk by chunk encoding
            stream_outputs = []
            n_chunks = math.ceil(xmax / chunk_left)
            offset_in = 0  # time offset for input
            offset_out = 0  # time offset for encoder output
            out_per_chunk = chunk_left // factor
            for _ in range(n_chunks):
                start = offset_in - lookback
                end = (offset_in + chunk_left + chunk_right) + lookahead
                chunk_pad = pad_list(
                    [np2tensor(f[max(0, start):end], device_id).float()
                     for f in feats], 0.)
                chunk_lens = torch.IntTensor([chunk_pad.size(1)] * len(feats))
                stream_out_dict = enc(chunk_pad, chunk_lens, task='all',
                                      streaming=True,
                                      lookback=start > 0,
                                      lookahead=end < xmax - 1)

                expected = enc_out_dict['ys']['xs'][:, offset_out:offset_out + out_per_chunk]
                actual = stream_out_dict['ys']['xs']
                actual = actual[:, :expected.size(1)]
                for t in range(expected.size(1)):
                    print(torch.equal(expected[:, t], actual[:, t]))
                stream_outputs.append(actual)

                offset_in += chunk_left
                offset_out += out_per_chunk
                if offset_in > xmax:
                    break

            enc.reset_cache()

            stream_outputs = torch.cat(stream_outputs, dim=1)
            assert enc_out_dict['ys']['xs'].size() == stream_outputs.size()
            assert torch.equal(enc_out_dict['ys']['xs'], stream_outputs)
def test_decoding(backward, params):
    """Run greedy or beam-search decoding of the Transformer decoder.

    With `recog_beam_width == 1` greedy decoding is checked; otherwise beam
    search is run once on its own and once with a three-model ensemble made
    of the same decoder, and the shapes of hypotheses, attention weights and
    scores are checked.

    Args:
        backward (bool): stored into the decoding parameters as `backward`
        params (dict): overrides passed to `make_decode_params`

    """
    args = make_args()
    params = make_decode_params(**params)
    params['backward'] = backward

    batch_size = params['recog_batch_size']
    emax = 40
    device = "cpu"

    eouts = np.random.randn(batch_size, emax, ENC_N_UNITS).astype(np.float32)
    elens = torch.IntTensor([len(x) for x in eouts])
    eouts = pad_list([np2tensor(x, device).float() for x in eouts], 0.)

    ctc_log_probs = None
    if params['recog_ctc_weight'] > 0:
        # Random logits (not uninitialised memory, which may contain
        # NaN/inf), normalised in the log domain as the name promises.
        ctc_logits = torch.randn(batch_size, emax, VOCAB, device=device)
        ctc_log_probs = torch.log_softmax(ctc_logits, dim=-1)

    lm = _make_rnnlm(device) if params['recog_lm_weight'] > 0 else None
    lm_second = _make_rnnlm(device) if params['recog_lm_second_weight'] > 0 else None
    lm_second_bwd = _make_rnnlm(device) if params['recog_lm_bwd_weight'] > 0 else None

    ylens = [4, 5, 3, 7]
    ys = [np.random.randint(0, VOCAB, ylen).astype(np.int32) for ylen in ylens]

    module = importlib.import_module(
        'neural_sp.models.seq2seq.decoders.transformer')
    dec = module.TransformerDecoder(**args)
    dec = dec.to(device)

    # TODO(hirofumi0810):
    # recog_lm_state_carry_over

    dec.eval()
    with torch.no_grad():
        if params['recog_beam_width'] == 1:
            out = dec.greedy(eouts, elens, max_len_ratio=1.0, idx2token=None,
                             exclude_eos=params['exclude_eos'],
                             refs_id=ys, utt_ids=None, speakers=None,
                             cache_states=params['cache_states'])
            assert len(out) == 2
            hyps, aws = out
            assert isinstance(hyps, list)
            assert len(hyps) == batch_size
            assert isinstance(aws, list)
            assert aws[0].shape == (args['n_heads'] * args['n_layers'],
                                    len(hyps[0]), emax)
        else:
            # Keyword arguments shared by the single-model and ensemble runs
            search_kwargs = dict(
                idx2token=None,
                lm=lm, lm_second=lm_second, lm_second_bwd=lm_second_bwd,
                ctc_log_probs=ctc_log_probs,
                nbest=params['nbest'], exclude_eos=params['exclude_eos'],
                refs_id=None, utt_ids=None, speakers=None,
                cache_states=params['cache_states'])

            out = dec.beam_search(eouts, elens, params, **search_kwargs)
            _check_beam_search_output(out, args, params, batch_size, emax)

            # ensemble
            n_ensemble = 3
            out = dec.beam_search(eouts, elens, params,
                                  ensmbl_eouts=[eouts] * n_ensemble,
                                  ensmbl_elens=[elens] * n_ensemble,
                                  ensmbl_decs=[dec] * n_ensemble,
                                  **search_kwargs)
            _check_beam_search_output(out, args, params, batch_size, emax)


def _make_rnnlm(device):
    """Build an RNNLM from `make_args_rnnlm()` and move it to `device`."""
    args_lm = make_args_rnnlm()
    module = importlib.import_module('neural_sp.models.lm.rnnlm')
    return module.RNNLM(args_lm).to(device)


def _check_beam_search_output(out, args, params, batch_size, emax):
    """Assert the structure of a `(nbest_hyps, aws, scores)` beam-search result."""
    assert len(out) == 3
    nbest_hyps, aws, scores = out
    assert isinstance(nbest_hyps, list)
    assert len(nbest_hyps) == batch_size
    assert len(nbest_hyps[0]) == params['nbest']
    ymax = len(nbest_hyps[0][0])
    assert isinstance(aws, list)
    assert aws[0][0].shape == (args['n_heads'] * args['n_layers'], ymax, emax)
    assert isinstance(scores, list)
    assert len(scores) == batch_size
    assert len(scores[0]) == params['nbest']
def beam_search(self, eouts, elens, params, idx2token,
                lm=None, lm_second=None, lm_second_rev=None,
                nbest=1, refs_id=None, utt_ids=None, speakers=None):
    """Beam search decoding.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (list): length `B`
        params (dict):
            recog_beam_width (int): size of beam
            recog_length_penalty (float): length penalty
            recog_lm_weight (float): weight of first path LM score
            recog_lm_second_weight (float): weight of second path LM score
        idx2token (): converter from index to token (None disables logging)
        lm: first path LM
        lm_second: second path LM
        lm_second_rev: second path backward LM (unused)
        nbest (int): unused, only the best hypothesis is returned
        refs_id (list): reference list
        utt_ids (list): utterance id list
        speakers (list): speaker list (unused)
    Returns:
        best_hyps (np.ndarray): built from a list of length `[B]` holding the
            best token sequence of every utterance (without the leading <eos>)

    """
    bs = eouts.size(0)

    beam_width = params['recog_beam_width']
    lp_weight = params['recog_length_penalty']
    lm_weight = params['recog_lm_weight']
    lm_weight_second = params['recog_lm_second_weight']

    if lm is not None:
        assert lm_weight > 0
        lm.eval()
    if lm_second is not None:
        assert lm_weight_second > 0
        lm_second.eval()

    best_hyps = []
    log_probs = torch.log_softmax(self.output(eouts), dim=-1)
    for b in range(bs):
        # Elements in the beam are (prefix, (p_b, p_no_blank))
        # Initialize the beam with the empty sequence, a probability of
        # 1 for ending in blank and zero for ending in non-blank (in log space).
        # The score entries are filled in as well so that the beam can be
        # logged even when the utterance has no frames.
        beam = [{
            'hyp': [self.eos],  # <eos> is used for LM
            'score': LOG_1,
            'p_b': LOG_1,
            'p_nb': LOG_0,
            'score_ctc': LOG_1,
            'score_lm': LOG_1,
            'score_lp': 0.0,
            'lmstate': None
        }]

        for t in range(elens[b]):
            new_beam = []

            # Pick up the top-k scores (shared by every prefix at this frame)
            _, topk_ids = torch.topk(
                log_probs[b:b + 1, t], k=min(beam_width, self.vocab),
                dim=-1, largest=True, sorted=True)
            cand_ids = tensor2np(topk_ids)[0]
            log_p_blank = log_probs[b, t, self.blank].item()

            for entry in beam:
                hyp = entry['hyp'][:]
                p_b = entry['p_b']
                p_nb = entry['p_nb']
                score_lm = entry['score_lm']

                # case 1. hyp is not extended
                new_p_b = np.logaddexp(p_b + log_p_blank, p_nb + log_p_blank)
                if len(hyp) > 1:
                    new_p_nb = p_nb + log_probs[b, t, hyp[-1]].item()
                else:
                    new_p_nb = LOG_0
                score_ctc = np.logaddexp(new_p_b, new_p_nb)
                score_lp = len(hyp[1:]) * lp_weight
                new_beam.append({
                    'hyp': hyp,
                    'score': score_ctc + score_lm + score_lp,
                    'p_b': new_p_b,
                    'p_nb': new_p_nb,
                    'score_ctc': score_ctc,
                    'score_lm': score_lm,
                    'score_lp': score_lp,
                    'lmstate': entry['lmstate']
                })

                # Update LM states for shallow fusion
                if lm is not None:
                    _, lmstate, lm_log_probs = lm.predict(
                        eouts.new_zeros(1, 1).fill_(hyp[-1]), entry['lmstate'])
                else:
                    lmstate = None

                # case 2. hyp is extended
                new_p_b = LOG_0
                c_prev = hyp[-1] if len(hyp) > 1 else None
                for c in cand_ids:
                    if c == self.blank:
                        continue
                    p_t = log_probs[b, t, c].item()

                    if c == c_prev:
                        # A repeated label can only extend the path that
                        # ended in blank
                        new_p_nb = p_b + p_t
                        # TODO(hirofumi): apply character LM here
                    else:
                        new_p_nb = np.logaddexp(p_b + p_t, p_nb + p_t)
                        # TODO(hirofumi): apply character LM here
                        # TODO(hirofumi): apply word LM here when c is a space

                    score_ctc = np.logaddexp(new_p_b, new_p_nb)
                    score_lp = (len(hyp[1:]) + 1) * lp_weight
                    # The LM score of this candidate is the prefix score plus
                    # its own term only; it must not carry over the terms of
                    # the candidates visited before it.
                    new_score_lm = score_lm
                    if lm_weight > 0 and lm is not None:
                        new_score_lm += lm_log_probs[0, 0, c].item() * lm_weight
                    new_beam.append({
                        'hyp': hyp + [c],
                        'score': score_ctc + new_score_lm + score_lp,
                        'p_b': new_p_b,
                        'p_nb': new_p_nb,
                        'score_ctc': score_ctc,
                        'score_lm': new_score_lm,
                        'score_lp': score_lp,
                        'lmstate': lmstate
                    })

            # Pruning
            beam = sorted(new_beam, key=lambda x: x['score'],
                          reverse=True)[:beam_width]

        # Rescoring lattice
        if lm_second is not None:
            rescored = []
            for entry in beam:
                ys = [np2tensor(np.fromiter(entry['hyp'], dtype=np.int64),
                                self.device_id)]
                ys_pad = pad_list(ys, lm_second.pad)
                _, _, lm_log_probs = lm_second.predict(ys_pad, None)
                score_ctc = np.logaddexp(entry['p_b'], entry['p_nb'])
                score_lm_second = lm_log_probs.sum() * lm_weight_second
                score_lp = len(entry['hyp'][1:]) * lp_weight
                rescored.append({
                    'hyp': entry['hyp'],
                    # The total combines CTC, second path LM and length
                    # penalty (the first path LM score is not part of it).
                    'score': score_ctc + score_lm_second + score_lp,
                    'score_ctc': score_ctc,
                    'score_lp': score_lp,
                    # Keep both LM scores so that each can be reported
                    'score_lm': entry['score_lm'],
                    'score_lm_second': score_lm_second
                })
            beam = sorted(rescored, key=lambda x: x['score'], reverse=True)

        best_hyps.append(np.array(beam[0]['hyp'][1:]))

        if idx2token is not None:
            if utt_ids is not None:
                logger.info('Utt-id: %s' % utt_ids[b])
            assert self.vocab == idx2token.vocab
            logger.info('=' * 200)
            for k in range(len(beam)):
                if refs_id is not None:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(beam[k]['hyp'][1:]))
                logger.info('log prob (hyp): %.7f' % beam[k]['score'])
                logger.info('log prob (hyp, ctc): %.7f' % (beam[k]['score_ctc']))
                # The stored lp/LM scores already include their weights
                logger.info('log prob (hyp, lp): %.7f' % (beam[k]['score_lp']))
                if lm is not None:
                    logger.info('log prob (hyp, first-path lm): %.7f' %
                                (beam[k]['score_lm']))
                if lm_second is not None:
                    logger.info('log prob (hyp, second-path lm): %.7f' %
                                (beam[k]['score_lm_second']))
                logger.info('-' * 50)

    return np.array(best_hyps)