Exemplo n.º 1
0
def train(args, hp, hp_str, logger, vocoder):
    """Train the FastSpeech feed-forward transformer.

    Runs ``hp.train.epochs`` passes over the training loader with gradient
    accumulation, writes scalars/images/audio to a ``SummaryWriter`` and
    saves a checkpoint (model, optimizer, step, ``hp_str``, git hash) every
    ``hp.train.save_interval`` steps.

    Args:
        args: Namespace providing ``name``, ``outdir`` and
            ``checkpoint_path`` (``None`` starts a fresh run).
        hp: Hyper-parameter object; its ``train``, ``data``, ``audio`` and
            ``model`` sections are read here.
        hp_str: Serialized hyper-parameters. Stored in every checkpoint and
            compared with the stored copy when resuming.
        logger: Logger used for resume, gradient and checkpoint messages.
        vocoder: Passed through to ``generate_audio`` to synthesize audio
            from predicted mels during validation.

    Returns:
        None. Returns before any training happens when
        ``args.checkpoint_path`` is set but the file does not exist.
    """
    os.makedirs(os.path.join(hp.train.chkpt_dir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name, "assets"), exist_ok=True)
    device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu")

    dataloader = loader.get_tts_dataset(hp.data.data_dir, hp.train.batch_size,
                                        hp)
    validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True)

    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = fastspeech.FeedForwardTransformer(idim, odim, hp)
    model = model.to(device)
    print("Model is loaded ...")
    githash = get_commit_hash()

    def build_optimizer():
        # Single place for the optimizer configuration used by both the
        # fresh-start and the resume paths.
        return get_std_opt(
            model,
            hp.model.adim,
            hp.model.transformer_warmup_steps,
            hp.model.transformer_lr,
        )

    if args.checkpoint_path is None:
        print("New Training")
        global_step = 0
        optimizer = build_optimizer()
    elif not os.path.exists(args.checkpoint_path):
        print("Checkpoint does not exixts")
        return None
    else:
        logger.info("Resuming from checkpoint: %s" % args.checkpoint_path)
        # map_location lets a checkpoint written on GPU be resumed on the
        # device selected above (including CPU).
        checkpoint = torch.load(args.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model"])
        optimizer = build_optimizer()
        optimizer.load_state_dict(checkpoint["optim"])
        global_step = checkpoint["step"]

        if hp_str != checkpoint["hp_str"]:
            logger.warning(
                "New hparams is different from checkpoint. Will use new.")

        if githash != checkpoint["githash"]:
            logger.warning("Code might be different: git hash is different.")
            logger.warning("%s -> %s" % (checkpoint["githash"], githash))

    print("Batch Size :", hp.train.batch_size)

    num_params(model)

    os.makedirs(os.path.join(hp.train.log_dir, args.name), exist_ok=True)
    writer = SummaryWriter(os.path.join(hp.train.log_dir, args.name))

    def to_device(*tensors):
        # Move a batch to the configured device instead of assuming CUDA.
        return tuple(t.to(device) for t in tensors)

    def write_report(report, prefix, step):
        # `report` is iterated as a sequence of {name: value} dicts; entries
        # with a None key or value are skipped, cupy objects are unwrapped.
        for entry in report:
            for key, value in entry.items():
                if key is None or value is None:
                    continue
                if "cupy" in str(type(value)):
                    value = value.get()
                if "cupy" in str(type(key)):
                    key = key.get()
                writer.add_scalar("{}/{}".format(prefix, key), value, step)

    def validate(step):
        # One pass over the validation loader: log losses, target/predicted
        # mel images and generated/target audio for the last item of each
        # batch.
        for valid in validloader:
            x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid
            model.eval()
            with torch.no_grad():
                _, report_dict_ = model(*to_device(
                    x_, input_length_, y_, out_length_, dur_, e_, p_))
                # Original author's note: output is [T, num_mel].
                mels_ = model.inference(x_[-1].to(device))
            model.train()

            write_report(report_dict_, "validation", step)

            utt_id = ids_[-1]
            mels_ = mels_.T  # original author's note: [num_mels, T]
            writer.add_image(
                "melspectrogram_target_{}".format(utt_id),
                plot_spectrogram_to_numpy(
                    y_[-1].T.data.cpu().numpy()[:, :out_length_[-1]]),
                step,
                dataformats="HWC",
            )
            writer.add_image(
                "melspectrogram_prediction_{}".format(utt_id),
                plot_spectrogram_to_numpy(mels_.data.cpu().numpy()),
                step,
                dataformats="HWC",
            )

            audio = generate_audio(mels_.unsqueeze(0), vocoder)
            audio = audio.cpu().float().numpy()
            # Scale by the peak-to-peak range; skip the division for a
            # constant signal, which would otherwise yield NaN/inf samples.
            span = audio.max() - audio.min()
            if span > 0:
                audio = audio / span

            writer.add_audio(
                tag="generated_audio_{}".format(utt_id),
                snd_tensor=torch.Tensor(audio),
                global_step=step,
                sample_rate=hp.audio.sample_rate,
            )

            # os.path.join works whether or not wav_dir ends with a separator.
            _, target = read_wav_np(
                os.path.join(hp.data.wav_dir, "{}.wav".format(utt_id)),
                sample_rate=hp.audio.sample_rate,
            )

            # NOTE(review): the leading space in this tag is kept verbatim so
            # existing TensorBoard runs keep the same tag name.
            writer.add_audio(
                tag=" target_audio_{}".format(utt_id),
                snd_tensor=torch.Tensor(target),
                global_step=step,
                sample_rate=hp.audio.sample_rate,
            )

    def save_checkpoint(step):
        # Log evaluation losses, then persist everything needed to resume.
        avg_p, avg_e, avg_d = evaluate(hp, validloader, model)
        writer.add_scalar("evaluation/Pitch Loss", avg_p, step)
        writer.add_scalar("evaluation/Energy Loss", avg_e, step)
        writer.add_scalar("evaluation/Dur Loss", avg_d, step)
        save_path = os.path.join(
            hp.train.chkpt_dir,
            args.name,
            "{}_fastspeech_{}_{}k_steps.pyt".format(args.name, githash,
                                                    step // 1000),
        )
        torch.save(
            {
                "model": model.state_dict(),
                "optim": optimizer.state_dict(),
                "step": step,
                "hp_str": hp_str,
                "githash": githash,
            },
            save_path,
        )
        logger.info("Saved checkpoint to: %s" % save_path)

    model.train()
    forward_count = 0
    try:
        for epoch in range(hp.train.epochs):
            start = time.time()
            running_loss = 0
            batches_seen = 0

            pbar = tqdm.tqdm(dataloader, desc="Loading train data")
            for data in pbar:
                global_step += 1
                # Original author's note: x is [batch, num_char],
                # input_length is [batch], y is [batch, T_in, num_mel],
                # out_length is [batch].
                x, input_length, y, _, out_length, _, dur, e, p = data

                loss, report_dict = model(*to_device(
                    x, input_length, y, out_length, dur, e, p))
                # Divide so that accumulated gradients average over the
                # accumulation window rather than sum.
                loss = loss.mean() / hp.train.accum_grad
                running_loss += loss.item()

                loss.backward()

                forward_count += 1
                batches_seen += 1
                # Only update parameters every `accum_grad` batches.
                if forward_count != hp.train.accum_grad:
                    continue
                forward_count = 0
                step = global_step

                # Clip, and use the resulting norm to detect a broken update.
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.train.grad_clip)
                logger.debug("grad norm=%s", grad_norm)
                if math.isnan(grad_norm):
                    logger.warning("grad norm is nan. Do not update model.")
                else:
                    optimizer.step()
                optimizer.zero_grad()

                if step % hp.train.summary_interval == 0:
                    pbar.set_description(
                        "Average Loss %.04f Loss %.04f | step %d" %
                        (running_loss / batches_seen, loss.item(), step))
                    write_report(report_dict, "main", step)

                if step % hp.train.validation_step == 0:
                    validate(step)

                if step % hp.train.save_interval == 0:
                    save_checkpoint(step)

            print("Time taken for epoch {} is {} sec\n".format(
                epoch + 1, int(time.time() - start)))
    finally:
        # Flush pending events even when training is interrupted.
        writer.close()
Exemplo n.º 2
0
def create_gta(args, hp, hp_str, logger):
    """Run a trained model over the dataset and save its mel outputs as .npy.

    For every item of the training and validation loaders, the mel returned
    by ``model._forward`` is transposed, trimmed to the item's output length,
    rescaled with ``(mel + 4) / 8`` and written to
    ``<hp.data.data_dir>/gta/<id>.npy``. "GTA" presumably stands for
    ground-truth-aligned — confirm with the project documentation.

    Args:
        args: Namespace providing ``checkpoint_path``; it is also passed to
            the model constructor.
        hp: Hyper-parameter object; ``data.data_dir``, ``train.ngpu``,
            ``train.batch_size`` and ``audio.num_mels`` are read here.
        hp_str: Unused in this function.
        logger: Unused in this function.

    Returns:
        None. Returns early when no checkpoint path is given or the file
        does not exist.
    """
    gta_dir = os.path.join(hp.data.data_dir, "gta")
    os.makedirs(gta_dir, exist_ok=True)
    device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu")

    dataloader = loader.get_tts_dataset(hp.data.data_dir, 1)
    validloader = loader.get_tts_dataset(hp.data.data_dir, 1, True)
    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = fastspeech.FeedForwardTransformer(idim, odim, args)

    checkpoint_path = args.checkpoint_path
    # A missing path (None) is treated like a missing file instead of
    # raising TypeError inside os.path.exists.
    if checkpoint_path is None or not os.path.exists(checkpoint_path):
        print("Checkpoint not exixts")
        return None

    print("\nSynthesis GTA Session...\n")
    # map_location lets a GPU-written checkpoint load on the selected device.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model"])

    model.eval()
    model = model.to(device)
    print("Model is loaded ...")
    print("Batch Size :", hp.train.batch_size)
    num_params(model)

    def trim_and_scale(mel, n_frames):
        # Transpose, drop frames beyond `n_frames`, then apply the fixed
        # affine rescale. NOTE(review): presumably maps [-4, 4] to [0, 1] —
        # confirm against the consumer of these files.
        mel = mel.T
        mel = mel[:, :n_frames]
        return (mel + 4) / 8

    def save_mel(utt_id, mel):
        np.save(
            os.path.join(gta_dir, "{}.npy".format(utt_id)),
            mel,
            allow_pickle=False,
        )

    only_validation = False
    if not only_validation:
        pbar = tqdm.tqdm(dataloader, desc="Loading train data")
        for data in pbar:
            x, input_length, y, _, out_length, ids = data
            with torch.no_grad():
                # NOTE(review): five outputs are unpacked here but three in
                # the validation loop below, and a different position is
                # taken; kept as in the original — confirm which matches
                # `_forward`.
                _, gta, _, _, _ = model._forward(
                    x.to(device), input_length.to(device), y.to(device),
                    out_length.to(device))
            gta = gta.cpu().numpy()

            for j, utt_id in enumerate(ids):
                save_mel(utt_id, trim_and_scale(gta[j], out_length[j]))

    pbar = tqdm.tqdm(validloader, desc="Loading Valid data")
    for data in pbar:
        x, input_length, y, _, out_length, ids = data
        with torch.no_grad():
            gta, _, _ = model._forward(
                x.to(device), input_length.to(device), y.to(device),
                out_length.to(device))
        gta = gta.cpu().numpy()

        for j, utt_id in enumerate(ids):
            print("Actual mel specs : {} = {}".format(utt_id, y[j].shape))
            print("Out length:", out_length[j])
            print("GTA size: {} = {}".format(utt_id, gta[j].shape))
            mel = trim_and_scale(gta[j], out_length[j])
            print("Mel size: {} = {}".format(utt_id, mel.shape))
            save_mel(utt_id, mel)
Exemplo n.º 3
0
def train(args):
    """Train the FastSpeech feed-forward transformer using module-level ``hp``.

    Runs ``hp.epochs`` passes over the training loader with gradient
    accumulation, prints and logs the reported losses, plots attention
    weights during validation, and saves separate model/optimizer state
    files every ``hp.save_interval`` steps.

    Args:
        args: Namespace providing ``outdir`` and ``resume``. ``resume`` is
            the model checkpoint path (``None`` starts fresh); the optimizer
            state is read from the same path with "model" replaced by
            "optim".

    Returns:
        None. Returns before training when ``args.resume`` is set but the
        file does not exist.
    """
    os.makedirs(hp.chkpt_dir, exist_ok=True)
    os.makedirs(args.outdir, exist_ok=True)
    os.makedirs(os.path.join(args.outdir, 'img'), exist_ok=True)
    device = torch.device("cuda" if hp.ngpu > 0 else "cpu")

    dataloader = loader.get_tts_dataset(hp.data_dir, hp.batch_size)
    validloader = loader.get_tts_dataset(hp.data_dir, 5, True)
    global_step = 0
    idim = hp.symbol_len
    odim = hp.num_mels
    model = fastspeech.FeedForwardTransformer(idim, odim)
    model = model.to(device)
    print("Model is loaded ...")

    def build_optimizer():
        # Shared by the fresh-start and the resume paths.
        return get_std_opt(model, hp.adim, hp.transformer_warmup_steps,
                           hp.transformer_lr)

    if args.resume is None:
        optimizer = build_optimizer()
    elif not os.path.exists(args.resume):
        print("Checkpoint not exixts")
        return None
    else:
        print('\nSynthesis Session...\n')
        # map_location lets GPU-written state load on the selected device.
        model.load_state_dict(torch.load(args.resume, map_location=device),
                              strict=False)
        optimizer = build_optimizer()
        optimizer.load_state_dict(
            torch.load(args.resume.replace("model", "optim"),
                       map_location=device))
        # global_step counts batches, so scale the optimizer's own counter
        # by the accumulation factor.
        global_step = hp.accum_grad * optimizer._step

    print("Batch Size :", hp.batch_size)
    num_params(model)

    writer = SummaryWriter(hp.log_dir)

    # Console label printed for each known loss key.
    loss_labels = {
        'l1_loss': "\nL1 loss :",
        'before_loss': "\nBefore loss :",
        'after_loss': "\nAfter loss :",
        'duration_loss': "\nD loss :",
        'pitch_loss': "\nP loss :",
        'energy_loss': "\nE loss :",
    }

    def to_device(*tensors):
        # Move a batch to the configured device instead of assuming CUDA.
        return tuple(t.to(device) for t in tensors)

    def report_losses(report, prefix, step):
        # `report` is iterated as a sequence of {name: value} dicts. Known
        # losses are printed (even when the value is None, as before); every
        # entry with a non-None key and value is written as a scalar.
        for entry in report:
            for key, value in entry.items():
                label = loss_labels.get(key)
                if label is not None:
                    print(label, value)
                if key is None or value is None:
                    continue
                if 'cupy' in str(type(value)):
                    value = value.get()
                if 'cupy' in str(type(key)):
                    key = key.get()
                writer.add_scalar("{}/{}".format(prefix, key), value, step)

    def validate(step):
        # One pass over the validation loader: log losses and plot/log the
        # attention weights of each batch.
        plot_class = model.attention_plot_class
        plot_fn = plot_class(os.path.join(args.outdir, 'att_ws'), device)
        for valid in validloader:
            x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid
            batch = to_device(x_, input_length_, y_, out_length_, dur_, e_,
                              p_)
            model.eval()
            with torch.no_grad():
                _, report_dict_ = model(*batch)
                # Only used for plotting, so no gradients are needed here.
                att_ws = model.calculate_all_attentions(*batch)
            model.train()

            print(" Validation Losses :")
            report_losses(report_dict_, "validation", step)

            plot_fn(step, input_length_, out_length_, att_ws)
            plot_fn.log_attentions(writer, step, input_length_, out_length_,
                                   att_ws)

    def save_checkpoint(step):
        # Model and optimizer states go to sibling files whose names differ
        # only in "model" vs "optim", which the resume path relies on.
        save_path = os.path.join(
            hp.chkpt_dir,
            'checkpoint_model_{}k_steps.pyt'.format(step // 1000))
        optim_path = os.path.join(
            hp.chkpt_dir,
            'checkpoint_optim_{}k_steps.pyt'.format(step // 1000))
        torch.save(model.state_dict(), save_path)
        torch.save(optimizer.state_dict(), optim_path)
        print("Model Saved")

    model.train()
    forward_count = 0
    print(model)
    try:
        for epoch in range(hp.epochs):
            start = time.time()
            running_loss = 0
            batches_seen = 0

            pbar = tqdm.tqdm(dataloader, desc='Loading train data')
            for data in pbar:
                global_step += 1
                # Original author's note: x is [batch, num_char],
                # input_length is [batch], y is [batch, T_in, num_mel],
                # out_length is [batch].
                x, input_length, y, _, out_length, _, dur, e, p = data

                loss, report_dict = model(*to_device(
                    x, input_length, y, out_length, dur, e, p))
                # Divide so that accumulated gradients average over the
                # accumulation window rather than sum.
                loss = loss.mean() / hp.accum_grad
                running_loss += loss.item()

                loss.backward()

                forward_count += 1
                batches_seen += 1
                # Only update parameters every `accum_grad` batches.
                if forward_count != hp.accum_grad:
                    continue
                forward_count = 0
                step = global_step

                # Clip, and use the resulting norm to detect a broken update.
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.grad_clip)
                logging.debug('grad norm={}'.format(grad_norm))
                if math.isnan(grad_norm):
                    logging.warning('grad norm is nan. Do not update model.')
                else:
                    optimizer.step()
                optimizer.zero_grad()

                if step % hp.summary_interval == 0:
                    pbar.set_description(
                        "Average Loss %.04f Loss %.04f | step %d" %
                        (running_loss / batches_seen, loss.item(), step))

                    print("Losses :")
                    report_losses(report_dict, "main", step)

                if step % hp.validation_step == 0:
                    validate(step)

                if step % hp.save_interval == 0:
                    save_checkpoint(step)

            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
    finally:
        # Flush pending events even when training is interrupted.
        writer.close()