Example #1
def synthesize():
    if not os.path.exists(hp.sampledir): os.mkdir(hp.sampledir)

    # Load graph
    g = Graph(mode="synthesize"); print("Graph loaded")

    # Load data
    texts = load_data(mode="synthesize")
    saver = tf.train.Saver()
    config = tf.ConfigProto(allow_soft_placement = True)

    with tf.Session(config=config) as sess:
        saver.restore(sess, tf.train.latest_checkpoint(hp.logdir)); print("Restored!")

        # Feed Forward
        ## mel
        y_hat = np.zeros((texts.shape[0], 200, hp.n_mels*hp.r), np.float32)  # hp.n_mels*hp.r
        for j in tqdm.tqdm(range(200)):
            _y_hat = sess.run(g.y_hat, {g.x: texts, g.y: y_hat})
            y_hat[:, j, :] = _y_hat[:, j, :]
        ## mag
        mags = sess.run(g.z_hat, {g.y_hat: y_hat})
        for i, mag in enumerate(mags):
            print("File {}.wav is being generated ...".format(i+1))
            audio = spectrogram2wav(mag)
            write(os.path.join(hp.sampledir, '{}.wav'.format(i+1)), hp.sr, audio)
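
Every example on this page calls spectrogram2wav but none of them defines it. For orientation, here is a minimal sketch of what that helper usually is in these DCTTS/Tacotron-style repos: Griffin-Lim inversion of a (T, 1+n_fft/2) magnitude spectrogram followed by inverse pre-emphasis. The name spectrogram2wav_sketch, the default hyperparameters, and the omission of the repos' de-normalization/dB-to-amplitude steps are assumptions for illustration, not the actual implementation (recent librosa also ships librosa.griffinlim).

import numpy as np
import librosa
from scipy import signal

def spectrogram2wav_sketch(mag, hop_length=256, win_length=1024, n_iter=50, preemphasis=0.97):
    """Griffin-Lim reconstruction of a waveform from a (T, 1+n_fft//2) magnitude spectrogram."""
    mag = mag.T                     # -> (1+n_fft//2, T), the layout librosa expects
    n_fft = 2 * (mag.shape[0] - 1)  # infer the FFT size from the spectrogram height

    # Start from random phase and iteratively re-impose the known magnitude.
    complex_spec = mag * np.exp(2j * np.pi * np.random.rand(*mag.shape))
    for _ in range(n_iter):
        wav = librosa.istft(complex_spec, hop_length=hop_length, win_length=win_length)
        est = librosa.stft(wav, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        complex_spec = mag * est / np.maximum(1e-8, np.abs(est))
    wav = librosa.istft(complex_spec, hop_length=hop_length, win_length=win_length)

    # Undo the pre-emphasis these repos usually apply during preprocessing.
    wav = signal.lfilter([1], [1, -preemphasis], wav)
    return wav.astype(np.float32)
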
Example #2
def copy_synth_SSRN_GL(hp, outdir):

    safe_makedir(outdir)

    dataset = load_data(hp, mode="synthesis") 
    fnames, texts = dataset['fpaths'], dataset['texts']
    bases = [basename(fname) for fname in fnames]
    mels = [np.load(os.path.join(hp.coarse_audio_dir, base + '.npy')) for base in bases]
    lengths = [a.shape[0] for a in mels]
    mels = list2batch(mels, 0)

    g = SSRNGraph(hp, mode="synthesize"); print("Graph (ssrn) loaded")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ssrn_epoch = restore_latest_model_parameters(sess, hp, 'ssrn')

        print('Run SSRN...')
        Z = synth_mel2mag(hp, mels, g, sess)

        for i, mag in enumerate(Z):
            print("Working on %s"%(bases[i]))
            mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length             
            wav = spectrogram2wav(hp, mag)
            soundfile.write(outdir + "/%s.wav"%(bases[i]), wav, hp.sr)
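
safe_makedir is not shown in this example; it is presumably a thin wrapper along these lines (an assumption):

import os

def safe_makedir(path):
    """Create a directory, ignoring the error if it already exists."""
    os.makedirs(path, exist_ok=True)
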
Example #3
def synthesis(text, args):
    m = Model()
    m_post = ModelPostNet()

    m.load_state_dict(load_checkpoint(args.restore_step1, "transformer"))
    m_post.load_state_dict(load_checkpoint(args.restore_step2, "postnet"))

    text = np.asarray(text_to_sequence(text, [hp.cleaners]))
    text = t.LongTensor(text).unsqueeze(0)
    text = text.cuda()
    mel_input = t.zeros([1, 1, 80]).cuda()
    pos_text = t.arange(1, text.size(1) + 1).unsqueeze(0)
    pos_text = pos_text.cuda()

    m = m.cuda()
    m_post = m_post.cuda()
    m.train(False)
    m_post.train(False)

    pbar = tqdm(range(args.max_len))
    with t.no_grad():
        for i in pbar:
            pos_mel = t.arange(1, mel_input.size(1) + 1).unsqueeze(0).cuda()
            mel_pred, postnet_pred, attn, stop_token, _, attn_dec = m.forward(
                text, mel_input, pos_text, pos_mel)
            mel_input = t.cat([mel_input, postnet_pred[:, -1:, :]], dim=1)

        mag_pred = m_post.forward(postnet_pred)

    wav = spectrogram2wav(mag_pred.squeeze(0).cpu().numpy())
    write(hp.sample_path + "/test.wav", hp.sr, wav)
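
The loop above always runs args.max_len decoder steps even though the model returns a stop_token that is never used. A hedged variant that stops early, assuming stop_token holds per-frame logits of shape (1, T, 1) (an assumption; check the model before relying on this layout):

import torch as t

def decode_with_stop_token(m, m_post, text, pos_text, max_len, threshold=0.5):
    """Same autoregressive loop, but break once the newest frame's stop probability crosses threshold."""
    mel_input = t.zeros([1, 1, 80]).cuda()
    with t.no_grad():
        for _ in range(max_len):
            pos_mel = t.arange(1, mel_input.size(1) + 1).unsqueeze(0).cuda()
            _, postnet_pred, _, stop_token, _, _ = m.forward(text, mel_input, pos_text, pos_mel)
            mel_input = t.cat([mel_input, postnet_pred[:, -1:, :]], dim=1)
            if t.sigmoid(stop_token)[0, -1, 0] > threshold:  # assumed layout
                break
        mag_pred = m_post.forward(postnet_pred)
    return mag_pred
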
Example #4
def synthesize():
    if not os.path.exists(hp.taco_sampledir): os.mkdir(hp.taco_sampledir)

    # Load graph
    g = Graph(mode="synthesize")
    print("Graph loaded")

    # Load data
    texts = load_data(mode="synthesize")
    _, mel_ref, _ = load_spectrograms(hp.ref_wavfile)
    mel_ref = np.tile(mel_ref, (texts.shape[0], 1, 1))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(hp.taco_logdir))
        print("Restored!")

        # Feed Forward
        ## mel
        _y_hat = sess.run(g.diff_mels_taco_hat, {
            g.random_texts_taco: texts,
            g.mels_taco: mel_ref
        })
        y_hat = _y_hat  # we can plot spectrogram

        mags = sess.run(g.diff_mags_taco_hat, {g.diff_mels_taco_hat: y_hat})
        for i, mag in enumerate(mags):
            print("File {}.wav is being generated ...".format(i + 1))
            audio = spectrogram2wav(mag)
            write(os.path.join(hp.taco_sampledir, '{}.wav'.format(i + 1)),
                  hp.sr, audio)
Example #5
def synthesize():
    if not os.path.exists(hp.sampledir): os.mkdir(hp.sampledir)

    # Load graph
    g = Graph(mode="synthesize"); print("Graph loaded")

    # Load data
    texts = load_data(mode="synthesize")

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(hp.logdir)); print("Restored!")

        # Feed Forward
        ## mel
        y_hat = np.zeros((texts.shape[0], 200, hp.n_mels*hp.r), np.float32)  # hp.n_mels*hp.r
        for j in tqdm.tqdm(range(200)):
            _y_hat = sess.run(g.y_hat, {g.x: texts, g.y: y_hat})
            y_hat[:, j, :] = _y_hat[:, j, :]
        ## mag
        mags = sess.run(g.z_hat, {g.y_hat: y_hat})
        for i, mag in enumerate(mags):
            print("File {}.wav is being generated ...".format(i+1))
            audio = spectrogram2wav(mag)
            write(os.path.join(hp.sampledir, '{}.wav'.format(i+1)), hp.sr, audio)
Example #6
def synthesis(text, args):
    m = Model()
    m_post = ModelPostNet()

    m.load_state_dict(load_checkpoint(args.step1, "transformer"))
    m_post.load_state_dict(load_checkpoint(args.step2, "postnet"))

    text = np.asarray(text_to_sequence(text, [hp.cleaners]))
    text = torch.LongTensor(text).unsqueeze(0)
    text = text.cuda()

    mel_input = np.load('3_0.pt.npy')

    pos_text = torch.arange(1, text.size(1) + 1).unsqueeze(0)
    pos_text = pos_text.cuda()

    m = m.cuda()
    m_post = m_post.cuda()
    m.train(False)
    m_post.train(False)

    with torch.no_grad():
        mag_pred = m_post.forward(
            torch.from_numpy(mel_input).unsqueeze(0).cuda())

    wav = spectrogram2wav(mag_pred.squeeze(0).cpu().numpy())
    write(hp.sample_path + "/test.wav", hp.sr, wav)
Example #7
def synthesize():
    if not os.path.exists(hp.sampledir): os.mkdir(hp.sampledir)

    # Load graph
    g = Graph(mode="synthesize")
    print("Graph loaded")

    # Load data
    texts, max_len = load_data(mode="synthesize")

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(hp.syn_logdir))
        print("Restored!")

        # Feed Forward
        ## mel
        y_hat = np.zeros((texts.shape[0], 200, hp.n_mels * hp.r),
                         np.float32)  # hp.n_mels*hp.r
        for j in tqdm.tqdm(range(200)):
            _y_hat = sess.run(g.y_hat, {g.x: texts, g.y: y_hat})
            y_hat[:, j, :] = _y_hat[:, j, :]

        ## alignments
        alignments = sess.run([g.alignments], {g.x: texts, g.y: y_hat})[0]
        ## mag
        mags = sess.run(g.z_hat, {g.y_hat: y_hat})
        for i, mag in enumerate(mags):
            print("File {}.wav is being generated ...".format(i + 1))
            text, alignment = texts[i], alignments[i]
            print(alignment.shape)
            print("len text", float(len(text)))
            min_sample_sec = float(
                get_EOS_index(text)) * SEC_PER_CHAR  #/SEC_PER_ITER
            print("min sec ", min_sample_sec)

            plot_test_alignment(alignment, i + 1)
            al_EOS_index = get_EOS_fire(alignment, text)
            al_EOS_index = None  # overridden here, so the untrimmed branch below always runs

            if al_EOS_index is not None:
                # trim the audio
                audio = spectrogram2wav(mag[:al_EOS_index * hp.r, :])
            else:
                audio = spectrogram2wav(mag, min_sample_sec)
            write(os.path.join(hp.sampledir, '{}.wav'.format(i + 1)), hp.sr,
                  audio)
Example #8
def synth_wave(hp, mag, outfile):
    if hp.vocoder == 'griffin_lim':
        wav = spectrogram2wav(hp, mag)
        if hp.store_synth_features: # To synthesize using WaveRNN save the mag spectrum created by SSRN
           np.save(outfile.replace('.wav','.npy'), mag)   
        soundfile.write(outfile, wav, hp.sr)
    elif hp.vocoder == 'world':
        world_synthesis(mag, outfile, hp)
Example #9
def synth_wave(hp, mag, outfile):
    if hp.vocoder == 'griffin_lim':
        wav = spectrogram2wav(hp, mag)
        if hp.store_synth_features:
            np.save(outfile.replace('.wav',''), mag)   
        soundfile.write(outfile, wav, hp.sr)
    elif hp.vocoder == 'world':
        world_synthesis(mag, outfile, hp)
Example #10
def synth_wave(hp, mag, outfile):
    if hp.vocoder == 'griffin_lim':
        wav = spectrogram2wav(hp, mag)
        #outfile = magfile.replace('.mag.npy', '.wav')
        #outfile = outfile.replace('.npy', '.wav')
        soundfile.write(outfile, wav, hp.sr)
    elif hp.vocoder == 'world':
        world_synthesis(mag, outfile, hp)
Example #11
def synthesize():
    # Load data
    X = load_test_data()

    # Load graph
    g = Graph(training=False);
    print("Graph loaded")

    # Inference
    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            # Restore parameters
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir));
            print("Restored!")

            # Get model name
            mname = open(hp.logdir + '/checkpoint', 'r').read().split('"')[1]

            # Synthesize
            file_id = 1
            for i in range(0, len(X), hp.batch_size):
                x = X[i:i + hp.batch_size]

                # Get melspectrogram
                mel_output = np.zeros((hp.batch_size, hp.T_y // hp.r, hp.n_mels * hp.r), np.float32)
                decoder_output = np.zeros((hp.batch_size, hp.T_y // hp.r, hp.embed_size), np.float32)
                prev_max_attentions = np.zeros((hp.batch_size,), np.int32)
                max_attentions = np.zeros((hp.batch_size, hp.T_y // hp.r))
                alignments = np.zeros((hp.T_x, hp.T_y // hp.r), np.float32)
                for j in range(hp.T_y // hp.r):
                    _mel_output, _decoder_output, _max_attentions, _alignments = \
                        sess.run([g.mel_output, g.decoder_output, g.max_attentions, g.alignments],
                                 {g.x: x,
                                  g.y1: mel_output,
                                  g.prev_max_attentions: prev_max_attentions})
                    mel_output[:, j, :] = _mel_output[:, j, :]
                    decoder_output[:, j, :] = _decoder_output[:, j, :]
                    alignments[:, j] = _alignments[0].T[:, j]
                    prev_max_attentions = _max_attentions[:, j]
                    max_attentions[:, j] = _max_attentions[:, j]
                plot_alignment(alignments[::-1, :], "sanity-check", 0)

                # Get magnitude
                mags = sess.run(g.mag_output, {g.decoder_output: decoder_output})

                # Generate wav files
                if not os.path.exists(hp.sampledir): os.makedirs(hp.sampledir)
                for mag in mags:
                    print("file id=", file_id)
                    # generate wav files
                    mag = mag * hp.mag_std + hp.mag_mean  # denormalize
                    audio = spectrogram2wav(np.power(10, mag) ** hp.sharpening_factor)
                    audio = signal.lfilter([1], [1, -hp.preemphasis], audio)
                    write(hp.sampledir + "/{}_{}.wav".format(mname, file_id), hp.sr, audio)
                    file_id += 1
Example #12
def synthesis(text, args, num):
    m = Model()
    m_post = ModelPostNet()

    m.load_state_dict(load_checkpoint(args.restore_step1, "transformer"))
    m_post.load_state_dict(load_checkpoint(args.restore_step2, "postnet"))

    text = np.asarray(text_to_sequence(text, [hp.cleaners]))
    text = t.LongTensor(text).unsqueeze(0)
    text = text.cuda()
    mel_input = t.zeros([1, 1, 80]).cuda()
    pos_text = t.arange(1, text.size(1) + 1).unsqueeze(0)
    pos_text = pos_text.cuda()

    m = m.cuda()
    m_post = m_post.cuda()
    m.train(False)
    m_post.train(False)

    pbar = tqdm(range(args.max_len))
    with t.no_grad():
        for i in pbar:
            pos_mel = t.arange(1, mel_input.size(1) + 1).unsqueeze(0).cuda()
            mel_pred, postnet_pred, attn, stop_token, _, attn_dec = m.forward(
                text, mel_input, pos_text, pos_mel)
            # print('mel_pred==================',mel_pred.shape)
            # print('postnet_pred==================', postnet_pred.shape)
            mel_input = t.cat([mel_input, postnet_pred[:, -1:, :]], dim=1)
            #print(postnet_pred[:, -1:, :])
            #print(t.argmax(attn[1][1][i]).item())
            #print('mel_input==================', mel_input.shape)

    # # Test the postnet directly with a ground-truth mel
    #aa = t.from_numpy(np.load('D:\SSHdownload\\000101.pt.npy')).cuda().unsqueeze(0)
    # # print(aa.shape)
    mag_pred = m_post.forward(postnet_pred)
    #real_mag = t.from_numpy((np.load('D:\SSHdownload\\003009.mag.npy'))).cuda().unsqueeze(dim=0)
    #wav = spectrogram2wav(postnet_pred)

    #print('shappe============',attn[2][0].shape)
    # count = 0
    # for j in range(4):
    #     count += 1
    #     attn1 = attn[0][j].cpu()
    #     plot_alignment(attn1, path='./training_loss/'+ str(args.restore_step1)+'_'+str(count)+'_'+'S'+str(num)+'.png', title='sentence'+str(num))

    attn1 = attn[0][1].cpu()
    plot_alignment(attn1,
                   path='./training_loss/' + str(args.restore_step1) + '_' +
                   'S' + str(num) + '.png',
                   title='sentence' + str(num))

    wav = spectrogram2wav(mag_pred.squeeze(0).cpu().detach().numpy())
    write(
        hp.sample_path + '/' + str(args.restore_step1) + '-' + "test" +
        str(num) + ".wav", hp.sr, wav)
Example #13
def synthesize(t2m, ssrn, data_loader, batch_size=100):
    '''
    DCTTS Architecture
    Text --> Text2Mel --> SSRN --> Wav file
    '''

    text2mel_total_time = 0

    # Text2Mel
    idx2char = load_vocab()[-1]
    with torch.no_grad():
        print('='*10, ' Text2Mel ', '='*10)
        is_test = [True, False]
        total_mel_hats = torch.zeros([len(data_loader.dataset), args.max_Ty, args.n_mels]).to(DEVICE)
        mags = torch.zeros([len(data_loader.dataset), args.max_Ty*args.r, args.n_mags]).to(DEVICE)
        
        for step, (texts, mel, _) in enumerate(data_loader):
            texts = texts.to(DEVICE)
            prev_mel_hats = torch.zeros([len(texts), args.max_Ty, args.n_mels]).to(DEVICE)


            text2mel_start_time = time.time()         
            for t in tqdm(range(args.max_Ty-1), unit='B', ncols=70):
                if t == args.max_Ty - 2:
                    is_test[1] = True
                mel_hats, A, result_tuple = t2m(texts, prev_mel_hats, t, is_test) # mel: (N, Ty/r, n_mels)
                prev_mel_hats[:, t+1, :] = mel_hats[:, t, :]
                print(mel_hats.sum(), mel.sum())
            
            text2mel_finish_time = time.time()
            text2mel_total_time += (text2mel_finish_time - text2mel_start_time)

            total_mel_hats[step*batch_size:(step+1)*batch_size, :, :] = prev_mel_hats

            
            print('='*10, ' Alignment ', '='*10)
            alignments = A.cpu().detach().numpy()
            visual_texts = texts.cpu().detach().numpy()
            for idx in range(len(alignments)):
                text = [idx2char[ch] for ch in visual_texts[idx]]
                utils.plot_att(alignments[idx], text, args.global_step, path=os.path.join(args.sampledir, 'A'), name='{}.png'.format(idx))
            print('='*10, ' SSRN ', '='*10)
            # Mel --> Mag
            mags[step*batch_size:(step+1)*batch_size:, :, :] = \
                ssrn(total_mel_hats[step*batch_size:(step+1)*batch_size, :, :]) # mag: (N, Ty, n_mags)
        mags = mags.cpu().detach().numpy()  # convert once, after the loop has filled every batch
        print('='*10, ' Vocoder ', '='*10)
        for idx in trange(len(mags), unit='B', ncols=70):
            wav = utils.spectrogram2wav(mags[idx])
            write(os.path.join(args.sampledir, '{}.wav'.format(idx+1)), args.sr, wav)
 
    result = list(result_tuple)
    result.append(text2mel_total_time)

    return result
Example #14
    def synthesize(self, text):
        text = text.strip() + '.' + hp.EOS_char
        char2idx, idx2char = load_vocab()

        ## Convert numbers to words
        lista_numeros = re.findall(r'\d+', text)
        for num in lista_numeros:
            text = text.replace(num, wahio(num))

        print('texto : ', text)
        text_encode = [char2idx[char] for char in text]

        ## ******** CQ********
        # Adjusting the process to the number of characters
        num_chars = len(text_encode)

        y_hat = np.zeros((1, num_chars, hp.n_mels * hp.r),
                         np.float32)  # hp.n_mels*hp.r
        for j in tqdm.tqdm(range(num_chars)):
            _y_hat = self.session.run(self.g.y_hat, {
                self.g.x: [text_encode],
                self.g.y: y_hat
            })
            y_hat[:, j, :] = _y_hat[:, j, :]

        ## mag
        mag = self.session.run(self.g.z_hat, {self.g.y_hat: y_hat})

        al_EOS_index = None  # no EOS detection here, so the untrimmed branch below always runs

        if al_EOS_index is not None:
            # trim the audio
            audio = spectrogram2wav(mag[:al_EOS_index * hp.r, :])
        else:
            audio = spectrogram2wav(mag[0, :, :], 1)
        #write(os.path.join(hp.sampledir, '{}.wav'.format(i+1)), hp.sr, audio)

        wav = audio

        out = io.BytesIO()
        save_wav(wav, out, hp.sr)
        return out.getvalue()
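
The wahio helper (digits to Spanish words) is not shown. If a stand-in is needed, the third-party num2words package does the same conversion; this is a substitute sketch, not the original helper:

import re
from num2words import num2words

def numbers_to_words(text, lang='es'):
    """Replace every digit run in text with its spelled-out form."""
    for num in re.findall(r'\d+', text):
        text = text.replace(num, num2words(int(num), lang=lang))
    return text
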
Example #15
File: test.py Project: thetobysiu/AdvDCTTS
def evaluate(model, data_loader, batch_size=100):
    # valid_loss = 0.
    with torch.no_grad():
        for step, (texts, mels, mags) in tqdm(enumerate(data_loader),
                                              total=len(data_loader)):
            texts, mels = texts.to(DEVICE), mels.to(DEVICE)
            mags_hat = model(mels)  # Predict
            mags_hat = mags_hat.cpu().numpy()
            mags = mags.numpy()
            # import pdb; pdb.set_trace()
            for idx in range(len(mags)):
                fname = step * batch_size + idx
                wav = utils.spectrogram2wav(mags_hat[idx])
                write(
                    os.path.join(args.testdir, '{:03d}-gen.wav'.format(fname)),
                    args.sr, wav)
                wav = utils.spectrogram2wav(mags[idx])
                write(
                    os.path.join(args.testdir, '{:03d}-gt.wav'.format(fname)),
                    args.sr, wav)
Example #16
def synth_wave(hp, magfile):
    mag = np.load(magfile)
    #print ('mag shape %s'%(str(mag.shape)))
    wav = spectrogram2wav(hp, mag)
    outfile = magfile.replace('.mag.npy', '.wav')
    outfile = outfile.replace('.npy', '.wav')
    #print magfile
    #print outfile
    #print 
    # write(outfile, hp.sr, wav)
    soundfile.write(outfile, wav, hp.sr)
Example #17
def sample_audio(g, sess):
    """
    Samples audio from the generator from training examples

    Parameters:

    g : TensorFlow Graph

    sess : TensorFlow Session
    """
    mname = 'gan'
    og, act, gen = sess.run([g.q, g.z, g.outputs2_gen])
    for i, (s0, s1, s2) in enumerate(zip(og, act, gen)):
        s0 = restore_shape(s0, hp.win_length // hp.hop_length, hp.r)
        s1 = restore_shape(s1, hp.win_length // hp.hop_length, hp.r)
        s2 = restore_shape(s2, hp.win_length // hp.hop_length, hp.r)
        # generate wav files
        if hp.use_log_magnitude:
            audio0 = spectrogram2wav(np.power(np.e, s0)**hp.power)
            audio1 = spectrogram2wav(np.power(np.e, s1)**hp.power)
            audio2 = spectrogram2wav(np.power(np.e, s2)**hp.power)
        else:
            s0 = np.where(s0 < 0, 0, s0)
            s1 = np.where(s1 < 0, 0, s1)
            s2 = np.where(s2 < 0, 0, s2)
            audio0 = spectrogram2wav(s0**hp.power)
            audio1 = spectrogram2wav(s1**hp.power)
            audio2 = spectrogram2wav(s2**hp.power)
        write(hp.outputdir + "/gan_{}_org.wav".format(i), hp.sr, audio0)
        write(hp.outputdir + "/gan_{}_act.wav".format(i), hp.sr, audio1)
        write(hp.outputdir + "/gan_{}_gen.wav".format(i), hp.sr, audio2)
Example #18
def copy_synth_GL(hp, outdir):

    safe_makedir(outdir)

    dataset = load_data(hp, mode="synthesis")
    fnames, texts = dataset['fpaths'], dataset['texts']
    bases = [basename(fname) for fname in fnames]

    for base in bases:
        print("Working on file %s" % (base))
        mag = np.load(os.path.join(hp.full_audio_dir, base + '.npy'))
        wav = spectrogram2wav(hp, mag)
        soundfile.write(outdir + "/%s.wav" % (base), wav, hp.sr)
Example #19
def synthesize():
    # Load data
    L = load_data("synthesize")

    # Load graph
    g = Graph(mode="synthesize")
    print("Graph loaded")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Restore parameters
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     'Text2Mel')
        saver1 = tf.train.Saver(var_list=var_list)
        saver1.restore(sess, tf.train.latest_checkpoint(hp.logdir + "-1"))
        print("Text2Mel Restored!")

        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'SSRN') + \
                   tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'gs')
        saver2 = tf.train.Saver(var_list=var_list)
        saver2.restore(sess, tf.train.latest_checkpoint(hp.logdir + "-2"))
        print("SSRN Restored!")

        # Feed Forward
        ## mel
        Y = np.zeros((len(L), hp.max_T, hp.n_mels), np.float32)
        prev_max_attentions = np.zeros((len(L), ), np.int32)
        for j in tqdm(range(hp.max_T)):
            _gs, _Y, _max_attentions, _alignments = \
                sess.run([g.global_step, g.Y, g.max_attentions, g.alignments],
                         {g.L: L,
                          g.mels: Y,
                          g.prev_max_attentions: prev_max_attentions})
            Y[:, j, :] = _Y[:, j, :]
            prev_max_attentions = _max_attentions[:, j]

        # Get magnitude
        Z = sess.run(g.Z, {g.Y: Y})

        # Generate wav files
        if not os.path.exists(hp.sampledir): os.makedirs(hp.sampledir)
        for i, mag in enumerate(Z):
            print("Working on file", i + 1)
            wav = spectrogram2wav(mag)
            write(hp.sampledir + "/{}.wav".format(i + 1), hp.sr, wav)
Example #20
def synthesize_full(model_name, texts):
	tf.reset_default_graph()
	g = Graph(lang=lang[model_name])
	texts = [clean_text(text) for text in texts]
	print("Graph loaded")
	with tf.Session() as sess:

	    sess.run(tf.global_variables_initializer())

	    # Restore parameters
	    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'Text2Mel')
	    saver1 = tf.train.Saver(var_list=var_list)
	    saver1.restore(sess, tf.train.latest_checkpoint(os.path.join("models",model_name, "logdir-1")))
	    print("Text2Mel Restored!")

	    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'SSRN') + \
	               tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'gs')
	    saver2 = tf.train.Saver(var_list=var_list)
	    saver2.restore(sess, tf.train.latest_checkpoint(os.path.join("models",model_name, "logdir-2")))
	    print("SSRN Restored!")

	    if len(texts) > 0:
	        L = load_text(texts,lang[model_name])
	        #print(L)
	        max_T = min(int(sum([len(text) for text in texts])*1.5), hp.max_T)
	        # Feed Forward
	        ## mel
	        Y = np.zeros((len(L), hp.max_T, hp.n_mels), np.float32)
	        prev_max_attentions = np.zeros((len(L),), np.int32)
	        for j in tqdm(range(max_T)):
	            _gs, _Y, _max_attentions, _alignments = \
	                sess.run([g.global_step, g.Y, g.max_attentions, g.alignments],
	                         {g.L: L,
	                          g.mels: Y,
	                          g.prev_max_attentions: prev_max_attentions})
	            Y[:, j, :] = _Y[:, j, :]
	            prev_max_attentions = _max_attentions[:, j]

	        # Get magnitude
	        Z = sess.run(g.Z, {g.Y: Y})

	        for i, mag in enumerate(Z):
	            print("Working on file", i+1)
	            wav = spectrogram2wav(mag)
	            write(f"{i}.wav", hp.sr, wav)
	            break
Example #21
    def synthesize(self, checkpoint_path, text=None):
        """
        Synthesize audio output from the given model
        :param checkpoint_path: the model to load from
        :param text: the text to synthesize
        :return:
        """
        print('Constructing model...')
        self.model = Model(mode="synthesize")
        self.load_text(text)

        # Session
        with tf.Session() as sess:
            # saving
            sess.run(tf.global_variables_initializer())
            print('Loading checkpoint: %s' % checkpoint_path)
            saver = tf.train.import_meta_graph(checkpoint_path)
            saver.restore(sess, tf.train.latest_checkpoint(LOG_DIR))

            # Feed Forward
            # mel
            self.mels_hat = np.zeros((self.text.shape[0], 200, N_MELS * REDUCTION_FACTOR), np.float32)

            # feed inputs
            for j in tqdm.tqdm(range(200)):
                feed_dict = {
                    self.model.txt: self.text,
                    self.model.mels: self.mels_hat
                }
                mel_hat2 = sess.run(self.model.mel_hat, feed_dict)
                self.mels_hat[:, j, :] = mel_hat2[:, j, :]

            # mag
            feed_dict2 = {self.model.mel_hat: self.mels_hat}
            self.mags = sess.run(self.model.mags_hat, feed_dict2)
            for i, mag in enumerate(self.mags):
                print("File {}.wav is being generated ...".format(i + 1))
                audio = spectrogram2wav(mag)
                librosa.output.write_wav(os.path.join(SAVE_DIR, '{}.wav'.format(i + 1)), audio, SR)
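
Note that librosa.output.write_wav was removed in librosa 0.8, so this example only runs on older librosa versions. On newer installs, soundfile is the usual replacement (a minimal drop-in helper, names assumed):

import soundfile

def write_wav(path, audio, sr):
    """Equivalent of the removed librosa.output.write_wav."""
    soundfile.write(path, audio, sr)
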
Example #22
def synthesize():
    # Load graph
    g = Graph(training=False)
    print("Graph loaded")
    x = load_test_data()
    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            # Restore parameters
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            print("Restored!")

            # Get model name
            mname = open(hp.logdir + '/checkpoint', 'r').read().split('"')[1]

            # Inference
            mels = np.zeros((hp.batch_size, hp.T_y // hp.r, hp.n_mels * hp.r),
                            np.float32)
            prev_max_attentions = np.zeros((hp.batch_size, ), np.int32)
            for j in range(hp.T_x):
                _mels, _max_attentions = sess.run(
                    [g.mels, g.max_attentions], {
                        g.x: x,
                        g.y1: mels,
                        g.prev_max_attentions: prev_max_attentions
                    })
                mels[:, j, :] = _mels[:, j, :]
                prev_max_attentions = _max_attentions[:, j]
            mags = sess.run(g.mags, {g.mels: mels})

    # Generate wav files
    if not os.path.exists(hp.sampledir): os.makedirs(hp.sampledir)
    for i, mag in enumerate(mags):
        # generate wav files
        mag = mag * hp.mag_std + hp.mag_mean  # denormalize
        audio = spectrogram2wav(np.exp(mag))
        write(hp.sampledir + "/{}_{}.wav".format(mname, i), hp.sr, audio)
Example #23
def synthesize():
    if not os.path.exists(hp.sampledir): os.mkdir(hp.sampledir)

    # Load data
    texts = load_data(mode="synthesize")

    # reference audio
    mels, maxlen = [], 0
    files = glob(hp.ref_audio)
    for f in files:
        _, mel, _= load_spectrograms(f)
        mel = np.reshape(mel, (-1, hp.n_mels))
        maxlen = max(maxlen, mel.shape[0])
        mels.append(mel)

    ref = np.zeros((len(mels), maxlen, hp.n_mels), np.float32)
    for i, m in enumerate(mels):
        ref[i, :m.shape[0], :] = m

    # Load graph
    g = Graph(mode="synthesize"); print("Graph loaded")

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(hp.logdir)); print("Restored!")

        # Feed Forward
        ## mel
        y_hat = np.zeros((texts.shape[0], 200, hp.n_mels*hp.r), np.float32)  # hp.n_mels*hp.r
        for j in tqdm.tqdm(range(200)):
            _y_hat = sess.run(g.y_hat, {g.x: texts, g.y: y_hat, g.ref: ref})
            y_hat[:, j, :] = _y_hat[:, j, :]
        ## mag
        mags = sess.run(g.z_hat, {g.y_hat: y_hat})
        for i, mag in enumerate(mags):
            print("File {}.wav is being generated ...".format(i+1))
            audio = spectrogram2wav(mag)
            write(os.path.join(hp.sampledir, '{}.wav'.format(i+1)), hp.sr, audio)
Example #24
def create_audio_wave(text,
                      max_len,
                      transformaer_pth_step=160000,
                      postnet_pth_step=100000,
                      save=None):
    m = Model()
    m_post = ModelPostNet()
    m.load_state_dict(load_checkpoint(transformaer_pth_step, "transformer"))
    m_post.load_state_dict(load_checkpoint(postnet_pth_step, "postnet"))

    text = np.asarray(text_to_sequence(text, [hp.cleaners]))
    text = t.LongTensor(text).unsqueeze(0)
    text = text.cuda()
    mel_input = t.zeros([1, 1, 80]).cuda()
    pos_text = t.arange(1, text.size(1) + 1).unsqueeze(0)
    pos_text = pos_text.cuda()

    m = m.cuda()
    m_post = m_post.cuda()
    m.train(False)
    m_post.train(False)
    max_len = max_len
    pbar = tqdm(range(max_len))
    with t.no_grad():
        for i in pbar:
            pos_mel = t.arange(1, mel_input.size(1) + 1).unsqueeze(0).cuda()
            mel_pred, postnet_pred, attn, stop_token, _, attn_dec = m.forward(
                text, mel_input, pos_text, pos_mel)
            print(np.array(attn_dec).shape)
            mel_input = t.cat([mel_input, postnet_pred[:, -1:, :]], dim=1)

        mag_pred = m_post.forward(postnet_pred)

    wav = spectrogram2wav(mag_pred.squeeze(0).cpu().numpy())
    if save:
        write(hp.sample_path + "/test.wav", hp.sr, wav)
    return wav
Example #25
def synthesize(model, data_loader, batch_size=100):
    '''
    Tacotron

    '''
    idx2char = load_vocab()[-1]
    with torch.no_grad():
        print('*' * 15, ' Synthesize ', '*' * 15)
        mags = torch.zeros(
            [len(data_loader.dataset), args.max_Ty * args.r,
             args.n_mags]).to(DEVICE)
        for step, (texts, _, _) in enumerate(data_loader):
            texts = texts.to(DEVICE)
            GO_frames = torch.zeros([texts.shape[0], 1,
                                     args.n_mels * args.r]).to(DEVICE)
            _, mags_hat, A = model(texts, GO_frames, synth=True)

            print('=' * 10, ' Alignment ', '=' * 10)
            alignments = A.cpu().detach().numpy()
            visual_texts = texts.cpu().detach().numpy()
            for idx in range(len(alignments)):
                text = [idx2char[ch] for ch in visual_texts[idx]]
                utils.plot_att(alignments[idx],
                               text,
                               args.global_step,
                               path=os.path.join(args.sampledir, 'A'),
                               name='{}.png'.format(idx + step * batch_size))
            mags[step * batch_size:(step + 1) *
                 batch_size:, :, :] = mags_hat  # mag: (N, Ty, n_mags)
        print('=' * 10, ' Vocoder ', '=' * 10)
        mags = mags.cpu().detach().numpy()
        for idx in trange(len(mags), unit='B', ncols=70):
            wav = utils.spectrogram2wav(mags[idx])
            write(os.path.join(args.sampledir, '{}.wav'.format(idx + 1)),
                  args.sr, wav)
    return None
Example #26
def validate(m, val_loader, global_step, writer):
    m_post = ModelPostNet()
    state_dict = t.load('./checkpoints/checkpoint_%s_%d.pth.tar' %
                        ('postnet', 250000))
    new_state_dict = OrderedDict()
    for k, value in state_dict['model'].items():
        key = k[7:]
        new_state_dict[key] = value
    m_post.load_state_dict(new_state_dict)
    m_post.cuda()
    m_post.eval()
    m.eval()
    with t.no_grad():
        n_data, val_loss = 0, 0
        for i, data in enumerate(val_loader):
            n_data += len(data[0])
            character, mel, mel_input, pos_text, pos_mel, text_length, mel_length, fname = data

            character = character.cuda()
            mel = mel.cuda()
            mel_input = mel_input.cuda()
            pos_text = pos_text.cuda()
            pos_mel = pos_mel.cuda()

            mel_pred, postnet_pred, attn_probs, decoder_outputs, attns_enc, attns_dec, attns_style = m.forward(
                character, mel_input, pos_text, pos_mel, mel, pos_mel)

            mel_loss = nn.L1Loss()(mel_pred, mel)
            post_mel_loss = nn.L1Loss()(postnet_pred, mel)

            mask = get_mask_from_lengths(mel_length).cuda()

            loss = mel_loss + post_mel_loss
            val_loss += loss.item()
        val_loss /= n_data
    mag_pred = m_post.forward(postnet_pred)
    for i, mag in enumerate(mag_pred[:3]):
        wav = spectrogram2wav(mag.detach().cpu().numpy())
        wav_path = os.path.join(
            os.path.join(hp.checkpoint_path, hp.log_directory), 'wav')
        if not os.path.exists(wav_path):
            os.makedirs(wav_path)
        write(
            os.path.join(wav_path,
                         "val_{}_synth_{}.wav".format(fname[i], global_step)),
            hp.sr, wav)
        print("written as val_{}_synth.wav".format(fname[i]))

    attns_enc_new = []
    attns_dec_new = []
    attn_probs_new = []
    attns_style_new = []
    for i in range(len(attns_enc)):
        attns_enc_new.append(attns_enc[i].unsqueeze(0))
        attns_dec_new.append(attns_dec[i].unsqueeze(0))
        attn_probs_new.append(attn_probs[i].unsqueeze(0))
        attns_style_new.append(attns_style[i].unsqueeze(0))
    attns_enc = t.cat(attns_enc_new, 0)
    attns_dec = t.cat(attns_dec_new, 0)
    attn_probs = t.cat(attn_probs_new, 0)
    attns_style = t.cat(attns_style_new, 0)

    attns_enc = attns_enc.contiguous().view(attns_enc.size(0), hp.batch_size,
                                            hp.n_heads, attns_enc.size(2),
                                            attns_enc.size(3))
    attns_enc = attns_enc.permute(1, 0, 2, 3, 4)
    attns_dec = attns_dec.contiguous().view(attns_dec.size(0), hp.batch_size,
                                            hp.n_heads, attns_dec.size(2),
                                            attns_dec.size(3))
    attns_dec = attns_dec.permute(1, 0, 2, 3, 4)
    attn_probs = attn_probs.contiguous().view(attn_probs.size(0),
                                              hp.batch_size, hp.n_heads,
                                              attn_probs.size(2),
                                              attn_probs.size(3))
    attn_probs = attn_probs.permute(1, 0, 2, 3, 4)
    attns_style = attns_style.contiguous().view(attns_style.size(0),
                                                hp.batch_size, hp.n_heads,
                                                attns_style.size(2),
                                                attns_style.size(3))
    attns_style = attns_style.permute(1, 0, 2, 3, 4)

    save_dir = os.path.join(hp.checkpoint_path, hp.log_directory, 'figure')
    writer.add_losses(mel_loss.item(), post_mel_loss.item(), global_step,
                      'Validation')
    writer.add_alignments(attns_enc.detach().cpu(),
                          attns_dec.detach().cpu(),
                          attn_probs.detach().cpu(),
                          attns_style.detach().cpu(), mel_length, text_length,
                          global_step, 'Validation', save_dir)

    msg = "Validation| loss : {:.4f} + {:.4f} = {:.4f}".format(
        mel_loss, post_mel_loss, loss)
    stream(msg)
    m.train()
Example #27
def synthesize():
    if not os.path.exists(hp.sampledir):
        os.mkdir(hp.sampledir)

    # Load data
    texts = load_data(mode="synthesize")

    # pad texts to multiple of batch_size
    texts_len = texts.shape[0]
    num_batches = int(ceil(float(texts_len) / hp.batch_size))
    padding_len = num_batches * hp.batch_size - texts_len
    texts = np.pad(texts, ((0, padding_len), (0, 0)),
                   'constant',
                   constant_values=0)

    # reference audio
    mels, maxlen = [], 0
    files = glob(hp.ref_audio)
    for f in files:
        _, mel, _ = load_spectrograms(f)
        mel = np.reshape(mel, (-1, hp.n_mels))
        maxlen = max(maxlen, mel.shape[0])
        mels.append(mel)

    ref = np.zeros((len(mels), maxlen, hp.n_mels), np.float32)
    for i, m in enumerate(mels):
        ref[i, :m.shape[0], :] = m

    # Load graph
    g = Graph(mode="synthesize")
    print("Graph loaded")

    saver = tf.train.Saver()
    with tf.Session() as sess:
        if len(sys.argv) == 1:
            saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            print("Restored latest checkpoint")
        else:
            saver.restore(sess, sys.argv[1])
            print("Restored checkpoint: %s" % sys.argv[1])

        batches = [
            texts[i:i + hp.batch_size]
            for i in range(0, texts.shape[0], hp.batch_size)
        ]
        start = 0
        batch_index = 0
        # Feed Forward
        for batch in batches:
            ref_batch, start = looper(ref, start, hp.batch_size)
            ## mel
            y_hat = np.zeros((batch.shape[0], 200, hp.n_mels * hp.r),
                             np.float32)  # hp.n_mels*hp.r
            for j in tqdm.tqdm(range(200)):
                _y_hat = sess.run(g.y_hat, {
                    g.x: batch,
                    g.y: y_hat,
                    g.ref: ref_batch
                })
                y_hat[:, j, :] = _y_hat[:, j, :]
            ## mag
            mags = sess.run(g.z_hat, {g.y_hat: y_hat})
            for i, mag in enumerate(mags):
                index_label = batch_index * hp.batch_size + i + 1
                if index_label > texts_len:
                    break
                print("File {}.wav is being generated ...".format(index_label))
                audio = spectrogram2wav(mag)
                write(os.path.join(hp.sampledir, '{}.wav'.format(index_label)),
                      hp.sr, audio)

            batch_index += 1
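
looper is not defined in this example. From its call site it appears to return a batch_size-long slice of ref plus the next start index, wrapping around at the end; a plausible sketch (an assumption, not the original):

import numpy as np

def looper(ref, start, batch_size):
    """Return batch_size rows of ref, cycling past the end, and the next start index."""
    idx = np.array([(start + k) % len(ref) for k in range(batch_size)])
    return ref[idx], (start + batch_size) % len(ref)
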
Example #28
def synthesis(args):
    m = Model()
    m_post = ModelPostNet()
    m_stop = ModelStopToken()
    m.load_state_dict(load_checkpoint(args.restore_step1, "transformer"))
    m_stop.load_state_dict(load_checkpoint(args.restore_step3, "stop_token"))
    m_post.load_state_dict(load_checkpoint(args.restore_step2, "postnet"))

    m=m.cuda()
    m_post = m_post.cuda()
    m_stop = m_stop.cuda()
    m.train(False)
    m_post.train(False)
    m_stop.train(False)
    test_dataset = get_dataset(hp.test_data_csv)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn_transformer, drop_last=True, num_workers=1)
    ref_dataset = get_dataset(hp.test_data_csv)
    ref_dataloader = DataLoader(ref_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn_transformer, drop_last=True, num_workers=1)

    writer = get_writer(hp.checkpoint_path, hp.log_directory)

    ref_dataloader_iter = iter(ref_dataloader)
    for i, data in enumerate(test_dataloader):
        character, mel, mel_input, pos_text, pos_mel, text_length, mel_length, fname = data
        ref_character, ref_mel, ref_mel_input, ref_pos_text, ref_pos_mel, ref_text_length, ref_mel_length, ref_fname = next(ref_dataloader_iter)
        stop_tokens = t.abs(pos_mel.ne(0).type(t.float) - 1)
        mel_input = t.zeros([1,1,80]).cuda()
        stop=[]
        character = character.cuda()
        mel = mel.cuda()
        mel_input = mel_input.cuda()
        pos_text = pos_text.cuda()
        pos_mel = pos_mel.cuda()
        ref_character = ref_character.cuda()
        ref_mel = ref_mel.cuda()
        ref_mel_input = ref_mel_input.cuda()
        ref_pos_text = ref_pos_text.cuda()
        ref_pos_mel = ref_pos_mel.cuda()

        with t.no_grad():
            start=time.time()
            for i in range(args.max_len):
                pos_mel = t.arange(1,mel_input.size(1)+1).unsqueeze(0).cuda()
                mel_pred, postnet_pred, attn_probs, decoder_output, attns_enc, attns_dec, attns_style = m.forward(character, mel_input, pos_text, pos_mel, ref_mel, ref_pos_mel)
                stop_token = m_stop.forward(decoder_output)
                mel_input = t.cat([mel_input, postnet_pred[:,-1:,:]], dim=1)
                stop.append(t.sigmoid(stop_token).squeeze(-1)[0,-1])
                if stop[-1] > 0.5:
                    print("stop token at " + str(i) + " is :", stop[-1])
                    print("model inference time: ", time.time() - start)
                    break
            if stop[-1] == 0:
                continue
            mag_pred = m_post.forward(postnet_pred)
            inf_time = time.time() - start
            print("inference time: ", inf_time)

        wav = spectrogram2wav(mag_pred.squeeze(0).cpu().numpy())
        print("rtx : ", (len(wav)/hp.sr) / inf_time)
        wav_path = os.path.join(hp.sample_path, 'wav')
        if not os.path.exists(wav_path):
            os.makedirs(wav_path)
        write(os.path.join(wav_path, "text_{}_ref_{}_synth.wav".format(fname, ref_fname)), hp.sr, wav)
        print("written as text{}_ref_{}_synth.wav".format(fname, ref_fname))
        attns_enc_new=[]
        attns_dec_new=[]
        attn_probs_new=[]
        attns_style_new=[]
        for i in range(len(attns_enc)):
            attns_enc_new.append(attns_enc[i].unsqueeze(0))
            attns_dec_new.append(attns_dec[i].unsqueeze(0))
            attn_probs_new.append(attn_probs[i].unsqueeze(0))
            attns_style_new.append(attns_style[i].unsqueeze(0))
        attns_enc = t.cat(attns_enc_new, 0)
        attns_dec = t.cat(attns_dec_new, 0)
        attn_probs = t.cat(attn_probs_new, 0)
        attns_style = t.cat(attns_style_new, 0)

        attns_enc = attns_enc.contiguous().view(attns_enc.size(0), 1, hp.n_heads, attns_enc.size(2), attns_enc.size(3))
        attns_enc = attns_enc.permute(1,0,2,3,4)
        attns_dec = attns_dec.contiguous().view(attns_dec.size(0), 1, hp.n_heads, attns_dec.size(2), attns_dec.size(3))
        attns_dec = attns_dec.permute(1,0,2,3,4)
        attn_probs = attn_probs.contiguous().view(attn_probs.size(0), 1, hp.n_heads, attn_probs.size(2), attn_probs.size(3))
        attn_probs = attn_probs.permute(1,0,2,3,4)
        attns_style = attns_style.contiguous().view(attns_style.size(0), 1, hp.n_heads, attns_style.size(2), attns_style.size(3))
        attns_style = attns_style.permute(1,0,2,3,4)

        save_dir = os.path.join(hp.sample_path, 'figure', "text_{}_ref_{}_synth.wav".format(fname, ref_fname))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        writer.add_alignments(attns_enc.detach().cpu(), attns_dec.detach().cpu(), attn_probs.detach().cpu(), attns_style.detach().cpu(), mel_length, text_length, args.restore_step1, 'Validation', save_dir)
Example #29
def convert(logdir='logdir/default/train2', queue=False):

    # Load graph
    model = Model(mode="convert",
                  batch_size=hp.Convert.batch_size,
                  queue=queue)

    session_conf = tf.ConfigProto(
        allow_soft_placement=True,
        device_count={
            'CPU': 1,
            'GPU': 0
        },
        gpu_options=tf.GPUOptions(allow_growth=True,
                                  per_process_gpu_memory_fraction=0.6),
    )
    with tf.Session(config=session_conf) as sess:
        # Load trained model
        sess.run(tf.global_variables_initializer())
        model.load(sess, 'convert', logdir=logdir)

        writer = tf.summary.FileWriter(logdir, sess.graph)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        gs = Model.get_global_step(logdir)

        if queue:
            pred_log_specs, y_log_spec, ppgs = sess.run(
                [model(), model.y_spec, model.ppgs])
        else:
            mfcc, spec, mel = get_wav_batch(model.mode, model.batch_size)
            pred_log_specs, y_log_spec, ppgs = sess.run(
                [model(), model.y_spec, model.ppgs],
                feed_dict={
                    model.x_mfcc: mfcc,
                    model.y_spec: spec,
                    model.y_mel: mel
                })

        # Denormalization
        # pred_log_specs = hp.mean_log_spec + hp.std_log_spec * pred_log_specs
        # y_log_spec = hp.mean_log_spec + hp.std_log_spec * y_log_spec
        # pred_log_specs = hp.min_log_spec + (hp.max_log_spec - hp.min_log_spec) * pred_log_specs
        # y_log_spec = hp.min_log_spec + (hp.max_log_spec - hp.min_log_spec) * y_log_spec

        # Convert log of magnitude to magnitude
        pred_specs, y_specs = np.e**pred_log_specs, np.e**y_log_spec

        # Emphasize the magnitude
        pred_specs = np.power(pred_specs, hp.Convert.emphasis_magnitude)
        y_specs = np.power(y_specs, hp.Convert.emphasis_magnitude)

        # Spectrogram to waveform
        audio = np.array([
            spectrogram2wav(spec.T, hp_default.n_fft, hp_default.win_length,
                            hp_default.hop_length, hp_default.n_iter)
            for spec in pred_specs
        ])  # list comprehension so this also works on Python 3, where map() is lazy
        y_audio = np.array([
            spectrogram2wav(spec.T, hp_default.n_fft, hp_default.win_length,
                            hp_default.hop_length, hp_default.n_iter)
            for spec in y_specs
        ])

        # Apply inverse pre-emphasis
        audio = inv_preemphasis(audio, coeff=hp_default.preemphasis)
        y_audio = inv_preemphasis(y_audio, coeff=hp_default.preemphasis)

        if not queue:
            # Concatenate to a wav
            y_audio = np.reshape(y_audio, (1, y_audio.size), order='C')
            audio = np.reshape(audio, (1, audio.size), order='C')

        # Write the result
        tf.summary.audio('A',
                         y_audio,
                         hp_default.sr,
                         max_outputs=hp.Convert.batch_size)
        tf.summary.audio('B',
                         audio,
                         hp_default.sr,
                         max_outputs=hp.Convert.batch_size)

        # Visualize PPGs
        heatmap = np.expand_dims(ppgs, 3)  # channel=1
        tf.summary.image('PPG', heatmap, max_outputs=ppgs.shape[0])

        writer.add_summary(sess.run(tf.summary.merge_all()), global_step=gs)
        writer.close()

        coord.request_stop()
        coord.join(threads)
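
inv_preemphasis is not shown here; the usual inverse of the pre-emphasis filter y[t] = x[t] - coeff * x[t-1] is the matching one-pole IIR filter. A sketch, assuming that convention:

from scipy import signal

def inv_preemphasis(wav, coeff=0.97):
    """Invert pre-emphasis: the IIR counterpart of y[t] = x[t] - coeff * x[t-1]."""
    return signal.lfilter([1], [1, -coeff], wav)
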
Example #30
def train(log_dir, dataset_size, start_epoch=0):
    # log directory
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    if not os.path.exists(os.path.join(log_dir, 'state')):
        os.mkdir(os.path.join(log_dir, 'state'))
    if not os.path.exists(os.path.join(log_dir, 'wav')):
        os.mkdir(os.path.join(log_dir, 'wav'))
    if not os.path.exists(os.path.join(log_dir, 'state_opt')):
        os.mkdir(os.path.join(log_dir, 'state_opt'))
    if not os.path.exists(os.path.join(log_dir, 'attn')):
        os.mkdir(os.path.join(log_dir, 'attn'))
    if not os.path.exists(os.path.join(log_dir, 'test_wav')):
        os.mkdir(os.path.join(log_dir, 'test_wav'))

    f = open(os.path.join(log_dir, 'log{}.txt'.format(start_epoch)), 'w')

    msg = 'use {}'.format(hp.device)
    print(msg)
    f.write(msg + '\n')

    # load model
    model = Tacotron().to(device)
    if torch.cuda.device_count() > 1:
        model = DataParallel(model)
    if start_epoch != 0:
        model_path = os.path.join(log_dir, 'state',
                                  'epoch{}.pt'.format(start_epoch))
        model.load_state_dict(torch.load(model_path))
        msg = 'Load model of' + model_path
    else:
        msg = 'New model'
    print(msg)
    f.write(msg + '\n')

    # load optimizer
    optimizer = optim.Adam(model.parameters(), lr=hp.lr)
    if start_epoch != 0:
        opt_path = os.path.join(log_dir, 'state_opt',
                                'epoch{}.pt'.format(start_epoch))
        optimizer.load_state_dict(torch.load(opt_path))
        msg = 'Load optimizer of' + opt_path
    else:
        msg = 'New optimizer'
    print(msg)
    f.write(msg + '\n')

    # print('lr = {}'.format(hp.lr))

    model = model.to(device)
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device)

    criterion = TacotronLoss()  # Loss

    # load data
    if dataset_size is None:
        train_dataset = SpeechDataset(r=slice(hp.eval_size, None))
    else:
        train_dataset = SpeechDataset(r=slice(hp.eval_size, hp.eval_size +
                                              dataset_size))

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=hp.batch_size,
                              collate_fn=collate_fn,
                              num_workers=8,
                              shuffle=True)

    num_train_data = len(train_dataset)
    total_step = hp.num_epochs * num_train_data // hp.batch_size
    start_step = start_epoch * num_train_data // hp.batch_size
    step = 0
    global_step = step + start_step
    prev = beg = int(time())

    for epoch in range(start_epoch + 1, hp.num_epochs):

        model.train(True)
        for i, batch in enumerate(train_loader):
            step += 1
            global_step += 1

            texts = batch['text'].to(device)
            mels = batch['mel'].to(device)
            mags = batch['mag'].to(device)

            optimizer.zero_grad()

            mels_input = mels[:, :-1, :]  # shift
            mels_input = mels_input[:, :, -hp.n_mels:]  # get last frame
            ref_mels = mels[:, 1:, :]

            mels_hat, mags_hat, _ = model(texts, mels_input, ref_mels)

            mel_loss, mag_loss = criterion(mels[:, 1:, :], mels_hat, mags,
                                           mags_hat)
            loss = mel_loss + mag_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           1.)  # clip gradients
            optimizer.step()
            # scheduler.step()

            if global_step in hp.lr_step:
                optimizer = set_lr(optimizer, global_step, f)

            if (i + 1) % hp.log_per_batch == 0:
                now = int(time())
                use_time = now - prev
                # total_time = hp.num_epoch * (now - beg) * num_train_data // (hp.batch_size * (i + 1) + epoch * num_train_data)
                total_time = total_step * (now - beg) // step
                left_time = total_time - (now - beg)
                left_time_h = left_time // 3600
                left_time_m = left_time // 60 % 60
                msg = 'step: {}/{}, epoch: {}, batch {}, loss: {:.3f}, mel_loss: {:.3f}, mag_loss: {:.3f}, use_time: {}s, left_time: {}h {}m'
                msg = msg.format(global_step, total_step, epoch, i + 1,
                                 loss.item(), mel_loss.item(), mag_loss.item(),
                                 use_time, left_time_h, left_time_m)

                f.write(msg + '\n')
                print(msg)

                prev = now

        # save model, optimizer and evaluate
        if epoch % hp.save_per_epoch == 0 and epoch != 0:
            torch.save(model.state_dict(),
                       os.path.join(log_dir, 'state/epoch{}.pt'.format(epoch)))
            torch.save(
                optimizer.state_dict(),
                os.path.join(log_dir, 'state_opt/epoch{}.pt'.format(epoch)))
            msg = 'save model, optimizer in epoch{}'.format(epoch)
            f.write(msg + '\n')
            print(msg)

            model.eval()

            #for file in os.listdir(hp.ref_wav):
            wavfile = hp.ref_wav
            name, _ = os.path.splitext(hp.ref_wav.split('/')[-1])

            text, mel, ref_mels = get_eval_data(hp.eval_text, wavfile)
            text = text.to(device)
            mel = mel.to(device)
            ref_mels = ref_mels.to(device)

            mel_hat, mag_hat, attn = model(text, mel, ref_mels)

            mag_hat = mag_hat.squeeze().detach().cpu().numpy()
            attn = attn.squeeze().detach().cpu().numpy()
            plt.imshow(attn.T, cmap='hot', interpolation='nearest')
            plt.xlabel('Decoder Steps')
            plt.ylabel('Encoder Steps')
            fig_path = os.path.join(log_dir,
                                    'attn/epoch{}-{}.png'.format(epoch, name))
            plt.savefig(fig_path, format='png')

            wav = spectrogram2wav(mag_hat)
            write(
                os.path.join(log_dir,
                             'wav/epoch{}-{}.wav'.format(epoch, name)), hp.sr,
                wav)

            msg = 'synthesis eval wav in epoch{} model'.format(epoch)
            print(msg)
            f.write(msg)

    msg = 'Training Finish !!!!'
    f.write(msg + '\n')
    print(msg)

    f.close()
Example #31
def synthesize(input_text,
               model_path,
               g,
               process_callback=None,
               elapsed_callback=None):
    # process_callback: pyqtsignal callback

    # Load text data
    if input_text:  # use user input
        L = data_load.load_data_text(input_text)
    else:  # use txt file
        L = data_load.load_data("synthesize")
    ## Load graph <- Graph loaded in main GUI thread or worker thread
    #g = Graph(mode="synthesize")
    #print("Graph loaded")

    start = time.time()
    # print(model_path)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Restore parameters
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     'Text2Mel')
        saver1 = tf.train.Saver(var_list=var_list)
        saver1.restore(sess, tf.train.latest_checkpoint(model_path + "-1"))
        print("Text2Mel Restored!")

        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'SSRN') + \
                   tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'gs')
        saver2 = tf.train.Saver(var_list=var_list)
        saver2.restore(sess, tf.train.latest_checkpoint(model_path + "-2"))
        print("SSRN Restored!")

        # Feed Forward
        ## mel
        Y = np.zeros((len(L), hp.max_T, hp.n_mels), np.float32)
        prev_max_attentions = np.zeros((len(L), ), np.int32)
        for j in tqdm(range(hp.max_T)):
            _gs, _Y, _max_attentions, _alignments = \
                sess.run([g.global_step, g.Y, g.max_attentions, g.alignments],
                         {g.L: L,
                          g.mels: Y,
                          g.prev_max_attentions: prev_max_attentions})
            Y[:, j, :] = _Y[:, j, :]
            prev_max_attentions = _max_attentions[:, j]
            if process_callback:
                if j % 5 == 0:
                    elapsed = time.time() - start
                    #process_callback(j,elapsed)
                    process_callback.emit(j / hp.max_T * 100)
                    if elapsed_callback:
                        elapsed_callback.emit(int(elapsed))

        # Get magnitude
        Z = sess.run(g.Z, {g.Y: Y})

        # Generate wav files
        #if not os.path.exists(hp.sampledir): os.makedirs(hp.sampledir)
        output = []
        for i, mag in enumerate(Z):
            print("Working on file", i + 1)
            wav = spectrogram2wav(mag)
            # Normalize from 32bit float to signed 16bit wav (scale by the absolute peak so negative peaks cannot overflow int16)
            wav = (wav / np.max(np.abs(wav)) * 32767).astype(np.int16)
            output.append(wav)
            #print('writing file')
            #write(hp.sampledir + "/{}.wav".format(i+1), hp.sr, wav)
        if process_callback:
            elapsed = time.time() - start
            #process_callback(hp.max_T,elapsed)
            process_callback.emit(100)
            if elapsed_callback:
                elapsed_callback.emit(int(elapsed))
        outwav = np.concatenate(output)
        return outwav