Example #1
def _shallow_ar_inference(out, stream_sizes, analysis_filts):
    from torchaudio.functional import lfilter

    out_streams = split_streams(out, stream_sizes)
    # back to conv1d friendly (B, C, T) format
    out_streams = map(lambda x: x.transpose(1, 2), out_streams)

    out_syn = []
    for sidx, os in enumerate(out_streams):
        out_stream_syn = torch.zeros_like(os)
        a = analysis_filts[sidx].get_filt_coefs()
        # apply IIR filter to each dimension
        for idx in range(os.shape[1]):
            # NOTE: scipy.signal.lfilter accepts coefficients in (b, a) order,
            # but torchaudio expects the opposite: (a, b) order
            ai = a[idx].view(-1).flip(0)
            bi = torch.zeros_like(ai)
            bi[0] = 1
            out_stream_syn[:, idx, :] = lfilter(os[:, idx, :],
                                                ai,
                                                bi,
                                                clamp=False)
        out_syn += [out_stream_syn]

    out_syn = torch.cat(out_syn, 1)
    return out_syn.transpose(1, 2)
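
For reference, every example here relies on split_streams to carve the concatenated feature axis into per-stream chunks. A minimal sketch of the assumed behavior (equivalent to torch.split along the last dimension; shapes and sizes below are hypothetical):

import torch

x = torch.randn(2, 100, 67)                      # hypothetical (B, T, D) multi-stream features
stream_sizes = [60, 1, 1, 5]                     # e.g. mgc, lf0, vuv, bap
streams = torch.split(x, stream_sizes, dim=-1)   # assumed equivalent of split_streams
assert [s.shape[-1] for s in streams] == stream_sizes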
Example #2
def preprocess_target(self, y):
    assert sum(self.stream_sizes) == y.shape[-1]
    ys = split_streams(y, self.stream_sizes)
    for idx, yi in enumerate(ys):
        ys[idx] = self.analysis_filts[idx](yi.transpose(1, 2)).transpose(1, 2)
    return torch.cat(ys, -1)
Example #3
def gen_waveform(labels, acoustic_features, acoustic_out_scaler,
        binary_dict, continuous_dict, stream_sizes, has_dynamic_features,
        subphone_features="coarse_coding", log_f0_conditioning=True, pitch_idx=None,
        num_windows=3, post_filter=True, sample_rate=48000, frame_period=5,
        relative_f0=True):

    windows = get_windows(num_windows)

    # Apply MLPG if necessary
    if np.any(has_dynamic_features):
        acoustic_features = multi_stream_mlpg(
            acoustic_features, acoustic_out_scaler.var_, windows, stream_sizes,
            has_dynamic_features)
        static_stream_sizes = get_static_stream_sizes(
            stream_sizes, has_dynamic_features, len(windows))
    else:
        static_stream_sizes = stream_sizes

    # Split multi-stream features
    mgc, target_f0, vuv, bap = split_streams(acoustic_features, static_stream_sizes)

    # Gen waveform by the WORLD vocoder
    fftlen = pyworld.get_cheaptrick_fft_size(sample_rate)
    alpha = pysptk.util.mcepalpha(sample_rate)

    if post_filter:
        mgc = merlin_post_filter(mgc, alpha)

    spectrogram = pysptk.mc2sp(mgc, fftlen=fftlen, alpha=alpha)
    aperiodicity = pyworld.decode_aperiodicity(bap.astype(np.float64), sample_rate, fftlen)

    ### F0 ###
    if relative_f0:
        diff_lf0 = target_f0
        # need to extract pitch sequence from the musical score
        linguistic_features = fe.linguistic_features(labels,
                                                    binary_dict, continuous_dict,
                                                    add_frame_features=True,
                                                    subphone_features=subphone_features)
        f0_score = _midi_to_hz(linguistic_features, pitch_idx, False)[:, None]
        lf0_score = f0_score.copy()
        nonzero_indices = np.nonzero(lf0_score)
        lf0_score[nonzero_indices] = np.log(f0_score[nonzero_indices])
        lf0_score = interp1d(lf0_score, kind="slinear")

        f0 = diff_lf0 + lf0_score
        f0[vuv < 0.5] = 0
        f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])
    else:
        f0 = target_f0

    generated_waveform = pyworld.synthesize(f0.flatten().astype(np.float64),
                                            spectrogram.astype(np.float64),
                                            aperiodicity.astype(np.float64),
                                            sample_rate, frame_period)

    return generated_waveform
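
The relative-F0 branch above adds the predicted log-F0 residual to the interpolated, score-derived log-F0 and converts back to Hz only for voiced frames. A small numpy sketch of that reconstruction (the arrays are made up for illustration):

import numpy as np

diff_lf0 = np.array([[0.02], [-0.01], [0.0]])    # predicted log-F0 residual
lf0_score = np.log(np.full((3, 1), 440.0))       # interpolated score log-F0 (A4)
vuv = np.array([[1.0], [1.0], [0.0]])            # voiced/unvoiced probability

f0 = diff_lf0 + lf0_score
f0[vuv < 0.5] = 0                                # zero out unvoiced frames
f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])  # back to Hz for voiced frames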
Example #4
File: gen.py Project: r9y9/nnsvs
def gen_spsvs_static_features(
    labels,
    acoustic_features,
    binary_dict,
    numeric_dict,
    stream_sizes,
    has_dynamic_features,
    subphone_features="coarse_coding",
    pitch_idx=None,
    num_windows=3,
    frame_period=5,
    relative_f0=True,
    vibrato_scale=1.0,
    vuv_threshold=0.3,
    force_fix_vuv=True,
):
    """Generate static features from predicted acoustic features

    Args:
        labels (HTSLabelFile): HTS labels
        acoustic_features (ndarray): predicted acoustic features
        binary_dict (dict): binary feature dictionary
        numeric_dict (dict): numeric feature dictionary
        stream_sizes (list): stream sizes
        has_dynamic_features (list): whether each stream has dynamic features
        subphone_features (str): subphone feature type
        pitch_idx (int): index of pitch features
        num_windows (int): number of windows
        frame_period (float): frame period
        relative_f0 (bool): whether to use relative f0
        vibrato_scale (float): vibrato scale
        vuv_threshold (float): vuv threshold
        force_fix_vuv (bool): whether to use post-processing to fix VUV.

    Returns:
        tuple: tuple of mgc, lf0, vuv and bap.
    """
    if np.any(has_dynamic_features):
        static_stream_sizes = get_static_stream_sizes(
            stream_sizes, has_dynamic_features, num_windows
        )
    else:
        static_stream_sizes = stream_sizes

    # Copy here to avoid inplace operations on input acoustic features
    acoustic_features = acoustic_features.copy()

    # Split multi-stream features
    streams = split_streams(acoustic_features, static_stream_sizes)

    if len(streams) == 4:
        mgc, target_f0, vuv, bap = streams
        vib, vib_flags = None, None
    elif len(streams) == 5:
        # Assuming diff-based vibrato parameters
        mgc, target_f0, vuv, bap, vib = streams
        vib_flags = None
    elif len(streams) == 6:
        # Assuming sine-based vibrato parameters
        mgc, target_f0, vuv, bap, vib, vib_flags = streams
    else:
        raise RuntimeError("Not supported streams")

    linguistic_features = fe.linguistic_features(
        labels,
        binary_dict,
        numeric_dict,
        add_frame_features=True,
        subphone_features=subphone_features,
    )

    # Correct V/UV based on special phone flags
    if force_fix_vuv:
        vuv = correct_vuv_by_phone(vuv, binary_dict, linguistic_features)

    # F0
    if relative_f0:
        diff_lf0 = target_f0
        f0_score = _midi_to_hz(linguistic_features, pitch_idx, False)[:, None]
        lf0_score = f0_score.copy()
        nonzero_indices = np.nonzero(lf0_score)
        lf0_score[nonzero_indices] = np.log(f0_score[nonzero_indices])
        lf0_score = interp1d(lf0_score, kind="slinear")

        f0 = diff_lf0 + lf0_score
        f0[vuv < vuv_threshold] = 0
        f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])
    else:
        f0 = target_f0
        f0[vuv < vuv_threshold] = 0
        f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])

    if vib is not None:
        if vib_flags is not None:
            # Generate sine-based vibrato
            vib_flags = vib_flags.flatten()
            m_a, m_f = vib[:, 0], vib[:, 1]

            # Fill zeros for non-vibrato frames
            m_a[vib_flags < 0.5] = 0
            m_f[vib_flags < 0.5] = 0

            # Gen vibrato
            sr_f0 = int(1 / (frame_period * 0.001))
            f0 = gen_sine_vibrato(f0.flatten(), sr_f0, m_a, m_f, vibrato_scale)
        else:
            # Generate diff-based vibrato
            f0 = f0.flatten() + vibrato_scale * vib.flatten()

    # NOTE: Back to log-domain for convenience
    lf0 = f0.copy()
    lf0[np.nonzero(lf0)] = np.log(f0[np.nonzero(lf0)])
    # NOTE: interpolation is necessary
    lf0 = interp1d(lf0, kind="slinear")

    lf0 = lf0[:, None] if len(lf0.shape) == 1 else lf0
    vuv = vuv[:, None] if len(vuv.shape) == 1 else vuv

    return mgc, lf0, vuv, bap
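
The interp1d call near the end fills unvoiced (zero) gaps so that the log-F0 trajectory stays continuous. A minimal sketch of the same idea, with np.interp standing in for the project's interp1d helper:

import numpy as np

lf0 = np.array([5.5, 0.0, 0.0, 5.6, 5.7])        # log-F0 with unvoiced gaps as zeros
voiced = np.flatnonzero(lf0 > 0)
lf0_interp = np.interp(np.arange(len(lf0)), voiced, lf0[voiced])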
Example #5
def train_loop(config, device, model, optimizer, lr_scheduler, data_loaders):
    criterion = nn.MSELoss(reduction="none")
    logger.info("Start utterance-wise training...")

    stream_weights = get_stream_weight(config.model.stream_weights,
                                       config.model.stream_sizes).to(device)

    best_loss = 10000000
    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            model.train() if train else model.eval()
            running_loss = 0
            for x, y, lengths in data_loaders[phase]:
                # Sort by lengths. This is needed for PyTorch's PackedSequence
                sorted_lengths, indices = torch.sort(lengths,
                                                     dim=0,
                                                     descending=True)
                x, y = x[indices].to(device), y[indices].to(device)

                optimizer.zero_grad()

                # Run forward
                y_hat = model(x, sorted_lengths)

                # Compute loss
                mask = make_non_pad_mask(sorted_lengths).unsqueeze(-1).to(
                    device)

                if config.train.stream_wise_loss:
                    # Stream-wise loss
                    streams = split_streams(y, config.model.stream_sizes)
                    streams_hat = split_streams(y_hat,
                                                config.model.stream_sizes)
                    loss = 0
                    for s_hat, s, sw in zip(streams_hat, streams,
                                            stream_weights):
                        s_hat_mask = s_hat.masked_select(mask)
                        s_mask = s.masked_select(mask)
                        loss += sw * criterion(s_hat_mask, s_mask).mean()
                else:
                    # Joint modeling
                    y_hat = y_hat.masked_select(mask)
                    y = y.masked_select(mask)
                    loss = criterion(y_hat, y).mean()

                if train:
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
            ave_loss = running_loss / len(data_loaders[phase])
            logger.info(f"[{phase}] [Epoch {epoch}]: loss {ave_loss}")
            if not train and ave_loss < best_loss:
                best_loss = ave_loss
                save_best_checkpoint(config, model, optimizer, best_loss)

        # step once per epoch (may consider updating per iteration)
        lr_scheduler.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            save_checkpoint(config, model, optimizer, lr_scheduler, epoch)

    # save at last epoch
    save_checkpoint(config, model, optimizer, lr_scheduler,
                    config.train.nepochs)
    logger.info(f"The best loss was {best_loss}")

    return model
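
A minimal sketch of the stream-wise masked loss used above, in plain PyTorch with hypothetical shapes (torch.split and an all-ones mask stand in for split_streams and make_non_pad_mask):

import torch
import torch.nn as nn

B, T, stream_sizes = 2, 5, [3, 1, 1]
y = torch.randn(B, T, sum(stream_sizes))
y_hat = torch.randn(B, T, sum(stream_sizes))
mask = torch.ones(B, T, 1, dtype=torch.bool)
stream_weights = [1.0, 0.5, 0.5]
criterion = nn.MSELoss(reduction="none")

loss = 0
for s_hat, s, sw in zip(torch.split(y_hat, stream_sizes, dim=-1),
                        torch.split(y, stream_sizes, dim=-1),
                        stream_weights):
    loss += sw * criterion(s_hat.masked_select(mask), s.masked_select(mask)).mean()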
Example #6
def gen_waveform(labels,
                 acoustic_features,
                 binary_dict,
                 continuous_dict,
                 stream_sizes,
                 has_dynamic_features,
                 subphone_features="coarse_coding",
                 log_f0_conditioning=True,
                 pitch_idx=None,
                 num_windows=3,
                 post_filter=True,
                 sample_rate=48000,
                 frame_period=5,
                 relative_f0=True):
    windows = get_windows(num_windows)

    # Apply MLPG if necessary
    if np.any(has_dynamic_features):
        static_stream_sizes = get_static_stream_sizes(stream_sizes,
                                                      has_dynamic_features,
                                                      len(windows))
    else:
        static_stream_sizes = stream_sizes

    # Split multi-stream features
    mgc, target_f0, vuv, bap = split_streams(acoustic_features,
                                             static_stream_sizes)

    # Gen waveform by the WORLD vocoder
    fftlen = pyworld.get_cheaptrick_fft_size(sample_rate)
    alpha = pysptk.util.mcepalpha(sample_rate)

    if post_filter:
        mgc = merlin_post_filter(mgc, alpha)

    spectrogram = pysptk.mc2sp(mgc, fftlen=fftlen, alpha=alpha)
    aperiodicity = pyworld.decode_aperiodicity(bap.astype(np.float64),
                                               sample_rate, fftlen)

    # fill aperiodicity with ones for unvoiced regions
    aperiodicity[vuv.reshape(-1) < 0.5, :] = 1.0
    # WORLD fails catastrophically for out of range aperiodicity
    aperiodicity = np.clip(aperiodicity, 0.0, 1.0)

    ### F0 ###
    if relative_f0:
        diff_lf0 = target_f0
        # need to extract pitch sequence from the musical score
        linguistic_features = fe.linguistic_features(
            labels,
            binary_dict,
            continuous_dict,
            add_frame_features=True,
            subphone_features=subphone_features)
        f0_score = _midi_to_hz(linguistic_features, pitch_idx, False)[:, None]
        lf0_score = f0_score.copy()
        nonzero_indices = np.nonzero(lf0_score)
        lf0_score[nonzero_indices] = np.log(f0_score[nonzero_indices])
        lf0_score = interp1d(lf0_score, kind="slinear")

        f0 = diff_lf0 + lf0_score
        f0[vuv < 0.5] = 0
        f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])
    else:
        f0 = target_f0
        f0[vuv < 0.5] = 0
        f0[np.nonzero(f0)] = np.exp(f0[np.nonzero(f0)])

    generated_waveform = pyworld.synthesize(f0.flatten().astype(np.float64),
                                            spectrogram.astype(np.float64),
                                            aperiodicity.astype(np.float64),
                                            sample_rate, frame_period)

    # Reduce the volume to prevent clipping
    # TODO: tune this scaling constant properly
    spectrogram *= 0.000000001
    sp = pyworld.code_spectral_envelope(spectrogram, sample_rate, 60)

    return f0, sp, bap, generated_waveform
Example #7
def train_step(
    model,
    optimizer,
    grad_scaler,
    train,
    in_feats,
    out_feats,
    lengths,
    out_scaler,
    feats_criterion="mse",
    stream_wise_loss=False,
    stream_weights=None,
    stream_sizes=None,
):
    model.train() if train else model.eval()
    optimizer.zero_grad()

    if feats_criterion in ["l2", "mse"]:
        criterion = nn.MSELoss(reduction="none")
    elif feats_criterion in ["l1", "mae"]:
        criterion = nn.L1Loss(reduction="none")
    else:
        raise RuntimeError("not supported criterion")

    prediction_type = (model.module.prediction_type() if isinstance(
        model, nn.DataParallel) else model.prediction_type())

    # Apply preprocess if required (e.g., FIR filter for shallow AR)
    # defaults to no-op
    if isinstance(model, nn.DataParallel):
        out_feats = model.module.preprocess_target(out_feats)
    else:
        out_feats = model.preprocess_target(out_feats)

    # Run forward
    with autocast(enabled=grad_scaler is not None):
        pred_out_feats = model(in_feats, lengths)

    # Mask (B, T, 1)
    mask = make_non_pad_mask(lengths).unsqueeze(-1).to(in_feats.device)

    # Compute loss
    if prediction_type == PredictionType.PROBABILISTIC:
        pi, sigma, mu = pred_out_feats
        # (B, max(T)) or (B, max(T), D_out)
        mask_ = mask if len(pi.shape) == 4 else mask.squeeze(-1)
        # Compute loss and apply mask
        with autocast(enabled=grad_scaler is not None):
            loss = mdn_loss(pi, sigma, mu, out_feats, reduce=False)
        loss = loss.masked_select(mask_).mean()
    else:
        if stream_wise_loss:
            w = get_stream_weight(stream_weights,
                                  stream_sizes).to(in_feats.device)
            streams = split_streams(out_feats, stream_sizes)
            pred_streams = split_streams(pred_out_feats, stream_sizes)
            loss = 0
            for pred_stream, stream, sw in zip(pred_streams, streams, w):
                with autocast(enabled=grad_scaler is not None):
                    loss += (sw * criterion(pred_stream.masked_select(mask),
                                            stream.masked_select(mask)).mean())
        else:
            with autocast(enabled=grad_scaler is not None):
                loss = criterion(pred_out_feats.masked_select(mask),
                                 out_feats.masked_select(mask)).mean()

    if prediction_type == PredictionType.PROBABILISTIC:
        with torch.no_grad():
            pred_out_feats_ = mdn_get_most_probable_sigma_and_mu(
                pi, sigma, mu)[1]
    else:
        pred_out_feats_ = pred_out_feats
    distortions = compute_distortions(pred_out_feats_, out_feats, lengths,
                                      out_scaler)

    if train:
        if grad_scaler is not None:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

    return loss, distortions
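
The grad_scaler branch above follows the standard PyTorch mixed-precision pattern: compute the loss under autocast, then scale, step, and update. A self-contained sketch with a toy model (the enabled flags mirror the grad_scaler-is-not-None check):

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

use_amp = torch.cuda.is_available()
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
grad_scaler = GradScaler(enabled=use_amp)

x, y = torch.randn(8, 4), torch.randn(8, 2)
optimizer.zero_grad()
with autocast(enabled=use_amp):
    loss = nn.functional.mse_loss(model(x), y)
grad_scaler.scale(loss).backward()
grad_scaler.step(optimizer)
grad_scaler.update()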
Example #8
def eval_spss_model(
    step,
    netG,
    in_feats,
    out_feats,
    lengths,
    model_config,
    out_scaler,
    writer,
    sr,
    trajectory_smoothing=True,
    trajectory_smoothing_cutoff=50,
):
    # make sure to be in eval mode
    netG.eval()
    is_autoregressive = (netG.module.is_autoregressive() if isinstance(
        netG, nn.DataParallel) else netG.is_autoregressive())
    prediction_type = (netG.module.prediction_type() if isinstance(
        netG, nn.DataParallel) else netG.prediction_type())
    utt_indices = [-1, -2, -3]
    utt_indices = utt_indices[:min(3, len(in_feats))]

    if np.any(model_config.has_dynamic_features):
        static_stream_sizes = get_static_stream_sizes(
            model_config.stream_sizes,
            model_config.has_dynamic_features,
            model_config.num_windows,
        )
    else:
        static_stream_sizes = model_config.stream_sizes

    for utt_idx in utt_indices:
        out_feats_denorm_ = out_scaler.inverse_transform(
            out_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0))
        mgc, lf0, vuv, bap = get_static_features(
            out_feats_denorm_,
            model_config.num_windows,
            model_config.stream_sizes,
            model_config.has_dynamic_features,
        )[:4]
        mgc = mgc.squeeze(0).cpu().numpy()
        lf0 = lf0.squeeze(0).cpu().numpy()
        vuv = vuv.squeeze(0).cpu().numpy()
        bap = bap.squeeze(0).cpu().numpy()

        f0, spectrogram, aperiodicity = gen_world_params(
            mgc, lf0, vuv, bap, sr)
        wav = pyworld.synthesize(f0, spectrogram, aperiodicity, sr, 5)
        group = f"utt{np.abs(utt_idx)}_reference"
        wav = wav / np.abs(wav).max() if np.max(wav) > 1.0 else wav
        writer.add_audio(group, wav, step, sr)

        # Run forward
        if is_autoregressive:
            outs = netG(
                in_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0),
                [lengths[utt_idx]],
                out_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0),
            )
        else:
            outs = netG(in_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0),
                        [lengths[utt_idx]])

        # ResF0 case
        if isinstance(outs, tuple) and len(outs) == 2:
            outs, _ = outs

        if prediction_type == PredictionType.PROBABILISTIC:
            pi, sigma, mu = outs
            pred_out_feats = mdn_get_most_probable_sigma_and_mu(pi, sigma,
                                                                mu)[1]
        else:
            pred_out_feats = outs
        # NOTE: multiple outputs
        if isinstance(pred_out_feats, list):
            pred_out_feats = pred_out_feats[-1]
        if isinstance(pred_out_feats, tuple):
            pred_out_feats = pred_out_feats[0]

        if not isinstance(pred_out_feats, list):
            pred_out_feats = [pred_out_feats]

        # Run inference
        if prediction_type == PredictionType.PROBABILISTIC:
            inference_out_feats, _ = netG.inference(
                in_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0),
                [lengths[utt_idx]])
        else:
            inference_out_feats = netG.inference(
                in_feats[utt_idx, :lengths[utt_idx]].unsqueeze(0),
                [lengths[utt_idx]])
        pred_out_feats.append(inference_out_feats)

        # Plot normalized input/output
        in_feats_ = in_feats[utt_idx, :lengths[utt_idx]].cpu().numpy()
        out_feats_ = out_feats[utt_idx, :lengths[utt_idx]].cpu().numpy()
        fig, ax = plt.subplots(3, 1, figsize=(8, 8))
        ax[0].set_title("Reference features")
        ax[1].set_title("Input features")
        ax[2].set_title("Predicted features")
        mesh = librosa.display.specshow(out_feats_.T,
                                        x_axis="frames",
                                        y_axis="frames",
                                        ax=ax[0],
                                        cmap="viridis")
        # NOTE: assuming normalized to N(0, 1)
        mesh.set_clim(-4, 4)
        fig.colorbar(mesh, ax=ax[0])
        mesh = librosa.display.specshow(in_feats_.T,
                                        x_axis="frames",
                                        y_axis="frames",
                                        ax=ax[1],
                                        cmap="viridis")
        mesh.set_clim(-4, 4)
        fig.colorbar(mesh, ax=ax[1])
        mesh = librosa.display.specshow(
            inference_out_feats.squeeze(0).cpu().numpy().T,
            x_axis="frames",
            y_axis="frames",
            ax=ax[2],
            cmap="viridis",
        )
        mesh.set_clim(-4, 4)
        fig.colorbar(mesh, ax=ax[2])
        for ax_ in ax:
            ax_.set_ylabel("Feature")
        plt.tight_layout()
        group = f"utt{np.abs(utt_idx)}_inference"
        writer.add_figure(f"{group}/Input-Output", fig, step)
        plt.close()

        assert len(pred_out_feats) == 2
        for idx, pred_out_feats_ in enumerate(pred_out_feats):
            pred_out_feats_ = pred_out_feats_.squeeze(0).cpu().numpy()
            pred_out_feats_denorm = (out_scaler.inverse_transform(
                torch.from_numpy(pred_out_feats_).to(
                    in_feats.device)).cpu().numpy())
            if np.any(model_config.has_dynamic_features):
                # (T, D_out) -> (T, static_dim)
                pred_out_feats_denorm = multi_stream_mlpg(
                    pred_out_feats_denorm,
                    (out_scaler.scale_**2).cpu().numpy(),
                    get_windows(model_config.num_windows),
                    model_config.stream_sizes,
                    model_config.has_dynamic_features,
                )
            pred_mgc, pred_lf0, pred_vuv, pred_bap = split_streams(
                pred_out_feats_denorm, static_stream_sizes)[:4]

            # Remove high-frequency components of mgc/bap
            # NOTE: this seems effective at suppressing artifacts of GAN-based post-filtering
            if trajectory_smoothing:
                modfs = int(1 / 0.005)
                for d in range(pred_mgc.shape[1]):
                    pred_mgc[:, d] = lowpass_filter(
                        pred_mgc[:, d],
                        modfs,
                        cutoff=trajectory_smoothing_cutoff)
                for d in range(pred_bap.shape[1]):
                    pred_bap[:, d] = lowpass_filter(
                        pred_bap[:, d],
                        modfs,
                        cutoff=trajectory_smoothing_cutoff)

            # Generated sample
            f0, spectrogram, aperiodicity = gen_world_params(
                pred_mgc, pred_lf0, pred_vuv, pred_bap, sr)
            wav = pyworld.synthesize(f0, spectrogram, aperiodicity, sr, 5)
            wav = wav / np.abs(wav).max() if np.max(wav) > 1.0 else wav
            if idx == 1:
                group = f"utt{np.abs(utt_idx)}_inference"
            else:
                group = f"utt{np.abs(utt_idx)}_forward"
            writer.add_audio(group, wav, step, sr)
            plot_spsvs_params(
                step,
                writer,
                mgc,
                lf0,
                vuv,
                bap,
                pred_mgc,
                pred_lf0,
                pred_vuv,
                pred_bap,
                group=group,
                sr=sr,
            )
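
The trajectory smoothing step low-pass filters each mgc/bap dimension along time at the frame rate (200 Hz for a 5 ms frame shift). A rough scipy-based sketch of the same idea, with butter/filtfilt standing in for the project's lowpass_filter helper:

import numpy as np
from scipy.signal import butter, filtfilt

modfs = int(1 / 0.005)                           # frame rate in Hz for a 5 ms shift
b, a = butter(5, 50 / (modfs / 2), btype="low")  # 50 Hz cutoff, as in the example
traj = np.random.randn(200, 60)                  # hypothetical (T, D) mgc trajectories
smoothed = np.stack([filtfilt(b, a, traj[:, d]) for d in range(traj.shape[1])], axis=1)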
Example #9
def train_loop(config, device, model, optimizer, lr_scheduler, data_loaders):
    criterion = nn.MSELoss(reduction="none")

    logger.info("Start utterance-wise training...")

    stream_weights = get_stream_weight(
        config.model.stream_weights, config.model.stream_sizes).to(device)

    best_loss = 10000000
    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            model.train() if train else model.eval()
            running_loss = 0
            for x, y, lengths in data_loaders[phase]:
                # Sort by lengths. This is needed for PyTorch's PackedSequence
                sorted_lengths, indices = torch.sort(lengths, dim=0, descending=True)
                x, y = x[indices].to(device), y[indices].to(device)

                optimizer.zero_grad()

                # Apply preprocess if required (e.g., FIR filter for shallow AR)
                # defaults to no-op
                y = model.preprocess_target(y)

                # Run forward
                if model.prediction_type() == PredictionType.PROBABILISTIC:
                    pi, sigma, mu = model(x, sorted_lengths)

                    # (B, max(T)) or (B, max(T), D_out)
                    mask = make_non_pad_mask(sorted_lengths).to(device)
                    mask = mask.unsqueeze(-1) if len(pi.shape) == 4 else mask
                    # Compute loss and apply mask
                    loss = mdn_loss(pi, sigma, mu, y, reduce=False)
                    loss = loss.masked_select(mask).mean()
                else:
                    y_hat = model(x, sorted_lengths)

                    # Compute loss
                    mask = make_non_pad_mask(sorted_lengths).unsqueeze(-1).to(device)

                    if config.train.stream_wise_loss:
                        # Stream-wise loss
                        streams = split_streams(y, config.model.stream_sizes)
                        streams_hat = split_streams(y_hat, config.model.stream_sizes)
                        loss = 0
                        for s_hat, s, sw in zip(streams_hat, streams, stream_weights):
                            s_hat_mask = s_hat.masked_select(mask)
                            s_mask = s.masked_select(mask)
                            loss += sw * criterion(s_hat_mask, s_mask).mean()
                    else:
                        # Joint modeling
                        y_hat = y_hat.masked_select(mask)
                        y = y.masked_select(mask)
                        loss = criterion(y_hat, y).mean()

                if train:
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
            ave_loss = running_loss / len(data_loaders[phase])
            logger.info("[%s] [Epoch %s]: loss %s", phase, epoch, ave_loss)
            if not train and ave_loss < best_loss:
                best_loss = ave_loss
                save_best_checkpoint(config, model, optimizer, best_loss)

        # step once per epoch (may consider updating per iteration)
        lr_scheduler.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            save_checkpoint(config, model, optimizer, lr_scheduler, epoch)

    # save at last epoch
    save_checkpoint(config, model, optimizer, lr_scheduler, config.train.nepochs)
    logger.info("The best loss was {%s}", best_loss)

    return model
Example #10
    def forward(self, x, lengths=None, is_inference=False):
        """Forward step

        Each feature stream is processed independently.

        Args:
            x (torch.Tensor): input tensor of shape (B, T, C)
            lengths (torch.Tensor): lengths of shape (B,)

        Returns:
            torch.Tensor: output tensor of shape (B, T, C)
        """
        streams = split_streams(x, self.stream_sizes)
        if len(streams) == 4:
            mgc, lf0, vuv, bap = streams
        elif len(streams) == 5:
            mgc, lf0, vuv, bap, vib = streams
        elif len(streams) == 6:
            mgc, lf0, vuv, bap, vib, vib_flags = streams
        else:
            raise ValueError("Invalid number of streams")

        if self.mgc_postfilter is not None:
            if self.mgc_offset > 0:
                # keep the first ${mgc_offset} dims of mgc unchanged
                mgc0 = mgc[:, :, :self.mgc_offset]
                if is_inference:
                    mgc_pf = self.mgc_postfilter.inference(
                        mgc[:, :, self.mgc_offset:], lengths)
                else:
                    mgc_pf = self.mgc_postfilter(mgc[:, :, self.mgc_offset:],
                                                 lengths)
                mgc_pf = torch.cat([mgc0, mgc_pf], dim=-1)
            else:
                if is_inference:
                    mgc_pf = self.mgc_postfilter.inference(mgc, lengths)
                else:
                    mgc_pf = self.mgc_postfilter(mgc, lengths)
            mgc = mgc_pf

        if self.bap_postfilter is not None:
            if self.bap_offset > 0:
                # keep the first ${bap_offset} dims of bap unchanged
                bap0 = bap[:, :, :self.bap_offset]
                if is_inference:
                    bap_pf = self.bap_postfilter.inference(
                        bap[:, :, self.bap_offset:], lengths)
                else:
                    bap_pf = self.bap_postfilter(bap[:, :, self.bap_offset:],
                                                 lengths)
                bap_pf = torch.cat([bap0, bap_pf], dim=-1)
            else:
                if is_inference:
                    bap_pf = self.bap_postfilter.inference(bap, lengths)
                else:
                    bap_pf = self.bap_postfilter(bap, lengths)
            bap = bap_pf

        if self.lf0_postfilter is not None:
            if is_inference:
                lf0 = self.lf0_postfilter.inference(lf0, lengths)
            else:
                lf0 = self.lf0_postfilter(lf0, lengths)

        if len(streams) == 4:
            out = torch.cat([mgc, lf0, vuv, bap], dim=-1)
        elif len(streams) == 5:
            out = torch.cat([mgc, lf0, vuv, bap, vib], dim=-1)
        elif len(streams) == 6:
            out = torch.cat([mgc, lf0, vuv, bap, vib, vib_flags], dim=-1)

        return out