def main(args):
    """Extract mel, energy and pitch features for every wav under ``args.data_path``.

    Wav files are discovered recursively.  For each file the mel spectrogram,
    a per-frame energy (norm of the magnitude spectrogram over the frequency
    axis) and a pitch track are written as ``.npy`` files into the ``mels``,
    ``energy`` and ``pitch`` sub-directories of ``hp.data_dir``.

    Args:
        args: Object exposing ``data_path``, the root directory to scan.
    """
    extractor = TacotronSTFT(filter_length=hp.n_fft,
                             hop_length=hp.hop_length,
                             win_length=hp.win_length,
                             n_mel_channels=hp.n_mels,
                             sampling_rate=hp.sample_rate,
                             mel_fmin=hp.fmin,
                             mel_fmax=hp.fmax)

    search_pattern = os.path.join(args.data_path, '**', '*.wav')
    found_wavs = glob.glob(search_pattern, recursive=True)

    # One output directory per feature kind, created up front.
    out_dirs = {
        kind: os.path.join(hp.data_dir, kind)
        for kind in ('mels', 'energy', 'pitch')
    }
    for directory in out_dirs.values():
        os.makedirs(directory, exist_ok=True)

    for wav_file in tqdm.tqdm(found_wavs, desc='preprocess wav to mel'):
        _, samples = read_wav_np(wav_file)

        # Pitch is computed on the raw numpy samples, one value per frame.
        f0 = pitch(samples)

        batch = torch.from_numpy(samples).unsqueeze(0)
        # mel: [1, num_mel, T], mag: [1, num_mag, T]; drop the batch axis.
        mel, mag = extractor.mel_spectrogram(batch)
        mel = mel.squeeze(0)
        mag = mag.squeeze(0)

        # Energy per frame: norm over the frequency axis.
        energy = torch.norm(mag, dim=0)
        # Trim the pitch track to the number of mel frames.
        f0 = f0[:mel.shape[1]]

        # Utterance id: file name up to its first dot.
        utt_id = os.path.basename(wav_file).split(".")[0]

        np.save(f"{out_dirs['mels']}/{utt_id}.npy",
                mel.numpy(),
                allow_pickle=False)
        np.save(f"{out_dirs['energy']}/{utt_id}.npy",
                energy.numpy(),
                allow_pickle=False)
        np.save(f"{out_dirs['pitch']}/{utt_id}.npy", f0, allow_pickle=False)
# Example #2
# 0
def preprocess(data_path, hp, file):
    """Extract acoustic features for every utterance listed in a metadata file.

    Each line of ``file`` is split on ``|``; field 4 is the wav path relative
    to ``data_path`` and field 2 is parsed by ``str_to_int_list`` into the
    duration sequence.  For every utterance the mel spectrogram, per-frame
    energy, pitch track and duration-averaged mel are saved as ``.npy`` files
    under ``hp.data.data_dir`` (sub-directories ``mels``, ``energy``,
    ``pitch`` and ``avg_mel_ph``).

    Args:
        data_path: Directory that the wav paths in the metadata are relative to.
        hp: Hyper-parameter object with ``audio`` and ``data`` sections.
        file: Path of the ``|``-separated metadata file (read as UTF-8).
    """
    audio_cfg = hp.audio
    extractor = TacotronSTFT(
        filter_length=audio_cfg.n_fft,
        hop_length=audio_cfg.hop_length,
        win_length=audio_cfg.win_length,
        n_mel_channels=audio_cfg.n_mels,
        sampling_rate=audio_cfg.sample_rate,
        mel_fmin=audio_cfg.fmin,
        mel_fmax=audio_cfg.fmax,
    )

    feature_root = hp.data.data_dir
    mel_dir = os.path.join(feature_root, "mels")
    energy_dir = os.path.join(feature_root, "energy")
    pitch_dir = os.path.join(feature_root, "pitch")
    avg_mel_dir = os.path.join(feature_root, "avg_mel_ph")
    for directory in (mel_dir, energy_dir, pitch_dir, avg_mel_dir):
        os.makedirs(directory, exist_ok=True)

    print("Sample Rate : ", hp.audio.sample_rate)

    with open(f"{file}", encoding="utf-8") as handle:
        rows = [line.strip().split("|") for line in handle]

    for fields in tqdm.tqdm(rows, desc="preprocess wav to mel"):
        wav_file = os.path.join(data_path, fields[4])
        _, samples = read_wav_np(wav_file, hp.audio.sample_rate)

        durations = torch.from_numpy(np.array(str_to_int_list(fields[2])))

        # Pitch is computed on the raw numpy samples, one value per frame.
        f0 = pitch(samples, hp)

        batch = torch.from_numpy(samples).unsqueeze(0)
        # mel: [1, num_mel, T], mag: [1, num_mag, T]; drop the batch axis.
        mel, mag = extractor.mel_spectrogram(batch)
        mel = mel.squeeze(0)
        mag = mag.squeeze(0)

        # Energy per frame: norm over the frequency axis.
        energy = torch.norm(mag, dim=0)
        # Trim the pitch track to the number of mel frames.
        f0 = f0[:mel.shape[1]]

        phone_mel = _average_mel_by_duration(mel, durations)
        # The leading dimension of the averaged mel must equal the number of
        # duration entries.
        assert phone_mel.shape[0] == durations.shape[-1]

        # Utterance id: file name up to its first dot.
        utt_id = os.path.basename(wav_file).split(".")[0]

        np.save(f"{mel_dir}/{utt_id}.npy", mel.numpy(), allow_pickle=False)
        np.save(f"{energy_dir}/{utt_id}.npy",
                energy.numpy(),
                allow_pickle=False)
        np.save(f"{pitch_dir}/{utt_id}.npy", f0, allow_pickle=False)
        np.save(f"{avg_mel_dir}/{utt_id}.npy",
                phone_mel.numpy(),
                allow_pickle=False)
