def _calc_att_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_mask: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
) -> Tuple[torch.Tensor, float]:
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                        self.ignore_id)
    ys_in_lens = ys_pad_lens + 1  # account for the added <sos> token

    # reverse the sequence, used for the right-to-left decoder
    r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
    r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos,
                                            self.ignore_id)
    # 1. Forward decoder
    decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask,
                                                 ys_in_pad, ys_in_lens,
                                                 r_ys_in_pad,
                                                 self.reverse_weight)
    # 2. Compute attention loss
    loss_att = self.criterion_att(decoder_out, ys_out_pad)
    r_loss_att = torch.tensor(0.0)
    if self.reverse_weight > 0.0:
        r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
    # interpolate left-to-right and right-to-left losses
    loss_att = loss_att * (1 - self.reverse_weight) \
        + r_loss_att * self.reverse_weight
    acc_att = th_accuracy(
        decoder_out.view(-1, self.vocab_size),
        ys_out_pad,
        ignore_label=self.ignore_id,
    )
    return loss_att, acc_att
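# --- Hedged usage sketch (not part of the original class) ---
# A minimal, self-contained illustration of the target preparation that
# _calc_att_loss performs before calling the decoder: append <sos>/<eos>
# with add_sos_eos and build the reversed targets with reverse_pad_list
# for the right-to-left decoder branch. The two tiny helpers below are
# simplified stand-ins written only for this sketch; the real helpers are
# assumed to live in wenet/utils/common.py. All token ids are made up.
def _demo_att_loss_targets():
    import torch
    from torch.nn.utils.rnn import pad_sequence

    sos, eos, ignore_id = 10, 10, -1

    def demo_add_sos_eos(ys_pad, sos, eos, ignore_id):
        # strip padding, prepend <sos> for the decoder input, append <eos>
        # for the decoder target, then re-pad
        ys = [y[y != ignore_id] for y in ys_pad]
        ys_in = [torch.cat([torch.tensor([sos]), y]) for y in ys]
        ys_out = [torch.cat([y, torch.tensor([eos])]) for y in ys]
        return (pad_sequence(ys_in, batch_first=True, padding_value=eos),
                pad_sequence(ys_out, batch_first=True,
                             padding_value=ignore_id))

    def demo_reverse_pad_list(ys_pad, ys_lens, pad_value):
        # flip only the valid part of every sequence, keep padding at the end
        return pad_sequence(
            [torch.flip(y[:i], [0]) for y, i in zip(ys_pad, ys_lens.tolist())],
            batch_first=True, padding_value=pad_value)

    ys_pad = torch.tensor([[1, 2, 3, 4], [5, 6, -1, -1]])
    ys_pad_lens = torch.tensor([4, 2])

    ys_in_pad, ys_out_pad = demo_add_sos_eos(ys_pad, sos, eos, ignore_id)
    r_ys_pad = demo_reverse_pad_list(ys_pad, ys_pad_lens, float(ignore_id))
    r_ys_in_pad, r_ys_out_pad = demo_add_sos_eos(r_ys_pad, sos, eos, ignore_id)

    print(ys_in_pad)     # tensor([[10, 1, 2, 3, 4], [10, 5, 6, 10, 10]])
    print(ys_out_pad)    # tensor([[1, 2, 3, 4, 10], [5, 6, 10, -1, -1]])
    print(r_ys_in_pad)   # tensor([[10, 4, 3, 2, 1], [10, 6, 5, 10, 10]])
    print(r_ys_out_pad)  # tensor([[4, 3, 2, 1, 10], [6, 5, 10, -1, -1]])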
def forward_attention_decoder(
    self,
    hyps: torch.Tensor,
    hyps_lens: torch.Tensor,
    encoder_out: torch.Tensor,
    reverse_weight: float = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Export interface for c++ call, forward decoder with multiple
        hypotheses from ctc prefix beam search and one encoder output

    Args:
        hyps (torch.Tensor): hyps from ctc prefix beam search,
            already padded with <sos> at the beginning
        hyps_lens (torch.Tensor): length of each hyp in hyps,
            including the <sos> token
        encoder_out (torch.Tensor): corresponding encoder output
        reverse_weight (float): weight of the right-to-left decoder;
            the right-to-left decoder is used only when it is > 0

    Returns:
        torch.Tensor: left-to-right decoder output
        torch.Tensor: right-to-left decoder output
    """
    assert encoder_out.size(0) == 1
    num_hyps = hyps.size(0)
    assert hyps_lens.size(0) == num_hyps
    encoder_out = encoder_out.repeat(num_hyps, 1, 1)
    encoder_mask = torch.ones(num_hyps,
                              1,
                              encoder_out.size(1),
                              dtype=torch.bool,
                              device=encoder_out.device)

    # input for the right-to-left decoder
    # hyps_lens counts the <sos> token, so subtract it here
    r_hyps_lens = hyps_lens - 1
    # hyps already includes the <sos> token, so strip it to recover
    # the original hyps before reversing
    r_hyps = hyps[:, 1:]
    r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
    r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)

    decoder_out, r_decoder_out, _ = self.decoder(
        encoder_out, encoder_mask, hyps, hyps_lens, r_hyps,
        reverse_weight)  # (num_hyps, max_hyps_len, vocab_size)
    decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
    # The right-to-left decoder may not be used during decoding,
    # depending on the reverse_weight param.
    # r_decoder_out will be 0.0 if reverse_weight is 0.0.
    r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
    return decoder_out, r_decoder_out
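# --- Hedged usage sketch (not part of the original class) ---
# How a caller might exercise the exported interface from Python, mirroring
# what the C++ runtime does: take the n-best token sequences from CTC prefix
# beam search, prepend <sos> to each, pad them into one batch, and forward
# all hypotheses against a single encoder output. `model`, the sos/eos ids
# and the example hypotheses are placeholders, not taken from the source.
def _demo_forward_attention_decoder(model, encoder_out, sos=10, eos=10):
    import torch
    from torch.nn.utils.rnn import pad_sequence

    # two hypothetical hypotheses (token id sequences) from CTC prefix
    # beam search
    nbest = [[1, 2, 3], [1, 2]]
    # prepend <sos> and pad; the pad value should not matter here because
    # the decoder masks positions beyond hyps_lens
    hyps = pad_sequence(
        [torch.tensor([sos] + h, dtype=torch.long) for h in nbest],
        batch_first=True, padding_value=eos)      # (num_hyps, max_hyps_len)
    hyps_lens = torch.tensor([len(h) + 1 for h in nbest],
                             dtype=torch.long)    # counts the <sos> token

    # encoder_out is expected to be (1, maxlen, encoder_dim)
    decoder_out, r_decoder_out = model.forward_attention_decoder(
        hyps, hyps_lens, encoder_out, reverse_weight=0.3)
    return decoder_out, r_decoder_out

# A typical call site might load a TorchScript export first (the path and
# the availability of the method on the scripted module are assumptions):
#   model = torch.jit.load("final.zip")
#   decoder_out, r_decoder_out = _demo_forward_attention_decoder(model, encoder_out)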
def attention_rescoring(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    beam_size: int,
    decoding_chunk_size: int = -1,
    num_decoding_left_chunks: int = -1,
    ctc_weight: float = 0.0,
    simulate_streaming: bool = False,
    reverse_weight: float = 0.0,
) -> List[int]:
    """ Apply attention rescoring decoding. CTC prefix beam search is
        applied first to get the n-best hypotheses, then the n-best are
        rescored with the attention decoder using the corresponding
        encoder output.

    Args:
        speech (torch.Tensor): (batch, max_len, feat_dim)
        speech_lengths (torch.Tensor): (batch, )
        beam_size (int): beam size for beam search
        decoding_chunk_size (int): decoding chunk for dynamic chunk
            trained model.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
            0: used for training, it's prohibited here
        ctc_weight (float): ctc score weight
        simulate_streaming (bool): whether to do encoder forward in a
            streaming fashion
        reverse_weight (float): right to left decoder weight

    Returns:
        List[int]: Attention rescoring result
    """
    assert speech.shape[0] == speech_lengths.shape[0]
    assert decoding_chunk_size != 0
    if reverse_weight > 0.0:
        # decoder should be a bitransformer decoder if reverse_weight > 0.0
        assert hasattr(self.decoder, 'right_decoder')
    device = speech.device
    batch_size = speech.shape[0]
    # For attention rescoring we only support batch_size=1
    assert batch_size == 1
    # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
    hyps, encoder_out = self._ctc_prefix_beam_search(
        speech, speech_lengths, beam_size, decoding_chunk_size,
        num_decoding_left_chunks, simulate_streaming)
    assert len(hyps) == beam_size
    hyps_pad = pad_sequence([
        torch.tensor(hyp[0], device=device, dtype=torch.long)
        for hyp in hyps
    ], True, self.ignore_id)  # (beam_size, max_hyps_len)
    ori_hyps_pad = hyps_pad
    hyps_lens = torch.tensor([len(hyp[0]) for hyp in hyps],
                             device=device,
                             dtype=torch.long)  # (beam_size,)
    hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
    hyps_lens = hyps_lens + 1  # Add <sos> at the beginning
    encoder_out = encoder_out.repeat(beam_size, 1, 1)
    encoder_mask = torch.ones(beam_size,
                              1,
                              encoder_out.size(1),
                              dtype=torch.bool,
                              device=device)
    # used for the right-to-left decoder
    r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
    r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
                                self.ignore_id)
    decoder_out, r_decoder_out, _ = self.decoder(
        encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
        reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
    decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
    decoder_out = decoder_out.cpu().numpy()
    # r_decoder_out will be 0.0 if reverse_weight is 0.0 or the decoder is
    # a conventional transformer decoder.
    r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
    r_decoder_out = r_decoder_out.cpu().numpy()
    # Rescore with the decoder scores plus the weighted ctc score
    best_score = -float('inf')
    best_index = 0
    for i, hyp in enumerate(hyps):
        score = 0.0
        for j, w in enumerate(hyp[0]):
            score += decoder_out[i][j][w]
        score += decoder_out[i][len(hyp[0])][self.eos]
        # add right-to-left decoder score
        if reverse_weight > 0:
            r_score = 0.0
            for j, w in enumerate(hyp[0]):
                r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
            r_score += r_decoder_out[i][len(hyp[0])][self.eos]
            score = score * (1 - reverse_weight) + r_score * reverse_weight
        # add ctc score
        score += hyp[1] * ctc_weight
        if score > best_score:
            best_score = score
            best_index = i
    return hyps[best_index][0]
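# --- Hedged usage sketch (not part of the original class) ---
# A self-contained numeric illustration of the score combination used at
# the end of attention_rescoring: for each hypothesis, sum the
# left-to-right decoder log-probs of its tokens plus <eos>, optionally mix
# in the right-to-left score with reverse_weight, then add the CTC score
# scaled by ctc_weight and keep the best hypothesis. The hypotheses, the
# eos id and the random log-prob tensors below are made up for the sketch.
def _demo_rescoring_scores(reverse_weight=0.3, ctc_weight=0.5, eos=9):
    import torch

    # two fake hypotheses as (token_ids, ctc_score) pairs, in the shape
    # returned by _ctc_prefix_beam_search
    hyps = [([1, 2, 3], -4.2), ([1, 2], -3.8)]
    vocab_size, max_len = 10, 4  # max_len = longest hyp + 1 for <eos>
    decoder_out = torch.randn(len(hyps), max_len, vocab_size).log_softmax(-1)
    r_decoder_out = torch.randn(len(hyps), max_len, vocab_size).log_softmax(-1)

    best_score, best_index = -float('inf'), 0
    for i, (tokens, ctc_score) in enumerate(hyps):
        # left-to-right attention score: token log-probs plus <eos> log-prob
        score = sum(decoder_out[i][j][w].item() for j, w in enumerate(tokens))
        score += decoder_out[i][len(tokens)][eos].item()
        if reverse_weight > 0:
            # the right-to-left decoder reads the tokens in reverse order
            r_score = sum(r_decoder_out[i][len(tokens) - j - 1][w].item()
                          for j, w in enumerate(tokens))
            r_score += r_decoder_out[i][len(tokens)][eos].item()
            score = score * (1 - reverse_weight) + r_score * reverse_weight
        score += ctc_score * ctc_weight
        if score > best_score:
            best_score, best_index = score, i
    return hyps[best_index][0]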