def forward(self, eouts, elens, ys, forced_align=False):
    """Compute CTC loss.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (List): length `B`
        ys (List): length `B`, each of which contains a list of size `[L]`
    Returns:
        loss (FloatTensor): `[1]`
        trigger_points (IntTensor): `[B, L]`

    """
    # Concatenate all elements in ys for warpctc_pytorch
    ylens = np2tensor(np.fromiter([len(y) for y in ys], dtype=np.int32))
    ys_ctc = torch.cat([np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int32))
                        for y in ys], dim=0)
    # NOTE: do not copy to GPUs here

    # Compute CTC loss
    logits = self.output(eouts)
    loss = self.loss_fn(logits.transpose(1, 0), ys_ctc, elens, ylens)

    # Label smoothing for CTC
    if self.lsm_prob > 0:
        loss = loss * (1 - self.lsm_prob) + kldiv_lsm_ctc(logits, elens) * self.lsm_prob

    trigger_points = self.forced_align(logits, elens, ys, ylens) if forced_align else None

    if not self.training:
        self.data_dict['elens'] = tensor2np(elens)
        self.prob_dict['probs'] = tensor2np(torch.softmax(logits, dim=-1))

    return loss, trigger_points
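# --- Hedged sketch (not part of neural_sp): what the kldiv_lsm_ctc term above
# computes, as I understand it: a KL-style regularizer pulling the frame-level
# CTC posteriors toward a uniform distribution over the vocabulary.
# Normalization and masking details in the real helper may differ.
import math

import torch
import torch.nn.functional as F


def kldiv_to_uniform_sketch(logits, elens):
    """logits: `[B, T, vocab]`; elens: IntTensor `[B]` of valid lengths."""
    bs, tmax, vocab = logits.size()
    log_probs = F.log_softmax(logits, dim=-1)
    # KL(p || uniform) = sum_v p_v * (log p_v + log V)
    kl = (log_probs.exp() * (log_probs + math.log(vocab))).sum(-1)  # `[B, T]`
    valid = torch.arange(tmax).unsqueeze(0) < elens.unsqueeze(1)    # `[B, T]`
    return kl.masked_fill(~valid.to(kl.device), 0).sum() / bs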
def ctc_probs_topk(self, eouts, temperature, topk):
    probs = F.softmax(self.ctc.output(eouts) / temperature, dim=-1)
    if topk is None:
        topk = probs.size(-1)
    _, topk_ids = torch.topk(probs.sum(1), k=topk, dim=-1, largest=True, sorted=True)
    return tensor2np(probs), tensor2np(topk_ids)
def ctc_posteriors(self, eouts, x_lens, temperature, topk):
    # Pass through the softmax layer
    logits_ctc = self.output_ctc(eouts)
    ctc_probs = F.softmax(logits_ctc / temperature, dim=-1)
    if topk is None:
        topk = ctc_probs.size(-1)
    _, indices_topk = torch.topk(ctc_probs.sum(1), k=topk, dim=-1, largest=True, sorted=True)
    return tensor2np(ctc_probs), tensor2np(indices_topk)
def add_ctc_score(self, hyp, topk_ids, ctc_state, total_scores_topk,
                  ctc_prefix_scorer, new_chunk=False, backward=False):
    beam_width = self.beam_width_bwd if backward else self.beam_width
    if ctc_prefix_scorer is None:
        return None, topk_ids.new_zeros(beam_width), total_scores_topk

    ctc_scores, new_ctc_states = ctc_prefix_scorer(
        hyp, tensor2np(topk_ids[0]), ctc_state, new_chunk=new_chunk)
    total_scores_ctc = torch.from_numpy(ctc_scores).to(self.device)
    total_scores_topk += total_scores_ctc * self.ctc_weight
    # Sort again
    total_scores_topk, joint_ids_topk = torch.topk(
        total_scores_topk, k=beam_width, dim=1, largest=True, sorted=True)
    topk_ids = topk_ids[:, joint_ids_topk[0]]
    new_ctc_states = new_ctc_states[joint_ids_topk[0].cpu().numpy()]
    return new_ctc_states, total_scores_ctc, total_scores_topk
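# --- Toy illustration (made-up numbers) of the re-sorting performed above:
# after adding the weighted CTC prefix scores, the joint top-k ordering can
# differ from the attention-only top-k ordering.
import torch

att_topk = torch.tensor([[-1.0, -1.3, -2.0]])        # attention-only candidate scores
ctc = torch.tensor([-2.5, -0.9, -3.0])               # CTC prefix scores for the same ids
joint = att_topk + 0.3 * ctc                         # ctc_weight = 0.3
joint_sorted, order = torch.topk(joint, k=3, dim=1)  # candidate 1 now outranks candidate 0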
def decode(self, ys, state=None, mems=None, cache=None, incremental=False):
    """Decode function.

    Args:
        ys (LongTensor): `[B, L]`
        state (list): dummy interface for RNNLM
        mems (list): length `n_layers`, each of which contains a FloatTensor `[B, mlen, d_model]`
        cache (list): length `n_layers`, each of which contains a FloatTensor `[B, L-1, d_model]`
        incremental (bool): ASR decoding mode
    Returns:
        logits (FloatTensor): `[B, L, vocab]`
        out (FloatTensor): `[B, L, d_model]`
        new_cache (list): length `n_layers`, each of which contains a FloatTensor `[B, L, d_model]`

    """
    # for ASR decoding
    if cache is None:
        cache = [None] * self.n_layers  # layers 1 to L
    if mems is None:
        mems = self.init_memory()

    # Create the self-attention mask
    bs, ylen = ys.size()[:2]
    if incremental and cache[0] is not None:
        ylen = cache[0].size(1) + 1
    causal_mask = ys.new_ones(ylen, ylen).byte()
    causal_mask = torch.tril(causal_mask, diagonal=0, out=causal_mask).unsqueeze(0)
    causal_mask = causal_mask.repeat([bs, 1, 1])

    out = self.pos_enc(self.embed(ys.long()))

    new_mems = [None] * self.n_layers
    new_cache = [None] * self.n_layers
    hidden_states = [out]
    for lth, (mem, layer) in enumerate(zip(mems, self.layers)):
        out = layer(out, causal_mask, cache=cache[lth], memory=mem)
        if incremental:
            new_cache[lth] = out
        elif lth < self.n_layers - 1:
            hidden_states.append(out)
            # NOTE: outputs from the last layer are not used for memory
        if not self.training and layer.yy_aws is not None:
            setattr(self, 'yy_aws_layer%d' % lth, tensor2np(layer.yy_aws))

    out = self.norm_out(out)
    if self.adaptive_softmax is None:
        logits = self.output(out)
    else:
        logits = out

    if incremental:
        # NOTE: do not update memory here during ASR decoding
        return logits, out, new_cache
    elif self.mem_len > 0:
        # Update memory
        new_mems = self.update_memory(mems, hidden_states)
        return logits, out, new_mems
    else:
        return logits, out, mems
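# --- Hedged usage sketch (calling convention assumed from the docstring, not
# verified): score a token sequence step by step with the incremental cache,
# the way ASR shallow fusion consumes this LM. `lm` is a trained instance of
# the class above; `token_ids` is a hypothetical list of token ids.
import torch

def lm_logprob_incremental(lm, token_ids):
    cache, log_prob = None, 0.
    for i in range(1, len(token_ids)):
        ys = torch.LongTensor([token_ids[:i]])  # all tokens so far: `[1, i]`
        logits, _, cache = lm.decode(ys, cache=cache, incremental=True)
        log_prob += torch.log_softmax(logits[0, -1], dim=-1)[token_ids[i]].item()
    return log_prob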
def ctc_forced_align(self, xs, ys, task='ys'):
    """CTC-based forced alignment.

    Args:
        xs (FloatTensor): `[B, T, idim]`
        ys (List): length `B`, each of which contains a list of size `[L]`
    Returns:
        trigger_points (np.ndarray): `[B, L]`

    """
    from neural_sp.models.seq2seq.decoders.ctc import CTCForcedAligner
    forced_aligner = CTCForcedAligner()
    self.eval()
    with torch.no_grad():
        eout_dict = self.encode(xs, 'ys')
        # NOTE: support the main task only
        ctc = getattr(self, 'dec_fwd').ctc
        logits = ctc.output(eout_dict[task]['xs'])
        ylens = np2tensor(np.fromiter([len(y) for y in ys], dtype=np.int32))
        trigger_points = forced_aligner(logits, eout_dict[task]['xlens'], ys, ylens)
    return tensor2np(trigger_points)
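# --- Hypothetical usage (`model`, `xs`, `ys` assumed): recover per-token
# trigger frames, e.g. for latency-regularized streaming decoders.
# trigger_points[b, l] is the encoder frame of the CTC spike for token l.
trigger_points = model.ctc_forced_align(xs, ys, task='ys')  # `[B, L]`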
def get_ctc_probs(self, xs, task='ys', temperature=1, topk=None):
    self.eval()
    with torch.no_grad():
        eout_dict = self.encode(xs, task)
        dir = 'fwd' if self.fwd_weight >= self.bwd_weight else 'bwd'
        if task == 'ys_sub1':
            dir += '_sub1'
        elif task == 'ys_sub2':
            dir += '_sub2'

        if task == 'ys':
            assert self.ctc_weight > 0
        elif task == 'ys_sub1':
            assert self.ctc_weight_sub1 > 0
        elif task == 'ys_sub2':
            assert self.ctc_weight_sub2 > 0
        ctc_probs, indices_topk = getattr(self, 'dec_' + dir).ctc_probs_topk(
            eout_dict[task]['xs'], temperature, topk)
    return tensor2np(ctc_probs), tensor2np(indices_topk), eout_dict[task]['xlens']
def forward(self, xs, xlens, task):
    """Forward pass.

    Args:
        xs (FloatTensor): `[B, T, input_dim]`
        xlens (list): `[B]`
        task (str): not supported now
    Returns:
        eouts (dict):
            xs (FloatTensor): `[B, T, d_model]`
            xlens (list): `[B]`

    """
    eouts = {'ys': {'xs': None, 'xlens': None},
             'ys_sub1': {'xs': None, 'xlens': None},
             'ys_sub2': {'xs': None, 'xlens': None}}

    if self.conv is None:
        xs = self.embed(xs)
    else:
        # Pass through CNN blocks before RNN layers
        xs, xlens = self.conv(xs, xlens)

    # Create the self-attention mask
    bs, xmax = xs.size()[:2]
    xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(1).expand(bs, xmax, xmax)
    xx_mask = xx_mask.unsqueeze(1).expand(bs, self.attn_n_heads, xmax, xmax)

    xs = self.pos_enc(xs)
    for l in range(self.n_layers):
        xs, xx_aws = self.layers[l](xs, xx_mask)
        if not self.training:
            setattr(self, 'xx_aws_layer%d' % l, tensor2np(xx_aws))
    xs = self.norm_out(xs)

    # Bridge layer
    if self.bridge is not None:
        xs = self.bridge(xs)

    eouts['ys']['xs'] = xs
    eouts['ys']['xlens'] = xlens
    return eouts
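# --- Hedged sketch (not part of neural_sp) of the padding mask assumed by
# make_pad_mask above: True marks valid frames, False marks padding. The real
# helper also handles device placement and dtype.
import torch

def make_pad_mask_sketch(seq_lens):
    bs = seq_lens.size(0)
    max_time = int(seq_lens.max())
    seq_range = torch.arange(max_time).unsqueeze(0).expand(bs, max_time)
    return seq_range < seq_lens.unsqueeze(-1)  # `[B, T]`, True for valid frames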
def decode(self, ys, state=None, mems=None, cache=None, incremental=False):
    """Decode function.

    Args:
        ys (LongTensor): `[B, L]`
        state (List): dummy interface for RNNLM
        mems (List): length `n_layers` (inter-utterance), each of which contains
            a FloatTensor of size `[B, mlen, d_model]`
        cache (List): length `n_layers` (intra-utterance), each of which contains
            a FloatTensor of size `[B, L-1, d_model]`
        incremental (bool): ASR decoding mode
    Returns:
        logits (FloatTensor): `[B, L, vocab]`
        out (FloatTensor): `[B, L, d_model]`
        new_cache (List): length `n_layers`, each of which contains
            a FloatTensor of size `[B, L, d_model]`

    """
    # for ASR decoding
    if cache is None:
        cache = [None] * self.n_layers  # layers 1 to L

    bs, ylen = ys.size()[:2]
    n_hist = 0
    if incremental and cache[0] is not None:
        n_hist = cache[0].size(1)
        ylen += n_hist

    # Create the self-attention mask
    causal_mask = ys.new_ones(ylen, ylen).byte()
    causal_mask = torch.tril(causal_mask).unsqueeze(0)
    causal_mask = causal_mask.repeat([bs, 1, 1])  # `[B, L, L]`

    out = self.pos_enc(self.embed_token_id(ys), scale=True, offset=max(0, n_hist))  # scaled + dropout
    new_cache = [None] * self.n_layers
    hidden_states = [out]
    for lth, layer in enumerate(self.layers):
        out = layer(out, causal_mask, cache=cache[lth])
        if incremental:
            new_cache[lth] = out
        elif lth < self.n_layers - 1:
            hidden_states.append(out)
            # NOTE: outputs from the last layer are not used for cache
        if not self.training and layer.yy_aws is not None:
            setattr(self, 'yy_aws_layer%d' % lth, tensor2np(layer.yy_aws))

    out = self.norm_out(out)
    if self.adaptive_softmax is None:
        logits = self.output(out)
    else:
        logits = out

    return logits, out, new_cache
def get_ctc_probs(self, xs, task='ys', temperature=1, topk=None):
    """Get CTC top-K probabilities.

    Args:
        xs (FloatTensor): `[B, T, idim]`
        task (str): task to evaluate
        temperature (float): softmax temperature
        topk (int): top-K classes to sample
    Returns:
        probs (np.ndarray): `[B, T, vocab]`
        topk_ids (np.ndarray): `[B, T, topk]`
        elens (IntTensor): `[B]`

    """
    self.eval()
    with torch.no_grad():
        eout_dict = self.encode(xs, task)
        dir = 'fwd' if self.fwd_weight >= self.bwd_weight else 'bwd'
        if task == 'ys_sub1':
            dir += '_sub1'
        elif task == 'ys_sub2':
            dir += '_sub2'

        if task == 'ys':
            assert self.ctc_weight > 0
        elif task == 'ys_sub1':
            assert self.ctc_weight_sub1 > 0
        elif task == 'ys_sub2':
            assert self.ctc_weight_sub2 > 0
        probs = getattr(self, 'dec_' + dir).ctc.probs(eout_dict[task]['xs'])
        if topk is None:
            topk = probs.size(-1)  # return all classes
        _, topk_ids = torch.topk(probs, k=topk, dim=-1, largest=True, sorted=True)
    return tensor2np(probs), tensor2np(topk_ids), eout_dict[task]['xlens']
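# --- Hypothetical usage (`model`, `xs` assumed): inspect frame-level CTC
# posteriors for the main task, keeping only the 10 most active classes.
probs, topk_ids, xlens = model.get_ctc_probs(xs, task='ys', temperature=1, topk=10)
# probs[b, t] is the posterior distribution over the vocabulary at frame t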
def sub_module(self, xs, xx_mask, lth, pos_embs=None, module='sub1'):
    if self.task_specific_layer:
        xs_sub = getattr(self, 'layer_' + module)(xs, xx_mask, pos_embs=pos_embs)
    else:
        xs_sub = xs.clone()
    xs_sub = getattr(self, 'norm_out_' + module)(xs_sub)
    if getattr(self, 'bridge_' + module) is not None:
        xs_sub = getattr(self, 'bridge_' + module)(xs_sub)
    if not self.training:
        self.aws_dict['xx_aws_%s_layer%d' % (module, lth)] = tensor2np(
            getattr(self, 'layer_' + module).xx_aws)
    return xs_sub
def ctc_forced_align(self, xs, ys, task='ys'):
    """CTC-based forced alignment.

    Args:
        xs (FloatTensor): `[B, T, idim]`
        ys (List): length `B`, each of which contains a list of size `[L]`
    Returns:
        trigger_points (np.ndarray): `[B, L]`

    """
    self.eval()
    with torch.no_grad():
        eout_dict = self.encode(xs, 'ys')
        # NOTE: support the main task only
        trigger_points = getattr(self, 'dec_fwd').ctc_forced_align(
            eout_dict[task]['xs'], eout_dict[task]['xlens'], ys)
    return tensor2np(trigger_points)
def decode(self, ys, state=None, is_asr=False):
    """Decode function.

    Args:
        ys (LongTensor): `[B, L]`
        state: previous tokens
        is_asr (bool): ASR decoding mode
    Returns:
        ys_emb (FloatTensor): `[B, L, n_units]`
        state: previous tokens

    """
    # Concatenate previous tokens
    if is_asr and state is not None:
        ys = torch.cat([state, ys], dim=1)
        # NOTE: this is used for ASR decoding

    ys_emb = self.embed(ys.long())

    # Create the self-attention mask
    bs, ymax = ys_emb.size()[:2]
    ylens = torch.IntTensor([ymax] * bs)
    yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand(bs, ymax, ymax)
    yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax, ymax)
    subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(), diagonal=0)
    subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
        bs, self.attn_n_heads, ymax, ymax)
    yy_mask = yy_mask & subsequent_mask

    ys_emb = self.pos_enc(ys_emb)
    for l in range(self.n_layers):
        ys_emb, yy_aws, _ = self.layers[l](ys_emb, yy_mask)
        if not self.training:
            setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
    ys_emb = self.norm_out(ys_emb)

    if is_asr:
        state = ys

    return ys_emb, state
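# --- Hedged sketch (not part of neural_sp): the combined padding + causal
# mask built above, reduced to its essentials. True marks positions a query
# may attend to: valid (non-pad) keys that are not in the future.
import torch

def causal_pad_mask_sketch(ylens, ymax):
    pad = torch.arange(ymax).unsqueeze(0) < ylens.unsqueeze(1)     # `[B, L]`
    causal = torch.tril(torch.ones(ymax, ymax, dtype=torch.bool))  # `[L, L]`
    return pad.unsqueeze(1) & causal.unsqueeze(0)                  # `[B, L, L]`

# e.g. causal_pad_mask_sketch(torch.tensor([3, 2]), 3)[1]
# -> [[True, False, False], [True, True, False], [True, True, False]]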
def decode(self, ys, state=None, mems=None, cache=None, incremental=False):
    """Decode function.

    Args:
        ys (LongTensor): `[B, L]`
        state (list): dummy interface for RNNLM
        mems (list): dummy interface for TransformerXL
        cache (list): length `n_layers`, each of which contains a FloatTensor `[B, L-1, d_model]`
        incremental (bool): ASR decoding mode
    Returns:
        logits (FloatTensor): `[B, L, vocab]`
        out (FloatTensor): `[B, L, d_model]`
        new_cache (list): length `n_layers`, each of which contains a FloatTensor `[B, L, d_model]`
        new_mems: dummy interface for TransformerXL

    """
    # for ASR decoding
    if cache is None:
        cache = [None] * self.n_layers

    # Create the self-attention mask
    bs, ylen = ys.size()[:2]
    if incremental and cache[0] is not None:
        ylen = cache[0].size(1) + 1
    causal_mask = ys.new_ones(ylen, ylen).byte()
    causal_mask = torch.tril(causal_mask, diagonal=0, out=causal_mask).unsqueeze(0)
    causal_mask = causal_mask.repeat([bs, 1, 1])

    new_cache = [None] * self.n_layers
    out = self.pos_enc(self.embed(ys.long()))
    for l, layer in enumerate(self.layers):
        out, yy_aws = layer(out, causal_mask, cache=cache[l])[:2]
        if incremental:
            new_cache[l] = out
        if not self.training and yy_aws is not None:
            setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
    out = self.norm_out(out)
    if self.adaptive_softmax is None:
        logits = self.output(out)
    else:
        logits = out

    return logits, out, new_cache
def decode_ctc(self, eouts, x_lens, beam_width=1, rnnlm=None):
    """Decoding by the CTC layer in the inference stage.

    This is only used for the joint CTC-attention model.

    Args:
        eouts (FloatTensor): `[B, T, enc_units]`
        beam_width (int): the size of beam
        rnnlm ():
    Returns:
        best_hyps (list): A list of length `[B]`, which contains arrays of size `[L]`

    """
    logits_ctc = self.output_ctc(eouts)
    if beam_width == 1:
        best_hyps = self.decode_ctc_greedy(tensor2np(logits_ctc), x_lens)
    else:
        best_hyps = self.decode_ctc_beam(F.log_softmax(logits_ctc, dim=-1),
                                         x_lens, beam_width, rnnlm)
        # TODO(hirofumi): decoding parameters
    return best_hyps
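# --- Hedged sketch of greedy (best-path) CTC decoding as performed by
# decode_ctc_greedy above: take the argmax class per frame, collapse repeats,
# then drop blanks. `blank_id = 0` is an assumption (neural_sp's default).
import numpy as np

def ctc_greedy_decode_sketch(logits, x_lens, blank_id=0):
    best_hyps = []
    for b in range(len(logits)):
        best_path = np.argmax(logits[b][:x_lens[b]], axis=-1)  # `[T]`
        collapsed = [t for i, t in enumerate(best_path)
                     if i == 0 or t != best_path[i - 1]]  # collapse repeats
        best_hyps.append(np.array([t for t in collapsed if t != blank_id]))
    return best_hyps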
def decode(self, ys, ys_prev=None, cache=False):
    """Decode function.

    Args:
        ys (LongTensor): `[B, L]`
        ys_prev (LongTensor): previous tokens
        cache (bool): concatenate previous tokens
    Returns:
        logits (FloatTensor): `[B, L, vocab]`
        ys_emb (FloatTensor): `[B, L, d_model]` (for ys_prev)
        ys_prev (LongTensor): previous tokens

    """
    # Concatenate previous tokens
    if cache and ys_prev is not None:
        ys = torch.cat([ys_prev, ys], dim=1)
        # NOTE: this is used for ASR decoding

    # Create the self-attention mask
    bs, ymax = ys.size()[:2]
    ylens = torch.IntTensor([ymax] * bs)
    tgt_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat([1, ymax, 1])
    subsequent_mask = tgt_mask.new_ones(ymax, ymax).byte()
    subsequent_mask = torch.tril(subsequent_mask, out=subsequent_mask).unsqueeze(0)
    tgt_mask = tgt_mask & subsequent_mask

    out = self.pos_enc(self.embed(ys.long()))
    for l in range(self.n_layers):
        out, yy_aws, _ = self.layers[l](out, tgt_mask)
        if not self.training:
            setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
    out = self.norm_out(out)
    if self.adaptive_softmax is None:
        logits = self.output(out)
    else:
        logits = out

    return logits, out, ys
def decode(self, ys, state=None, cache=False):
    """Decode function.

    Args:
        ys (LongTensor): `[B, L]`
        state (LongTensor): `[B, L]`
        cache (bool): concatenate previous tokens
    Returns:
        logits (FloatTensor): `[B, L, vocab]`
        out (FloatTensor): `[B, L, d_model]`
        new_state (LongTensor): previous tokens

    """
    # Concatenate previous tokens
    if cache and state is not None:
        ys = torch.cat([state, ys], dim=1)
        # NOTE: this is used for ASR decoding

    # Create the self-attention mask
    bs, ylen = ys.size()[:2]
    causal_mask = ys.new_ones(ylen, ylen).byte()
    causal_mask = torch.tril(causal_mask, diagonal=0, out=causal_mask).unsqueeze(0)
    causal_mask = causal_mask.repeat([bs, 1, 1])

    out = self.pos_enc(self.embed(ys.long()))
    for l, layer in enumerate(self.layers):
        out, yy_aws = layer(out, causal_mask)[:2]
        if not self.training:
            setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
    out = self.norm_out(out)
    if self.adaptive_softmax is None:
        logits = self.output(out)
    else:
        logits = out

    return logits, out, ys
def beam_search(self, eouts, elens, params, idx2token=None,
                lm=None, lm_second=None, lm_second_bwd=None, ctc_log_probs=None,
                nbest=1, exclude_eos=False, refs_id=None, utt_ids=None, speakers=None,
                ensmbl_eouts=[], ensmbl_elens=[], ensmbl_decs=[], cache_states=True):
    """Beam search decoding.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        params (dict): decoding hyperparameters
        idx2token (): converter from index to token
        lm (torch.nn.module): first-pass LM
        lm_second (torch.nn.module): second-pass LM
        lm_second_bwd (torch.nn.module): second-pass backward LM
        ctc_log_probs (FloatTensor):
        nbest (int): number of N-best list
        exclude_eos (bool): exclude <eos> from hypothesis
        refs_id (List): reference list
        utt_ids (List): utterance id list
        speakers (List): speaker list
        ensmbl_eouts (List[FloatTensor]): encoder outputs for ensemble models
        ensmbl_elens (List[IntTensor]): encoder output lengths for ensemble models
        ensmbl_decs (List[torch.nn.Module]): decoders for ensemble models
        cache_states (bool): cache decoder states for fast decoding
    Returns:
        nbest_hyps_idx (List): length `[B]`, each of which contains a list of N hypotheses
        aws (List): length `[B]`, each of which contains arrays of size `[H, L, T]`
        scores (List):

    """
    bs, xmax, _ = eouts.size()
    n_models = len(ensmbl_decs) + 1

    beam_width = params.get('recog_beam_width')
    assert 1 <= nbest <= beam_width
    ctc_weight = params.get('recog_ctc_weight')
    max_len_ratio = params.get('recog_max_len_ratio')
    min_len_ratio = params.get('recog_min_len_ratio')
    lp_weight = params.get('recog_length_penalty')
    length_norm = params.get('recog_length_norm')
    cache_emb = params.get('recog_cache_embedding')
    lm_weight = params.get('recog_lm_weight')
    lm_weight_second = params.get('recog_lm_second_weight')
    lm_weight_second_bwd = params.get('recog_lm_bwd_weight')
    eos_threshold = params.get('recog_eos_threshold')
    lm_state_carry_over = params.get('recog_lm_state_carry_over')
    softmax_smoothing = params.get('recog_softmax_smoothing')
    eps_wait = params.get('recog_mma_delay_threshold')

    helper = BeamSearch(beam_width, self.eos, ctc_weight, lm_weight, self.device)
    lm = helper.verify_lm_eval_mode(lm, lm_weight, cache_emb)
    lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second, cache_emb)
    lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd, lm_weight_second_bwd, cache_emb)

    # Cache token embeddings
    if cache_emb:
        self.cache_embedding(eouts.device)

    if ctc_log_probs is not None:
        assert ctc_weight > 0
        ctc_log_probs = tensor2np(ctc_log_probs)

    nbest_hyps_idx, aws, scores = [], [], []
    eos_flags = []
    for b in range(bs):
        # Initialization per utterance
        lmstate = None
        ys = eouts.new_zeros((1, 1), dtype=torch.int64).fill_(self.eos)

        for layer in self.layers:
            layer.reset()

        # For joint CTC-Attention decoding
        ctc_prefix_scorer = None
        if ctc_log_probs is not None:
            if self.bwd:
                ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b][::-1], self.blank, self.eos)
            else:
                ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b], self.blank, self.eos)

        if speakers is not None:
            if speakers[b] == self.prev_spk:
                if lm_state_carry_over and isinstance(lm, RNNLM):
                    lmstate = self.lmstate_final
            self.prev_spk = speakers[b]

        end_hyps = []
        hyps = [{'hyp': [self.eos],
                 'ys': ys,
                 'cache': None,
                 'score': 0.,
                 'score_att': 0.,
                 'score_ctc': 0.,
                 'score_lm': 0.,
                 'aws': [None],
                 'lmstate': lmstate,
                 'ensmbl_cache': [[None] * dec.n_layers for dec in ensmbl_decs] if n_models > 1 else None,
                 'ctc_state': ctc_prefix_scorer.initial_state() if ctc_prefix_scorer is not None else None,
                 'quantity_rate': 1.,
                 'streamable': True,
                 'streaming_failed_point': 1000}]
        streamable_global = True
        ymax = math.ceil(elens[b] * max_len_ratio)
        for i in range(ymax):
            # batchfy all hypotheses for batch decoding
            cache = [None] * self.n_layers
            if cache_states and i > 0:
                for lth in range(self.n_layers):
                    cache[lth] = torch.cat([beam['cache'][lth] for beam in hyps], dim=0)
            ys = eouts.new_zeros((len(hyps), i + 1), dtype=torch.int64)
            for j, beam in enumerate(hyps):
                ys[j, :] = beam['ys']
            if i > 0:
                xy_aws_prev = torch.cat([beam['aws'][-1] for beam in hyps], dim=0)  # `[B, n_layers, H_ma, 1, klen]`
            else:
                xy_aws_prev = None

            # Update LM states for shallow fusion
            y_lm = ys[:, -1:].clone()  # NOTE: this is important
            _, lmstate, scores_lm = helper.update_rnnlm_state_batch(lm, hyps, y_lm)

            # for the main model
            causal_mask = eouts.new_ones(i + 1, i + 1, dtype=torch.uint8)
            causal_mask = torch.tril(causal_mask).unsqueeze(0).repeat([ys.size(0), 1, 1])

            out = self.pos_enc(self.embed_token_id(ys), scale=True)  # scaled + dropout

            n_heads_total = 0
            eouts_b = eouts[b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])  # `[n_beams, T, dim]`
            new_cache = [None] * self.n_layers
            xy_aws_layers = []
            xy_aws = None
            lth_s = self.mma_first_layer - 1
            # Autoregressive decoding through the decoder layers
            for lth, layer in enumerate(self.layers):
                out = layer(out, causal_mask, eouts_b, None,
                            cache=cache[lth],
                            xy_aws_prev=xy_aws_prev[:, lth - lth_s] if lth >= lth_s and i > 0 else None,
                            eps_wait=eps_wait)
                xy_aws = layer.xy_aws
                new_cache[lth] = out
                if xy_aws is not None:
                    xy_aws_layers.append(xy_aws)

            logits = self.output(self.norm_out(out[:, -1]))  # output at the current step
            probs = torch.softmax(logits * softmax_smoothing, dim=1)
            xy_aws_layers = torch.stack(xy_aws_layers, dim=1)  # `[B, H, n_layers, L, T]`

            # Ensemble initialization
            ensmbl_cache = [[None] * dec.n_layers for dec in ensmbl_decs]
            if n_models > 1 and cache_states and i > 0:
                for i_e, dec in enumerate(ensmbl_decs):
                    for lth in range(dec.n_layers):
                        ensmbl_cache[i_e][lth] = torch.cat(
                            [beam['ensmbl_cache'][i_e][lth] for beam in hyps], dim=0)

            # for the ensemble
            ensmbl_new_cache = [[None] * dec.n_layers for dec in ensmbl_decs]
            for i_e, dec in enumerate(ensmbl_decs):
                out_e = dec.pos_enc(dec.embed(ys))  # scaled + dropout
                eouts_e = ensmbl_eouts[i_e][b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
                for lth in range(dec.n_layers):
                    out_e = dec.layers[lth](out_e, causal_mask, eouts_e, None,
                                            cache=ensmbl_cache[i_e][lth])
                    ensmbl_new_cache[i_e][lth] = out_e
                logits_e = dec.output(dec.norm_out(out_e[:, -1]))
                probs += torch.softmax(logits_e * softmax_smoothing, dim=1)
                # NOTE: sum in the probability scale (not log-scale)

            # Ensemble: average over models in the probability domain
            scores_att = torch.log(probs / n_models)  # `[n_beams, vocab]`

            new_hyps = []
            for j, beam in enumerate(hyps):
                # Expand each beam with the top-K candidates
                # Attention scores
                total_scores_att = beam['score_att'] + scores_att[j:j + 1]  # `[1, vocab]`
                total_scores = total_scores_att * (1 - ctc_weight)

                # Add LM score <before> top-K selection
                if lm is not None:
                    total_scores_lm = beam['score_lm'] + scores_lm[j:j + 1, -1]
                    total_scores += total_scores_lm * lm_weight
                else:
                    total_scores_lm = eouts.new_zeros(1, self.vocab)

                total_scores_topk, topk_ids = torch.topk(
                    total_scores, k=beam_width, dim=1, largest=True, sorted=True)

                # Add length penalty
                if lp_weight > 0:
                    total_scores_topk += (len(beam['hyp'][1:]) + 1) * lp_weight

                # Add CTC score
                new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                    beam['hyp'], topk_ids, beam['ctc_state'],
                    total_scores_topk, ctc_prefix_scorer)

                new_aws = beam['aws'] + [xy_aws_layers[j:j + 1, :, :, -1:]]
                aws_j = torch.cat(new_aws[1:], dim=3)  # `[1, H, n_layers, L, T]`

                # forward direction
                for k in range(beam_width):
                    idx = topk_ids[0, k].item()  # index of the k-th candidate
                    length_norm_factor = len(beam['hyp'][1:]) + 1 if length_norm else 1  # current length
                    total_score = total_scores_topk[0, k].item() / length_norm_factor

                    if idx == self.eos:
                        # Exclude short hypotheses
                        if len(beam['hyp'][1:]) < elens[b] * min_len_ratio:
                            continue
                        # EOS threshold: find the best score among non-<eos> tokens
                        max_score_no_eos = scores_att[j, :idx].max(0)[0].item()
                        max_score_no_eos = max(max_score_no_eos, scores_att[j, idx + 1:].max(0)[0].item())
                        if scores_att[j, idx].item() <= eos_threshold * max_score_no_eos:
                            continue  # keep decoding; skip this candidate

                    streaming_failed_point = beam['streaming_failed_point']
                    quantity_rate = 1.
                    # streamability check (MoChA)
                    if self.attn_type == 'mocha':
                        n_tokens_hyp_k = i + 1
                        n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                        quantity_diff = n_tokens_hyp_k * n_heads_total - n_quantity_k

                        if quantity_diff != 0:
                            if idx == self.eos:
                                n_tokens_hyp_k -= 1  # NOTE: do not count <eos> for streamability
                                n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                            else:
                                streamable_global = False
                            if n_tokens_hyp_k * n_heads_total == 0:
                                quantity_rate = 0
                            else:
                                quantity_rate = n_quantity_k / (n_tokens_hyp_k * n_heads_total)

                        if beam['streamable'] and not streamable_global:
                            streaming_failed_point = i

                    new_hyps.append(
                        {'hyp': beam['hyp'] + [idx],
                         'ys': torch.cat([beam['ys'], eouts.new_zeros((1, 1), dtype=torch.int64).fill_(idx)], dim=-1),
                         'cache': [new_cache_l[j:j + 1] for new_cache_l in new_cache] if cache_states else cache,
                         'score': total_score,
                         'score_att': total_scores_att[0, idx].item(),
                         'score_ctc': total_scores_ctc[k].item(),
                         'score_lm': total_scores_lm[0, idx].item(),
                         'aws': new_aws,
                         'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                     'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                         'ctc_state': new_ctc_states[k] if ctc_prefix_scorer is not None else None,
                         'ensmbl_cache': [[new_cache_e_l[j:j + 1] for new_cache_e_l in new_cache_e]
                                          for new_cache_e in ensmbl_new_cache] if cache_states else None,
                         'streamable': streamable_global,
                         'streaming_failed_point': streaming_failed_point,
                         'quantity_rate': quantity_rate})

            # Local pruning: keep the beam_width best hypotheses
            new_hyps_sorted = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

            # Remove complete hypotheses (prune to beam_width)
            new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                new_hyps_sorted, end_hyps, prune=True)
            hyps = new_hyps[:]
            if is_finish:
                break

        # Global pruning at the end of the utterance
        if len(end_hyps) == 0:
            end_hyps = hyps[:]
        elif len(end_hyps) < nbest and nbest > 1:
            end_hyps.extend(hyps[:nbest - len(end_hyps)])

        # forward/backward second-pass LM rescoring
        end_hyps = helper.lm_rescoring(end_hyps, lm_second, lm_weight_second,
                                       length_norm=length_norm, tag='second')
        end_hyps = helper.lm_rescoring(end_hyps, lm_second_bwd, lm_weight_second_bwd,
                                       length_norm=length_norm, tag='second_bwd')

        # Sort by score
        end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

        for j in range(len(end_hyps[0]['aws'][1:])):
            tmp = end_hyps[0]['aws'][j + 1]
            end_hyps[0]['aws'][j + 1] = tmp.view(1, -1, tmp.size(-2), tmp.size(-1))

        # metrics for streaming inference
        self.streamable = end_hyps[0]['streamable']
        self.quantity_rate = end_hyps[0]['quantity_rate']
        self.last_success_frame_ratio = None

        if idx2token is not None:
            if utt_ids is not None:
                logger.info('Utt-id: %s' % utt_ids[b])
            assert self.vocab == idx2token.vocab
            logger.info('=' * 200)
            for k in range(len(end_hyps)):
                if refs_id is not None:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(
                    end_hyps[k]['hyp'][1:][::-1] if self.bwd else end_hyps[k]['hyp'][1:]))
                logger.info('num tokens (hyp): %d' % len(end_hyps[k]['hyp'][1:]))
                logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                logger.info('log prob (hyp, att): %.7f' % (end_hyps[k]['score_att'] * (1 - ctc_weight)))
                if ctc_prefix_scorer is not None:
                    logger.info('log prob (hyp, ctc): %.7f' % (end_hyps[k]['score_ctc'] * ctc_weight))
                if lm is not None:
                    logger.info('log prob (hyp, first-pass lm): %.7f' % (end_hyps[k]['score_lm'] * lm_weight))
                if lm_second is not None:
                    logger.info('log prob (hyp, second-pass lm): %.7f' %
                                (end_hyps[k]['score_lm_second'] * lm_weight_second))
                if lm_second_bwd is not None:
                    logger.info('log prob (hyp, second-pass lm, reverse): %.7f' %
                                (end_hyps[k]['score_lm_second_bwd'] * lm_weight_second_bwd))
                if self.attn_type == 'mocha':
                    logger.info('streamable: %s' % end_hyps[k]['streamable'])
                    logger.info('streaming failed point: %d' % (end_hyps[k]['streaming_failed_point'] + 1))
                    logger.info('quantity rate [%%]: %.2f' % (end_hyps[k]['quantity_rate'] * 100))
                logger.info('-' * 50)

        if self.attn_type == 'mocha' and end_hyps[0]['streaming_failed_point'] < 1000:
            assert not self.streamable
            aws_last_success = end_hyps[0]['aws'][1:][end_hyps[0]['streaming_failed_point'] - 1]
            rightmost_frame = max(0, aws_last_success[0, :, 0].nonzero()[:, -1].max().item()) + 1
            frame_ratio = rightmost_frame * 100 / xmax
            self.last_success_frame_ratio = frame_ratio
            logger.info('streaming last success frame ratio: %.2f' % frame_ratio)

        # N-best list
        if self.bwd:
            # Reverse the order
            nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:][::-1]) for n in range(nbest)]]
            aws += [[tensor2np(torch.cat(end_hyps[n]['aws'][1:][::-1], dim=2).squeeze(0)) for n in range(nbest)]]
        else:
            nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]
            aws += [[tensor2np(torch.cat(end_hyps[n]['aws'][1:], dim=2).squeeze(0)) for n in range(nbest)]]
        scores += [[end_hyps[n]['score_att'] for n in range(nbest)]]

        # Check <eos>
        eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

    # Exclude <eos> (<sos> in case of the backward decoder)
    if exclude_eos:
        if self.bwd:
            nbest_hyps_idx = [[nbest_hyps_idx[b][n][1:] if eos_flags[b][n]
                               else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
            aws = [[aws[b][n][:, 1:] if eos_flags[b][n] else aws[b][n]
                    for n in range(nbest)] for b in range(bs)]
        else:
            nbest_hyps_idx = [[nbest_hyps_idx[b][n][:-1] if eos_flags[b][n]
                               else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
            aws = [[aws[b][n][:, :-1] if eos_flags[b][n] else aws[b][n]
                    for n in range(nbest)] for b in range(bs)]

    # Store ASR/LM state
    if bs == 1:
        self.lmstate_final = end_hyps[0]['lmstate']

    return nbest_hyps_idx, aws, scores
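# --- Toy illustration (made-up numbers) of how a hypothesis score is
# assembled per step in the search above:
#   score = (1 - ctc_weight) * log p_att + ctc_weight * log p_ctc
#           + lm_weight * log p_lm + lp_weight * (|hyp| + 1)
# followed by length normalization when recog_length_norm is set.
att, ctc_lp, lm_lp = -1.2, -1.5, -2.0           # hypothetical log probabilities
ctc_weight, lm_weight, lp_weight = 0.3, 0.5, 1.0
hyp_len = 7                                     # tokens generated so far
score = (1 - ctc_weight) * att + ctc_weight * ctc_lp \
    + lm_weight * lm_lp + lp_weight * (hyp_len + 1)
normalized = score / (hyp_len + 1)              # length-normalized score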
def greedy(self, eouts, elens, max_len_ratio, idx2token, exclude_eos=False,
           refs_id=None, utt_ids=None, speakers=None, cache_states=True):
    """Greedy decoding.

    Args:
        eouts (FloatTensor): `[B, T, enc_units]`
        elens (IntTensor): `[B]`
        max_len_ratio (int): maximum sequence length of tokens
        idx2token (): converter from index to token
        exclude_eos (bool): exclude <eos> from hypothesis
        refs_id (List): reference list
        utt_ids (List): utterance id list
        speakers (List): speaker list
        cache_states (bool): cache decoder states for fast decoding
    Returns:
        hyps (List): length `[B]`, each of which contains arrays of size `[L]`
        aws (List): length `[B]`, each of which contains arrays of size `[H * n_layers, L, T]`

    """
    bs, xmax = eouts.size()[:2]

    ys = eouts.new_zeros((bs, 1), dtype=torch.int64).fill_(self.eos)

    for layer in self.layers:
        layer.reset()

    cache = [None] * self.n_layers

    hyps_batch = []
    ylens = torch.zeros(bs).int()
    eos_flags = [False] * bs
    xy_aws_layers_steps = []
    ymax = math.ceil(xmax * max_len_ratio)  # maximum output length
    for i in range(ymax):
        # Lower-triangular mask to block future information
        causal_mask = eouts.new_ones(i + 1, i + 1, dtype=torch.uint8)
        causal_mask = torch.tril(causal_mask).unsqueeze(0).repeat([bs, 1, 1])

        new_cache = [None] * self.n_layers
        xy_aws_layers = []
        out = self.pos_enc(self.embed_token_id(ys), scale=True)  # scaled + dropout
        for lth, layer in enumerate(self.layers):
            out = layer(out, causal_mask, eouts, None, cache=cache[lth])
            new_cache[lth] = out
            if layer.xy_aws is not None:
                xy_aws_layers.append(layer.xy_aws[:, :, -1:])

        if cache_states:
            cache = new_cache[:]

        # Pick up 1-best
        y = self.output(self.norm_out(out))[:, -1:].argmax(-1)
        hyps_batch += [y]
        xy_aws_layers = torch.stack(xy_aws_layers, dim=2)  # `[B, H, n_layers, 1, T]`
        xy_aws_layers_steps.append(xy_aws_layers)

        # Count lengths of hypotheses
        for b in range(bs):
            if not eos_flags[b]:
                if y[b].item() == self.eos:
                    eos_flags[b] = True
                ylens[b] += 1  # include <eos>

        # Break if <eos> is output for all utterances in the mini-batch
        if sum(eos_flags) == bs:
            break
        if i == ymax - 1:
            break

        ys = torch.cat([ys, y], dim=-1)  # concatenate in L dimension

    hyps_batch = tensor2np(torch.cat(hyps_batch, dim=1))
    xy_aws_layers_steps = torch.cat(xy_aws_layers_steps, dim=-2)  # `[B, H, n_layers, L, T]`
    xy_aws_layers_steps = xy_aws_layers_steps.reshape(
        bs, self.n_heads * self.n_layers, ys.size(1), xmax)
    xy_aws = tensor2np(xy_aws_layers_steps)

    # Truncate by the first <eos> (<sos> in case of the backward decoder)
    if self.bwd:
        # Reverse the order
        hyps = [hyps_batch[b, :ylens[b]][::-1] for b in range(bs)]
        aws = [xy_aws[b, :, :ylens[b], :][:, ::-1] for b in range(bs)]
    else:
        hyps = [hyps_batch[b, :ylens[b]] for b in range(bs)]
        aws = [xy_aws[b, :, :ylens[b], :] for b in range(bs)]

    # Exclude <eos> (<sos> in case of the backward decoder)
    if exclude_eos:
        if self.bwd:
            hyps = [hyps[b][1:] if eos_flags[b] else hyps[b] for b in range(bs)]
            aws = [aws[b][:, 1:] if eos_flags[b] else aws[b] for b in range(bs)]
        else:
            hyps = [hyps[b][:-1] if eos_flags[b] else hyps[b] for b in range(bs)]
            aws = [aws[b][:, :-1] if eos_flags[b] else aws[b] for b in range(bs)]

    if idx2token is not None:
        # idx -> token
        for b in range(bs):
            if utt_ids is not None:
                logger.debug('Utt-id: %s' % utt_ids[b])
            if refs_id is not None and self.vocab == idx2token.vocab:
                logger.debug('Ref: %s' % idx2token(refs_id[b]))
            if self.bwd:
                logger.debug('Hyp: %s' % idx2token(hyps[b][::-1]))
            else:
                logger.debug('Hyp: %s' % idx2token(hyps[b]))
            logger.info('=' * 200)
            # NOTE: do not show with logger.info here

    return hyps, aws
def forward_att(self, eouts, elens, ys, trigger_points=None):
    """Compute XE loss for the Transformer decoder.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (List): length `[B]`, each of which contains a list of size `[L]`
        trigger_points (IntTensor): `[B, L]`
    Returns:
        loss (FloatTensor): `[1]`
        acc (float): accuracy for token prediction
        ppl (float): perplexity
        losses_auxiliary (dict):

    """
    losses_auxiliary = {}

    # Append <sos> and <eos>
    ys_in, ys_out, ylens = append_sos_eos(ys, self.eos, self.eos, self.pad, self.device, self.bwd)
    if not self.training:
        self.data_dict['elens'] = tensor2np(elens)
        self.data_dict['ylens'] = tensor2np(ylens)
        self.data_dict['ys'] = tensor2np(ys_out)

    # Create target self-attention mask
    bs, ymax = ys_in.size()[:2]
    tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
    causal_mask = tgt_mask.new_ones(ymax, ymax, dtype=tgt_mask.dtype)
    causal_mask = torch.tril(causal_mask).unsqueeze(0)
    tgt_mask = tgt_mask & causal_mask  # `[B, L (query), L (key)]`

    # Create source-target mask
    src_mask = make_pad_mask(elens.to(self.device)).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

    # Create attention padding mask for quantity loss
    if self.attn_type == 'mocha':
        attn_mask = (ys_out != self.pad).unsqueeze(1).unsqueeze(3)  # `[B, 1, L, 1]`
    else:
        attn_mask = None

    # external LM integration
    lmout = None
    if self.lm is not None:
        self.lm.eval()
        with torch.no_grad():
            lmout, lmstate, _ = self.lm.predict(ys_in, None)
        lmout = self.lm_output_proj(lmout)

    out = self.pos_enc(self.embed_token_id(ys_in), scale=True)  # scaled + dropout
    xy_aws_layers = []
    xy_aws = None
    for lth, layer in enumerate(self.layers):
        out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout)
        # Attention padding
        xy_aws = layer.xy_aws
        if xy_aws is not None and self.attn_type == 'mocha':
            xy_aws_masked = xy_aws.masked_fill_(attn_mask.expand_as(xy_aws) == 0, 0)
            # NOTE: attention padding is quite effective for quantity loss
            xy_aws_layers.append(xy_aws_masked.clone())
        if not self.training:
            self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws)
            self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws)
            self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta)
            self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose)
            self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm)
    logits = self.output(self.norm_out(out))

    # Compute XE loss (+ label smoothing)
    loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)

    # Quantity loss
    losses_auxiliary['loss_quantity'] = 0.
    if self.attn_type == 'mocha':
        # Average over all heads across all layers
        n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
        # NOTE: count <eos> tokens
        n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                             for aws in xy_aws_layers])  # `[B]`
        n_tokens_pred /= len(xy_aws_layers)
        losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

    # Compute token-level accuracy in teacher-forcing
    acc = compute_accuracy(logits, ys_out, self.pad)

    return loss, acc, ppl, losses_auxiliary
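# --- Hedged sketch (not part of neural_sp) of the label-smoothed XE loss
# computed by cross_entropy_lsm above, as I understand it: interpolate the
# NLL of the target with the cross-entropy against a uniform distribution,
# over non-pad tokens. Normalization details in the real helper may differ.
import torch
import torch.nn.functional as F


def cross_entropy_lsm_sketch(logits, ys_out, lsm_prob, pad):
    vocab = logits.size(-1)
    log_probs = F.log_softmax(logits.view(-1, vocab), dim=-1)  # `[B * L, vocab]`
    ys_flat = ys_out.view(-1)
    valid = ys_flat != pad
    nll = -log_probs[torch.arange(ys_flat.size(0)), ys_flat.clamp(min=0)]
    uniform_xent = -log_probs.mean(-1)  # cross-entropy against the uniform dist.
    loss = ((1 - lsm_prob) * nll + lsm_prob * uniform_xent)[valid].mean()
    ppl = float(nll[valid].mean().exp())
    return loss, ppl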
def beam_search(self, eouts, elens, params, idx2token=None,
                lm=None, lm_second=None, lm_bwd=None, ctc_log_probs=None,
                nbest=1, exclude_eos=False, refs_id=None, utt_ids=None, speakers=None,
                ensmbl_eouts=None, ensmbl_elens=None, ensmbl_decs=[], cache_states=True):
    """Beam search decoding.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        params (dict): hyperparameters for decoding
        idx2token (): converter from index to token
        lm: first-pass LM
        lm_second: second-pass LM
        lm_bwd: first/second-pass backward LM
        ctc_log_probs (FloatTensor):
        nbest (int):
        exclude_eos (bool): exclude <eos> from hypothesis
        refs_id (list): reference list
        utt_ids (list): utterance id list
        speakers (list): speaker list
        ensmbl_eouts (list): list of FloatTensor
        ensmbl_elens (list): list of list
        ensmbl_decs (list): list of torch.nn.Module
        cache_states (bool): cache decoder states for fast decoding
    Returns:
        nbest_hyps_idx (list): length `B`, each of which contains a list of N hypotheses
        aws (list): length `B`, each of which contains arrays of size `[H, L, T]`
        scores (list):

    """
    bs, xmax, _ = eouts.size()
    n_models = len(ensmbl_decs) + 1

    beam_width = params['recog_beam_width']
    assert 1 <= nbest <= beam_width
    ctc_weight = params['recog_ctc_weight']
    max_len_ratio = params['recog_max_len_ratio']
    min_len_ratio = params['recog_min_len_ratio']
    lp_weight = params['recog_length_penalty']
    length_norm = params['recog_length_norm']
    lm_weight = params['recog_lm_weight']
    lm_weight_second = params['recog_lm_second_weight']
    lm_weight_bwd = params['recog_lm_bwd_weight']
    eos_threshold = params['recog_eos_threshold']
    lm_state_carry_over = params['recog_lm_state_carry_over']
    softmax_smoothing = params['recog_softmax_smoothing']
    eps_wait = params['recog_mma_delay_threshold']

    if lm is not None:
        assert lm_weight > 0
        lm.eval()
    if lm_second is not None:
        assert lm_weight_second > 0
        lm_second.eval()
    if lm_bwd is not None:
        assert lm_weight_bwd > 0
        lm_bwd.eval()

    if ctc_log_probs is not None:
        assert ctc_weight > 0
        ctc_log_probs = tensor2np(ctc_log_probs)

    nbest_hyps_idx, aws, scores = [], [], []
    eos_flags = []
    for b in range(bs):
        # Initialization per utterance
        lmstate = None
        ys = eouts.new_zeros(1, 1).fill_(self.eos).long()

        # For joint CTC-Attention decoding
        ctc_prefix_scorer = None
        if ctc_log_probs is not None:
            if self.bwd:
                ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b][::-1], self.blank, self.eos)
            else:
                ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b], self.blank, self.eos)

        if speakers is not None:
            if speakers[b] == self.prev_spk:
                if lm_state_carry_over and isinstance(lm, RNNLM):
                    lmstate = self.lmstate_final
            self.prev_spk = speakers[b]

        helper = BeamSearch(beam_width, self.eos, ctc_weight, self.device_id)

        end_hyps = []
        ymax = int(math.floor(elens[b] * max_len_ratio)) + 1
        hyps = [{'hyp': [self.eos],
                 'ys': ys,
                 'cache': None,
                 'score': 0.,
                 'score_attn': 0.,
                 'score_ctc': 0.,
                 'score_lm': 0.,
                 'aws': [None],
                 'lmstate': lmstate,
                 'ensmbl_aws': [[None]] * (n_models - 1),
                 'ctc_state': ctc_prefix_scorer.initial_state() if ctc_prefix_scorer is not None else None,
                 'streamable': True,
                 'streaming_failed_point': 1000}]
        streamable_global = True
        for t in range(ymax):
            # batchfy all hypotheses for batch decoding
            cache = [None] * self.n_layers
            if cache_states and t > 0:
                for lth in range(self.n_layers):
                    cache[lth] = torch.cat([beam['cache'][lth] for beam in hyps], dim=0)
            ys = eouts.new_zeros(len(hyps), t + 1).long()
            for j, beam in enumerate(hyps):
                ys[j, :] = beam['ys']
            if t > 0:
                xy_aws_prev = torch.cat([beam['aws'][-1] for beam in hyps], dim=0)  # `[B, n_layers, H_ma, 1, klen]`
            else:
                xy_aws_prev = None

            # Update LM states for shallow fusion
            lmstate, scores_lm = None, None
            if lm is not None:
                if hyps[0]['lmstate'] is not None:
                    lm_hxs = torch.cat([beam['lmstate']['hxs'] for beam in hyps], dim=1)
                    lm_cxs = torch.cat([beam['lmstate']['cxs'] for beam in hyps], dim=1)
                    lmstate = {'hxs': lm_hxs, 'cxs': lm_cxs}
                y = ys[:, -1:].clone()  # NOTE: this is important
                _, lmstate, scores_lm = lm.predict(y, lmstate)

            # for the main model
            causal_mask = eouts.new_ones(t + 1, t + 1).byte()
            causal_mask = torch.tril(causal_mask, out=causal_mask).unsqueeze(0).repeat([ys.size(0), 1, 1])

            out = self.pos_enc(self.embed(ys))  # scaled

            mlen = 0  # TODO: fix later
            if self.memory_transformer:
                # NOTE: TransformerXL does not use positional encoding in the token embedding
                mems = self.init_memory()
                # adopt zero-centered offset
                pos_idxs = torch.arange(mlen - 1, -(t + 1) - 1, -1.0, dtype=torch.float)
                pos_embs = self.pos_emb(pos_idxs, self.device_id)
                out = self.dropout_emb(out)
                hidden_states = [out]

            n_heads_total = 0
            eouts_b = eouts[b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
            new_cache = [None] * self.n_layers
            xy_aws_all_layers = []
            lth_s = self.mocha_first_layer - 1
            for lth, layer in enumerate(self.layers):
                if self.memory_transformer:
                    out = layer(out, causal_mask, eouts_b, None,
                                cache=cache[lth],
                                pos_embs=pos_embs, memory=mems[lth], u=self.u, v=self.v)
                    hidden_states.append(out)
                else:
                    out = layer(out, causal_mask, eouts_b, None,
                                cache=cache[lth],
                                xy_aws_prev=xy_aws_prev[:, lth - lth_s] if lth >= lth_s and t > 0 else None,
                                eps_wait=eps_wait)
                new_cache[lth] = out
                if layer.xy_aws is not None:
                    xy_aws_all_layers.append(layer.xy_aws)
            logits = self.output(self.norm_out(out))
            probs = torch.softmax(logits[:, -1] * softmax_smoothing, dim=1)
            xy_aws_all_layers = torch.stack(xy_aws_all_layers, dim=1)  # `[B, H, n_layers, L, T]`

            # for the ensemble
            ensmbl_new_cache = []
            if n_models > 1:
                # Ensemble initialization
                # ensmbl_cache = []
                # cache_e = [None] * self.n_layers
                # if cache_states and t > 0:
                #     for lth in range(self.n_layers):
                #         cache_e[lth] = torch.cat([beam['ensmbl_cache'][lth] for beam in hyps], dim=0)
                for i_e, dec in enumerate(ensmbl_decs):
                    out_e = dec.pos_enc(dec.embed(ys))  # scaled
                    eouts_e = ensmbl_eouts[i_e][b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
                    new_cache_e = [None] * dec.n_layers
                    for lth in range(dec.n_layers):
                        out_e, _, xy_aws_e, _, _ = dec.layers[lth](out_e, causal_mask, eouts_e, None,
                                                                   cache=cache[lth])
                        new_cache_e[lth] = out_e
                    ensmbl_new_cache.append(new_cache_e)
                    logits_e = dec.output(dec.norm_out(out_e))
                    probs += torch.softmax(logits_e[:, -1] * softmax_smoothing, dim=1)
                    # NOTE: sum in the probability scale (not log-scale)

            # Ensemble in log-scale
            scores_attn = torch.log(probs) / n_models

            new_hyps = []
            for j, beam in enumerate(hyps):
                # Attention scores
                total_scores_attn = beam['score_attn'] + scores_attn[j:j + 1]
                total_scores = total_scores_attn * (1 - ctc_weight)

                # Add LM score <before> top-K selection
                if lm is not None:
                    total_scores_lm = beam['score_lm'] + scores_lm[j:j + 1, -1]
                    total_scores += total_scores_lm * lm_weight
                else:
                    total_scores_lm = eouts.new_zeros(1, self.vocab)

                total_scores_topk, topk_ids = torch.topk(
                    total_scores, k=beam_width, dim=1, largest=True, sorted=True)

                # Add length penalty
                if lp_weight > 0:
                    total_scores_topk += (len(beam['hyp'][1:]) + 1) * lp_weight

                # Add CTC score
                new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                    beam['hyp'], topk_ids, beam['ctc_state'],
                    total_scores_topk, ctc_prefix_scorer)

                new_aws = beam['aws'] + [xy_aws_all_layers[j:j + 1, :, :, -1:]]
                aws_j = torch.cat(new_aws[1:], dim=3)  # `[1, H, n_layers, L, T]`
                streaming_failed_point = beam['streaming_failed_point']

                # forward direction
                for k in range(beam_width):
                    idx = topk_ids[0, k].item()
                    length_norm_factor = len(beam['hyp'][1:]) + 1 if length_norm else 1
                    total_score = total_scores_topk[0, k].item() / length_norm_factor

                    if idx == self.eos:
                        # Exclude short hypotheses
                        if len(beam['hyp']) - 1 < elens[b] * min_len_ratio:
                            continue
                        # EOS threshold
                        max_score_no_eos = scores_attn[j, :idx].max(0)[0].item()
                        max_score_no_eos = max(max_score_no_eos, scores_attn[j, idx + 1:].max(0)[0].item())
                        if scores_attn[j, idx].item() <= eos_threshold * max_score_no_eos:
                            continue

                    quantity_rate = 1.
                    if 'mocha' in self.attn_type:
                        n_tokens_hyp_k = t + 1
                        n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                        quantity_diff = n_tokens_hyp_k * n_heads_total - n_quantity_k

                        if quantity_diff != 0:
                            if idx == self.eos:
                                n_tokens_hyp_k -= 1  # NOTE: do not count <eos> for streamability
                                n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                            else:
                                streamable_global = False
                            if n_tokens_hyp_k * n_heads_total == 0:
                                quantity_rate = 0
                            else:
                                quantity_rate = n_quantity_k / (n_tokens_hyp_k * n_heads_total)

                        if beam['streamable'] and not streamable_global:
                            streaming_failed_point = t

                    new_hyps.append(
                        {'hyp': beam['hyp'] + [idx],
                         'ys': torch.cat([beam['ys'], eouts.new_zeros(1, 1).fill_(idx).long()], dim=-1),
                         'cache': [new_cache_l[j:j + 1] for new_cache_l in new_cache] if cache_states else cache,
                         'score': total_score,
                         'score_attn': total_scores_attn[0, idx].item(),
                         'score_ctc': total_scores_ctc[k].item(),
                         'score_lm': total_scores_lm[0, idx].item(),
                         'aws': new_aws,
                         'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                     'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                         'ctc_state': new_ctc_states[k] if ctc_prefix_scorer is not None else None,
                         'ensmbl_cache': ensmbl_new_cache,
                         'streamable': streamable_global,
                         'streaming_failed_point': streaming_failed_point,
                         'quantity_rate': quantity_rate})

            # Local pruning
            new_hyps_sorted = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

            # Remove complete hypotheses
            new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                new_hyps_sorted, end_hyps, prune=True)
            hyps = new_hyps[:]
            if is_finish:
                break

        # Global pruning
        if len(end_hyps) == 0:
            end_hyps = hyps[:]
        elif len(end_hyps) < nbest and nbest > 1:
            end_hyps.extend(hyps[:nbest - len(end_hyps)])

        # forward second-pass LM rescoring
        if lm_second is not None:
            self.lm_rescoring(end_hyps, lm_second, lm_weight_second, tag='second')
        # backward second-pass LM rescoring
        if lm_bwd is not None and lm_weight_bwd > 0:
            self.lm_rescoring(end_hyps, lm_bwd, lm_weight_bwd, tag='second_bwd')

        # Sort by score
        end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

        for j in range(len(end_hyps[0]['aws'][1:])):
            tmp = end_hyps[0]['aws'][j + 1]
            end_hyps[0]['aws'][j + 1] = tmp.view(1, -1, tmp.size(-2), tmp.size(-1))

        # metrics for streaming inference
        self.streamable = end_hyps[0]['streamable']
        self.quantity_rate = end_hyps[0]['quantity_rate']
        self.last_success_frame_ratio = None

        if idx2token is not None:
            if utt_ids is not None:
                logger.info('Utt-id: %s' % utt_ids[b])
            assert self.vocab == idx2token.vocab
            logger.info('=' * 200)
            for k in range(len(end_hyps)):
                if refs_id is not None:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(
                    end_hyps[k]['hyp'][1:][::-1] if self.bwd else end_hyps[k]['hyp'][1:]))
                logger.info('num tokens (hyp): %d' % len(end_hyps[k]['hyp'][1:]))
                logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                logger.info('log prob (hyp, att): %.7f' % (end_hyps[k]['score_attn'] * (1 - ctc_weight)))
                if ctc_prefix_scorer is not None:
                    logger.info('log prob (hyp, ctc): %.7f' % (end_hyps[k]['score_ctc'] * ctc_weight))
                if lm is not None:
                    logger.info('log prob (hyp, first-pass lm): %.7f' % (end_hyps[k]['score_lm'] * lm_weight))
                if lm_second is not None:
                    logger.info('log prob (hyp, second-pass lm): %.7f' %
                                (end_hyps[k]['score_lm_second'] * lm_weight_second))
                if lm_bwd is not None:
                    logger.info('log prob (hyp, second-pass lm-bwd): %.7f' %
                                (end_hyps[k]['score_lm_second_bwd'] * lm_weight_bwd))
                if 'mocha' in self.attn_type:
                    logger.info('streamable: %s' % end_hyps[k]['streamable'])
                    logger.info('streaming failed point: %d' % (end_hyps[k]['streaming_failed_point'] + 1))
                    logger.info('quantity rate [%%]: %.2f' % (end_hyps[k]['quantity_rate'] * 100))
                logger.info('-' * 50)

        if 'mocha' in self.attn_type and end_hyps[0]['streaming_failed_point'] < 1000:
            assert not self.streamable
            aws_last_success = end_hyps[0]['aws'][1:][end_hyps[0]['streaming_failed_point'] - 1]
            rightmost_frame = max(0, aws_last_success[0, :, 0].nonzero()[:, -1].max().item()) + 1
            frame_ratio = rightmost_frame * 100 / xmax
            self.last_success_frame_ratio = frame_ratio
            logger.info('streaming last success frame ratio: %.2f' % frame_ratio)

        # N-best list
        if self.bwd:
            # Reverse the order
            nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:][::-1]) for n in range(nbest)]]
            aws += [tensor2np(torch.cat(end_hyps[0]['aws'][1:][::-1], dim=2).squeeze(0))]
        else:
            nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]
            aws += [tensor2np(torch.cat(end_hyps[0]['aws'][1:], dim=2).squeeze(0))]
        scores += [[end_hyps[n]['score_attn'] for n in range(nbest)]]

        # Check <eos>
        eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

    # Exclude <eos> (<sos> in case of the backward decoder)
    if exclude_eos:
        if self.bwd:
            nbest_hyps_idx = [[nbest_hyps_idx[b][n][1:] if eos_flags[b][n]
                               else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
        else:
            nbest_hyps_idx = [[nbest_hyps_idx[b][n][:-1] if eos_flags[b][n]
                               else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]

    # Store ASR/LM state
    if len(end_hyps) > 0:
        self.lmstate_final = end_hyps[0]['lmstate']

    return nbest_hyps_idx, aws, scores
def greedy(self, eouts, elens, max_len_ratio, idx2token, exclude_eos=False,
           refs_id=None, utt_ids=None, speakers=None, cache_states=True):
    """Greedy decoding.

    Args:
        eouts (FloatTensor): `[B, T, enc_units]`
        elens (IntTensor): `[B]`
        max_len_ratio (int): maximum sequence length of tokens
        idx2token (): converter from index to token
        exclude_eos (bool): exclude <eos> from hypothesis
        refs_id (list): reference list
        utt_ids (list): utterance id list
        speakers (list): speaker list
        cache_states (bool):
    Returns:
        hyps (list): length `B`, each of which contains arrays of size `[L]`
        aws (list): length `B`, each of which contains arrays of size `[L, T]`

    """
    bs, xtime = eouts.size()[:2]

    ys = eouts.new_zeros(bs, 1).fill_(self.eos).long()

    cache = [None] * self.n_layers

    hyps_batch = []
    ylens = torch.zeros(bs).int()
    eos_flags = [False] * bs
    ymax = int(math.floor(xtime * max_len_ratio)) + 1
    for t in range(ymax):
        causal_mask = eouts.new_ones(t + 1, t + 1).byte()
        causal_mask = torch.tril(causal_mask, out=causal_mask).unsqueeze(0)

        new_cache = [None] * self.n_layers
        out = self.pos_enc(self.embed(ys))  # scaled
        for lth, layer in enumerate(self.layers):
            out = layer(out, causal_mask, eouts, None, cache=cache[lth])
            new_cache[lth] = out

        if cache_states:
            cache = new_cache[:]

        # Pick up 1-best
        y = self.output(self.norm_out(out))[:, -1:].argmax(-1)
        hyps_batch += [y]

        # Count lengths of hypotheses
        for b in range(bs):
            if not eos_flags[b]:
                if y[b].item() == self.eos:
                    eos_flags[b] = True
                ylens[b] += 1  # include <eos>

        # Break if <eos> is output for all utterances in the mini-batch
        if sum(eos_flags) == bs:
            break
        if t == ymax - 1:
            break

        ys = torch.cat([ys, y], dim=-1)  # concatenate in L dimension

    hyps_batch = tensor2np(torch.cat(hyps_batch, dim=1))
    xy_aws = tensor2np(layer.xy_aws.transpose(1, 2).transpose(2, 3))

    # Truncate by the first <eos> (<sos> in case of the backward decoder)
    if self.bwd:
        # Reverse the order
        hyps = [hyps_batch[b, :ylens[b]][::-1] for b in range(bs)]
        aws = [xy_aws[b, :, :ylens[b]][::-1] for b in range(bs)]
    else:
        hyps = [hyps_batch[b, :ylens[b]] for b in range(bs)]
        aws = [xy_aws[b, :, :ylens[b]] for b in range(bs)]

    # Exclude <eos> (<sos> in case of the backward decoder)
    if exclude_eos:
        if self.bwd:
            hyps = [hyps[b][1:] if eos_flags[b] else hyps[b] for b in range(bs)]
        else:
            hyps = [hyps[b][:-1] if eos_flags[b] else hyps[b] for b in range(bs)]

    for b in range(bs):
        if utt_ids is not None:
            logger.debug('Utt-id: %s' % utt_ids[b])
        if refs_id is not None and self.vocab == idx2token.vocab:
            logger.debug('Ref: %s' % idx2token(refs_id[b]))
        if self.bwd:
            logger.debug('Hyp: %s' % idx2token(hyps[b][::-1]))
        else:
            logger.debug('Hyp: %s' % idx2token(hyps[b]))

    return hyps, aws
def forward_att(self, eouts, elens, ys, return_logits=False,
                teacher_logits=None, trigger_points=None):
    """Compute XE loss for the Transformer decoder.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (list): length `B`, each of which contains a list of size `[L]`
        return_logits (bool): return logits for knowledge distillation
        teacher_logits (FloatTensor): `[B, L, vocab]`
        trigger_points (IntTensor): `[B, T]`
    Returns:
        loss (FloatTensor): `[1]`
        acc (float): accuracy for token prediction
        ppl (float): perplexity
        loss_quantity (FloatTensor): `[1]`
        loss_headdiv (FloatTensor): `[1]`
        loss_latency (FloatTensor): `[1]`

    """
    # Append <sos> and <eos>
    ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.eos, self.pad, self.bwd)
    if not self.training:
        self.data_dict['elens'] = tensor2np(elens)
        self.data_dict['ylens'] = tensor2np(ylens)
        self.data_dict['ys'] = tensor2np(ys_out)

    # Create target self-attention mask
    xmax = eouts.size(1)
    bs, ymax = ys_in.size()[:2]
    mlen = 0
    tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
    causal_mask = tgt_mask.new_ones(ymax, ymax).byte()
    causal_mask = torch.tril(causal_mask, diagonal=0 + mlen, out=causal_mask).unsqueeze(0)
    tgt_mask = tgt_mask & causal_mask  # `[B, L (query), L (key)]`

    # Create source-target mask
    src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

    # external LM integration
    lmout = None
    if self.lm is not None:
        self.lm.eval()
        with torch.no_grad():
            lmout, lmstate, _ = self.lm.predict(ys_in, None)
        lmout = self.lm_output_proj(lmout)

    out = self.pos_enc(self.embed(ys_in))  # scaled

    mems = self.init_memory()
    pos_embs = None
    if self.memory_transformer:
        out = self.dropout_emb(out)
        # NOTE: TransformerXL does not use positional encoding in the token embedding
        # adopt zero-centered offset
        pos_idxs = torch.arange(mlen - 1, -ymax - 1, -1.0, dtype=torch.float)
        pos_embs = self.pos_emb(pos_idxs, self.device_id)

    hidden_states = [out]
    xy_aws_layers = []
    for lth, (mem, layer) in enumerate(zip(mems, self.layers)):
        out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout,
                    pos_embs=pos_embs, memory=mem, u=self.u, v=self.v)
        if lth < self.n_layers - 1:
            hidden_states.append(out)
            # NOTE: outputs from the last layer are not used for memory
        # Attention padding
        xy_aws = layer.xy_aws
        if xy_aws is not None and 'mocha' in self.attn_type:
            tgt_mask_v2 = (ys_out != self.pad).unsqueeze(1).unsqueeze(3)  # `[B, 1, L, 1]`
            xy_aws = xy_aws.masked_fill_(tgt_mask_v2.repeat([1, xy_aws.size(1), 1, xmax]) == 0, 0)
            # NOTE: attention padding is quite effective for quantity loss
            xy_aws_layers.append(xy_aws.clone())
        if not self.training:
            if layer.yy_aws is not None:
                self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws)
            if layer.xy_aws is not None:
                self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws)
            if layer.xy_aws_beta is not None:
                self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta)
            if layer.xy_aws_p_choose is not None:
                self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose)
            if layer.yy_aws_lm is not None:
                self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm)
    logits = self.output(self.norm_out(out))

    # for knowledge distillation
    if return_logits:
        return logits

    # Compute XE loss (+ label smoothing)
    loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)

    losses_auxiliary = {}

    # Quantity loss
    losses_auxiliary['loss_quantity'] = 0.
    if 'mocha' in self.attn_type:
        # Average over all heads across all layers
        n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
        # NOTE: count <eos> tokens
        n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                             for aws in xy_aws_layers])  # `[B]`
        n_tokens_pred /= len(xy_aws_layers)
        losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

    # Compute token-level accuracy in teacher-forcing
    acc = compute_accuracy(logits, ys_out, self.pad)

    return loss, acc, ppl, losses_auxiliary
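# --- Hedged sketch (not part of neural_sp) of the MoChA quantity loss
# computed above: the expected number of attended tokens (attention mass
# summed over time, averaged over heads and layers) should match the
# reference token count including <eos>.
import torch

def quantity_loss_sketch(xy_aws_layers, n_tokens_ref):
    # each element of xy_aws_layers: `[B, H, L, T]` monotonic attention weights
    n_tokens_pred = sum(aws.sum(3).sum(2).sum(1).abs() / aws.size(1)
                        for aws in xy_aws_layers) / len(xy_aws_layers)  # `[B]`
    return (n_tokens_pred - n_tokens_ref).abs().mean()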
def greedy(self, eouts, elens, max_len_ratio, exclude_eos=False):
    """Greedy decoding in the inference stage.

    Args:
        eouts (FloatTensor): `[B, T, enc_units]`
        elens (list): A list of length `[B]`
        max_len_ratio (int): maximum sequence length of tokens
        exclude_eos (bool): exclude <eos> from hypothesis
    Returns:
        best_hyps (list): A list of length `[B]`, which contains arrays of size `[L]`
        aw (list): A list of length `[B]`, which contains arrays of size `[L, T]`

    """
    bs, max_xlen, d_model = eouts.size()

    # Start from <sos> (<eos> in case of the backward decoder)
    ys = eouts.new_zeros(bs, 1).fill_(self.eos).long()

    yy_mask = None

    best_hyps_tmp = []
    ylens = np.zeros((bs,), dtype=np.int32)
    yy_aws_tmp = [None] * bs
    xy_aws_tmp = [None] * bs
    eos_flags = [False] * bs
    for t in range(int(np.floor(max_xlen * max_len_ratio)) + 1):
        # Make source-target attention mask
        yx_mask = eouts.new_ones(bs, t + 1, max_xlen)
        for b in range(bs):
            if elens[b] < max_xlen:
                yx_mask[b, :, elens[b]:] = 0

        # Add positional embedding
        out = self.embed(ys) * (self.d_model ** 0.5)
        if self.pe_type:
            out = self.pos_emb_out(out)
        for l in range(self.n_layers):
            out, yy_aw, xy_aw = self.layers[l](eouts, out, yx_mask, yy_mask)
            # xy_aw: `[B, head, T, L]`
        out = self.layer_norm_top(out)
        logits_t = self.output(out)

        # Pick up 1-best
        y = logits_t.detach().argmax(-1)[:, -1:]
        best_hyps_tmp += [y]

        # Count lengths of hypotheses
        for b in range(bs):
            if not eos_flags[b]:
                if y[b].item() == self.eos:
                    eos_flags[b] = True
                yy_aws_tmp[b] = yy_aw[b:b + 1]  # TODO: fix this
                xy_aws_tmp[b] = xy_aw[b:b + 1]
                ylens[b] += 1  # NOTE: include <eos>

        # Break if <eos> is output for all utterances in the mini-batch
        if sum(eos_flags) == bs:
            break

        ys = torch.cat([ys, y], dim=-1)  # concatenate in L dimension

    best_hyps_tmp = torch.cat(best_hyps_tmp, dim=1)
    # xy_aws_tmp = torch.stack(xy_aws_tmp, dim=0)

    # Convert to numpy
    best_hyps_tmp = tensor2np(best_hyps_tmp)
    # xy_aws_tmp = tensor2np(xy_aws_tmp)

    # if self.score.attn_nheads > 1:
    #     xy_aws_tmp = xy_aws_tmp[:, :, :, 0]
    #     # TODO(hirofumi): fix for MHA

    # Truncate by the first <eos> (<sos> in case of the backward decoder)
    if self.backward:
        # Reverse the order
        best_hyps = [best_hyps_tmp[b, :ylens[b]][::-1] for b in range(bs)]
        # aws = [xy_aws_tmp[b, :ylens[b]][::-1] for b in range(bs)]
    else:
        best_hyps = [best_hyps_tmp[b, :ylens[b]] for b in range(bs)]
        # aws = [xy_aws_tmp[b, :ylens[b]] for b in range(bs)]

    # Exclude <eos> (<sos> in case of the backward decoder)
    if exclude_eos:
        if self.backward:
            best_hyps = [best_hyps[b][1:] if eos_flags[b] else best_hyps[b] for b in range(bs)]
        else:
            best_hyps = [best_hyps[b][:-1] if eos_flags[b] else best_hyps[b] for b in range(bs)]

    # return best_hyps, aws
    return best_hyps, None
def forward(self, xs, xlens, task, streaming=False,
            lookback=False, lookahead=False):
    """Forward pass.

    Args:
        xs (FloatTensor): `[B, T, input_dim]`
        xlens (IntTensor): `[B]` (on CPU)
        task (str): ys/ys_sub1/ys_sub2
        streaming (bool): streaming encoding
        lookback (bool): truncate leftmost frames for lookback in CNN context
        lookahead (bool): truncate rightmost frames for lookahead in CNN context
    Returns:
        eouts (dict):
            xs (FloatTensor): `[B, T, d_model]`
            xlens (IntTensor): `[B]` (on CPU)

    """
    eouts = {'ys': {'xs': None, 'xlens': None},
             'ys_sub1': {'xs': None, 'xlens': None},
             'ys_sub2': {'xs': None, 'xlens': None}}

    bs, xmax = xs.size()[:2]
    n_chunks = 0
    unidir = self.unidir
    lc_bidir = self.lc_bidir
    N_l, N_c, N_r = self.chunk_size_left, self.chunk_size_current, self.chunk_size_right

    if streaming and self.streaming_type == 'mask':
        assert xmax <= N_c
    elif streaming and self.streaming_type == 'reshape':
        assert xmax <= (N_l + N_c + N_r)

    if lc_bidir:
        if self.streaming_type == 'mask' and not streaming:
            xs = chunkwise(xs, 0, N_c, 0, padding=True)  # `[B * n_chunks, N_c, idim]`
            # NOTE: CNN consumes inputs in the current chunk to avoid extra lookahead latency
            # That is, CNN outputs are independent of chunk boundaries
        elif self.streaming_type == 'reshape':
            xs = chunkwise(xs, N_l, N_c, N_r, padding=not streaming)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
        n_chunks = xs.size(0) // bs
        assert bs * n_chunks == xs.size(0)
        if streaming:
            assert n_chunks == 1, xs.size()

    if self.conv is None:
        xs = self.embed(xs)
    else:
        # Pass through CNN blocks
        xs, xlens = self.conv(xs, xlens,
                              lookback=False if lc_bidir else lookback,
                              lookahead=False if lc_bidir else lookahead)
        # NOTE: CNN lookahead beyond the current chunk is not allowed in chunkwise processing
        N_l = max(0, N_l // self.conv.subsampling_factor)
        N_c = N_c // self.conv.subsampling_factor
        N_r = N_r // self.conv.subsampling_factor

    if lc_bidir:
        if self.streaming_type == 'mask' and not streaming:
            # back to the original shape (during training only)
            xs = xs.contiguous().view(bs, -1, xs.size(2))[:, :xlens.max()]  # `[B, emax, d_model]`
        # NOTE: do nothing in the streaming mode
    elif streaming:
        xs = xs[:, :xlens.max()]  # for unidirectional

    if self.enc_type == 'conv':
        eouts['ys']['xs'] = xs
        eouts['ys']['xlens'] = xlens
        return eouts

    if not streaming:
        self.reset_cache()
    n_hist = self.cache[0]['input_san'].size(1) if streaming and self.cache[0] is not None else 0

    # positional encoding
    if self.pe_type in ['relative', 'relative_xl']:
        xs = xs * self.scale  # NOTE: first layer only
        rel_pos_embs = self.pos_emb(xs, mlen=n_hist)
    else:
        xs = self.pos_enc(xs, scale=True, offset=max(0, n_hist))
        rel_pos_embs = None

    new_cache = [None] * self.n_layers
    if lc_bidir:
        # chunkwise streaming encoder
        if self.streaming_type == 'reshape':
            xx_mask = None  # NOTE: no mask to avoid masking all frames in a chunk
        elif self.streaming_type == 'mask':
            if streaming:
                n_chunks = math.ceil((xlens.max().item() + n_hist) / N_c)
            xx_mask = make_chunkwise_san_mask(xs, xlens + n_hist, N_l, N_c, n_chunks)

        for lth, layer in enumerate(self.layers):
            xs, cache = layer(xs, xx_mask, cache=self.cache[lth],
                              pos_embs=rel_pos_embs, u_bias=self.u_bias, v_bias=self.v_bias)
            if self.streaming_type == 'mask':
                new_cache[lth] = cache
            if not self.training and not streaming:
                if self.streaming_type == 'reshape':
                    # Fold the chunked attention weights back into a `[B, H, emax, emax]` canvas
                    n_heads = layer.xx_aws.size(1)
                    xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c, N_l:N_l + N_c]
                    xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c)
                    emax = xlens.max().item()
                    xx_aws_center = xx_aws.new_zeros(bs, n_heads, emax, emax)
                    for chunk_idx in range(n_chunks):
                        offset = chunk_idx * N_c
                        emax_chunk = xx_aws_center[:, :, offset:offset + N_c].size(2)
                        xx_aws_chunk = xx_aws[:, chunk_idx, :, :emax_chunk, :emax_chunk]
                        xx_aws_center[:, :, offset:offset + N_c, offset:offset + N_c] = xx_aws_chunk
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws_center)
                elif self.streaming_type == 'mask':
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(layer.xx_aws)
                self.data_dict['elens%d' % lth] = tensor2np(xlens)

            if self.subsample is not None:
                xs, xlens = self.subsample[lth](xs, xlens)
                N_l = max(0, N_l // self.subsample[lth].factor)
                N_c = N_c // self.subsample[lth].factor
                N_r = N_r // self.subsample[lth].factor
                if self.pe_type in ['relative', 'relative_xl']:
                    rel_pos_embs = self.pos_emb(xs)
                if self.streaming_type == 'mask':
                    xx_mask = make_chunkwise_san_mask(xs, xlens, N_l, N_c, n_chunks)

        # Extract the center region
        if self.streaming_type == 'reshape':
            xs = xs[:, N_l:N_l + N_c]  # `[B * n_chunks, N_c, d_model]`
            xs = xs.contiguous().view(bs, -1, xs.size(2))
            xs = xs[:, :xlens.max()]
    else:
        xx_mask = make_san_mask(xs, xlens + n_hist, unidir, self.lookaheads[0])
        for lth, layer in enumerate(self.layers):
            xs, cache = layer(xs, xx_mask, cache=self.cache[lth],
                              pos_embs=rel_pos_embs, u_bias=self.u_bias, v_bias=self.v_bias)
            new_cache[lth] = cache
            if not self.training and not streaming:
                self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(layer.xx_aws)
                self.data_dict['elens%d' % lth] = tensor2np(xlens)

            # Pick up outputs in the sub task before the projection layer
            if lth == self.n_layers_sub1 - 1:
                xs_sub1 = self.sub_module(xs, xx_mask, lth, rel_pos_embs, 'sub1')
                xlens_sub1 = xlens.clone()
                if task == 'ys_sub1':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens_sub1
                    return eouts
            if lth == self.n_layers_sub2 - 1:
                xs_sub2 = self.sub_module(xs, xx_mask, lth, rel_pos_embs, 'sub2')
                xlens_sub2 = xlens.clone()
                if task == 'ys_sub2':
                    eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens_sub2
                    return eouts

            if lth < len(self.layers) - 1:
                if self.subsample is not None and self.subsample[lth].factor > 1:
                    xs, xlens = self.subsample[lth](xs, xlens)
                    n_hist = (self.cache[lth + 1]['input_san'].size(1)
                              if streaming and self.cache[lth + 1] is not None else 0)
                    if self.pe_type in ['relative', 'relative_xl']:
                        rel_pos_embs = self.pos_emb(xs, mlen=n_hist)
                    xx_mask = make_san_mask(xs, xlens + n_hist, unidir, self.lookaheads[lth + 1])
                elif self.lookaheads[lth] != self.lookaheads[lth + 1]:
                    xx_mask = make_san_mask(xs, xlens + n_hist, unidir, self.lookaheads[lth + 1])

    xs = self.norm_out(xs)

    if streaming:
        self.cache = new_cache

    # Bridge layer
    if self.bridge is not None:
        xs = self.bridge(xs)

    if task in ['all', 'ys']:
        eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
    if self.n_layers_sub1 >= 1 and task == 'all':
        eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens_sub1
    if self.n_layers_sub2 >= 1 and task == 'all':
        eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens_sub2
    return eouts
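# Illustrative sketch (an assumption, not the real helper): what a
# `chunkwise(xs, N_l, N_c, N_r)` utility plausibly does for
# streaming_type == 'reshape'. It slides over T in hops of N_c and stacks each
# [left | current | right] window on the batch axis, which is consistent with
# the `[B * n_chunks, N_l+N_c+N_r, idim]` shape noted above.
import torch

def chunkwise_sketch(xs, N_l, N_c, N_r):
    bs, xmax, idim = xs.size()
    n_chunks = -(-xmax // N_c)  # ceil division
    # zero-pad so every window [i*N_c - N_l, i*N_c + N_c + N_r) is in range
    xs_pad = torch.zeros(bs, N_l + n_chunks * N_c + N_r, idim)
    xs_pad[:, N_l:N_l + xmax] = xs
    windows = [xs_pad[:, i * N_c:i * N_c + N_l + N_c + N_r] for i in range(n_chunks)]
    return torch.cat(windows, dim=0)  # `[B * n_chunks, N_l+N_c+N_r, idim]`

xs = torch.randn(2, 100, 40)
out = chunkwise_sketch(xs, N_l=32, N_c=16, N_r=8)
print(out.shape)  # torch.Size([14, 56, 40]): 2 utterances x 7 chunks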
def beam_search(self, eouts, elens, params, idx2token,
                lm=None, lm_second=None, lm_second_bwd=None, ctc_log_probs=None,
                nbest=1, exclude_eos=False, refs_id=None, utt_ids=None, speakers=None,
                ensmbl_eouts=None, ensmbl_elens=None, ensmbl_decs=[]):
    """Beam search decoding.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (IntTensor): `[B]`
        params (dict):
            recog_beam_width (int): size of beam
            recog_max_len_ratio (int): maximum sequence length of tokens
            recog_min_len_ratio (float): minimum sequence length of tokens
            recog_length_penalty (float): length penalty
            recog_coverage_penalty (float): coverage penalty
            recog_coverage_threshold (float): threshold for coverage penalty
            recog_lm_weight (float): weight of LM score
        idx2token (): converter from index to token
        lm: first-pass LM
        lm_second: second-pass LM
        lm_second_bwd: second-pass backward LM
        ctc_log_probs (FloatTensor):
        nbest (int):
        exclude_eos (bool): exclude <eos> from hypothesis
        refs_id (list): reference list
        utt_ids (list): utterance id list
        speakers (list): speaker list
        ensmbl_eouts (list): list of FloatTensor
        ensmbl_elens (list): list of list
        ensmbl_decs (list): list of torch.nn.Module
    Returns:
        nbest_hyps_idx (list): A list of length `[B]`, which contains a list of N hypotheses
        aws: dummy
        scores: dummy

    """
    bs = eouts.size(0)

    beam_width = params['recog_beam_width']
    ctc_weight = params['recog_ctc_weight']
    lm_weight = params['recog_lm_weight']
    lm_weight_second = params['recog_lm_second_weight']
    lm_weight_second_bwd = params['recog_lm_bwd_weight']
    asr_state_carry_over = params['recog_asr_state_carry_over']
    lm_state_carry_over = params['recog_lm_state_carry_over']

    if lm is not None:
        assert lm_weight > 0
        lm.eval()
    if lm_second is not None:
        assert lm_weight_second > 0
        lm_second.eval()
    if lm_second_bwd is not None:
        assert lm_weight_second_bwd > 0
        lm_second_bwd.eval()

    if ctc_log_probs is not None:
        assert ctc_weight > 0
        ctc_log_probs = tensor2np(ctc_log_probs)

    nbest_hyps_idx = []
    eos_flags = []
    for b in range(bs):
        # Initialization per utterance
        y = eouts.new_zeros(bs, 1).fill_(self.eos).long()
        y_emb = self.dropout_emb(self.embed(y))
        dout, dstate = self.recurrency(y_emb, None)
        lmstate = None

        # For joint CTC-Attention decoding
        ctc_prefix_scorer = None
        if ctc_log_probs is not None:
            ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b], self.blank, self.eos)

        if speakers is not None:
            if speakers[b] == self.prev_spk:
                if lm_state_carry_over and isinstance(lm, RNNLM):
                    lmstate = self.lmstate_final
            self.prev_spk = speakers[b]

        helper = BeamSearch(beam_width, self.eos, ctc_weight, self.device_id)

        end_hyps = []
        hyps = [{'hyp': [self.eos],
                 'ref_id': [self.eos],
                 'score': 0.,
                 'score_rnnt': 0.,
                 'score_lm': 0.,
                 'score_ctc': 0.,
                 'dout': dout,
                 'dstate': dstate,
                 'lmstate': lmstate,
                 'ctc_state': ctc_prefix_scorer.initial_state()
                     if ctc_prefix_scorer is not None else None}]
        for t in range(elens[b]):
            # preprocess for batch decoding
            douts = torch.cat([beam['dout'] for beam in hyps], dim=0)
            outs = self.joint(eouts[b:b + 1, t:t + 1].repeat([douts.size(0), 1, 1]), douts)
            scores_rnnt = torch.log_softmax(outs.squeeze(2).squeeze(1), dim=-1)

            # Update LM states for shallow fusion
            y = eouts.new_zeros(len(hyps), 1).long()
            for j, beam in enumerate(hyps):
                y[j, 0] = beam['hyp'][-1]
            lmstate, scores_lm = None, None
            if lm is not None:
                if hyps[0]['lmstate'] is not None:
                    lm_hxs = torch.cat([beam['lmstate']['hxs'] for beam in hyps], dim=1)
                    lm_cxs = torch.cat([beam['lmstate']['cxs'] for beam in hyps], dim=1)
                    lmstate = {'hxs': lm_hxs, 'cxs': lm_cxs}
                lmout, lmstate, scores_lm = lm.predict(y, lmstate)

            new_hyps = []
            for j, beam in enumerate(hyps):
                dout = douts[j:j + 1]
                dstate = beam['dstate']
                lmstate = beam['lmstate']

                # RNN-T scores
                total_scores_rnnt = beam['score_rnnt'] + scores_rnnt[j:j + 1]
                total_scores = total_scores_rnnt * (1 - ctc_weight)

                # Add LM score <after> top-K selection
                total_scores_topk, topk_ids = torch.topk(
                    total_scores, k=beam_width, dim=-1, largest=True, sorted=True)
                if lm is not None:
                    total_scores_lm = beam['score_lm'] + scores_lm[j, -1, topk_ids[0]]
                    total_scores_topk += total_scores_lm * lm_weight
                else:
                    total_scores_lm = eouts.new_zeros(beam_width)

                # Add CTC score
                new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                    beam['hyp'], topk_ids, beam['ctc_state'],
                    total_scores_topk, ctc_prefix_scorer)

                for k in range(beam_width):
                    idx = topk_ids[0, k].item()

                    if idx == self.blank:
                        beam['score'] = total_scores_topk[0, k].item()
                        beam['score_rnnt'] = total_scores_topk[0, k].item()
                        new_hyps.append(beam.copy())
                        continue

                    # skip blank-dominant frames
                    # if total_scores_topk[0, self.blank].item() > 0.7:
                    #     continue

                    # Update prediction network only when predicting non-blank labels
                    hyp_id = beam['hyp'] + [idx]
                    hyp_str = ' '.join(list(map(str, hyp_id)))
                    # if hyp_str in self.state_cache.keys():
                    #     # from cache
                    #     dout = self.state_cache[hyp_str]['dout']
                    #     new_dstate = self.state_cache[hyp_str]['dstate']
                    #     lmstate = self.state_cache[hyp_str]['lmstate']
                    # else:
                    y = eouts.new_zeros(1, 1).fill_(idx).long()
                    y_emb = self.dropout_emb(self.embed(y))
                    dout, new_dstate = self.recurrency(y_emb, dstate)
                    # store in cache
                    self.state_cache[hyp_str] = {
                        'dout': dout,
                        'dstate': new_dstate,
                        'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                    'cxs': lmstate['cxs'][:, j:j + 1]}
                                   if lmstate is not None else None,
                    }

                    new_hyps.append(
                        {'hyp': hyp_id,
                         'score': total_scores_topk[0, k].item(),
                         'score_rnnt': total_scores_rnnt[0, idx].item(),
                         'score_ctc': total_scores_ctc[k].item(),
                         'score_lm': total_scores_lm[k].item(),
                         'dout': dout,
                         'dstate': new_dstate,
                         'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                     'cxs': lmstate['cxs'][:, j:j + 1]}
                                    if lmstate is not None else None,
                         'ctc_state': new_ctc_states[k]
                             if ctc_prefix_scorer is not None else None})

            # Merge hypotheses having the same token sequences
            new_hyps_merged = {}
            for beam in new_hyps:
                hyp_str = ' '.join(list(map(str, beam['hyp'])))
                if hyp_str not in new_hyps_merged.keys():
                    new_hyps_merged[hyp_str] = beam
                elif beam['score'] > new_hyps_merged[hyp_str]['score']:
                    new_hyps_merged[hyp_str] = beam
            new_hyps = [v for v in new_hyps_merged.values()]

            # Local pruning
            new_hyps_tmp = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

            # Remove complete hypotheses
            new_hyps = []
            for hyp in new_hyps_tmp:
                new_hyps += [hyp]
            if len(end_hyps) >= beam_width:
                end_hyps = end_hyps[:beam_width]
                break
            hyps = new_hyps[:]

        # Global pruning
        if len(end_hyps) == 0:
            end_hyps = hyps[:]
        elif len(end_hyps) < nbest and nbest > 1:
            end_hyps.extend(hyps[:nbest - len(end_hyps)])

        # forward second-pass LM rescoring
        if lm_second is not None:
            self.lm_rescoring(end_hyps, lm_second, lm_weight_second, tag='second')
        # backward second-pass LM rescoring
        if lm_second_bwd is not None:
            self.lm_rescoring(end_hyps, lm_second_bwd, lm_weight_second_bwd, tag='second_rev')

        end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

        # Reset state cache
        self.state_cache = OrderedDict()

        if utt_ids is not None:
            logger.info('Utt-id: %s' % utt_ids[b])
        if idx2token is not None:
            logger.info('=' * 200)
            for k in range(len(end_hyps)):
                if refs_id is not None and self.vocab == idx2token.vocab:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(end_hyps[k]['hyp'][1:]))
                logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                if ctc_log_probs is not None:
                    logger.info('log prob (hyp, ctc): %.7f' % (end_hyps[k]['score_ctc']))
                if lm is not None:
                    logger.info('log prob (hyp, lm): %.7f' % (end_hyps[k]['score_lm']))
                logger.info('-' * 50)

        # N-best list
        nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]

        # Check <eos>
        eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

    return nbest_hyps_idx, None, None
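# Illustrative sketch (dummy values): the hypothesis-merging step above recombines
# hypotheses that spell the same token sequence (reached via different blank
# alignments), keeping the best score. Note the source takes the max rather than
# a log-sum-exp over the merged alignments.
hyps = [{'hyp': [2, 5, 7], 'score': -1.3},
        {'hyp': [2, 5, 7], 'score': -0.9},   # same tokens, better score
        {'hyp': [2, 5, 8], 'score': -1.1}]
merged = {}
for beam in hyps:
    key = ' '.join(map(str, beam['hyp']))
    if key not in merged or beam['score'] > merged[key]['score']:
        merged[key] = beam
hyps = sorted(merged.values(), key=lambda x: x['score'], reverse=True)
print([h['score'] for h in hyps])  # [-0.9, -1.1]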
def greedy(self, eouts, elens, max_len_ratio, exclude_eos=False,
           idx2token=None, refs_id=None, speakers=None, oracle=False):
    """Greedy decoding in the inference stage (used only for evaluation during training).

    Args:
        eouts (FloatTensor): `[B, T, enc_units]`
        elens (IntTensor): `[B]`
        max_len_ratio (int): maximum sequence length of tokens
        exclude_eos (bool): exclude <eos> from the final hypotheses
        idx2token (): converter from index to token
        refs_id (list): reference list
        speakers (list): speaker list
        oracle (bool): use reference tokens as decoding history (teacher-forcing)
    Returns:
        best_hyps (list): A list of length `[B]`, which contains arrays of size `[L]`
        aw (list): A list of length `[B]`, which contains arrays of size `[L, T]`

    """
    bs, xmax = eouts.size()[:2]

    # Start from <sos> (<eos> in case of the backward decoder)
    ys_all = eouts.new_zeros(bs, 1).fill_(self.eos).long()
    # TODO(hirofumi): Create the source-target mask for batch decoding

    best_hyps_batch = []
    ylens = torch.zeros(bs).int()
    yy_aws_tmp = [None] * bs
    xy_aws_tmp = [None] * bs
    eos_flags = [False] * bs
    for t in range(int(np.floor(xmax * max_len_ratio)) + 1):
        # Create the self-attention mask
        yy_mask = make_pad_mask(ylens + 1, self.device_id).unsqueeze(1).expand(bs, t + 1, t + 1)
        yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, t + 1, t + 1)
        subsequent_mask = torch.tril(yy_mask.new_ones((t + 1, t + 1)).byte(), diagonal=0)
        subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
            bs, self.attn_n_heads, t + 1, t + 1)
        yy_mask = yy_mask & subsequent_mask

        # Create the source-target mask
        xmax = eouts.size(1)
        x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand(bs, t + 1, xmax)
        y_mask = make_pad_mask(ylens + 1, self.device_id).unsqueeze(2).expand(bs, t + 1, xmax)
        xy_mask = (x_mask * y_mask).unsqueeze(1).expand(bs, self.attn_n_heads, t + 1, xmax)

        out = self.pos_enc(self.embed(ys_all))
        for l in range(self.n_layers):
            out, yy_aws, xy_aws = self.layers[l](out, yy_mask, eouts, xy_mask)
        out = self.norm_out(out)

        # Pick up 1-best
        y = self.output(out).argmax(-1)[:, -1:]
        best_hyps_batch += [y]

        # Count lengths of hypotheses
        for b in range(bs):
            if not eos_flags[b]:
                if y[b].item() == self.eos:
                    eos_flags[b] = True
                    yy_aws_tmp[b] = yy_aws[b:b + 1]  # TODO: fix this
                    xy_aws_tmp[b] = xy_aws[b:b + 1]
                ylens[b] += 1  # NOTE: include <eos>

        # Break if <eos> is output in all utterances in the mini-batch
        if sum(eos_flags) == bs:
            break

        ys_all = torch.cat([ys_all, y], dim=-1)  # Concatenate in L dimension

    best_hyps_batch = torch.cat(best_hyps_batch, dim=1)
    # xy_aws_tmp = torch.stack(xy_aws_tmp, dim=0)

    # Convert to numpy
    best_hyps_batch = tensor2np(best_hyps_batch)
    # xy_aws_tmp = tensor2np(xy_aws_tmp)

    # if self.score.attn_n_heads > 1:
    #     xy_aws_tmp = xy_aws_tmp[:, :, :, 0]
    #     # TODO(hirofumi): fix for MHA

    # Truncate by the first <eos> (<sos> in case of the backward decoder)
    if self.bwd:
        # Reverse the order
        best_hyps = [best_hyps_batch[b, :ylens[b]][::-1] for b in range(bs)]
        # aws = [xy_aws_tmp[b, :ylens[b]][::-1] for b in range(bs)]
    else:
        best_hyps = [best_hyps_batch[b, :ylens[b]] for b in range(bs)]
        # aws = [xy_aws_tmp[b, :ylens[b]] for b in range(bs)]

    # Exclude <eos> (<sos> in case of the backward decoder)
    if exclude_eos:
        if self.bwd:
            best_hyps = [best_hyps[b][1:] if eos_flags[b] else best_hyps[b]
                         for b in range(bs)]
        else:
            best_hyps = [best_hyps[b][:-1] if eos_flags[b] else best_hyps[b]
                         for b in range(bs)]

    # return best_hyps, aws
    return best_hyps, None
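# Illustrative sketch: the causal (subsequent) mask built above lets position i
# attend to positions <= i only; it is AND-ed with the padding mask elementwise.
import torch

t = 4
subsequent_mask = torch.tril(torch.ones(t, t, dtype=torch.bool), diagonal=0)
print(subsequent_mask.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)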
def forward_att(self, eouts, elens, ys, return_logits=False):
    """Compute XE loss for the sequence-to-sequence model.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        ys (list): A list of length `[B]`, which contains a list of size `[L]`
        return_logits (bool): return logits for knowledge distillation
    Returns:
        loss (FloatTensor): `[1]`
        acc (float): accuracy for token prediction
        ppl (float): perplexity

    """
    bs = eouts.size(0)

    # Append <sos> and <eos>
    eos = eouts.new_zeros(1).fill_(self.eos).long()
    ys = [np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int64),
                    self.device_id) for y in ys]
    ylens = np2tensor(np.fromiter([y.size(0) + 1 for y in ys], dtype=np.int32))  # +1 for <eos>
    ys_in_pad = pad_list([torch.cat([eos, y], dim=0) for y in ys], self.pad)
    ys_out_pad = pad_list([torch.cat([y, eos], dim=0) for y in ys], self.pad)

    # Create the self-attention mask
    bs, ymax = ys_in_pad.size()[:2]
    yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand(bs, ymax, ymax)
    yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax, ymax)
    subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(), diagonal=0)
    subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
        bs, self.attn_n_heads, ymax, ymax)
    yy_mask = yy_mask & subsequent_mask

    # Create the source-target mask
    xmax = eouts.size(1)
    x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand(bs, ymax, xmax)
    y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).expand(bs, ymax, xmax)
    xy_mask = (x_mask * y_mask).unsqueeze(1).expand(bs, self.attn_n_heads, ymax, xmax)

    ys_emb = self.pos_enc(self.embed(ys_in_pad))
    for l in range(self.n_layers):
        ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts, xy_mask)
        if not self.training:
            setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
            setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
    logits = self.norm_out(ys_emb)
    if self.adaptive_softmax is None:
        logits = self.output(logits)
    if return_logits:
        return logits

    # Compute XE sequence loss
    if self.adaptive_softmax is None:
        if self.lsm_prob > 0 and self.training:
            # Label smoothing
            loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                     ys_out_pad.view(-1), self.lsm_prob, self.pad)
        else:
            loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                   ys_out_pad.view(-1),
                                   ignore_index=self.pad, size_average=True)

        # Focal loss
        if self.focal_loss_weight > 0:
            fl = focal_loss(logits, ys_out_pad, ylens,
                            alpha=self.focal_loss_weight,
                            gamma=self.focal_loss_gamma)
            loss = loss * (1 - self.focal_loss_weight) + fl * self.focal_loss_weight
    else:
        loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                     ys_out_pad.view(-1)).loss

    # Compute token-level accuracy in teacher-forcing
    if self.adaptive_softmax is None:
        acc = compute_accuracy(logits, ys_out_pad, self.pad)
    else:
        acc = compute_accuracy(self.adaptive_softmax.log_prob(
            logits.view((-1, logits.size(2)))), ys_out_pad, self.pad)
    ppl = min(np.exp(loss.item()), np.inf)

    # scale loss to be comparable with the CTC loss
    loss *= ylens.float().mean()

    return loss, acc, ppl
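# Illustrative sketch (an assumption about the helper, not its actual source):
# label-smoothed cross entropy as a `cross_entropy_lsm`-style function plausibly
# computes it, interpolating the gold-label NLL with a uniform-distribution term.
import torch
import torch.nn.functional as F

def xe_lsm_sketch(logits, ys, lsm_prob, pad):
    # logits: `[N, vocab]`, ys: `[N]`
    log_probs = F.log_softmax(logits, dim=-1)
    nll = F.nll_loss(log_probs, ys, ignore_index=pad, reduction='mean')
    uniform = (-log_probs).mean(-1)[ys != pad].mean()  # XE against the uniform part
    return (1 - lsm_prob) * nll + lsm_prob * uniform

logits = torch.randn(7, 50)
ys = torch.randint(0, 50, (7,))
print(xe_lsm_sketch(logits, ys, lsm_prob=0.1, pad=-1))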
def beam_search(self, eouts, elens, params, rnnlm,
                nbest=1, exclude_eos=False, id2token=None, refs=None):
    """Beam search decoding in the inference stage.

    Args:
        eouts (FloatTensor): `[B, T, dec_units]`
        elens (list): A list of length `[B]`
        params (dict):
            beam_width (int): the size of beam
            max_len_ratio (int): the maximum sequence length of tokens
            min_len_ratio (float): the minimum sequence length of tokens
            length_penalty (float): length penalty
            coverage_penalty (float): coverage penalty
            coverage_threshold (float): threshold for coverage penalty
            rnnlm_weight (float): the weight of RNNLM score
        rnnlm (torch.nn.Module):
        nbest (int):
        exclude_eos (bool): exclude <eos> from the final hypotheses
        id2token (): converter from index to token
        refs (): reference transcriptions
    Returns:
        nbest_hyps (list): A list of length `[B]`, which contains a list of n hypotheses
        aws (list): A list of length `[B]`, which contains arrays of size `[L, T]`
        scores (list):

    """
    bs, _, enc_nunits = eouts.size()
    device_id = eouts.get_device()

    # For cold fusion
    if params['rnnlm_weight'] > 0 and not self.cold_fusion:
        assert self.rnnlm_cf
        self.rnnlm_cf.eval()

    # For shallow fusion
    if rnnlm is not None:
        rnnlm.eval()

    if self.backward:
        sos, eos = self.eos, self.sos
    else:
        sos, eos = self.sos, self.eos

    nbest_hyps, aws, scores = [], [], []
    eos_flags = []
    for b in range(bs):
        # Initialization per utterance
        dout, (hx_list, cx_list) = self.init_dec_state(
            1, self.nlayers, device_id, eouts[b:b + 1], elens[b:b + 1])
        _dout, _dstate = self.init_dec_state(1, 1, device_id, eouts[b:b + 1], elens[b:b + 1])
        context = eouts.new_zeros(1, 1, enc_nunits)
        self.score.reset()

        complete = []
        beam = [{'hyp': [sos],
                 'score': 0,
                 'scores': [0],
                 'score_raw': 0,
                 'dout': dout,
                 'hx_list': hx_list,
                 'cx_list': cx_list,
                 'context': context,
                 'aws': [None],
                 'rnnlm_hx_list': None,
                 'rnnlm_cx_list': None,
                 'prev_cov': 0,
                 '_dout': _dout,
                 '_dstate': _dstate}]
        for t in range(int(math.floor(elens[b] * params['max_len_ratio'])) + 1):
            new_beam = []
            for i_beam in range(len(beam)):
                # Recurrency
                y = eouts.new_zeros(1, 1).fill_(beam[i_beam]['hyp'][-1]).long()
                y_emb = self.embed(y)
                dout, (hx_list, cx_list), _dout, _dstate = self.recurrency(
                    y_emb, beam[i_beam]['context'],
                    (beam[i_beam]['hx_list'], beam[i_beam]['cx_list']),
                    beam[i_beam]['_dstate'])

                # Score
                context, aw = self.score(eouts[b:b + 1, :elens[b]], elens[b:b + 1],
                                         dout, beam[i_beam]['aws'][-1])

                if self.rnnlm_cf:
                    # Update RNNLM states for cold fusion
                    y_lm = eouts.new_zeros(1, 1).fill_(beam[i_beam]['hyp'][-1]).long()
                    y_lm_emb = self.rnnlm_cf.embed(y_lm).squeeze(1)
                    logits_lm_t, lm_out, rnnlm_state = self.rnnlm_cf.predict(
                        y_lm_emb,
                        (beam[i_beam]['rnnlm_hx_list'], beam[i_beam]['rnnlm_cx_list']))
                elif rnnlm is not None:
                    # Update RNNLM states for shallow fusion
                    y_lm = eouts.new_zeros(1, 1).fill_(beam[i_beam]['hyp'][-1]).long()
                    y_lm_emb = rnnlm.embed(y_lm).squeeze(1)
                    logits_lm_t, lm_out, rnnlm_state = rnnlm.predict(
                        y_lm_emb,
                        (beam[i_beam]['rnnlm_hx_list'], beam[i_beam]['rnnlm_cx_list']))
                else:
                    logits_lm_t, lm_out, rnnlm_state = None, None, None

                # Generate
                attentional_t = self.generate(context, dout, logits_lm_t, lm_out)
                if self.rnnlm_init and self.internal_lm:
                    # Residual connection
                    attentional_t += _dout
                logits_t = self.output(attentional_t)

                # Pass through the softmax layer & convert to the log scale
                log_probs = F.log_softmax(logits_t.squeeze(1), dim=1)  # log-prob level
                # log_probs = logits_t.squeeze(1)  # logits level
                # NOTE: `[1 (B), 1, vocab]` -> `[1 (B), vocab]`

                # Pick up the top-k scores
                log_probs_topk, indices_topk = torch.topk(
                    log_probs, k=params['beam_width'], dim=1, largest=True, sorted=True)

                for k in range(params['beam_width']):
                    # Exclude short hypotheses
                    if (indices_topk[0, k].item() == eos and
                            len(beam[i_beam]['hyp']) < elens[b] * params['min_len_ratio']):
                        continue

                    # Add length penalty
                    score_raw = beam[i_beam]['score_raw'] + log_probs_topk[0, k].item()
                    score = score_raw + params['length_penalty']

                    # Add coverage penalty
                    if params['coverage_penalty'] > 0:
                        # Recompute coverage penalty at each step
                        score -= beam[i_beam]['prev_cov'] * params['coverage_penalty']
                        aw_stack = torch.stack(beam[i_beam]['aws'][1:] + [aw], dim=-1)
                        cov_sum = aw_stack.detach().cpu().numpy()
                        if params['coverage_threshold'] == 0:
                            cov_sum = np.sum(cov_sum) / self.score.nheads
                        else:
                            cov_sum = np.sum(cov_sum[np.where(
                                cov_sum > params['coverage_threshold'])[0]]) / self.score.nheads
                        score += cov_sum * params['coverage_penalty']
                    else:
                        cov_sum = 0

                    # Add RNNLM score
                    if params['rnnlm_weight'] > 0:
                        lm_log_probs = F.log_softmax(logits_lm_t.squeeze(1), dim=1)
                        assert log_probs.size() == lm_log_probs.size()
                        score += (lm_log_probs[0, indices_topk[0, k].item()].item()
                                  * params['rnnlm_weight'])

                    new_beam.append(
                        {'hyp': beam[i_beam]['hyp'] + [indices_topk[0, k].item()],
                         'score': score,
                         'scores': beam[i_beam]['scores'] + [score],
                         'score_raw': score_raw,
                         'score_lm': 0,  # TODO(hirofumi):
                         'score_lp': 0,  # TODO(hirofumi):
                         'score_cp': 0,  # TODO(hirofumi):
                         'hx_list': hx_list[:],
                         'cx_list': cx_list[:] if cx_list is not None else None,
                         'dout': dout,
                         'context': context,
                         'aws': beam[i_beam]['aws'] + [aw],
                         'rnnlm_hx_list': rnnlm_state[0][:] if rnnlm_state is not None else None,
                         'rnnlm_cx_list': rnnlm_state[1][:] if rnnlm_state is not None else None,
                         'prev_cov': cov_sum,
                         '_dout': _dout,
                         '_dstate': _dstate[:]})

            new_beam = sorted(new_beam, key=lambda x: x['score'], reverse=True)

            # Remove complete hypotheses
            not_complete = []
            for cand in new_beam[:params['beam_width']]:
                if cand['hyp'][-1] == eos:
                    complete += [cand]
                else:
                    not_complete += [cand]
            if len(complete) >= params['beam_width']:
                complete = complete[:params['beam_width']]
                break
            beam = not_complete[:params['beam_width']]

        # Sort by score
        if len(complete) == 0:
            complete = beam
        elif len(complete) < nbest and nbest > 1:
            complete.extend(beam[:nbest - len(complete)])
        complete = sorted(complete, key=lambda x: x['score'], reverse=True)

        # N-best list
        if self.backward:
            # Reverse the order
            nbest_hyps += [[np.array(complete[n]['hyp'][1:][::-1]) for n in range(nbest)]]
            if self.score.nheads > 1:
                aws += [[complete[n]['aws'][0, 1:][::-1] for n in range(nbest)]]
            else:
                aws += [[complete[n]['aws'][1:][::-1] for n in range(nbest)]]
            scores += [[complete[n]['scores'][1:][::-1] for n in range(nbest)]]
        else:
            nbest_hyps += [[np.array(complete[n]['hyp'][1:]) for n in range(nbest)]]
            if self.score.nheads > 1:
                aws += [[complete[n]['aws'][0, 1:] for n in range(nbest)]]
            else:
                aws += [[complete[n]['aws'][1:] for n in range(nbest)]]
            scores += [[complete[n]['scores'][1:] for n in range(nbest)]]
            # scores += [[complete[n]['score_raw'] for n in range(nbest)]]

        # Check <eos>
        eos_flag = [True if complete[n]['hyp'][-1] == eos else False
                    for n in range(nbest)]
        eos_flags.append(eos_flag)

        if id2token is not None:
            if refs is not None:
                logger.info('Ref: %s' % refs[b].lower())
            for n in range(nbest):
                logger.info('Hyp: %s' % id2token(nbest_hyps[0][n]))
            if refs is not None:
                logger.info('log prob (ref): ')
            for n in range(nbest):
                logger.info('log prob (hyp): %.3f' % complete[n]['score'])
                logger.info('log prob (hyp, raw): %.3f' % complete[n]['score_raw'])

    # Concatenate in L dimension
    for b in range(len(aws)):
        for n in range(nbest):
            aws[b][n] = tensor2np(torch.stack(aws[b][n], dim=1).squeeze(0))

    # Exclude <eos> (<sos> in case of the backward decoder)
    if exclude_eos:
        if self.backward:
            nbest_hyps = [[nbest_hyps[b][n][1:] if eos_flags[b][n] else nbest_hyps[b][n]
                           for n in range(nbest)] for b in range(bs)]
        else:
            nbest_hyps = [[nbest_hyps[b][n][:-1] if eos_flags[b][n] else nbest_hyps[b][n]
                           for n in range(nbest)] for b in range(bs)]

    return nbest_hyps, aws, scores
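# Illustrative sketch (dummy numbers): the scoring arithmetic above combines the
# raw log-prob, a constant length reward per emitted token, and a coverage term
# that counts accumulated attention mass above a threshold.
import numpy as np

length_penalty, coverage_penalty, coverage_threshold = 0.2, 0.1, 0.5
score_raw = -4.7  # summed token log-probs so far
aw_sum = np.array([0.9, 0.7, 0.3, 0.1])  # attention mass per encoder frame
cov_sum = np.sum(aw_sum[aw_sum > coverage_threshold])  # 0.9 + 0.7 = 1.6
score = score_raw + length_penalty + cov_sum * coverage_penalty
print(round(score, 3))  # -4.34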
def greedy(self, eouts, elens, max_len_ratio, exclude_eos=False):
    """Greedy decoding in the inference stage.

    Args:
        eouts (FloatTensor): `[B, T, enc_units]`
        elens (list): A list of length `[B]`
        max_len_ratio (int): the maximum sequence length of tokens
        exclude_eos (bool): exclude <eos> from the final hypotheses
    Returns:
        best_hyps (list): A list of length `[B]`, which contains arrays of size `[L]`
        aw (list): A list of length `[B]`, which contains arrays of size `[L, T]`

    """
    bs, enc_time, enc_nunits = eouts.size()
    device_id = eouts.get_device()

    # Initialization
    dout, dstate = self.init_dec_state(bs, self.nlayers, device_id, eouts, elens)
    _dout, _dstate = self.init_dec_state(bs, 1, device_id, eouts, elens)
    context = eouts.new_zeros(bs, 1, enc_nunits)
    self.score.reset()
    aw = None
    rnnlm_state = None

    if self.backward:
        sos, eos = self.eos, self.sos
    else:
        sos, eos = self.sos, self.eos

    # Start from <sos> (<eos> in case of the backward decoder)
    y = eouts.new_zeros(bs, 1).fill_(sos).long()

    best_hyps_tmp, aws_tmp = [], []
    y_lens = np.zeros((bs,), dtype=np.int32)
    eos_flags = [False] * bs
    for t in range(int(math.floor(enc_time * max_len_ratio)) + 1):
        # Recurrency
        y_emb = self.embed(y)
        dout, dstate, _dout, _dstate = self.recurrency(y_emb, context, dstate, _dstate)

        # Update RNNLM states for cold fusion
        if self.rnnlm_cf:
            y_lm = self.rnnlm_cf.embed(y)
            logits_lm_t, lm_out, rnnlm_state = self.rnnlm_cf.predict(y_lm, rnnlm_state)
        else:
            logits_lm_t, lm_out = None, None

        # Score
        context, aw = self.score(eouts, elens, dout, aw)

        # Generate
        attentional_t = self.generate(context, dout, logits_lm_t, lm_out)
        if self.rnnlm_init and self.internal_lm:
            # Residual connection
            attentional_t += _dout
        logits_t = self.output(attentional_t)

        # Pick up 1-best
        # NOTE: argmax on the GPU tensor directly; no numpy/host round-trip needed
        y = logits_t.squeeze(1).detach().argmax(-1).unsqueeze(1)
        best_hyps_tmp += [y]
        if self.score.nheads > 1:
            aws_tmp += [aw[0]]
        else:
            aws_tmp += [aw]

        # Count lengths of hypotheses
        for b in range(bs):
            if not eos_flags[b]:
                if y[b].item() == eos:
                    eos_flags[b] = True
                y_lens[b] += 1  # NOTE: include <eos>

        # Break if <eos> is output in all utterances in the mini-batch
        if sum(eos_flags) == bs:
            break

    # Concatenate in L dimension
    best_hyps_tmp = torch.cat(best_hyps_tmp, dim=1)
    aws_tmp = torch.stack(aws_tmp, dim=1)

    # Convert to numpy
    best_hyps_tmp = tensor2np(best_hyps_tmp)
    aws_tmp = tensor2np(aws_tmp)

    # Truncate by the first <eos> (<sos> in case of the backward decoder)
    if self.backward:
        # Reverse the order
        best_hyps = [best_hyps_tmp[b, :y_lens[b]][::-1] for b in range(bs)]
        aws = [aws_tmp[b, :y_lens[b]][::-1] for b in range(bs)]
    else:
        best_hyps = [best_hyps_tmp[b, :y_lens[b]] for b in range(bs)]
        aws = [aws_tmp[b, :y_lens[b]] for b in range(bs)]

    # Exclude <eos> (<sos> in case of the backward decoder)
    if exclude_eos:
        if self.backward:
            best_hyps = [best_hyps[b][1:] if eos_flags[b] else best_hyps[b]
                         for b in range(bs)]
        else:
            best_hyps = [best_hyps[b][:-1] if eos_flags[b] else best_hyps[b]
                         for b in range(bs)]

    return best_hyps, aws
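# Hedged usage sketch tying the two inference paths above together; `decoder`,
# `eouts`, `elens`, and `id2token` are hypothetical stand-ins for objects built
# elsewhere in this codebase.
#
# best_hyps, aws = decoder.greedy(eouts, elens, max_len_ratio=1.0, exclude_eos=True)
# nbest_hyps, aws, scores = decoder.beam_search(
#     eouts, elens,
#     params={'beam_width': 10, 'max_len_ratio': 1.0, 'min_len_ratio': 0.0,
#             'length_penalty': 0.1, 'coverage_penalty': 0.0,
#             'coverage_threshold': 0.0, 'rnnlm_weight': 0.0},
#     rnnlm=None, nbest=1, exclude_eos=True)
# for hyp in best_hyps:
#     print(id2token(hyp))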