def main(args, hp):
    """Extract mel, energy and pitch features for every wav under ``args.data_path``.

    Wav files are discovered recursively and read at ``hp.audio.sample_rate``.
    For each file the mel spectrogram, per-frame energy and pitch track are
    saved as ``.npy`` files in the ``mels``, ``energy`` and ``pitch``
    sub-directories of ``hp.data.data_dir``.

    Args:
        args: Object exposing ``data_path``, the root directory to scan.
        hp: Hyper-parameter object with ``audio`` and ``data`` sections.
    """
    audio_cfg = hp.audio
    extractor = TacotronSTFT(
        filter_length=audio_cfg.n_fft,
        hop_length=audio_cfg.hop_length,
        win_length=audio_cfg.win_length,
        n_mel_channels=audio_cfg.n_mels,
        sampling_rate=audio_cfg.sample_rate,
        mel_fmin=audio_cfg.fmin,
        mel_fmax=audio_cfg.fmax,
    )

    search_pattern = os.path.join(args.data_path, "**", "*.wav")
    found_wavs = glob.glob(search_pattern, recursive=True)

    feature_root = hp.data.data_dir
    mel_dir = os.path.join(feature_root, "mels")
    energy_dir = os.path.join(feature_root, "energy")
    pitch_dir = os.path.join(feature_root, "pitch")
    for directory in (mel_dir, energy_dir, pitch_dir):
        os.makedirs(directory, exist_ok=True)

    print("Sample Rate : ", hp.audio.sample_rate)

    for wav_file in tqdm.tqdm(found_wavs, desc="preprocess wav to mel"):
        _, samples = read_wav_np(wav_file, hp.audio.sample_rate)

        # Pitch is computed on the raw numpy samples, one value per frame.
        f0 = pitch(samples, hp)

        batch = torch.from_numpy(samples).unsqueeze(0)
        # mel: [1, num_mel, T], mag: [1, num_mag, T]; drop the batch axis.
        mel, mag = extractor.mel_spectrogram(batch)
        mel = mel.squeeze(0)
        mag = mag.squeeze(0)

        # Energy per frame: norm over the frequency axis.
        energy = torch.norm(mag, dim=0)
        # Trim the pitch track to the number of mel frames.
        f0 = f0[:mel.shape[1]]

        # Utterance id: file name up to its first dot.
        utt_id = os.path.basename(wav_file).split(".")[0]

        np.save(f"{mel_dir}/{utt_id}.npy", mel.numpy(), allow_pickle=False)
        np.save(f"{energy_dir}/{utt_id}.npy",
                energy.numpy(),
                allow_pickle=False)
        np.save(f"{pitch_dir}/{utt_id}.npy", f0, allow_pickle=False)
