def translate(self, x, trans_args, char_list, rnnlm=None):
    """E2E beam search.

    :param ndarray x: input source text feature (B, T, D)
    :param Namespace trans_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    prev = self.training
    self.eval()

    # 1. encoder
    # make a utt list (1) to use the same interface for encoder
    if self.multilingual:
        ilen = [len(x[0][1:])]
        h = to_device(
            self,
            torch.from_numpy(np.fromiter(map(int, x[0][1:]), dtype=np.int64)),
        )
    else:
        ilen = [len(x[0])]
        h = to_device(
            self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64))
        )
    hs, _, _ = self.enc(self.dropout(self.embed(h.unsqueeze(0))), ilen)

    # 2. decoder
    # decode the first utterance
    y = self.dec.recognize_beam(hs[0], None, trans_args, char_list, rnnlm)

    if prev:
        self.train()

    return y
def recognize(self, h, recog_args):
    """Greedy search implementation.

    Args:
        h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
        recog_args (Namespace): argument Namespace containing options

    Returns:
        hyp (list of dicts): 1-best decoding results

    """
    z_list, c_list = self.zero_state(h.unsqueeze(0))
    ey = to_device(self, torch.zeros((1, self.embed_dim)))

    hyp = {'score': 0.0, 'yseq': [self.blank]}

    y, (z_list, c_list) = self.rnn_forward(ey, (z_list, c_list))

    for hi in h:
        ytu = F.log_softmax(self.joint(hi, y[0]), dim=0)
        logp, pred = torch.max(ytu, dim=0)

        if pred != self.blank:
            hyp['yseq'].append(int(pred))
            hyp['score'] += float(logp)

            eys = to_device(
                self, torch.full((1, 1), hyp['yseq'][-1], dtype=torch.long)
            )
            ey = self.dropout_embed(self.embed(eys))

            y, (z_list, c_list) = self.rnn_forward(ey[0], (z_list, c_list))

    return [hyp]
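# recognize() above is time-synchronous greedy RNN-T decoding: the prediction
# network advances only when a non-blank label is emitted. A minimal
# self-contained sketch of the same control flow; the toy modules below are
# hypothetical stand-ins, not the real classes used here.
import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB, D_ENC, D_DEC, BLANK = 6, 4, 4, 0

embed = nn.Embedding(VOCAB, D_DEC)       # prediction-network embedding
pred_rnn = nn.LSTMCell(D_DEC, D_DEC)     # one-layer prediction network
joint = nn.Linear(D_ENC + D_DEC, VOCAB)  # joint network (concat variant)

enc_out = torch.randn(10, D_ENC)         # (T, D_enc) from some encoder
h = torch.zeros(1, D_DEC)
c = torch.zeros(1, D_DEC)
y = torch.zeros(1, D_DEC)                # prediction output before any label

yseq, score = [BLANK], 0.0
for h_t in enc_out:                      # one greedy decision per frame
    logp = F.log_softmax(joint(torch.cat([h_t, y[0]], dim=0)), dim=0)
    best, pred = torch.max(logp, dim=0)
    if int(pred) != BLANK:               # non-blank: emit and advance pred net
        yseq.append(int(pred))
        score += float(best)
        h, c = pred_rnn(embed(torch.tensor([yseq[-1]])), (h, c))
        y = h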
def translate_batch(self, xs, trans_args, char_list, rnnlm=None):
    """E2E batch beam search.

    :param list xs: list of input source text feature arrays
        [(T_1, D), (T_2, D), ...]
    :param Namespace trans_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    prev = self.training
    self.eval()

    # 1. Encoder
    if self.multilingual:
        ilens = np.fromiter((len(xx[1:]) for xx in xs), dtype=np.int64)
        hs = [to_device(self, torch.from_numpy(xx[1:])) for xx in xs]
    else:
        ilens = np.fromiter((len(xx) for xx in xs), dtype=np.int64)
        hs = [to_device(self, torch.from_numpy(xx)) for xx in xs]
    xpad = pad_list(hs, self.pad)
    hs_pad, hlens, _ = self.enc(self.dropout(self.embed(xpad)), ilens)

    # 2. Decoder
    hlens = torch.tensor(list(map(int, hlens)))  # make sure hlens is tensor
    y = self.dec.recognize_beam_batch(
        hs_pad, hlens, None, trans_args, char_list, rnnlm
    )

    if prev:
        self.train()

    return y
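# pad_list above is assumed to right-pad a list of variable-length tensors to
# the batch maximum with a fill value. A minimal sketch of that behavior
# (hypothetical reimplementation, not the toolkit source):
import torch

def pad_list_sketch(xs, pad_value):
    """Pad [(T_1, ...), (T_2, ...), ...] to (B, max(T_i), ...)."""
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new_full((n_batch, max_len, *xs[0].size()[1:]), pad_value)
    for i, x in enumerate(xs):
        pad[i, : x.size(0)] = x
    return pad

# e.g. pad_list_sketch([torch.ones(3, 2), torch.ones(5, 2)], 0.0) -> (2, 5, 2)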
def forward(self, state, x):
    # update state with input label x
    if state is None:  # make initial states and log-prob vectors
        self.var_word_eos = to_device(x, self.var_word_eos)
        self.var_word_unk = to_device(x, self.var_word_unk)
        wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
        wlm_logprobs = F.log_softmax(z_wlm, dim=1)
        clm_state, z_clm = self.subwordlm(None, x)
        log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
        new_node = self.lexroot
        clm_logprob = 0.0
        xi = self.space
    else:
        clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
        xi = int(x)
        if xi == self.space:  # inter-word transition
            if node is not None and node[1] >= 0:  # check if the node is word end
                w = to_device(x, torch.LongTensor([node[1]]))
            else:  # this node is not a word end, which means <unk>
                w = self.var_word_unk
            # update wordlm state and log-prob vector
            wlm_state, z_wlm = self.wordlm(wlm_state, w)
            wlm_logprobs = F.log_softmax(z_wlm, dim=1)
            new_node = self.lexroot  # move to the tree root
            clm_logprob = 0.0
        elif node is not None and xi in node[0]:  # intra-word transition
            new_node = node[0][xi]
            clm_logprob += log_y[0, xi]
        elif self.open_vocab:  # if no path in the tree, enter open-vocabulary mode
            new_node = None
            clm_logprob += log_y[0, xi]
        else:  # if open_vocab flag is disabled, return 0 probabilities
            log_y = to_device(
                x, torch.full((1, self.subword_dict_size), self.logzero)
            )
            return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.0), log_y

        clm_state, z_clm = self.subwordlm(clm_state, x)
        log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight

    # apply word-level probabilities for <space> and <eos> labels
    if xi != self.space:
        if new_node is not None and new_node[1] >= 0:  # if new node is word end
            wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
        else:
            wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
        log_y[:, self.space] = wlm_logprob
        log_y[:, self.eos] = wlm_logprob
    else:
        log_y[:, self.space] = self.logzero
        log_y[:, self.eos] = self.logzero

    return (
        (clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)),
        log_y,
    )
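# The lexical prefix tree walked above is assumed to store, per node, a dict
# of children keyed by subword id plus a word id that is >= 0 only at word
# ends (node layout hypothetical). A minimal sketch of building and walking
# such a tree:
def make_tree(word_dict):               # words given as tuples of subword ids
    root = [{}, -1]                     # [children, word_id]
    for subwords, wid in word_dict.items():
        node = root
        for s in subwords:
            node = node[0].setdefault(s, [{}, -1])
        node[1] = wid                   # mark word end
    return root

tree = make_tree({(3, 4): 0, (3, 5, 6): 1})
node = tree
for s in (3, 5, 6):                     # intra-word transitions
    node = node[0].get(s)               # None would mean "no path in tree"
print(node[1])                          # 1 -> reached a word end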
def decoder_forward(self, hs_pad, hlens, ys_pad):
    """Run the prediction network and joint network over padded targets."""
    ys = [y[y != self.ignore_id] for y in ys_pad]
    hlens = list(map(int, hlens))

    blank = ys[0].new([self.blank])

    ys_in = [torch.cat([blank, y], dim=0) for y in ys]
    ys_in_pad = pad_list(ys_in, self.blank)

    olength = ys_in_pad.size(1)

    z_list, c_list = self.zero_state(hs_pad)
    eys = self.dropout_embed(self.embed(ys_in_pad))

    z_all = []
    for i in six.moves.range(olength):
        y, (z_list, c_list) = self.rnn_forward(eys[:, i, :], (z_list, c_list))
        z_all.append(y)

    h_dec = torch.stack(z_all, dim=1)

    h_enc = hs_pad.unsqueeze(2)
    h_dec = h_dec.unsqueeze(1)

    z = self.joint(h_enc, h_dec)
    y = pad_list(ys, self.blank).type(torch.int32)

    z_len = to_device(self, torch.IntTensor(hlens))
    y_len = to_device(self, torch.IntTensor([_y.size(0) for _y in ys]))

    return z, y, z_len, y_len
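# The two unsqueeze calls above broadcast encoder frames against decoder steps
# so the joint network scores every (t, u) pair at once, producing the
# transducer lattice (B, T, U+1, vocab). A shape-only sketch, with a simple
# additive joint standing in for self.joint (names hypothetical):
import torch

B, T, U1, D, V = 2, 7, 5, 8, 10
h_enc = torch.randn(B, T, D).unsqueeze(2)   # (B, T, 1, D)
h_dec = torch.randn(B, U1, D).unsqueeze(1)  # (B, 1, U+1, D)
lin = torch.nn.Linear(D, V)
z = lin(torch.tanh(h_enc + h_dec))          # broadcast -> (B, T, U+1, V)
print(z.shape)                              # torch.Size([2, 7, 5, 10])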
def init_state(self, init_tensor):
    """Initialize decoder states.

    Args:
        init_tensor (torch.Tensor): batch of input features
            (B, emb_dim / dec_dim)

    Returns:
        (tuple): batch of decoder states
            ([L x (B, dec_dim)], [L x (B, dec_dim)])

    """
    dtype = init_tensor.dtype
    z_list = [
        to_device(init_tensor, torch.zeros(init_tensor.size(0), self.dunits)).to(
            dtype
        )
        for _ in range(self.dlayers)
    ]
    c_list = [
        to_device(init_tensor, torch.zeros(init_tensor.size(0), self.dunits)).to(
            dtype
        )
        for _ in range(self.dlayers)
    ]

    return (z_list, c_list)
def forward(self, state, x):
    """Forward neural networks."""
    if state is None:
        h = [
            to_device(self, self.zero_state(x.size(0)))
            for n in range(self.n_layers)
        ]
        state = {"h": h}
        if self.typ == "lstm":
            c = [
                to_device(self, self.zero_state(x.size(0)))
                for n in range(self.n_layers)
            ]
            state = {"c": c, "h": h}

    h = [None] * self.n_layers
    emb = self.embed(x)
    if self.typ == "lstm":
        c = [None] * self.n_layers
        h[0], c[0] = self.rnn[0](
            self.dropout[0](emb), (state["h"][0], state["c"][0])
        )
        for n in range(1, self.n_layers):
            h[n], c[n] = self.rnn[n](
                self.dropout[n](h[n - 1]), (state["h"][n], state["c"][n])
            )
        state = {"c": c, "h": h}
    else:
        h[0] = self.rnn[0](self.dropout[0](emb), state["h"][0])
        for n in range(1, self.n_layers):
            h[n] = self.rnn[n](self.dropout[n](h[n - 1]), state["h"][n])
        state = {"h": h}
    y = self.lo(self.dropout[-1](h[-1]))
    return state, y
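# Hedged sketch of the state threading above: each layer is a *Cell module,
# so per-step hidden/cell states are lists indexed by layer, and the input to
# layer n is the output of layer n-1 (toy sizes, hypothetical names):
import torch
import torch.nn as nn

n_layers, n_vocab, d = 2, 10, 8
embed = nn.Embedding(n_vocab, d)
rnn = nn.ModuleList([nn.LSTMCell(d, d) for _ in range(n_layers)])
lo = nn.Linear(d, n_vocab)

x = torch.tensor([1, 4, 2])                      # batch of 3 token ids
h = [torch.zeros(3, d) for _ in range(n_layers)]
c = [torch.zeros(3, d) for _ in range(n_layers)]

inp = embed(x)
for n in range(n_layers):                        # thread states layer by layer
    h[n], c[n] = rnn[n](inp, (h[n], c[n]))
    inp = h[n]
y = lo(inp)                                      # (3, n_vocab) logits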
def forward(
    self,
    feats: torch.Tensor,
    feats_len: torch.Tensor,
    prev_states: Optional[List[torch.Tensor]] = None,
):
    """Forward encoder.

    Args:
        feats: Feature sequences. (B, F, D_feats)
        feats_len: Feature sequences lengths. (B,)
        prev_states: Previous encoder hidden states. [N x (B, T, D_enc)]

    Returns:
        enc_out: Encoder output sequences. (B, T, D_enc) with or without
            encoder intermediate output sequences.
            ((B, T, D_enc), [N x (B, T, D_enc)])
        enc_out_len: Encoder output sequences lengths. (B,)
        current_states: Encoder hidden states. [N x (B, T, D_enc)]

    """
    if prev_states is None:
        prev_states = [None] * len(self.enc)
    assert len(prev_states) == len(self.enc)

    _enc_out = feats
    _enc_out_len = feats_len

    current_states = []
    for rnn_module, prev_state in zip(self.enc, prev_states):
        _enc_out, _enc_out_len, states = rnn_module(
            _enc_out,
            _enc_out_len,
            prev_states=prev_state,
        )
        current_states.append(states)

    if isinstance(_enc_out, tuple):
        enc_out, aux_enc_out = _enc_out[0], _enc_out[1]
        enc_out_len, aux_enc_out_len = _enc_out_len[0], _enc_out_len[1]

        enc_out_mask = to_device(enc_out, make_pad_mask(enc_out_len).unsqueeze(-1))
        enc_out = enc_out.masked_fill(enc_out_mask, 0.0)

        for i in range(len(aux_enc_out)):
            aux_mask = to_device(
                aux_enc_out[i], make_pad_mask(aux_enc_out_len[i]).unsqueeze(-1)
            )
            aux_enc_out[i] = aux_enc_out[i].masked_fill(aux_mask, 0.0)

        return (
            (enc_out, aux_enc_out),
            (enc_out_len, aux_enc_out_len),
            current_states,
        )
    else:
        enc_out_mask = to_device(
            _enc_out, make_pad_mask(_enc_out_len).unsqueeze(-1)
        )

        return _enc_out.masked_fill(enc_out_mask, 0.0), _enc_out_len, current_states
def recognize(self, h, recog_args):
    """Greedy search implementation for transformer-transducer.

    Args:
        h (torch.Tensor): encoder hidden state sequences (maxlen_in, Henc)
        recog_args (Namespace): argument Namespace containing options

    Returns:
        hyp (list of dicts): 1-best decoding results

    """
    hyp = {'score': 0.0, 'yseq': [self.blank]}

    ys = to_device(
        self, torch.tensor(hyp['yseq'], dtype=torch.long)
    ).unsqueeze(0)
    ys_mask = to_device(self, subsequent_mask(1).unsqueeze(0))
    y, c = self.forward_one_step(ys, ys_mask, None)

    for i, hi in enumerate(h):
        ytu = torch.log_softmax(self.joint(hi, y[0]), dim=0)
        logp, pred = torch.max(ytu, dim=0)

        if pred != self.blank:
            hyp['yseq'].append(int(pred))
            hyp['score'] += float(logp)

            ys = to_device(self, torch.tensor(hyp['yseq']).unsqueeze(0))
            ys_mask = to_device(
                self, subsequent_mask(len(hyp['yseq'])).unsqueeze(0)
            )

            y, c = self.forward_one_step(ys, ys_mask, c)

    return [hyp]
def forward(self, hs_pad, hlens, ys_pad):
    """Decoder forward.

    Args:
        hs_pad (torch.Tensor): batch of padded hidden state sequences (B, Tmax, D)
        hlens (torch.Tensor): batch of lengths of hidden state sequences (B)
        ys_pad (torch.Tensor): batch of padded character id sequence tensor
            (B, Lmax)

    Returns:
        loss (torch.Tensor): rnnt-att loss value

    """
    ys = [y[y != self.ignore_id] for y in ys_pad]
    hlens = list(map(int, hlens))

    blank = ys[0].new([self.blank])

    ys_in = [torch.cat([blank, y], dim=0) for y in ys]
    ys_in_pad = pad_list(ys_in, self.blank)

    olength = ys_in_pad.size(1)

    c_list = [self.zero_state(hs_pad)]
    z_list = [self.zero_state(hs_pad)]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(hs_pad))
        z_list.append(self.zero_state(hs_pad))

    att_w = None
    self.att[0].reset()

    eys = self.dropout_emb(self.embed(ys_in_pad))

    z_all = []
    for i in six.moves.range(olength):
        att_c, att_w = self.att[0](
            hs_pad, hlens, self.dropout_dec[0](z_list[0]), att_w
        )
        ey = torch.cat((eys[:, i, :], att_c), dim=1)

        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
        z_all.append(self.dropout_dec[-1](z_list[-1]))

    h_dec = torch.stack(z_all, dim=1)

    h_enc = hs_pad.unsqueeze(2)
    h_dec = h_dec.unsqueeze(1)

    z = self.joint(h_enc, h_dec)
    y = pad_list(ys, self.blank).type(torch.int32)

    z_len = to_device(self, torch.IntTensor(hlens))
    y_len = to_device(self, torch.IntTensor([_y.size(0) for _y in ys]))

    loss = to_device(self, self.rnnt_loss(z, y, z_len, y_len))

    return loss
def forward(self, hs_pad, hlens, ys_pad):
    """CTC forward.

    :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
    :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
    :param torch.Tensor ys_pad: batch of padded character id sequence tensor
        (B, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    """
    # TODO(kan-bayashi): need to make more smart way
    ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys

    self.loss = None
    hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32))
    olens = torch.from_numpy(np.fromiter((x.size(0) for x in ys), dtype=np.int32))

    # zero padding for hs
    ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))

    # zero padding for ys
    ys_true = torch.cat(ys).cpu().int()  # batch x olen

    # get length info
    logging.info(
        self.__class__.__name__
        + " input lengths: "
        + "".join(str(hlens).split("\n"))
    )
    logging.info(
        self.__class__.__name__
        + " output lengths: "
        + "".join(str(olens).split("\n"))
    )

    # get ctc loss
    # expected shape of seqLength x batchSize x alphabet_size
    dtype = ys_hat.dtype
    ys_hat = ys_hat.transpose(0, 1)
    if self.ctc_type == "warpctc" or dtype == torch.float16:
        # warpctc only supports float32
        # torch.ctc does not support float16 (#1751)
        ys_hat = ys_hat.to(dtype=torch.float32)
    if self.ctc_type == "builtin":
        # use GPU when using the cuDNN implementation
        ys_true = to_device(self, ys_true)
    self.loss = to_device(self, self.loss_fn(ys_hat, ys_true, hlens, olens)).to(
        dtype=dtype
    )

    if self.reduce:
        # NOTE: sum() is needed to keep consistency since warpctc returns a
        # tensor w/ shape (1,) but builtin returns a tensor w/o shape (scalar).
        self.loss = self.loss.sum()

    logging.info("ctc loss:" + str(float(self.loss)))

    return self.loss
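# For reference, a minimal self-contained use of PyTorch's builtin
# torch.nn.CTCLoss with the same (T, B, C) layout and concatenated 1-D
# targets as above (toy shapes and values, not taken from this model):
import torch
import torch.nn.functional as F

T, B, C = 12, 2, 5                           # frames, batch, vocab (0 = blank)
logits = torch.randn(T, B, C)
log_probs = F.log_softmax(logits, dim=2)     # CTC expects log-probabilities
targets = torch.tensor([1, 2, 3, 2, 2], dtype=torch.long)  # all ys concatenated
input_lengths = torch.tensor([12, 10], dtype=torch.long)
target_lengths = torch.tensor([3, 2], dtype=torch.long)

loss_fn = torch.nn.CTCLoss(blank=0, reduction="sum")
loss = loss_fn(log_probs, targets, input_lengths, target_lengths)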
def forward(self, xs_pad, ilens, prev_states=None):
    """Forward encoder.

    Args:
        xs_pad: Batch of padded input sequences (B, Tmax, idim)
        ilens: Batch of lengths of input sequences (B)
        prev_states: Batch of previous encoder hidden states (B, ??)

    Returns:
        : Batch of padded output sequences (B, Tmax, hdim)
            or tuple w/ aux outputs ((B, Tmax, hdim), [L x (B, Tmax, hdim)])
        : Batch of lengths of output sequences (B)
        : Batch of hidden state sequences (B, Tmax, hdim)

    """
    if prev_states is None:
        prev_states = [None] * len(self.enc)
    assert len(prev_states) == len(self.enc)

    current_states = []
    for module, prev_state in zip(self.enc, prev_states):
        xs_pad, ilens, states = module(
            xs_pad,
            ilens,
            prev_state=prev_state,
        )
        current_states.append(states)

    if isinstance(xs_pad, tuple):
        final_xs_pad, aux_xs_list = xs_pad[0], xs_pad[1]

        mask = to_device(final_xs_pad, make_pad_mask(ilens).unsqueeze(-1))

        aux_xs_list = [layer.masked_fill(mask, 0.0) for layer in aux_xs_list]

        return (
            (
                final_xs_pad.masked_fill(mask, 0.0),
                aux_xs_list,
            ),
            ilens,
            current_states,
        )
    else:
        mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1))

        return xs_pad.masked_fill(mask, 0.0), ilens, current_states
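# make_pad_mask above is assumed to return True at padded positions. A
# minimal equivalent via a broadcast comparison (hypothetical
# reimplementation), plus the masking idiom used throughout this file:
import torch

def make_pad_mask_sketch(lengths):
    """(B,) lengths -> (B, Tmax) bool mask, True where padded."""
    lengths = torch.as_tensor(lengths)
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

xs = torch.randn(3, 5, 4)                             # (B, Tmax, D)
mask = make_pad_mask_sketch([5, 3, 2]).unsqueeze(-1)  # (B, Tmax, 1)
xs = xs.masked_fill(mask, 0.0)                        # zero out padded frames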
def translate_batch(self, xs, trans_args, char_list, rnnlm=None):
    """E2E batch beam search.

    :param list xs: list of input acoustic feature arrays
        [(T_1, D), (T_2, D), ...]
    :param Namespace trans_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    prev = self.training
    self.eval()

    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

    # subsample frame
    xs = [xx[::self.subsample[0], :] for xx in xs]
    xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)

    # 1. Encoder
    hs_pad, hlens, _ = self.enc(xs_pad, ilens)

    # 2. Decoder
    hlens = torch.tensor(list(map(int, hlens)))  # make sure hlens is tensor
    y = self.dec.recognize_beam_batch(
        hs_pad, hlens, None, trans_args, char_list, rnnlm
    )

    if prev:
        self.train()

    return y
def subsample_frames(self, x):
    """Subsample input frames and wrap them as a device tensor."""
    # subsample frame
    x = x[::self.subsample[0], :]
    ilen = [x.shape[0]]
    h = to_device(self, torch.from_numpy(np.array(x, dtype=np.float32)))
    h = h.contiguous()
    return h, ilen
def create_length_mask(self, length, max_len, num_output):
    batch_size = len(length)

    mask = torch.zeros(batch_size, max_len, num_output)
    for i in range(batch_size):
        mask[i, :length[i], :] = 1
    mask = to_device(self, mask)

    return mask
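# The Python loop above can be vectorized with a broadcast comparison. A
# hedged equivalent sketch with the same semantics (hypothetical helper name):
import torch

def create_length_mask_vec(length, max_len, num_output):
    length = torch.as_tensor(length)
    valid = torch.arange(max_len).unsqueeze(0) < length.unsqueeze(1)  # (B, T)
    return valid.unsqueeze(-1).expand(-1, -1, num_output).float()

# create_length_mask_vec([3, 1], 4, 2)[1, :, 0] -> tensor([1., 0., 0., 0.])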
def score(self, hyp, cache, init_tensor=None):
    """Forward one step.

    Args:
        hyp (dataclass): hypothesis
        cache (dict): states cache

    Returns:
        y (torch.Tensor): decoder outputs (1, dec_dim)
        state (tuple): decoder states
            ([L x (1, dec_dim)], [L x (1, dec_dim)])
        (torch.Tensor): token id for LM (1)

    """
    vy = to_device(self, torch.full((1, 1), hyp.yseq[-1], dtype=torch.long))

    # separator avoids key collisions (e.g. [1, 23] vs. [12, 3])
    str_yseq = "_".join(map(str, hyp.yseq))

    if str_yseq in cache:
        y, state = cache[str_yseq]
    else:
        ey = self.embed(vy)
        y, state = self.rnn_forward(ey[0], hyp.dec_state)
        cache[str_yseq] = (y, state)

    return y, state, vy[0]
def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
    """E2E batch prediction.

    :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: batch of padded output score sequences (B, Tmax, odim)
    :rtype: torch.Tensor
    """
    prev = self.training
    self.eval()

    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

    # subsample frame
    xs = [xx[::self.subsample[0], :] for xx in xs]
    xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)

    # 0. Frontend
    if self.frontend is not None:
        enhanced, hlens, mask = self.frontend(xs_pad, ilens)
        hs_pad, hlens = self.feature_transform(enhanced, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    batchsize = hs_pad.size(0)

    # 1. Encoder
    hyps, hlens, _ = self.enc(hs_pad, hlens)
    hyps = hyps.view(batchsize, -1, self.odim)

    if prev:
        self.train()

    return hyps
def enhance(self, xs):
    """Forward only the frontend stage.

    :param ndarray xs: input acoustic feature (T, C, F)
    """
    if self.frontend is None:
        raise RuntimeError("Frontend doesn't exist")
    prev = self.training
    self.eval()

    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

    # subsample frame
    xs = [xx[::self.subsample[0], :] for xx in xs]
    xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)
    enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
    if prev:
        self.train()

    if isinstance(enhanced, (tuple, list)):
        enhanced = list(enhanced)
        mask = list(mask)
        for idx in range(len(enhanced)):  # number of speakers
            enhanced[idx] = enhanced[idx].cpu().numpy()
            mask[idx] = mask[idx].cpu().numpy()
        return enhanced, mask, ilens

    return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
def recognize(self, x, recog_args, char_list, rnnlm=None):
    """E2E beam search.

    :param ndarray x: input acoustic feature (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    prev = self.training
    self.eval()

    # subsample frame
    x = x[::self.subsample[0], :]
    ilen = [x.shape[0]]
    h = to_device(self, torch.from_numpy(np.array(x, dtype=np.float32)))

    # 1. encoder
    # make a utt list (1) to use the same interface for encoder
    h = h.contiguous()
    h, _ = self.enc(h.unsqueeze(0), ilen)

    # calculate log P(z_t|X) for CTC scores
    if recog_args.ctc_weight > 0.0:
        lpz = self.ctc.log_softmax(h)[0]
    else:
        lpz = None

    # 2. decoder
    # decode the first utterance
    y = self.dec.recognize_beam(h[0], lpz, recog_args, char_list, rnnlm)

    if prev:
        self.train()

    return y
def enhance(self, xs):
    """Forward only the frontend stage.

    Args:
        xs (ndarray): input acoustic feature (T, C, F)

    Returns:
        enhanced (ndarray): enhanced feature sequences
        mask (torch.Tensor): estimated mask sequences
        ilens (torch.Tensor): batch of lengths of input sequences (B)

    """
    if self.frontend is None:
        raise RuntimeError("Frontend doesn't exist")
    prev = self.training
    self.eval()

    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

    # subsample frame
    xs = [xx[::self.subsample[0], :] for xx in xs]
    xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)
    enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
    if prev:
        self.train()

    return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
def forward(self, xs, labels=None):
    """Encode padded input features with the SLU encoder and classify them."""
    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
    xs = [to_device(self.slu, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)

    embeddings = self.slu(xs_pad, ilens, None)
    outputs = self.classifier(embeddings, labels)

    return outputs
def forward(self, xs_pad, ilens):
    """Encoder forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :return: batch of hidden state sequences
        [num_spkrs x (B, Tmax, eprojs)]
    :rtype: list of torch.Tensor
    """
    # mixture encoder
    for module in self.enc_mix:
        xs_pad, ilens, _ = module(xs_pad, ilens)

    # SD and Rec encoder
    xs_pad_sd = [xs_pad for i in range(self.num_spkrs)]
    ilens_sd = [ilens for i in range(self.num_spkrs)]
    for ns in range(self.num_spkrs):
        # Encoder_SD: speaker differentiate encoder
        for module in self.enc_sd[ns]:
            xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])
        # Encoder_Rec: recognition encoder
        for module in self.enc_rec:
            xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])

    # make mask to remove bias value in padded part
    mask = to_device(self, make_pad_mask(ilens_sd[0]).unsqueeze(-1))

    return [x.masked_fill(mask, 0.0) for x in xs_pad_sd], ilens_sd[0]
def forward(self, xs, ilens=None):
    """Encode input sequence.

    Args:
        xs (torch.Tensor): Input tensor (#batch, time, idim).
        ilens (torch.Tensor): Input lengths (#batch,).

    Returns:
        torch.Tensor: Output tensor (#batch, time, attention_dim)
            or pooled sentence embedding (#batch, attention_dim).
        torch.Tensor: Input lengths (#batch,) or None.
        torch.Tensor: Intonation-type logits or None.
    """
    masks = to_device(xs, make_non_pad_mask(ilens)).unsqueeze(-2)
    xs = self.embed(xs)
    xs, _ = self.encoders(xs, masks)
    if self.normalize_before:
        xs = self.after_norm(xs)

    if self.query is None:
        return xs, ilens, None

    # Predict sentence type
    # (B, T, 1)
    mask = to_device(xs, make_pad_mask(ilens)).unsqueeze(-1)
    # (B, T, D)
    keys = torch.tanh(self.K(xs))
    values = xs
    # (B, T, D) -> (B, T, 1)
    logits = torch.sum(keys * self.query, dim=-1).unsqueeze(-1)
    logits = logits.masked_fill(mask, -float('inf'))
    scores = F.softmax(logits, dim=1)
    # (B, T, 1) -> (B, 1, T)
    scores = self.score_dropout(scores.masked_fill(mask, 0.0)).transpose(1, 2)
    # (B, 1, T) * (B, T, D) -> (B, 1, D)
    x = torch.matmul(scores, values)

    # Predict intonation type
    intotype_logits = None
    if self.pred_prj is not None:
        intotype_logits = self.pred_prj(x.squeeze(1))

    # Return squeezed character embeddings or the original encoded embeddings
    if self.reduce_character_embedding:
        return x.squeeze(1), None, intotype_logits

    return xs, ilens, intotype_logits
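# Hedged sketch of the query-based pooling above: score each frame against a
# learned query, mask padding with -inf before the softmax, then average the
# frames by those weights (toy sizes; names are hypothetical stand-ins):
import torch
import torch.nn.functional as F

B, T, D = 2, 5, 8
xs = torch.randn(B, T, D)
K = torch.nn.Linear(D, D)                 # key projection
query = torch.randn(D)                    # learned sentence query
pad = torch.tensor([[False] * T, [False, False, False, True, True]])

logits = (torch.tanh(K(xs)) * query).sum(-1)     # (B, T) frame scores
logits = logits.masked_fill(pad, float("-inf"))  # ignore padded frames
scores = F.softmax(logits, dim=1).unsqueeze(1)   # (B, 1, T)
pooled = torch.matmul(scores, xs).squeeze(1)     # (B, D) sentence embedding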
def final(self, state):
    wlm_state, cumsum_probs, node = state
    if node is not None and node[1] >= 0:  # check if the node is word end
        w = to_device(cumsum_probs, torch.LongTensor([node[1]]))
    else:  # this node is not a word end, which means <unk>
        w = self.var_word_unk
    wlm_state, z_wlm = self.wordlm(wlm_state, w)
    return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
def forward(self, hs_pad, hlens, ys_pad):
    """Forward function for transducer.

    Args:
        hs_pad (torch.Tensor): batch of padded hidden state sequences (B, Tmax, D)
        hlens (torch.Tensor): batch of lengths of hidden state sequences (B)
        ys_pad (torch.Tensor): batch of padded character id sequence tensor
            (B, Lmax)

    Returns:
        loss (torch.Tensor): rnnt loss value

    """
    ys = [y[y != self.ignore_id] for y in ys_pad]
    hlens = list(map(int, hlens))

    blank = ys[0].new([self.blank])

    ys_in = [torch.cat([blank, y], dim=0) for y in ys]
    ys_in_pad = pad_list(ys_in, self.blank)

    olength = ys_in_pad.size(1)

    z_list, c_list = self.zero_state(hs_pad)
    eys = self.dropout_embed(self.embed(ys_in_pad))

    z_all = []
    for i in six.moves.range(olength):
        y, (z_list, c_list) = self.rnn_forward(eys[:, i, :], (z_list, c_list))
        z_all.append(y)

    h_dec = torch.stack(z_all, dim=1)

    h_enc = hs_pad.unsqueeze(2)
    h_dec = h_dec.unsqueeze(1)

    z = self.joint(h_enc, h_dec)
    y = pad_list(ys, self.blank).type(torch.int32)

    z_len = to_device(self, torch.IntTensor(hlens))
    y_len = to_device(self, torch.IntTensor([_y.size(0) for _y in ys]))

    loss = to_device(self, self.rnnt_loss(z, y, z_len, y_len))

    return loss
def recognize(self, x, recog_args, char_list, rnnlm=None):
    """E2E beam search.

    :param ndarray x: input acoustic feature (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    prev = self.training
    self.eval()
    ilens = [x.shape[0]]

    # subsample frame
    x = x[::self.subsample[0], :]
    h = to_device(self, to_torch_tensor(x).float())
    # make a utt list (1) to use the same interface for encoder
    hs = h.contiguous().unsqueeze(0)

    # 0. Frontend
    if self.frontend is not None:
        hs, hlens, mask = self.frontend(hs, ilens)
        hlens_n = [None] * self.num_spkrs
        for i in range(self.num_spkrs):
            hs[i], hlens_n[i] = self.feature_transform(hs[i], hlens)
        hlens = hlens_n
    else:
        hs, hlens = hs, ilens

    # 1. Encoder
    if not isinstance(hs, list):  # single-channel multi-speaker input x
        hs, hlens, _ = self.enc(hs, hlens)
    else:  # multi-channel multi-speaker input x
        for i in range(self.num_spkrs):
            hs[i], hlens[i], _ = self.enc(hs[i], hlens[i])

    # calculate log P(z_t|X) for CTC scores
    if recog_args.ctc_weight > 0.0:
        lpz = [self.ctc.log_softmax(i)[0] for i in hs]
    else:
        lpz = None

    # 2. decoder
    # decode the first utterance
    y = [
        self.dec.recognize_beam(
            hs[i][0], lpz[i], recog_args, char_list, rnnlm, strm_idx=i
        )
        for i in range(self.num_spkrs)
    ]

    if prev:
        self.train()

    return y
def attractor_loss(self, att_prob, label):
    batch_size = len(label)
    bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
    # create attractor label [1, 1, ..., 1, 0]
    # att_label: (Batch, num_spk + 1, 1)
    att_label = to_device(self, torch.zeros(batch_size, label.size(2) + 1, 1))
    att_label[:, : label.size(2), :] = 1
    loss = bce_loss(att_prob, att_label)
    loss = torch.mean(torch.mean(loss, dim=1))
    return loss
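# Hedged sketch of the attractor objective above: the first num_spk
# attractors should be "valid" (label 1) and the one extra attractor invalid
# (label 0), trained with element-wise BCE (toy values, hypothetical names):
import torch

num_spk, B = 2, 3
att_prob = torch.randn(B, num_spk + 1, 1)  # attractor existence logits
att_label = torch.zeros(B, num_spk + 1, 1)
att_label[:, :num_spk, :] = 1              # [1, 1, ..., 1, 0]

bce = torch.nn.BCEWithLogitsLoss(reduction="none")
loss = bce(att_prob, att_label).mean(dim=1).mean()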
def expand_to(self, xs, lens):
    """Expand (B, D) embeddings to (B, T, D), zeroing the padded frames.

    xs: (B, D)
    lens: (B,)
    """
    # (B, T, 1)
    mask = to_device(xs, make_pad_mask(lens).unsqueeze(-1))
    # (B, D) -> (B, 1, D) -> (B, T, D)
    xs = xs.unsqueeze(1).expand(-1, mask.size(1), -1).masked_fill(mask, 0.0)
    return xs
def min_pit_ctc_batch(self, losses):
    '''E2E min_pit_ctc_batch

    :param torch.Tensor losses: CTC losses (B, 1|4|9)
    :return: min ctc loss value
    :rtype: torch.Tensor (B)
    :return: permutation of min ctc loss value
    :rtype: torch.Tensor (B, 1|2|3)
    '''
    if self.num_spkrs == 1:
        return (
            to_device(self, torch.mean(losses[:, 0])),
            to_device(self, torch.zeros(losses.size(0)).long()),
        )
    else:
        bs = losses.size(0)
        ret = [self.min_pit_process(losses[i]) for i in range(bs)]

        loss_perm = torch.stack([r[0] for r in ret], dim=0)
        permutation = torch.tensor([r[1] for r in ret]).long()

        return torch.mean(to_device(self, loss_perm)), to_device(self, permutation)
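# min_pit_process above is assumed to pick, per utterance, the speaker
# permutation whose averaged pairwise CTC loss is smallest. A hedged
# brute-force sketch of that idea (hypothetical helper, num_spkrs=2):
import itertools
import torch

def min_pit_sketch(pair_losses, num_spkrs=2):
    """pair_losses: (num_spkrs**2,), ordered (hyp 0, ref 0), (hyp 0, ref 1), ..."""
    best_loss, best_perm = None, None
    for perm in itertools.permutations(range(num_spkrs)):
        # average the pairwise losses selected by this hyp -> ref assignment
        loss = torch.stack(
            [pair_losses[h * num_spkrs + r] for h, r in enumerate(perm)]
        ).mean()
        if best_loss is None or loss < best_loss:
            best_loss, best_perm = loss, torch.tensor(perm)
    return best_loss, best_perm

loss, perm = min_pit_sketch(torch.tensor([0.9, 0.2, 0.3, 0.8]))
# picks permutation (1, 0) with loss (0.2 + 0.3) / 2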
def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
    """E2E batch beam search.

    :param list xs: list of input acoustic feature arrays
        [(T_1, D), (T_2, D), ...]
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    prev = self.training
    self.eval()

    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

    # subsample frame
    xs = [xx[::self.subsample[0], :] for xx in xs]
    xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)

    # 0. Frontend
    if self.frontend is not None:
        enhanced, hlens, mask = self.frontend(xs_pad, ilens)
        hs_pad, hlens = self.feature_transform(enhanced, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. Encoder
    hs_pad, hlens, _ = self.enc(hs_pad, hlens)

    # calculate log P(z_t|X) for CTC scores
    if recog_args.ctc_weight > 0.0:
        lpz = self.ctc.log_softmax(hs_pad)
        normalize_score = False
    else:
        lpz = None
        normalize_score = True

    # 2. Decoder
    hlens = torch.tensor(list(map(int, hlens)))  # make sure hlens is tensor
    y = self.dec.recognize_beam_batch(
        hs_pad,
        hlens,
        lpz,
        recog_args,
        char_list,
        rnnlm,
        normalize_score=normalize_score,
    )

    if prev:
        self.train()

    return y