def forward(self, xs_pad, ilens):
    """Encodermix forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :return: list: batch of hidden state sequences [num_spkrs x (B, Tmax, eprojs)]
    :rtype: torch.Tensor
    """
    # mixture encoder
    for module in self.enc_mix:
        xs_pad, ilens, _ = module(xs_pad, ilens)

    # SD and Rec encoder
    xs_pad_sd = [xs_pad for i in range(self.num_spkrs)]
    ilens_sd = [ilens for i in range(self.num_spkrs)]
    for ns in range(self.num_spkrs):
        # Encoder_SD: speaker differentiate encoder
        for module in self.enc_sd[ns]:
            xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])
        # Encoder_Rec: recognition encoder
        for module in self.enc_rec:
            xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])

    # make mask to remove bias value in padded part
    mask = to_device(self, make_pad_mask(ilens_sd[0]).unsqueeze(-1))

    return [x.masked_fill(mask, 0.0) for x in xs_pad_sd], ilens_sd, None

def forward(
    self, input: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """LabelAggregate forward function.

    Args:
        input: (Batch, Nsamples, Label_dim)
        ilens: (Batch)
    Returns:
        output: (Batch, Frames, Label_dim)
    """
    bs = input.size(0)
    max_length = input.size(1)
    label_dim = input.size(2)

    # NOTE(jiatong):
    # The default behaviour of label aggregation is compatible with
    # torch.stft about framing and padding.

    # Step1: center padding
    if self.center:
        pad = self.win_length // 2
        max_length = max_length + 2 * pad
        input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)
        input[:, :pad, :] = input[:, pad:(2 * pad), :]
        input[:, (max_length - pad):max_length, :] = input[
            :, (max_length - 2 * pad):(max_length - pad), :
        ]

    nframe = (max_length - self.win_length) // self.hop_length + 1

    # Step2: framing
    output = input.as_strided(
        (bs, nframe, self.win_length, label_dim),
        (max_length * label_dim, self.hop_length * label_dim, label_dim, 1),
    )

    # Step3: aggregate label
    output = torch.gt(output.sum(dim=2, keepdim=False), self.win_length // 2)
    output = output.float()

    # Step4: process lengths
    if ilens is not None:
        if self.center:
            pad = self.win_length // 2
            ilens = ilens + 2 * pad
        olens = (ilens - self.win_length) // self.hop_length + 1
        output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
    else:
        olens = None

    return output, olens

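# A minimal, self-contained sketch of the framing arithmetic used by
# LabelAggregate above, assuming only torch; `expected_nframes` is a
# hypothetical helper, and torch.Tensor.unfold stands in for the strided view.
import torch

def expected_nframes(nsamples: int, win_length: int, hop_length: int, center: bool) -> int:
    # With center padding the sequence grows by win_length // 2 on each side,
    # matching the olens formula in Step4.
    if center:
        nsamples = nsamples + 2 * (win_length // 2)
    return (nsamples - win_length) // hop_length + 1

x = torch.randn(2, 160, 4)            # (B, Nsamples, Label_dim)
frames = x.unfold(1, 32, 16)          # (B, nframe, Label_dim, win_length)
assert frames.size(1) == expected_nframes(160, 32, 16, center=False)
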
def forward(
    self, feat: torch.Tensor, ilens: torch.LongTensor
) -> Tuple[torch.Tensor, torch.LongTensor]:
    # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
    mel_feat = torch.matmul(feat, self.melmat)
    logmel_feat = (mel_feat + 1e-20).log()

    # Zero padding
    logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
    return logmel_feat, ilens

def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: LID loss value
    :rtype: torch.Tensor
    """
    # forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # LID output layer
    pred_pad = self.lid_lo(hs_pad)

    # compute LID loss
    self.loss = self.criterion(pred_pad, ys_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_pad, ignore_label=self.ignore_id
    )

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(self.acc, loss_data)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss

def forward(
    self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward."""
    input = self.linear_in(input)

    args = {"return_dict": True}

    mask = (~make_pad_mask(input_lengths)).to(input.device).float()
    if self.extend_attention_mask:
        args["attention_mask"] = _extend_attention_mask(mask)
    else:
        args["attention_mask"] = mask

    if self.use_inputs_embeds:
        args["inputs_embeds"] = input
    else:
        args["hidden_states"] = input

    if self.transformer.config.model_type == "mpnet":
        args["head_mask"] = [None for _ in self.transformer.layer]

    output = self.transformer(**args).last_hidden_state
    return output, input_lengths

def forward(
    self,
    feat: torch.Tensor,
    ilens: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
    mel_feat = torch.matmul(feat, self.melmat.to(feat.device))
    mel_feat = torch.clamp(mel_feat, min=1e-10)

    if self.log_base is None:
        logmel_feat = mel_feat.log()
    elif self.log_base == 2.0:
        logmel_feat = mel_feat.log2()
    elif self.log_base == 10.0:
        logmel_feat = mel_feat.log10()
    else:
        # Change of base: log_b(x) = ln(x) / ln(b)
        logmel_feat = mel_feat.log() / torch.log(self.log_base)

    # Zero padding
    if ilens is not None:
        logmel_feat = logmel_feat.masked_fill(
            make_pad_mask(ilens, logmel_feat, 1), 0.0
        )
    else:
        ilens = feat.new_full(
            [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
        )
    return logmel_feat, ilens

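# A quick numeric check of the change-of-base branch above, assuming only torch:
# log_b(x) = ln(x) / ln(b), so an arbitrary log_base reduces to the natural log.
import torch

x = torch.tensor([10.0, 100.0])
base = torch.tensor(10.0)
assert torch.allclose(x.log() / torch.log(base), x.log10())
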
def store_penultimate_state(self, xs_pad, ilens, ys_pad, moe_coes, moe_coe_lens):
    moe_coes = moe_coes[:, :max(moe_coe_lens)]  # for data parallel
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)

    # multi-encoder forward
    cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)

    moe_coes = moe_coes.unsqueeze(-1)
    hs_pad = cn_hs_pad * moe_coes[:, :, 1] + en_hs_pad * moe_coes[:, :, 0]
    self.hs_pad = hs_pad

    # forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask, penultimate_state = self.decoder(
        ys_in_pad, ys_mask, hs_pad, hs_mask, moe_coes, return_penultimate_state=True
    )

    # plot penultimate_state, (B, T, att_dim)
    return penultimate_state.squeeze(0).detach().cpu().numpy()

def nll(
    self, text: torch.Tensor, text_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    batch_size = text.size(0)
    # For data parallel
    text = text[:, :text_lengths.max()]

    # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
    # train: (Batch, Length) -> x, y: (Batch, Length + 1)
    x = F.pad(text, [1, 0], "constant", self.eos)
    t = F.pad(text, [0, 1], "constant", self.ignore_id)
    for i, l in enumerate(text_lengths):
        t[i, l] = self.sos
    x_lengths = text_lengths + 1

    # 2. Forward Language model
    # x: (Batch, Length) -> y: (Batch, Length, NVocab)
    y, _ = self.lm(x, None)

    # 3. Calc negative log likelihood
    # nll: (BxL,)
    nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
    # nll: (BxL,) -> (BxL,)
    nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
    # nll: (BxL,) -> (B, L)
    nll = nll.view(batch_size, -1)
    return nll, x_lengths

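# A minimal sketch of turning the per-token nll above into corpus perplexity,
# assuming `nll` (B, L) is already zeroed on padded positions and `x_lengths`
# (B,) counts the valid tokens, exactly as returned by the function above.
import torch

def perplexity(nll: torch.Tensor, x_lengths: torch.Tensor) -> float:
    # Mean negative log likelihood over all valid tokens, then exponentiate.
    return float(torch.exp(nll.sum() / x_lengths.sum()))
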
def store_penultimate_state(self, xs_pad, ilens, ys_pad):
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)

    # multi-encoder forward
    cn_hs_pad, hs_mask = self.cn_encoder(xs_pad, src_mask)
    en_hs_pad, hs_mask = self.en_encoder(xs_pad, src_mask)

    hs_pad = torch.cat((cn_hs_pad, en_hs_pad), dim=-1)
    coe = self.aggregation_linear(hs_pad)  # (B, T, 2)
    coe = F.softmax(self.aggre_scaling * coe, dim=-1).unsqueeze(-1)  # (B, T, 2, 1)
    hs_pad = coe[:, :, 0] * cn_hs_pad + coe[:, :, 1] * en_hs_pad

    # For MLME usage the coefficients are inverted (1 - coe).
    # Soft assign mode would return the inverted coefficients directly;
    # hard assign mode (used here) binarizes them by thresholding at 0.5,
    # i.e. chooses the argmax position.
    penultimate_state = 1 - coe
    penultimate_state = penultimate_state.squeeze(0).detach().cpu().numpy()  # (T, 2)
    penultimate_state = (penultimate_state > 0.5).astype(int)

    # forward decoder (not needed when only storing the aggregation coefficients)
    # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    # ys_mask = target_mask(ys_in_pad, self.ignore_id)
    # pred_pad, pred_mask, penultimate_state = self.decoder(
    #     ys_in_pad, ys_mask, hs_pad, hs_mask, return_penultimate_state=True)

    # plot penultimate_state, (B, T, att_dim)
    return penultimate_state

def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization.

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means:
        norm_vars:
        eps:
    """
    ilens_ = ilens.type_as(x)
    # mean: (B, D)
    mean = x.sum(dim=1) / ilens_[:, None]

    if norm_means:
        x -= mean[:, None, :]
        x_ = x
    else:
        x_ = x - mean[:, None, :]

    # Zero padding
    # NOTE: masked_fill_ is required; the out-of-place masked_fill
    # would discard its result here.
    x_.masked_fill_(make_pad_mask(ilens, x_, 1), 0.0)
    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        x /= var.sqrt()[:, None, :]
        x_ = x
    return x_, ilens

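# A quick numerical check of the pad-aware mean normalization above; the small
# `_pad_mask` helper below is a stand-in for the project's make_pad_mask.
import torch

def _pad_mask(ilens: torch.Tensor, t: int) -> torch.Tensor:
    # True on padded positions: (B, T, 1), broadcastable over D
    return (torch.arange(t)[None, :] >= ilens[:, None]).unsqueeze(-1)

ilens = torch.tensor([5, 3])
x = torch.randn(2, 5, 3).masked_fill(_pad_mask(ilens, 5), 0.0)  # zero padded
mean = x.sum(dim=1) / ilens[:, None].type_as(x)                 # (B, D), pad-aware
y = (x - mean[:, None, :]).masked_fill(_pad_mask(ilens, 5), 0.0)
# After normalization, each utterance sums to zero over its valid frames.
assert torch.allclose(y.sum(dim=1), torch.zeros(2, 3), atol=1e-5)
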
def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_length: torch.Tensor,
    prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Forward Hubert Pretrain Encoder.

    Args:
        xs_pad: input tensor (B, L, D)
        ilens: input length (B)
        ys_pad: target tensor (B, L)
        ys_pad_length: target length (B)
        prev_states: Not to be used now.
    Returns:
        position embedded tensor and mask
    """
    self.cast_mask_emb()
    masks = make_pad_mask(ilens).to(xs_pad.device)
    ys_pad = ys_pad[:, :min(ys_pad_length)]
    enc_outputs = self.encoder(
        xs_pad,
        padding_mask=masks,
        mask=True,
        target_list=[ys_pad],
        features_only=False,
    )
    return enc_outputs

def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Embed positions in tensor.

    Args:
        xs_pad: input tensor (B, L, D)
        ilens: input length (B)
        prev_states: Not to be used now.
    Returns:
        position embedded tensor and mask
    """
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

    if isinstance(
        self.embed, (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
    ):
        xs_pad, masks = self.embed(xs_pad, masks)
    else:
        xs_pad = self.embed(xs_pad)

    xs_pad, masks = self.encoders(xs_pad, masks)
    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)

    olens = masks.squeeze(1).sum(1)
    return xs_pad, olens, None

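# A minimal sketch of recovering output lengths from the (B, 1, L) non-pad mask
# used by the encoders above, assuming only torch.
import torch

ilens = torch.tensor([5, 3])
L = int(ilens.max())
masks = (torch.arange(L)[None, :] < ilens[:, None])[:, None, :]  # (B, 1, L), True = valid
olens = masks.squeeze(1).sum(1)                                  # back to (B,)
assert torch.equal(olens, ilens)
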
def forward(
    self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward function.

    Args:
        x: (B, L, ...)
        ilens: (B,)
    """
    if ilens is None:
        ilens = x.new_full([x.size(0)], x.size(1))
    norm_means = self.norm_means
    norm_vars = self.norm_vars
    self.mean = self.mean.to(x.device, x.dtype)
    self.std = self.std.to(x.device, x.dtype)
    mask = make_pad_mask(ilens, x, 1)

    # feat: (B, T, D)
    if norm_means:
        if x.requires_grad:
            x = x - self.mean
        else:
            x -= self.mean
    if x.requires_grad:
        x = x.masked_fill(mask, 0.0)
    else:
        x.masked_fill_(mask, 0.0)

    if norm_vars:
        x /= self.std

    return x, ilens

def forward_mt(self, xs_pad, ys_in_pad, ys_out_pad, ys_mask):
    """Forward pass in the auxiliary MT task.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
    :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_out_pad: batch of padded target sequences (B, Lmax)
    :param torch.Tensor ys_mask: batch of input token mask (B, Lmax)
    :return: MT loss value
    :rtype: torch.Tensor
    :return: accuracy in MT decoder
    :rtype: float
    """
    loss, acc = 0.0, None
    if self.mt_weight == 0:
        return loss, acc

    ilens = torch.sum(xs_pad != self.ignore_id, dim=1).cpu().numpy()
    # NOTE: xs_pad is padded with -1
    xs = [x[x != self.ignore_id] for x in xs_pad]  # parse padded xs
    xs_zero_pad = pad_list(xs, self.pad)  # re-pad with zero
    xs_zero_pad = xs_zero_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_zero_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder_mt(xs_zero_pad, src_mask)
    pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    loss = self.criterion(pred_pad, ys_out_pad)
    acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )
    return loss, acc

def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Embed positions in tensor.

    Args:
        xs_pad: input tensor (B, L, D)
        ilens: input length (B)
        prev_states: Not to be used now.
    Returns:
        position embedded tensor and mask
    """
    masks = make_pad_mask(ilens).to(xs_pad.device)

    # make sure the conv feature extractor stays frozen
    self.wav2vec.feature_grad_mult = 0
    xs_pad = self.wav2vec.forward(
        xs_pad, mask=True, padding_mask=masks, features_only=True
    )['x']

    feats_lens = []
    for lens in ilens:
        feats_lens.append(
            get_output_lens(self.wav2vec.feature_extractor.conv_layers, lens)
        )
    olens = torch.stack(feats_lens)

    # xs_pad = self.projection(xs_pad)
    return xs_pad, olens, None

def forward(
    self,
    xs_pad: torch.Tensor,
    ilens: torch.Tensor,
    prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Calculate forward propagation.

    Args:
        xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
        ilens (torch.Tensor): Input length (#batch).
        prev_states (torch.Tensor): Not to be used now.
    Returns:
        torch.Tensor: Output tensor (#batch, L, output_size).
        torch.Tensor: Output length (#batch).
        torch.Tensor: Not to be used now.
    """
    masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

    if isinstance(
        self.embed, (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
    ):
        xs_pad, masks = self.embed(xs_pad, masks)
    else:
        xs_pad = self.embed(xs_pad)

    xs_pad, masks = self.encoders(xs_pad, masks)
    if isinstance(xs_pad, tuple):
        xs_pad = xs_pad[0]

    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)

    olens = masks.squeeze(1).sum(1)
    return xs_pad, olens, None

def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax)
    :param torch.Tensor ilens: batch of lengths of source sequences (B)
    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # 1. forward encoder
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
    xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    self.hs_pad = hs_pad

    # 2. forward decoder
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_mask = target_mask(ys_in_pad, self.ignore_id)
    pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
    self.pred_pad = pred_pad

    # 3. compute attention loss
    loss = self.criterion(pred_pad, ys_out_pad)
    self.acc = th_accuracy(
        pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
    )

    # TODO(karita) show predicted text
    # TODO(karita) calculate these stats

    # 4. compute bleu
    if self.training or self.error_calculator is None:
        bleu = 0.0
    else:
        ys_hat = pred_pad.argmax(dim=-1)
        bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    # copied from e2e_mt
    self.loss = loss
    loss_data = float(self.loss)
    if self.normalize_length:
        self.ppl = np.exp(loss_data)
    else:
        ys_out_pad = ys_out_pad.view(-1)
        ignore = ys_out_pad == self.ignore_id  # (B,)
        total = len(ys_out_pad) - ignore.sum().item()
        self.ppl = np.exp(loss_data * ys_out_pad.size(0) / total)
    if not math.isnan(loss_data):
        self.reporter.report(loss_data, self.acc, self.ppl, bleu)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss

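# A minimal sketch of the perplexity correction above: when the criterion
# normalizes over all positions (including ignored ones), rescale the loss by
# total positions / valid tokens before exponentiating. Toy numbers only.
import math

loss_data = 2.0        # criterion output averaged over all B * Lmax positions
total_positions = 40   # ys_out_pad.size(0) after flattening
valid_tokens = 32      # positions not equal to ignore_id
ppl = math.exp(loss_data * total_positions / valid_tokens)  # ~12.18
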
def _forward(
    self,
    xs,
    ilens,
    ys=None,
    olens=None,
    spembs=None,
    ds=None,
    is_inference=False,
    alpha=1.0,
):
    # forward encoder
    x_masks = self._source_mask(ilens)
    hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

    # integrate speaker embedding
    if self.spk_embed_dim is not None:
        hs = self._integrate_with_spk_embed(hs, spembs)

    # forward duration predictor and length regulator
    d_masks = make_pad_mask(ilens).to(xs.device)
    if is_inference:
        d_outs = self.duration_predictor.inference(hs, d_masks)  # (B, Tmax)
        hs = self.length_regulator(hs, d_outs, ilens, alpha)  # (B, Lmax, adim)
    else:
        if ds is None:
            with torch.no_grad():
                ds = self.duration_calculator(xs, ilens, ys, olens, spembs)  # (B, Tmax)
        d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
        hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)

    # forward decoder
    if olens is not None:
        if self.reduction_factor > 1:
            olens_in = olens.new([olen // self.reduction_factor for olen in olens])
        else:
            olens_in = olens
        h_masks = self._source_mask(olens_in)
    else:
        h_masks = None
    zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)
    before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)  # (B, Lmax, odim)

    # postnet -> (B, Lmax//r * r, odim)
    if self.postnet is None:
        after_outs = before_outs
    else:
        after_outs = before_outs + self.postnet(
            before_outs.transpose(1, 2)
        ).transpose(1, 2)

    if is_inference:
        return before_outs, after_outs, d_outs
    else:
        return before_outs, after_outs, ds, d_outs

def forward(self, xs, activation=None):
    ilens = [x.shape[0] for x in xs]
    xs_pad = pad_sequence(xs, batch_first=True, padding_value=-1)
    src_mask = (~make_pad_mask(ilens)).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.enc.forward(xs_pad, src_mask)
    ys_pad = self.linear(hs_pad)
    # apply the optional output activation (the parameter was previously unused)
    if activation is not None:
        ys_pad = activation(ys_pad)
    return ys_pad

def forward(
    self,
    hs_pad: torch.Tensor,
    hlens: torch.Tensor,
    ys_in_pad: torch.Tensor,
    ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward decoder.

    Args:
        hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
        hlens: (batch)
        ys_in_pad: input token ids, int64 (batch, maxlen_out)
            if input_layer == "embed".
            In the other cases, input tensor (batch, maxlen_out, #mels).
        ys_in_lens: (batch)
    Returns:
        (tuple): tuple containing:
            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True
            olens: (batch,)
    """
    tgt = ys_in_pad
    # tgt_mask: (B, 1, L)
    tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
    # m: (1, L, L)
    m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
    # tgt_mask: (B, L, L)
    tgt_mask = tgt_mask & m

    memory = hs_pad
    memory_mask = (~make_pad_mask(hlens))[:, None, :].to(memory.device)

    x = self.embed(tgt)
    x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory, memory_mask)
    if self.normalize_before:
        x = self.after_norm(x)
    if self.output_layer is not None:
        x = self.output_layer(x)

    olens = tgt_mask.sum(1)
    return x, olens

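# A minimal sketch of the target-mask construction above, assuming only torch;
# torch.tril stands in for the project's subsequent_mask helper.
import torch

ys_in_lens = torch.tensor([4, 2])
L = int(ys_in_lens.max())
pad_mask = torch.arange(L)[None, :] < ys_in_lens[:, None]      # (B, L), True = valid
causal = torch.tril(torch.ones(L, L, dtype=torch.bool))[None]  # (1, L, L)
tgt_mask = pad_mask[:, None, :] & causal  # (B, L, L): i attends to valid j <= i
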
def store_penultimate_state(self, xs_pad, ilens, ys_pad):
    self.eval()
    xs_pad = xs_pad[:, :max(ilens)]  # for data parallel
    src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
    hs_pad, hs_mask = self.encoder(xs_pad, src_mask)

    # plot penultimate_state, (B, T, att_dim)
    return hs_pad.squeeze(0).detach().cpu().numpy()

def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    Args:
        xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
        ilens (torch.Tensor): batch of lengths of input sequences (B)
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

    Returns:
        loss (torch.Tensor): transducer loss value
    """
    # 1. encoder
    if self.etype == 'transformer':
        xs_pad = xs_pad[:, :max(ilens)]
        src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
        hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
    else:
        hs_pad, hlens = xs_pad, ilens
        hs_pad, hlens, _ = self.encoder(hs_pad, hlens)
        hs_mask = hlens
    self.hs_pad = hs_pad

    # 1.5. transducer preparation related
    ys_in_pad, target, pred_len, target_len = prepare_loss_inputs(ys_pad, hs_mask)

    # 2. decoder
    if self.dtype == 'transformer':
        ys_mask = target_mask(ys_in_pad, self.blank_id)
        pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
    else:
        if self.rnnt_mode == 'rnnt':
            pred_pad = self.decoder(hs_pad, ys_in_pad)
        else:
            pred_pad = self.decoder(hs_pad, ys_in_pad, pred_len)
    self.pred_pad = pred_pad

    # 3. loss computation
    loss = self.criterion(pred_pad, target, pred_len, target_len)
    self.loss = loss
    loss_data = float(self.loss)

    # 4. compute cer/wer
    if self.training or self.error_calculator is None:
        cer, wer = None, None
    else:
        cer, wer = self.error_calculator(hs_pad, ys_pad)

    if not math.isnan(loss_data):
        self.reporter.report(loss_data, cer, wer)
    else:
        logging.warning('loss (=%f) is not correct', loss_data)
    return self.loss

def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.Tensor = None,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply utterance mean and variance normalization.

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,)
        norm_means:
        norm_vars:
        eps:
    """
    if ilens is None:
        ilens = x.new_full([x.size(0)], x.size(1))
    ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])

    # Zero padding
    if x.requires_grad:
        x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
    else:
        x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
    # mean: (B, 1, D)
    mean = x.sum(dim=1, keepdim=True) / ilens_

    if norm_means:
        x -= mean

        if norm_vars:
            var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            # NOTE: divide by std itself, consistent with the else branch;
            # dividing by std.sqrt() would normalize with the wrong power.
            x = x / std
        return x, ilens
    else:
        if norm_vars:
            y = x - mean
            y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
            var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
            std = torch.clamp(var.sqrt(), min=eps)
            x /= std
        return x, ilens

def _forward(
    self,
    xs: torch.Tensor,
    ilens: torch.Tensor,
    ys: torch.Tensor = None,
    olens: torch.Tensor = None,
    ds: torch.Tensor = None,
    spembs: torch.Tensor = None,
    is_inference: bool = False,
    alpha: float = 1.0,
) -> Sequence[torch.Tensor]:
    # forward encoder
    x_masks = self._source_mask(ilens)
    hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

    # integrate with GST
    if self.use_gst:
        style_embs = self.gst(ys)
        hs = hs + style_embs.unsqueeze(1)

    # integrate speaker embedding
    if self.spk_embed_dim is not None:
        hs = self._integrate_with_spk_embed(hs, spembs)

    # forward duration predictor and length regulator
    d_masks = make_pad_mask(ilens).to(xs.device)
    if is_inference:
        d_outs = self.duration_predictor.inference(hs, d_masks)  # (B, Tmax)
        hs = self.length_regulator(hs, d_outs, ilens, alpha)  # (B, Lmax, adim)
    else:
        d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
        hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)

    # forward decoder
    if olens is not None and not is_inference:
        if self.reduction_factor > 1:
            olens_in = olens.new([olen // self.reduction_factor for olen in olens])
        else:
            olens_in = olens
        h_masks = self._source_mask(olens_in)
    else:
        h_masks = None
    zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)
    before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)  # (B, Lmax, odim)

    # postnet -> (B, Lmax//r * r, odim)
    if self.postnet is None:
        after_outs = before_outs
    else:
        after_outs = before_outs + self.postnet(
            before_outs.transpose(1, 2)
        ).transpose(1, 2)

    return before_outs, after_outs, d_outs

def forward(
    self, xs: Union[torch.Tensor, ComplexTensor], ilens: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
    """Mask estimator forward function.

    Args:
        xs: (B, F, C, T)
        ilens: (B,)
    Returns:
        hs (torch.Tensor): The hidden vector (B, F, C, T)
        masks: A tuple of the masks. (B, F, C, T)
        ilens: (B,)
    """
    assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
    _, _, C, input_length = xs.size()
    # (B, F, C, T) -> (B, C, T, F)
    xs = xs.permute(0, 2, 3, 1)

    # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
    if is_complex(xs):
        xs = (xs.real ** 2 + xs.imag ** 2) ** 0.5
    # xs: (B, C, T, F) -> xs: (B * C, T, F)
    xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
    # ilens: (B,) -> ilens_: (B * C)
    ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

    # xs: (B * C, T, F) -> xs: (B * C, T, D)
    xs, _, _ = self.brnn(xs, ilens_)
    # xs: (B * C, T, D) -> xs: (B, C, T, D)
    xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

    masks = []
    for linear in self.linears:
        # xs: (B, C, T, D) -> mask: (B, C, T, F)
        mask = linear(xs)

        if self.nonlinear == "sigmoid":
            mask = torch.sigmoid(mask)
        elif self.nonlinear == "relu":
            mask = torch.relu(mask)
        elif self.nonlinear == "tanh":
            mask = torch.tanh(mask)
        elif self.nonlinear == "crelu":
            mask = torch.clamp(mask, min=0, max=1)
        # Zero padding
        # NOTE: the in-place masked_fill_ is required; the out-of-place
        # masked_fill would discard its result here.
        mask.masked_fill_(make_pad_mask(ilens, mask, length_dim=2), 0)

        # (B, C, T, F) -> (B, F, C, T)
        mask = mask.permute(0, 3, 1, 2)

        # Takes care of multi-GPU cases: if input_length > max(ilens)
        if mask.size(-1) < input_length:
            mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
        masks.append(mask)

    return tuple(masks), ilens

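# A minimal shape walk-through of the channel flattening used above, assuming
# only torch: all channels share one BLSTM by folding C into the batch axis.
import torch

B, F_, C, T = 2, 129, 4, 50
xs = torch.randn(B, F_, C, T).permute(0, 2, 3, 1)            # (B, C, T, F)
xs = xs.contiguous().view(-1, T, F_)                         # (B * C, T, F)
ilens = torch.tensor([50, 48])
ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)  # (B * C,)
assert xs.shape == (B * C, T, F_) and ilens_.shape == (B * C,)
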
def forward(
    self, x: torch.Tensor, ilens: torch.LongTensor
) -> Tuple[torch.Tensor, torch.LongTensor]:
    # feat: (B, T, D)
    if self.norm_means:
        x += self.bias.type_as(x)
        # NOTE: use the in-place masked_fill_; the out-of-place masked_fill
        # would discard its result here.
        x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)

    if self.norm_vars:
        x *= self.scale.type_as(x)
    return x, ilens

def forward(
    self, feat: torch.Tensor, ilens: torch.LongTensor
) -> Tuple[torch.Tensor, torch.LongTensor]:
    # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
    mel_feat = torch.matmul(feat, self.melmat)
    logmel_feat = (mel_feat + 1e-20).log()

    # Zero padding
    logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
    return logmel_feat, ilens

def _forward(
    self,
    xs: torch.Tensor,
    ilens: torch.Tensor,
    ys: torch.Tensor = None,
    olens: torch.Tensor = None,
    ds: torch.Tensor = None,
    ps: torch.Tensor = None,
    es: torch.Tensor = None,
):
    x_masks = self._source_mask(ilens)  # (B, 1, Tmax)
    y_masks = self._source_mask(olens)  # (B, 1, Lmax)
    hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

    d_masks = make_pad_mask(ilens).to(xs.device)

    if self.stop_gradient_from_pitch_predictor:
        p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
    else:
        p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
    if self.stop_gradient_from_energy_predictor:
        e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
    else:
        e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))
    d_outs = self.duration_predictor(hs, d_masks)

    # use ground-truth pitch/energy in training
    p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
    e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
    hs = hs + e_embs + p_embs

    mu = self.length_regulator(hs, ds)  # (B, Lmax, adim)
    mu, _ = self.pre_decoder(mu, y_masks)  # (B, Lmax, adim)
    mu = self.feat_out(mu)  # (B, Lmax, odim)
    mu = mu.transpose(1, 2)  # (B, odim, Lmax)

    # pad the time axis to a multiple of 4 for the decoder
    if mu.size(2) % 4 != 0:
        mu = torch.cat(
            [
                mu,
                torch.zeros(
                    [mu.size(0), self.odim, 4 - mu.size(2) % 4],
                    dtype=mu.dtype,
                    device=mu.device,
                ),
            ],
            dim=2,
        )
        y_masks = torch.cat(
            [
                y_masks,
                torch.zeros(
                    [y_masks.size(0), 1, 4 - y_masks.size(2) % 4],
                    dtype=y_masks.dtype,
                    device=y_masks.device,
                ),
            ],
            dim=2,
        )
    noise_estimation, z = self.decoder(ys, y_masks, mu)

    return noise_estimation, z, d_outs, p_outs, e_outs, mu, y_masks

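# A minimal sketch of the pad-to-multiple-of-4 step above, assuming only torch;
# the decoder is presumed to downsample the time axis by 4, hence the padding.
import torch

mu = torch.randn(2, 80, 17)  # (B, odim, Lmax)
if mu.size(2) % 4 != 0:
    pad = 4 - mu.size(2) % 4
    mu = torch.cat([mu, mu.new_zeros(mu.size(0), mu.size(1), pad)], dim=2)
assert mu.size(2) % 4 == 0  # 17 -> 20
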
def expand_to(self, xs, lens):
    """Expand (B, D) vectors to (B, T, D) sequences with zeroed padding.

    xs: (B, D)
    lens: (B,)
    """
    # (B, T, 1)
    mask = to_device(xs, make_pad_mask(lens).unsqueeze(-1))
    # (B, D) -> (B, 1, D) -> (B, T, D)
    xs = xs.unsqueeze(1).expand(-1, mask.size(1), -1).masked_fill(mask, 0.0)
    return xs

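# A tiny usage sketch of expand_to, assuming only torch; the inline mask mirrors
# make_pad_mask(lens).unsqueeze(-1) from the function above.
import torch

xs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # (B=2, D=2)
lens = torch.tensor([3, 1])
T = int(lens.max())
mask = (torch.arange(T)[None, :] >= lens[:, None]).unsqueeze(-1)  # (B, T, 1)
out = xs.unsqueeze(1).expand(-1, T, -1).masked_fill(mask, 0.0)
# out[0] repeats [1., 2.] for 3 steps; out[1] keeps one step, then zeros.
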
def forward(
    self, input: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """STFT forward function.

    Args:
        input: (Batch, Nsamples) or (Batch, Nsamples, Channels)
        ilens: (Batch)
    Returns:
        output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
    """
    bs = input.size(0)
    if input.dim() == 3:
        multi_channel = True
        # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
        input = input.transpose(1, 2).reshape(-1, input.size(1))
    else:
        multi_channel = False

    # output: (Batch, Freq, Frames, 2=real_imag)
    # or (Batch * Channels, Freq, Frames, 2=real_imag)
    output = torch.stft(
        input,
        n_fft=self.n_fft,
        win_length=self.win_length,
        hop_length=self.hop_length,
        center=self.center,
        pad_mode=self.pad_mode,
        normalized=self.normalized,
        onesided=self.onesided,
    )
    # output: (Batch, Freq, Frames, 2=real_imag)
    # -> (Batch, Frames, Freq, 2=real_imag)
    output = output.transpose(1, 2)
    if multi_channel:
        # output: (Batch * Channels, Frames, Freq, 2=real_imag)
        # -> (Batch, Frames, Channels, Freq, 2=real_imag)
        output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(1, 2)

    if ilens is not None:
        if self.center:
            pad = self.win_length // 2
            ilens = ilens + 2 * pad
        olens = (ilens - self.win_length) // self.hop_length + 1
        output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
    else:
        olens = None

    return output, olens

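# A minimal consistency check of the olens arithmetic above, assuming only
# torch, with n_fft == win_length so torch.stft's center padding (n_fft // 2
# per side) matches the win_length // 2 used in the formula.
import torch

n_fft = win_length = 400
hop_length, nsamples = 160, 16000
wav = torch.randn(1, nsamples)
spec = torch.stft(wav, n_fft=n_fft, win_length=win_length, hop_length=hop_length,
                  center=True, return_complex=True)  # (1, Freq, Frames)
olen = (nsamples + 2 * (win_length // 2) - win_length) // hop_length + 1
assert spec.size(-1) == olen  # 101 frames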