def _make_optimizer(model, hp):
    """Build the optimizer for ``model`` via ``get_std_opt`` from ``hp.model``."""
    return get_std_opt(
        model,
        hp.model.adim,
        hp.model.transformer_warmup_steps,
        hp.model.transformer_lr,
    )


def _write_report_scalars(writer, report, prefix, step):
    """Log each non-None key/value pair in ``report`` as scalar ``prefix/key``.

    ``report`` is an iterable of dicts.  Keys or values whose type name
    contains ``cupy`` are converted with ``.get()`` before logging.
    """
    for entry in report:
        for key, value in entry.items():
            if key is None or value is None:
                continue
            if "cupy" in str(type(value)):
                value = value.get()
            if "cupy" in str(type(key)):
                key = key.get()
            writer.add_scalar("{}/{}".format(prefix, key), value, step)


def train(args, hp, hp_str, logger, vocoder):
    """Train the FeedForwardTransformer model, with periodic validation and checkpoints.

    Creates the checkpoint/output/log directories, builds the data loaders
    and model, optionally resumes from ``args.checkpoint_path``, then runs
    ``hp.train.epochs`` epochs with gradient accumulation
    (``hp.train.accum_grad``) and gradient clipping (``hp.train.grad_clip``).
    Scalars, spectrogram images and audio are written to a ``SummaryWriter``;
    checkpoints are saved every ``hp.train.save_interval`` steps.

    All tensors are moved to the device selected from ``hp.train.ngpu``
    (``cuda`` when positive, otherwise ``cpu``), and checkpoints are loaded
    onto that same device.

    Args:
        args: Run arguments; reads ``name``, ``outdir`` and ``checkpoint_path``.
        hp: Hyper-parameter object with ``train``, ``data``, ``audio`` and
            ``model`` sections.
        hp_str: Serialized hyper-parameters, stored in checkpoints and
            compared against the value in a resumed checkpoint.
        logger: Logger used for checkpoint-related messages.
        vocoder: Vocoder passed to ``generate_audio`` during validation.

    Returns:
        None.  Returns early (also None) when ``args.checkpoint_path`` is set
        but does not exist.
    """
    os.makedirs(os.path.join(hp.train.chkpt_dir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name, "assets"), exist_ok=True)
    device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu")

    dataloader = loader.get_tts_dataset(hp.data.data_dir, hp.train.batch_size,
                                        hp)
    validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True)

    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = fastspeech.FeedForwardTransformer(idim, odim, hp)
    # set torch device
    model = model.to(device)
    print("Model is loaded ...")
    githash = get_commit_hash()
    if args.checkpoint_path is not None:
        if not os.path.exists(args.checkpoint_path):
            print("Checkpoint does not exixts")
            return None

        logger.info("Resuming from checkpoint: %s" % args.checkpoint_path)
        # map_location keeps resuming possible when the checkpoint was saved
        # on a device that is not available in this run.
        checkpoint = torch.load(args.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model"])
        optimizer = _make_optimizer(model, hp)
        optimizer.load_state_dict(checkpoint["optim"])
        global_step = checkpoint["step"]

        if hp_str != checkpoint["hp_str"]:
            logger.warning(
                "New hparams is different from checkpoint. Will use new.")

        if githash != checkpoint["githash"]:
            logger.warning("Code might be different: git hash is different.")
            logger.warning("%s -> %s" % (checkpoint["githash"], githash))
    else:
        print("New Training")
        global_step = 0
        optimizer = _make_optimizer(model, hp)

    print("Batch Size :", hp.train.batch_size)

    num_params(model)

    os.makedirs(os.path.join(hp.train.log_dir, args.name), exist_ok=True)
    writer = SummaryWriter(os.path.join(hp.train.log_dir, args.name))
    model.train()
    # Number of forward/backward passes since the last optimizer update.
    forward_count = 0
    for epoch in range(hp.train.epochs):
        start = time.time()
        running_loss = 0
        j = 0

        pbar = tqdm.tqdm(dataloader, desc="Loading train data")
        for data in pbar:
            global_step += 1
            x, input_length, y, _, out_length, _, dur, e, p = data
            # x : [batch , num_char], input_length : [batch], y : [batch, T_in, num_mel]
            # out_length : [batch]

            loss, report_dict = model(
                x.to(device),
                input_length.to(device),
                y.to(device),
                out_length.to(device),
                dur.to(device),
                e.to(device),
                p.to(device),
            )
            # Scale so that accumulated gradients average over accum_grad passes.
            loss = loss.mean() / hp.train.accum_grad
            running_loss += loss.item()

            loss.backward()

            # Only update parameters once accum_grad passes have accumulated.
            forward_count += 1
            j = j + 1
            if forward_count != hp.train.accum_grad:
                continue
            forward_count = 0
            step = global_step

            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       hp.train.grad_clip)
            logging.debug("grad norm={}".format(grad_norm))
            if math.isnan(grad_norm):
                logging.warning("grad norm is nan. Do not update model.")
            else:
                optimizer.step()
            optimizer.zero_grad()

            if step % hp.train.summary_interval == 0:
                pbar.set_description(
                    "Average Loss %.04f Loss %.04f | step %d" %
                    (running_loss / j, loss.item(), step))
                _write_report_scalars(writer, report_dict, "main", step)

            if step % hp.train.validation_step == 0:
                for valid in validloader:
                    x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid
                    model.eval()
                    with torch.no_grad():
                        _, report_dict_ = model(
                            x_.to(device),
                            input_length_.to(device),
                            y_.to(device),
                            out_length_.to(device),
                            dur_.to(device),
                            e_.to(device),
                            p_.to(device),
                        )

                        # [T, num_mel]
                        mels_ = model.inference(x_[-1].to(device))

                    model.train()
                    _write_report_scalars(writer, report_dict_, "validation",
                                          step)

                    mels_ = mels_.T  # Out: [num_mels, T]
                    writer.add_image(
                        "melspectrogram_target_{}".format(ids_[-1]),
                        plot_spectrogram_to_numpy(
                            y_[-1].T.data.cpu().numpy()[:, :out_length_[-1]]),
                        step,
                        dataformats="HWC",
                    )
                    writer.add_image(
                        "melspectrogram_prediction_{}".format(ids_[-1]),
                        plot_spectrogram_to_numpy(mels_.data.cpu().numpy()),
                        step,
                        dataformats="HWC",
                    )

                    # Synthesize from the mel predicted for the last item of
                    # the batch so the audio matches the image logged above.
                    audio = generate_audio(mels_.unsqueeze(0), vocoder)
                    audio = audio.cpu().float().numpy()
                    # Scale by the peak-to-peak range; skip when the signal is
                    # constant to avoid dividing by zero.
                    span = audio.max() - audio.min()
                    if span > 0:
                        audio = audio / span

                    writer.add_audio(
                        tag="generated_audio_{}".format(ids_[-1]),
                        snd_tensor=torch.Tensor(audio),
                        global_step=step,
                        sample_rate=hp.audio.sample_rate,
                    )

                    # os.path.join works whether or not wav_dir ends with a
                    # path separator.
                    _, target = read_wav_np(
                        os.path.join(hp.data.wav_dir, f"{ids_[-1]}.wav"),
                        sample_rate=hp.audio.sample_rate,
                    )

                    writer.add_audio(
                        tag=" target_audio_{}".format(ids_[-1]),
                        snd_tensor=torch.Tensor(target),
                        global_step=step,
                        sample_rate=hp.audio.sample_rate,
                    )

            if step % hp.train.save_interval == 0:
                avg_p, avg_e, avg_d = evaluate(hp, validloader, model)
                writer.add_scalar("evaluation/Pitch Loss", avg_p, step)
                writer.add_scalar("evaluation/Energy Loss", avg_e, step)
                writer.add_scalar("evaluation/Dur Loss", avg_d, step)
                save_path = os.path.join(
                    hp.train.chkpt_dir,
                    args.name,
                    "{}_fastspeech_{}_{}k_steps.pyt".format(
                        args.name, githash, step // 1000),
                )

                torch.save(
                    {
                        "model": model.state_dict(),
                        "optim": optimizer.state_dict(),
                        "step": step,
                        "hp_str": hp_str,
                        "githash": githash,
                    },
                    save_path,
                )
                logger.info("Saved checkpoint to: %s" % save_path)
        print("Time taken for epoch {} is {} sec\n".format(
            epoch + 1, int(time.time() - start)))