def _shallow_ar_inference(out, stream_sizes, analysis_filts):
    from torchaudio.functional import lfilter

    out_streams = split_streams(out, stream_sizes)
    # back to conv1d friendly (B, C, T) format
    out_streams = map(lambda x: x.transpose(1, 2), out_streams)

    out_syn = []
    for sidx, os in enumerate(out_streams):
        out_stream_syn = torch.zeros_like(os)
        a = analysis_filts[sidx].get_filt_coefs()
        # apply IIR filter for each dimension
        for idx in range(os.shape[1]):
            # NOTE: scipy.signal.lfilter accepts (b, a) in this order,
            # but torchaudio expects the opposite: (a, b)
            ai = a[idx].view(-1).flip(0)
            bi = torch.zeros_like(ai)
            bi[0] = 1
            out_stream_syn[:, idx, :] = lfilter(os[:, idx, :], ai, bi,
                                                clamp=False)
        out_syn += [out_stream_syn]

    out_syn = torch.cat(out_syn, 1)
    return out_syn.transpose(1, 2)
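
# A minimal sanity check of the coefficient-order convention noted above
# (an illustrative sketch, not part of the library): torchaudio's lfilter
# takes (waveform, a_coeffs, b_coeffs), the reverse of
# scipy.signal.lfilter(b, a, x). The two should agree on a one-pole IIR.
def _check_lfilter_convention():
    import numpy as np
    import torch
    from scipy.signal import lfilter as scipy_lfilter
    from torchaudio.functional import lfilter as ta_lfilter

    x = np.random.randn(100).astype(np.float32)
    # one-pole IIR: y[t] = x[t] + 0.9 * y[t - 1]
    b = np.array([1.0, 0.0], dtype=np.float32)
    a = np.array([1.0, -0.9], dtype=np.float32)
    y_scipy = scipy_lfilter(b, a, x)
    y_torch = ta_lfilter(
        torch.from_numpy(x).unsqueeze(0),
        torch.from_numpy(a),  # a first ...
        torch.from_numpy(b),  # ... then b
        clamp=False,
    ).squeeze(0).numpy()
    assert np.allclose(y_scipy, y_torch, atol=1e-4)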
def preprocess_target(self, y):
    """Apply per-stream analysis filters to the target features."""
    assert sum(self.stream_sizes) == y.shape[-1]
    ys = split_streams(y, self.stream_sizes)
    for idx, yi in enumerate(ys):
        ys[idx] = self.analysis_filts[idx](yi.transpose(1, 2)).transpose(1, 2)
    return torch.cat(ys, -1)
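
# Quick illustration of the split/concat round-trip relied on above
# (hypothetical sizes; e.g. mgc(60) + lf0(1) + vuv(1) + bap(5)):
#
# y = torch.randn(2, 100, 67)
# mgc, lf0, vuv, bap = split_streams(y, [60, 1, 1, 5])
# assert torch.equal(torch.cat([mgc, lf0, vuv, bap], dim=-1), y)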
def gen_waveform(labels, acoustic_features, acoustic_out_scaler,
                 binary_dict, continuous_dict, stream_sizes,
                 has_dynamic_features,
                 subphone_features="coarse_coding",
                 log_f0_conditioning=True,
                 pitch_idx=None, num_windows=3, post_filter=True,
                 sample_rate=48000, frame_period=5, relative_f0=True):
    windows = get_windows(num_windows)

    # Apply MLPG if necessary
    if np.any(has_dynamic_features):
        acoustic_features = multi_stream_mlpg(
            acoustic_features, acoustic_out_scaler.var_, windows,
            stream_sizes, has_dynamic_features)
        static_stream_sizes = get_static_stream_sizes(
            stream_sizes, has_dynamic_features, len(windows))
    else:
        static_stream_sizes = stream_sizes

    # Split multi-stream features
    mgc, target_f0, vuv, bap = split_streams(acoustic_features,
                                             static_stream_sizes)

    # Generate waveform with the WORLD vocoder
    fftlen = pyworld.get_cheaptrick_fft_size(sample_rate)
    alpha = pysptk.util.mcepalpha(sample_rate)

    if post_filter:
        mgc = merlin_post_filter(mgc, alpha)

    spectrogram = pysptk.mc2sp(mgc, fftlen=fftlen, alpha=alpha)
    aperiodicity = pyworld.decode_aperiodicity(bap.astype(np.float64),
                                               sample_rate, fftlen)

    ### F0 ###
    if relative_f0:
        diff_lf0 = target_f0
        # need to extract the pitch sequence from the musical score
        linguistic_features = fe.linguistic_features(
            labels, binary_dict, continuous_dict,
            add_frame_features=True,
            subphone_features=subphone_features)
        f0_score = _midi_to_hz(linguistic_features, pitch_idx, False)[:, None]
        lf0_score = f0_score.copy()
        nonzero_indices = np.nonzero(lf0_score)
        lf0_score[nonzero_indices] = np.log(f0_score[nonzero_indices])
        lf0_score = interp1d(lf0_score, kind="slinear")

        f0 = diff_lf0 + lf0_score
        f0[vuv < 0.5] = 0
        f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])
    else:
        f0 = target_f0

    generated_waveform = pyworld.synthesize(f0.flatten().astype(np.float64),
                                            spectrogram.astype(np.float64),
                                            aperiodicity.astype(np.float64),
                                            sample_rate, frame_period)

    return generated_waveform
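
# The relative-F0 path above recovers absolute F0 from the predicted log-F0
# residual: f0 = exp(diff_lf0 + lf0_score) on voiced frames, 0 elsewhere.
# A self-contained sketch of just that step (hypothetical helper, mirroring
# the code above):
def _recover_f0_from_residual(diff_lf0, lf0_score, vuv, threshold=0.5):
    f0 = diff_lf0 + lf0_score
    f0[vuv < threshold] = 0
    f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])
    return f0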
def gen_spsvs_static_features(
    labels,
    acoustic_features,
    binary_dict,
    numeric_dict,
    stream_sizes,
    has_dynamic_features,
    subphone_features="coarse_coding",
    pitch_idx=None,
    num_windows=3,
    frame_period=5,
    relative_f0=True,
    vibrato_scale=1.0,
    vuv_threshold=0.3,
    force_fix_vuv=True,
):
    """Generate static features from predicted acoustic features

    Args:
        labels (HTSLabelFile): HTS labels
        acoustic_features (ndarray): predicted acoustic features
        binary_dict (dict): binary feature dictionary
        numeric_dict (dict): numeric feature dictionary
        stream_sizes (list): stream sizes
        has_dynamic_features (list): whether each stream has dynamic features
        subphone_features (str): subphone feature type
        pitch_idx (int): index of pitch features
        num_windows (int): number of windows
        frame_period (float): frame period
        relative_f0 (bool): whether to use relative f0
        vibrato_scale (float): vibrato scale
        vuv_threshold (float): vuv threshold
        force_fix_vuv (bool): whether to use post-processing to fix VUV.

    Returns:
        tuple: tuple of mgc, lf0, vuv and bap.
    """
    if np.any(has_dynamic_features):
        static_stream_sizes = get_static_stream_sizes(
            stream_sizes, has_dynamic_features, num_windows
        )
    else:
        static_stream_sizes = stream_sizes

    # Copy here to avoid inplace operations on input acoustic features
    acoustic_features = acoustic_features.copy()

    # Split multi-stream features
    streams = split_streams(acoustic_features, static_stream_sizes)
    if len(streams) == 4:
        mgc, target_f0, vuv, bap = streams
        vib, vib_flags = None, None
    elif len(streams) == 5:
        # Assuming diff-based vibrato parameters
        mgc, target_f0, vuv, bap, vib = streams
        vib_flags = None
    elif len(streams) == 6:
        # Assuming sine-based vibrato parameters
        mgc, target_f0, vuv, bap, vib, vib_flags = streams
    else:
        raise RuntimeError("Not supported streams")

    linguistic_features = fe.linguistic_features(
        labels,
        binary_dict,
        numeric_dict,
        add_frame_features=True,
        subphone_features=subphone_features,
    )

    # Correct V/UV based on special phone flags
    if force_fix_vuv:
        vuv = correct_vuv_by_phone(vuv, binary_dict, linguistic_features)

    # F0
    if relative_f0:
        diff_lf0 = target_f0
        f0_score = _midi_to_hz(linguistic_features, pitch_idx, False)[:, None]
        lf0_score = f0_score.copy()
        nonzero_indices = np.nonzero(lf0_score)
        lf0_score[nonzero_indices] = np.log(f0_score[nonzero_indices])
        lf0_score = interp1d(lf0_score, kind="slinear")

        f0 = diff_lf0 + lf0_score
        f0[vuv < vuv_threshold] = 0
        f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])
    else:
        f0 = target_f0
        f0[vuv < vuv_threshold] = 0
        f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])

    if vib is not None:
        if vib_flags is not None:
            # Generate sine-based vibrato
            vib_flags = vib_flags.flatten()
            m_a, m_f = vib[:, 0], vib[:, 1]

            # Fill zeros for non-vibrato frames
            m_a[vib_flags < 0.5] = 0
            m_f[vib_flags < 0.5] = 0

            # Gen vibrato
            sr_f0 = int(1 / (frame_period * 0.001))
            f0 = gen_sine_vibrato(f0.flatten(), sr_f0, m_a, m_f, vibrato_scale)
        else:
            # Generate diff-based vibrato
            f0 = f0.flatten() + vibrato_scale * vib.flatten()

    # NOTE: Back to log-domain for convenience
    lf0 = f0.copy()
    lf0[np.nonzero(lf0)] = np.log(f0[np.nonzero(lf0)])
    # NOTE: interpolation is necessary
    lf0 = interp1d(lf0, kind="slinear")

    lf0 = lf0[:, None] if len(lf0.shape) == 1 else lf0
    vuv = vuv[:, None] if len(vuv.shape) == 1 else vuv

    return mgc, lf0, vuv, bap
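
# Usage sketch (assumed inputs, following the docstring above): the returned
# static streams can be converted to WORLD parameters the same way
# gen_waveform does.
#
# mgc, lf0, vuv, bap = gen_spsvs_static_features(
#     labels, acoustic_features, binary_dict, numeric_dict,
#     stream_sizes, has_dynamic_features, pitch_idx=pitch_idx)
# fftlen = pyworld.get_cheaptrick_fft_size(sample_rate)
# alpha = pysptk.util.mcepalpha(sample_rate)
# spectrogram = pysptk.mc2sp(mgc, fftlen=fftlen, alpha=alpha)
# aperiodicity = pyworld.decode_aperiodicity(bap.astype(np.float64),
#                                            sample_rate, fftlen)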
def train_loop(config, device, model, optimizer, lr_scheduler, data_loaders):
    criterion = nn.MSELoss(reduction="none")
    logger.info("Start utterance-wise training...")

    stream_weights = get_stream_weight(config.model.stream_weights,
                                       config.model.stream_sizes).to(device)

    best_loss = 10000000
    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            model.train() if train else model.eval()
            running_loss = 0
            for x, y, lengths in data_loaders[phase]:
                # Sort by lengths. This is needed for pytorch's PackedSequence
                sorted_lengths, indices = torch.sort(lengths, dim=0,
                                                     descending=True)
                x, y = x[indices].to(device), y[indices].to(device)

                optimizer.zero_grad()

                # Run forward
                y_hat = model(x, sorted_lengths)

                # Compute loss
                mask = make_non_pad_mask(sorted_lengths).unsqueeze(-1).to(device)

                if config.train.stream_wise_loss:
                    # Stream-wise loss
                    streams = split_streams(y, config.model.stream_sizes)
                    streams_hat = split_streams(y_hat, config.model.stream_sizes)
                    loss = 0
                    for s_hat, s, sw in zip(streams_hat, streams,
                                            stream_weights):
                        s_hat_mask = s_hat.masked_select(mask)
                        s_mask = s.masked_select(mask)
                        loss += sw * criterion(s_hat_mask, s_mask).mean()
                else:
                    # Joint modeling
                    y_hat = y_hat.masked_select(mask)
                    y = y.masked_select(mask)
                    loss = criterion(y_hat, y).mean()

                if train:
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
            ave_loss = running_loss / len(data_loaders[phase])
            logger.info(f"[{phase}] [Epoch {epoch}]: loss {ave_loss}")
            if not train and ave_loss < best_loss:
                best_loss = ave_loss
                save_best_checkpoint(config, model, optimizer, best_loss)

        # step per each epoch (may consider updating per iter.)
        lr_scheduler.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            save_checkpoint(config, model, optimizer, lr_scheduler, epoch)

    # save at last epoch
    save_checkpoint(config, model, optimizer, lr_scheduler,
                    config.train.nepochs)
    logger.info(f"The best loss was {best_loss}")

    return model
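
# Sketch of the masking scheme used in the loss above (hypothetical shapes):
# padded frames are excluded via boolean selection before averaging.
#
# lengths = torch.tensor([3, 2])
# mask = make_non_pad_mask(lengths).unsqueeze(-1)  # (B, T, 1), True on valid frames
# y_hat, y = torch.randn(2, 3, 4), torch.randn(2, 3, 4)
# loss = criterion(y_hat.masked_select(mask), y.masked_select(mask)).mean()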
def gen_waveform(labels, acoustic_features, binary_dict, continuous_dict,
                 stream_sizes, has_dynamic_features,
                 subphone_features="coarse_coding",
                 log_f0_conditioning=True,
                 pitch_idx=None, num_windows=3, post_filter=True,
                 sample_rate=48000, frame_period=5, relative_f0=True):
    windows = get_windows(num_windows)

    # Get static stream sizes if dynamic features are used
    if np.any(has_dynamic_features):
        static_stream_sizes = get_static_stream_sizes(stream_sizes,
                                                      has_dynamic_features,
                                                      len(windows))
    else:
        static_stream_sizes = stream_sizes

    # Split multi-stream features
    mgc, target_f0, vuv, bap = split_streams(acoustic_features,
                                             static_stream_sizes)

    # Generate waveform with the WORLD vocoder
    fftlen = pyworld.get_cheaptrick_fft_size(sample_rate)
    alpha = pysptk.util.mcepalpha(sample_rate)

    if post_filter:
        mgc = merlin_post_filter(mgc, alpha)

    spectrogram = pysptk.mc2sp(mgc, fftlen=fftlen, alpha=alpha)
    aperiodicity = pyworld.decode_aperiodicity(bap.astype(np.float64),
                                               sample_rate, fftlen)
    # fill aperiodicity with ones for unvoiced regions
    aperiodicity[vuv.reshape(-1) < 0.5, :] = 1.0
    # WORLD fails catastrophically for out-of-range aperiodicity
    aperiodicity = np.clip(aperiodicity, 0.0, 1.0)

    ### F0 ###
    if relative_f0:
        diff_lf0 = target_f0
        # need to extract the pitch sequence from the musical score
        linguistic_features = fe.linguistic_features(
            labels, binary_dict, continuous_dict,
            add_frame_features=True,
            subphone_features=subphone_features)
        f0_score = _midi_to_hz(linguistic_features, pitch_idx, False)[:, None]
        lf0_score = f0_score.copy()
        nonzero_indices = np.nonzero(lf0_score)
        lf0_score[nonzero_indices] = np.log(f0_score[nonzero_indices])
        lf0_score = interp1d(lf0_score, kind="slinear")

        f0 = diff_lf0 + lf0_score
        f0[vuv < 0.5] = 0
        f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])
    else:
        f0 = target_f0
        f0[vuv < 0.5] = 0
        f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])

    generated_waveform = pyworld.synthesize(f0.flatten().astype(np.float64),
                                            spectrogram.astype(np.float64),
                                            aperiodicity.astype(np.float64),
                                            sample_rate, frame_period)

    # Reduce the volume to prevent clipping
    # TODO: tune this scaling constant properly
    spectrogram *= 0.000000001
    sp = pyworld.code_spectral_envelope(spectrogram, sample_rate, 60)

    return f0, sp, bap, generated_waveform
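
# Usage sketch (assumed I/O; `soundfile` is one possible writer, not a
# dependency implied by this code):
#
# f0, sp, bap, wav = gen_waveform(labels, acoustic_features, binary_dict,
#                                 continuous_dict, stream_sizes,
#                                 has_dynamic_features, pitch_idx=pitch_idx)
# import soundfile as sf
# sf.write("out.wav", wav / np.abs(wav).max() * 0.9, 48000)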
def train_step(
    model,
    optimizer,
    grad_scaler,
    train,
    in_feats,
    out_feats,
    lengths,
    out_scaler,
    feats_criterion="mse",
    stream_wise_loss=False,
    stream_weights=None,
    stream_sizes=None,
):
    model.train() if train else model.eval()
    optimizer.zero_grad()

    if feats_criterion in ["l2", "mse"]:
        criterion = nn.MSELoss(reduction="none")
    elif feats_criterion in ["l1", "mae"]:
        criterion = nn.L1Loss(reduction="none")
    else:
        raise RuntimeError("not supported criterion")

    prediction_type = (model.module.prediction_type()
                       if isinstance(model, nn.DataParallel)
                       else model.prediction_type())

    # Apply preprocess if required (e.g., FIR filter for shallow AR)
    # defaults to no-op
    if isinstance(model, nn.DataParallel):
        out_feats = model.module.preprocess_target(out_feats)
    else:
        out_feats = model.preprocess_target(out_feats)

    # Run forward
    with autocast(enabled=grad_scaler is not None):
        pred_out_feats = model(in_feats, lengths)

    # Mask (B, T, 1)
    mask = make_non_pad_mask(lengths).unsqueeze(-1).to(in_feats.device)

    # Compute loss
    if prediction_type == PredictionType.PROBABILISTIC:
        pi, sigma, mu = pred_out_feats

        # (B, max(T)) or (B, max(T), D_out)
        mask_ = mask if len(pi.shape) == 4 else mask.squeeze(-1)

        # Compute loss and apply mask
        with autocast(enabled=grad_scaler is not None):
            loss = mdn_loss(pi, sigma, mu, out_feats, reduce=False)
        loss = loss.masked_select(mask_).mean()
    else:
        if stream_wise_loss:
            w = get_stream_weight(stream_weights,
                                  stream_sizes).to(in_feats.device)
            streams = split_streams(out_feats, stream_sizes)
            pred_streams = split_streams(pred_out_feats, stream_sizes)
            loss = 0
            for pred_stream, stream, sw in zip(pred_streams, streams, w):
                with autocast(enabled=grad_scaler is not None):
                    loss += (sw * criterion(
                        pred_stream.masked_select(mask),
                        stream.masked_select(mask)).mean())
        else:
            with autocast(enabled=grad_scaler is not None):
                loss = criterion(pred_out_feats.masked_select(mask),
                                 out_feats.masked_select(mask)).mean()

    if prediction_type == PredictionType.PROBABILISTIC:
        with torch.no_grad():
            pred_out_feats_ = mdn_get_most_probable_sigma_and_mu(
                pi, sigma, mu)[1]
    else:
        pred_out_feats_ = pred_out_feats
    distortions = compute_distortions(pred_out_feats_, out_feats, lengths,
                                      out_scaler)

    if train:
        if grad_scaler is not None:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

    return loss, distortions
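
# Usage sketch (assumed training setup): mixed precision is enabled simply by
# passing a GradScaler; passing None runs the same step in full precision.
#
# from torch.cuda.amp import GradScaler
# grad_scaler = GradScaler() if use_amp else None
# for in_feats, out_feats, lengths in data_loader:
#     loss, distortions = train_step(
#         model, optimizer, grad_scaler, True,
#         in_feats.to(device), out_feats.to(device), lengths, out_scaler)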
def eval_spss_model(
    step,
    netG,
    in_feats,
    out_feats,
    lengths,
    model_config,
    out_scaler,
    writer,
    sr,
    trajectory_smoothing=True,
    trajectory_smoothing_cutoff=50,
):
    # make sure to be in eval mode
    netG.eval()
    is_autoregressive = (netG.module.is_autoregressive()
                         if isinstance(netG, nn.DataParallel)
                         else netG.is_autoregressive())
    prediction_type = (netG.module.prediction_type()
                       if isinstance(netG, nn.DataParallel)
                       else netG.prediction_type())
    utt_indices = [-1, -2, -3]
    utt_indices = utt_indices[:min(3, len(in_feats))]

    if np.any(model_config.has_dynamic_features):
        static_stream_sizes = get_static_stream_sizes(
            model_config.stream_sizes,
            model_config.has_dynamic_features,
            model_config.num_windows,
        )
    else:
        static_stream_sizes = model_config.stream_sizes

    for utt_idx in utt_indices:
        out_feats_denorm_ = out_scaler.inverse_transform(
            out_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0))
        mgc, lf0, vuv, bap = get_static_features(
            out_feats_denorm_,
            model_config.num_windows,
            model_config.stream_sizes,
            model_config.has_dynamic_features,
        )[:4]
        mgc = mgc.squeeze(0).cpu().numpy()
        lf0 = lf0.squeeze(0).cpu().numpy()
        vuv = vuv.squeeze(0).cpu().numpy()
        bap = bap.squeeze(0).cpu().numpy()

        f0, spectrogram, aperiodicity = gen_world_params(mgc, lf0, vuv,
                                                         bap, sr)
        wav = pyworld.synthesize(f0, spectrogram, aperiodicity, sr, 5)
        group = f"utt{np.abs(utt_idx)}_reference"
        wav = wav / np.abs(wav).max() if np.max(wav) > 1.0 else wav
        writer.add_audio(group, wav, step, sr)

        # Run forward
        if is_autoregressive:
            outs = netG(
                in_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0),
                [lengths[utt_idx]],
                out_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0),
            )
        else:
            outs = netG(in_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0),
                        [lengths[utt_idx]])

        # ResF0 case
        if isinstance(outs, tuple) and len(outs) == 2:
            outs, _ = outs

        if prediction_type == PredictionType.PROBABILISTIC:
            pi, sigma, mu = outs
            pred_out_feats = mdn_get_most_probable_sigma_and_mu(pi, sigma,
                                                                mu)[1]
        else:
            pred_out_feats = outs

        # NOTE: multiple outputs
        if isinstance(pred_out_feats, list):
            pred_out_feats = pred_out_feats[-1]
        if isinstance(pred_out_feats, tuple):
            pred_out_feats = pred_out_feats[0]
        if not isinstance(pred_out_feats, list):
            pred_out_feats = [pred_out_feats]

        # Run inference
        if prediction_type == PredictionType.PROBABILISTIC:
            inference_out_feats, _ = netG.inference(
                in_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0),
                [lengths[utt_idx]])
        else:
            inference_out_feats = netG.inference(
                in_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0),
                [lengths[utt_idx]])
        pred_out_feats.append(inference_out_feats)

        # Plot normalized input/output
        in_feats_ = in_feats[utt_idx, :lengths[utt_idx]].cpu().numpy()
        out_feats_ = out_feats[utt_idx, :lengths[utt_idx]].cpu().numpy()
        fig, ax = plt.subplots(3, 1, figsize=(8, 8))
        ax[0].set_title("Reference features")
        ax[1].set_title("Input features")
        ax[2].set_title("Predicted features")
        mesh = librosa.display.specshow(out_feats_.T, x_axis="frames",
                                        y_axis="frames", ax=ax[0],
                                        cmap="viridis")
        # NOTE: assuming normalized to N(0, 1)
        mesh.set_clim(-4, 4)
        fig.colorbar(mesh, ax=ax[0])
        mesh = librosa.display.specshow(in_feats_.T, x_axis="frames",
                                        y_axis="frames", ax=ax[1],
                                        cmap="viridis")
        mesh.set_clim(-4, 4)
        fig.colorbar(mesh, ax=ax[1])
        mesh = librosa.display.specshow(
            inference_out_feats.squeeze(0).cpu().numpy().T,
            x_axis="frames",
            y_axis="frames",
            ax=ax[2],
            cmap="viridis",
        )
        mesh.set_clim(-4, 4)
        fig.colorbar(mesh, ax=ax[2])
        for ax_ in ax:
            ax_.set_ylabel("Feature")
        plt.tight_layout()
        group = f"utt{np.abs(utt_idx)}_inference"
        writer.add_figure(f"{group}/Input-Output", fig, step)
        plt.close()

        assert len(pred_out_feats) == 2
        for idx, pred_out_feats_ in enumerate(pred_out_feats):
            pred_out_feats_ = pred_out_feats_.squeeze(0).cpu().numpy()
            pred_out_feats_denorm = (out_scaler.inverse_transform(
                torch.from_numpy(pred_out_feats_).to(
                    in_feats.device)).cpu().numpy())
            if np.any(model_config.has_dynamic_features):
                # (T, D_out) -> (T, static_dim)
                pred_out_feats_denorm = multi_stream_mlpg(
                    pred_out_feats_denorm,
                    (out_scaler.scale_**2).cpu().numpy(),
                    get_windows(model_config.num_windows),
                    model_config.stream_sizes,
                    model_config.has_dynamic_features,
                )
            pred_mgc, pred_lf0, pred_vuv, pred_bap = split_streams(
                pred_out_feats_denorm, static_stream_sizes)[:4]

            # Remove high-frequency components of mgc/bap
            # NOTE: It seems to be effective to suppress artifacts of
            # GAN-based post-filtering
            if trajectory_smoothing:
                # 5 ms frame shift -> trajectories sampled at 200 Hz
                modfs = int(1 / 0.005)
                for d in range(pred_mgc.shape[1]):
                    pred_mgc[:, d] = lowpass_filter(
                        pred_mgc[:, d], modfs,
                        cutoff=trajectory_smoothing_cutoff)
                for d in range(pred_bap.shape[1]):
                    pred_bap[:, d] = lowpass_filter(
                        pred_bap[:, d], modfs,
                        cutoff=trajectory_smoothing_cutoff)

            # Generated sample
            f0, spectrogram, aperiodicity = gen_world_params(
                pred_mgc, pred_lf0, pred_vuv, pred_bap, sr)
            wav = pyworld.synthesize(f0, spectrogram, aperiodicity, sr, 5)
            wav = wav / np.abs(wav).max() if np.max(wav) > 1.0 else wav
            if idx == 1:
                group = f"utt{np.abs(utt_idx)}_inference"
            else:
                group = f"utt{np.abs(utt_idx)}_forward"
            writer.add_audio(group, wav, step, sr)
            plot_spsvs_params(
                step,
                writer,
                mgc,
                lf0,
                vuv,
                bap,
                pred_mgc,
                pred_lf0,
                pred_vuv,
                pred_bap,
                group=group,
                sr=sr,
            )
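
# A minimal sketch of the trajectory smoothing above, assuming
# `lowpass_filter` is a zero-phase Butterworth low-pass (the actual helper
# may differ). With a 5 ms frame shift the trajectories are sampled at
# 200 Hz, so cutoff=50 keeps only slow modulations.
def _lowpass_filter_sketch(x, fs, cutoff=50, order=5):
    from scipy.signal import butter, filtfilt

    nyq = 0.5 * fs
    b, a = butter(order, cutoff / nyq, btype="low")
    return filtfilt(b, a, x)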
def train_loop(config, device, model, optimizer, lr_scheduler, data_loaders):
    criterion = nn.MSELoss(reduction="none")
    logger.info("Start utterance-wise training...")

    stream_weights = get_stream_weight(config.model.stream_weights,
                                       config.model.stream_sizes).to(device)

    best_loss = 10000000
    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            model.train() if train else model.eval()
            running_loss = 0
            for x, y, lengths in data_loaders[phase]:
                # Sort by lengths. This is needed for pytorch's PackedSequence
                sorted_lengths, indices = torch.sort(lengths, dim=0,
                                                     descending=True)
                x, y = x[indices].to(device), y[indices].to(device)

                optimizer.zero_grad()

                # Apply preprocess if required (e.g., FIR filter for shallow AR)
                # defaults to no-op
                y = model.preprocess_target(y)

                # Run forward
                if model.prediction_type() == PredictionType.PROBABILISTIC:
                    pi, sigma, mu = model(x, sorted_lengths)

                    # (B, max(T)) or (B, max(T), D_out)
                    mask = make_non_pad_mask(sorted_lengths).to(device)
                    mask = mask.unsqueeze(-1) if len(pi.shape) == 4 else mask

                    # Compute loss and apply mask
                    loss = mdn_loss(pi, sigma, mu, y, reduce=False)
                    loss = loss.masked_select(mask).mean()
                else:
                    y_hat = model(x, sorted_lengths)

                    # Compute loss
                    mask = make_non_pad_mask(sorted_lengths).unsqueeze(-1).to(
                        device)

                    if config.train.stream_wise_loss:
                        # Stream-wise loss
                        streams = split_streams(y, config.model.stream_sizes)
                        streams_hat = split_streams(y_hat,
                                                    config.model.stream_sizes)
                        loss = 0
                        for s_hat, s, sw in zip(streams_hat, streams,
                                                stream_weights):
                            s_hat_mask = s_hat.masked_select(mask)
                            s_mask = s.masked_select(mask)
                            loss += sw * criterion(s_hat_mask, s_mask).mean()
                    else:
                        # Joint modeling
                        y_hat = y_hat.masked_select(mask)
                        y = y.masked_select(mask)
                        loss = criterion(y_hat, y).mean()

                if train:
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
            ave_loss = running_loss / len(data_loaders[phase])
            logger.info("[%s] [Epoch %s]: loss %s", phase, epoch, ave_loss)
            if not train and ave_loss < best_loss:
                best_loss = ave_loss
                save_best_checkpoint(config, model, optimizer, best_loss)

        # step per each epoch (may consider updating per iter.)
        lr_scheduler.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            save_checkpoint(config, model, optimizer, lr_scheduler, epoch)

    # save at last epoch
    save_checkpoint(config, model, optimizer, lr_scheduler,
                    config.train.nepochs)
    logger.info("The best loss was %s", best_loss)

    return model
def forward(self, x, lengths=None, is_inference=False):
    """Forward step

    Each feature stream is processed independently.

    Args:
        x (torch.Tensor): input tensor of shape (B, T, C)
        lengths (torch.Tensor): lengths of shape (B,)
        is_inference (bool): run the post-filters in inference mode if True

    Returns:
        torch.Tensor: output tensor of shape (B, T, C)
    """
    streams = split_streams(x, self.stream_sizes)
    if len(streams) == 4:
        mgc, lf0, vuv, bap = streams
    elif len(streams) == 5:
        mgc, lf0, vuv, bap, vib = streams
    elif len(streams) == 6:
        mgc, lf0, vuv, bap, vib, vib_flags = streams
    else:
        raise ValueError("Invalid number of streams")

    if self.mgc_postfilter is not None:
        if self.mgc_offset > 0:
            # keep unchanged for the 0-to-${mgc_offset}-th dim of mgc
            mgc0 = mgc[:, :, :self.mgc_offset]
            if is_inference:
                mgc_pf = self.mgc_postfilter.inference(
                    mgc[:, :, self.mgc_offset:], lengths)
            else:
                mgc_pf = self.mgc_postfilter(mgc[:, :, self.mgc_offset:],
                                             lengths)
            mgc_pf = torch.cat([mgc0, mgc_pf], dim=-1)
        else:
            if is_inference:
                mgc_pf = self.mgc_postfilter.inference(mgc, lengths)
            else:
                mgc_pf = self.mgc_postfilter(mgc, lengths)
        mgc = mgc_pf

    if self.bap_postfilter is not None:
        if self.bap_offset > 0:
            # keep unchanged for the 0-to-${bap_offset}-th dim of bap
            bap0 = bap[:, :, :self.bap_offset]
            if is_inference:
                bap_pf = self.bap_postfilter.inference(
                    bap[:, :, self.bap_offset:], lengths)
            else:
                bap_pf = self.bap_postfilter(bap[:, :, self.bap_offset:],
                                             lengths)
            bap_pf = torch.cat([bap0, bap_pf], dim=-1)
        else:
            if is_inference:
                bap_pf = self.bap_postfilter.inference(bap, lengths)
            else:
                bap_pf = self.bap_postfilter(bap, lengths)
        bap = bap_pf

    if self.lf0_postfilter is not None:
        if is_inference:
            lf0 = self.lf0_postfilter.inference(lf0, lengths)
        else:
            lf0 = self.lf0_postfilter(lf0, lengths)

    if len(streams) == 4:
        out = torch.cat([mgc, lf0, vuv, bap], dim=-1)
    elif len(streams) == 5:
        out = torch.cat([mgc, lf0, vuv, bap, vib], dim=-1)
    elif len(streams) == 6:
        out = torch.cat([mgc, lf0, vuv, bap, vib, vib_flags], dim=-1)
    return out
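
# Illustration of the offset behavior above (hypothetical tensors): with
# mgc_offset=2, the first two mel-cepstrum coefficients bypass the
# post-filter and are concatenated back unchanged.
#
# mgc = torch.randn(1, 10, 60)
# mgc0, rest = mgc[:, :, :2], mgc[:, :, 2:]
# out = torch.cat([mgc0, postfilter(rest, lengths)], dim=-1)  # (1, 10, 60)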