Example #1
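# Shared imports assumed by all three examples (the snippets as scraped omit
# them). Project-local names such as Hyper, Text2Mel, SuperRes, load, save,
# process_text, PrettyBar, plot_spectrum, plot_attention, spectrogram2wav,
# BatchMaker, LogHelper, MovingAverage and load_data come from the
# surrounding repository and are not importable here.
import os
import time

import numpy as np
import torch
from torch import nn
from scipy.io import wavfile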
def synthesis(text_list, plot=True):
    device = "cuda:2"
    # load graphs
    graph0 = Text2Mel()
    graph1 = SuperRes()
    load(os.path.join(Hyper.logdir, "text2mel/pkg/trained.pkg"),
         graph0,
         device='cpu')
    load(os.path.join(Hyper.logdir, "superres/pkg/trained.pkg"),
         graph1,
         device='cpu')
    graph0.eval()
    graph1.eval()
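    # Weights are loaded onto the CPU first; each graph is moved to the target
    # device one phase at a time below, so only one model occupies device
    # memory. eval() switches dropout and batch-norm to inference behavior.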
    # make dir
    syn_dir = os.path.join(Hyper.root_dir, "synthesis")
    if not os.path.exists(syn_dir):
        os.makedirs(syn_dir)

    # phase1: text to mel
    graph0.to(device)
    texts = [process_text(text, padding=True) for text in text_list]
    texts = torch.LongTensor(np.asarray(texts)).to(device)
    mels = torch.FloatTensor(
        np.zeros(
            (len(texts), Hyper.dim_f, Hyper.data_max_mel_length))).to(device)
    prev_atten = None
    bar = PrettyBar(Hyper.data_max_mel_length - 1)
    bar.set_description("Text to Mel")
    for t in bar:
        _, new_mel = graph0(texts, mels, None if t == 0 else t - 1, prev_atten)
        mels[:, :, t + 1].data.copy_(new_mel[:, :, t].data)
        prev_atten = graph0.attention
    for i in range(len(text_list)):
        # mels[:, :, :-1].data.copy_(mels[:, :, 1:].data)
        if plot:
            plot_attention(graph0.attention[i].cpu().data, "atten", i, True,
                           syn_dir)
            plot_spectrum(mels[i].cpu().data, "mels", i, True, syn_dir)
    del graph0

    # phase2: super resolution
    graph1.to(device)
    _, mags = graph1(mels)
    bar = PrettyBar(len(text_list))
    bar.set_description("Super Resolution")
    for i in bar:
        wav = spectrogram2wav(mags[i].cpu().data.numpy())
        wavfile.write(os.path.join(syn_dir, "syn_{}.wav".format(i)),
                      Hyper.audio_samplerate, wav)
        if plot:
            plot_spectrum(mags[i].cpu().data, "mags", i, True, syn_dir)
    del graph1
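A minimal usage sketch for the function above, assuming trained checkpoints
already exist under Hyper.logdir (the sentence strings and the __main__ guard
are illustrative, not from the source):

if __name__ == "__main__":
    sentences = [
        "The birch canoe slid on the smooth planks.",
        "Glue the sheet to the dark blue background.",
    ]
    # writes syn_0.wav, syn_1.wav, ... and optional plots into
    # <Hyper.root_dir>/synthesis
    synthesis(sentences, plot=False)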
Example #2
def train_text2mel(load_trained):
    # create log dir
    logdir = os.path.join(Hyper.logdir, "text2mel")
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    if not os.path.exists(os.path.join(logdir, "pkg")):
        os.mkdir(os.path.join(logdir, "pkg"))

    # device
    device = Hyper.device_text2mel

    graph = Text2Mel().to(device)
    # set the training flag
    graph.train()
    # load data and get batch maker
    names, lengths, texts = load_data()
    batch_maker = BatchMaker(Hyper.batch_size, names, lengths, texts)

    criterion_mels = nn.L1Loss().to(device)
    criterion_bd1 = nn.BCEWithLogitsLoss().to(device)
    criterion_atten = nn.L1Loss().to(device)
    optimizer = torch.optim.Adam(graph.parameters(),
                                 lr=Hyper.adam_alpha,
                                 betas=Hyper.adam_betas,
                                 eps=Hyper.adam_eps)

    lossplot_mels = LogHelper("mel_l1", logdir)
    lossplot_bd1 = LogHelper("mel_BCE", logdir)
    lossplot_atten = LogHelper("atten", logdir)

    dynamic_guide = float(Hyper.guide_weight)
    global_step = 0

    # check if load
    if load_trained > 0:
        print("load model trained for {}k batches".format(load_trained))
        global_step = load(
            os.path.join(logdir, "pkg/save_{}k.pkg".format(load_trained)),
            graph, {
                "mels": criterion_mels,
                "bd1": criterion_bd1,
                "atten": criterion_atten
            }, optimizer)
        # replay the guide-weight decay for the batches already trained
        dynamic_guide *= Hyper.guide_decay**(load_trained * 1000)

    # number of passes over the data, chosen so that roughly Hyper.num_batches
    # batches are seen in total
    for loop_cnt in range(
            int(Hyper.num_batches / batch_maker.num_batches() + 0.5)):
        print("loop", loop_cnt)
        bar = PrettyBar(batch_maker.num_batches())
        bar.set_description("training...")
        loss_str0 = MovingAverage()
        loss_str1 = MovingAverage()
        loss_str2 = MovingAverage()
        for bi in bar:
            batch = batch_maker.next_batch()
            # make batch
            texts = torch.LongTensor(batch["texts"]).to(device)
            # input mels shifted right by one frame, with a zero frame
            # prepended (teacher forcing: frame t is predicted from frames < t)
            shift_mels = torch.FloatTensor(
                np.concatenate((np.zeros(
                    (batch["mels"].shape[0], batch["mels"].shape[1], 1)),
                                batch["mels"][:, :, :-1]),
                               axis=2)).to(device)
            # ground truth
            mels = torch.FloatTensor(batch["mels"]).to(device)

            # forward
            pred_logits, pred_mels = graph(texts, shift_mels)
            # loss: an optional per-utterance variant narrows each prediction
            # to its true mel length so padding frames are excluded; disabled
            # in favor of the simpler full-tensor losses below
            use_per_sample_masked_loss = False
            if use_per_sample_masked_loss:
                loss_mels = sum(
                    criterion_mels(
                        torch.narrow(pred_mels[i], -1, 0, batch["mel_lengths"]
                                     [i]),
                        torch.narrow(mels[i], -1, 0, batch["mel_lengths"][i]))
                    for i in range(batch_maker.batch_size())) / float(
                        batch_maker.batch_size())
                loss_bd1 = sum(
                    criterion_bd1(
                        torch.narrow(pred_logits[i], -1, 0,
                                     batch["mel_lengths"][i]),
                        torch.narrow(mels[i], -1, 0, batch["mel_lengths"][i]))
                    for i in range(batch_maker.batch_size())) / float(
                        batch_maker.batch_size())
            else:
                loss_mels = criterion_mels(pred_mels, mels)
                loss_bd1 = criterion_bd1(pred_logits, mels)
            # guided attention: penalize attention mass far from the diagonal,
            # weighted by the precomputed guide matrices
            atten_guide = torch.FloatTensor(batch["atten_guides"]).to(device)
            atten_mask = torch.FloatTensor(batch["atten_masks"]).to(device)
            # NOTE: the next line overwrites the batch mask with all-ones,
            # effectively disabling per-utterance masking of the guide loss
            atten_mask = torch.ones_like(graph.attention)
            loss_atten = criterion_atten(
                atten_guide * graph.attention * atten_mask,
                torch.zeros_like(graph.attention)) * dynamic_guide
            loss = loss_mels + loss_bd1 + loss_atten

            # backward
            graph.zero_grad()  # redundant with optimizer.zero_grad() below
            optimizer.zero_grad()
            loss.backward()
            # clip grad
            nn.utils.clip_grad_value_(graph.parameters(), 1)
            optimizer.step()
            # log
            loss_str0.add(loss_mels.cpu().data.mean())
            loss_str1.add(loss_bd1.cpu().data.mean())
            loss_str2.add(loss_atten.cpu().data.mean())
            lossplot_mels.add(loss_str0(), global_step)
            lossplot_bd1.add(loss_str1(), global_step)
            lossplot_atten.add(loss_str2(), global_step)

            # adjust dynamic_guide
            # dynamic_guide = float((loss_mels + loss_bd1).cpu().data.mean() / loss_atten.cpu().data.mean())
            dynamic_guide = max(dynamic_guide * Hyper.guide_decay,
                                Hyper.guide_lowbound)
            bar.set_description(
                "gs: {}, mels: {}, bd1: {}, atten: {}, scale: {}".format(
                    global_step, loss_str0(), loss_str1(), loss_str2(),
                    "%4f" % dynamic_guide))

            # plot
            if global_step % 100 == 0:
                gs = 0  # fixed plot index: each snapshot overwrites the previous one
                plot_spectrum(mels[0].cpu().data, "mel_true", gs, dir=logdir)
                plot_spectrum(shift_mels[0].cpu().data,
                              "mel_input",
                              gs,
                              dir=logdir)
                plot_spectrum(pred_mels[0].cpu().data,
                              "mel_pred",
                              gs,
                              dir=logdir)
                plot_spectrum(graph.query[0].cpu().data,
                              "query",
                              gs,
                              dir=logdir)
                plot_attention(graph.attention[0].cpu().data,
                               "atten",
                               gs,
                               True,
                               dir=logdir)
                plot_attention(atten_guide[0].cpu().data,
                               "atten_guide",
                               gs,
                               True,
                               dir=logdir)
                if global_step % 500 == 0:
                    lossplot_mels.plot()
                    lossplot_bd1.plot()
                    lossplot_atten.plot()

                if global_step % 10000 == 0:
                    save(
                        os.path.join(
                            logdir,
                            "pkg/save_{}k.pkg".format(global_step // 1000)),
                        graph, {
                            "mels": criterion_mels,
                            "bd1": criterion_bd1,
                            "atten": criterion_atten
                        }, optimizer, global_step, True)

            # increase global step
            global_step += 1
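The atten_guides consumed above are built outside this snippet. In the DC-TTS
formulation (Tachibana et al., 2017), the guide for a text of length N and a
mel of length T is near zero on the diagonal and approaches one away from it,
so guide * attention penalizes off-diagonal alignment. A minimal sketch, with
the function name and the spread parameter g as illustrative assumptions:

def make_attention_guide(text_len, mel_len, g=0.2):
    # W[n, t] = 1 - exp(-((n/N - t/T)^2) / (2 * g^2)): ~0 near the diagonal
    # (alignment is free there), approaching 1 far from it
    n = np.arange(text_len).reshape(-1, 1) / text_len  # shape (N, 1)
    t = np.arange(mel_len).reshape(1, -1) / mel_len    # shape (1, T)
    return 1.0 - np.exp(-((n - t) ** 2) / (2.0 * g * g))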
Example #3
def synthesis(text_list, plot=True):
    info = {}

    absolute_beginning = time.time()
    info["start_time"] = absolute_beginning
    Latency_beginning = None
    device = "cpu"  # "cuda:0"
    # load graphs
    graph0 = Text2Mel()
    graph1 = SuperRes()
    load(os.path.join(Hyper.logdir, "text2mel/pkg/trained.pkg"),
         graph0,
         device='cpu')
    load(os.path.join(Hyper.logdir, "superres/pkg/trained.pkg"),
         graph1,
         device='cpu')
    graph0.eval()
    graph1.eval()
    # make dir
    syn_dir = os.path.join(Hyper.root_dir, "synthesis")
    if not os.path.exists(syn_dir):
        os.makedirs(syn_dir)

    # phase1: text to mel
    graph0.to(device)
    texts = [process_text(text, padding=True) for text in text_list]
    texts = torch.LongTensor(np.asarray(texts)).to(device)
    mels = torch.FloatTensor(
        np.zeros(
            (len(texts), Hyper.dim_f, Hyper.data_max_mel_length))).to(device)
    prev_atten = None
    bar = PrettyBar(Hyper.data_max_mel_length - 1)
    bar.set_description("Text to Mel")

    begin_of_frame_synthesis = time.time()
    init_time = begin_of_frame_synthesis - absolute_beginning
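    # init_time covers everything before frame generation: model construction,
    # checkpoint loading, directory setup, and text preprocessing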

    for t in bar:
        _, new_mel = graph0(texts, mels, None if t == 0 else t - 1, prev_atten)
        mels[:, :, t + 1].data.copy_(new_mel[:, :, t].data)
        prev_atten = graph0.attention

        if Latency_beginning is None:
            # latency: the time from the start of synthesis until the first
            # mel frame has been produced
            Latency_beginning = time.time() - absolute_beginning
            Latency_synthesis = time.time() - begin_of_frame_synthesis

    duration_mels = time.time() - begin_of_frame_synthesis

    info_mels = {}
    for i, mel in enumerate(mels):
        info_mels["syn_{}.wav".format(i)] = mel

    for i in range(len(text_list)):
        # mels[:, :, :-1].data.copy_(mels[:, :, 1:].data)
        if plot:
            plot_attention(graph0.attention[i].cpu().data, "atten", i, True,
                           syn_dir)
            plot_spectrum(mels[i].cpu().data, "mels", i, True, syn_dir)
    del graph0

    # phase2: super resolution
    graph1.to(device)
    _, mags = graph1(mels)
    duration_mags = time.time() - begin_of_frame_synthesis  # cumulative: includes the mel phase above

    bar = PrettyBar(len(text_list))
    bar.set_description("Super Resolution")

    info_samples = {}
    info_mags = {}
    for i in bar:
        info_mags["syn_{}.wav".format(i)] = mags[i]
        wav = spectrogram2wav(mags[i].cpu().data.numpy())
        info_samples["syn_{}.wav".format(i)] = wav
        wavfile.write(os.path.join(syn_dir, "syn_{}.wav".format(i)),
                      Hyper.audio_samplerate, wav)
        if plot:
            plot_spectrum(mags[i].cpu().data, "mags", i, True, syn_dir)
    del graph1

    # total frame-synthesis time; excludes init_time (model loading and setup)
    duration_total = time.time() - begin_of_frame_synthesis
    time_measurements = {
        "init_time": init_time,
        "Latency_beginning": Latency_beginning,
        "Latency_synthesis": Latency_synthesis,
        "duration_mels": duration_mels,
        "duration_mags": duration_mags,
        "duration_total": duration_total,
    }

    info["mels"] = info_mels
    info["mags"] = info_mags
    info["samples"] = info_samples
    info["time_measurements"] = time_measurents
    return info
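A sketch of how the returned dictionary might be consumed, for example to
report latency and a real-time factor. Only synthesis() and
Hyper.audio_samplerate come from the code above; the rest is illustrative:

info = synthesis(["The quick brown fox jumps over the lazy dog."], plot=False)
tm = info["time_measurements"]
print("latency to first mel frame: {:.3f}s".format(tm["Latency_beginning"]))

# real-time factor: synthesis time divided by the duration of the audio produced
audio_seconds = sum(len(wav) / Hyper.audio_samplerate
                    for wav in info["samples"].values())
print("RTF: {:.2f}".format(tm["duration_total"] / audio_seconds))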