Example #1
def get_and_plot_alignments(hp, epoch, attention_graph, sess, attention_inputs, attention_mels, alignment_dir):
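    '''Run attention_graph on the given inputs/mels and save one alignment plot per plotted sentence under alignment_dir.'''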
    return_values = sess.run([attention_graph.alignments], # use attention_graph to obtain attention maps for a few given inputs and mels
                             {attention_graph.L: attention_inputs, 
                              attention_graph.mels: attention_mels}) 
    alignments = return_values[0] # sess run returns a list, so unpack this list
    for i in range(hp.num_sentences_to_plot_attention):
        plot_alignment(hp, alignments[i], i+1, epoch, dir=alignment_dir)
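
# A possible call site for the helper above (a sketch only; `g`, `val_texts`, `val_mels` and
# `hp.logdir` are assumed names, not taken from this source):
# get_and_plot_alignments(hp, epoch, g, sess, val_texts, val_mels, os.path.join(hp.logdir, 'alignments'))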
Example #2
def synthesis(test_lines, model, device, log_dir):
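    '''Synthesize every line in test_lines with the trained model and save the predicted mels (.npy) and alignment plots (.png) under log_dir/synthesis_mels.'''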
    global global_epoch
    global global_step
    synthesis_dir = os.path.join(log_dir, "synthesis_mels")
    os.makedirs(synthesis_dir, exist_ok=True)
    model.eval()
    with torch.no_grad():
        for idx, line in enumerate(test_lines):
            txt = text_to_seq(line)
            if device > -1:
                txt = txt.cuda(device)
            frames, _, _, alignment = model(txt)

            dst_alignment_path = join(
                synthesis_dir, "{}_alignment_{}.png".format(global_step, idx))
            dst_mels_path = join(synthesis_dir,
                                 "{}_mels_{}.npy".format(global_step, idx))
            alignment = alignment.cpu().numpy()  # move tensors off the GPU before plotting / saving
            plot_alignment(alignment.T,
                           dst_alignment_path,
                           info="{}, {}".format(hparams.builder, global_step))
            np.save(dst_mels_path, frames.cpu().numpy())
Example #3
def synthesize(hp, speaker_id='', num_sentences=0, ncores=1, topoutdir='', t2m_epoch=-1, ssrn_epoch=-1):
    '''
    topoutdir: store samples under here; defaults to hp.sampledir
    t2m_epoch and ssrn_epoch: default -1 means use latest. Otherwise go to archived models.
    '''
    assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

    dataset = load_data(hp, mode="synthesis") #since mode != 'train' or 'validation', will load test_transcript rather than transcript
    fpaths, L = dataset['fpaths'], dataset['texts']
    position_in_phone_data = duration_data = labels = None # default
    if hp.use_external_durations:
        duration_data = dataset['durations']
        if num_sentences > 0:
            duration_data = duration_data[:num_sentences, :, :]

    if 'position_in_phone' in hp.history_type:
        ## TODO: combine + deduplicate with relevant code in train.py for making validation set
        def duration2position(duration, fractional=False):     
            ### very roundabout -- need to deflate A matrix back to integers:
            duration = duration.sum(axis=0)
            #print(duration)
            # sys.exit('evs')   
            positions = durations_to_position(duration, fractional=fractional)
            ###positions = end_pad_for_reduction_shape_sync(positions, hp)
            positions = positions[0::hp.r, :]         
            #print(positions)
            return positions

        position_in_phone_data = [duration2position(dur, fractional=('fractional' in hp.history_type)) \
                        for dur in duration_data]       
        position_in_phone_data = list2batch(position_in_phone_data, hp.max_T)



    # Ensure we aren't trying to generate more utterances than are actually in our test_transcript
    if num_sentences > 0:
        assert num_sentences <= len(fpaths)
        L = L[:num_sentences, :]
        fpaths = fpaths[:num_sentences]

    bases = [basename(fpath) for fpath in fpaths]

    if hp.merlin_label_dir:
        labels = [np.load("{}/{}".format(hp.merlin_label_dir, basename(fpath)+".npy")) \
                              for fpath in fpaths ]
        labels = list2batch(labels, hp.max_N)


    if speaker_id:
        speaker2ix = dict(zip(hp.speaker_list, range(len(hp.speaker_list))))
        speaker_ix = speaker2ix[speaker_id]

        ## Speaker codes are held in (batch, 1) matrix -- tiling is done inside the graph:
        speaker_data = np.ones((len(L), 1))  *  speaker_ix
    else:
        speaker_data = None

    # Load graph 
    ## TODO: generalise to combine other types of models into a synthesis pipeline?
    g1 = Text2MelGraph(hp, mode="synthesize"); print("Graph 1 (t2m) loaded")
    g2 = SSRNGraph(hp, mode="synthesize"); print("Graph 2 (ssrn) loaded")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        ### TODO: specify epoch from comm line?
        ### TODO: t2m and ssrn from separate configs?

        if t2m_epoch > -1:
            restore_archived_model_parameters(sess, hp, 't2m', t2m_epoch)
        else:
            t2m_epoch = restore_latest_model_parameters(sess, hp, 't2m')

        if ssrn_epoch > -1:    
            restore_archived_model_parameters(sess, hp, 'ssrn', ssrn_epoch)
        else:
            ssrn_epoch = restore_latest_model_parameters(sess, hp, 'ssrn')

        # Pass input L through Text2Mel Graph
        t = start_clock('Text2Mel generating...')
        ### TODO: after further efficiency testing, remove this fork
        if 1:  ### efficient route -- only make K&V once  ## 3.86, 3.70, 3.80 seconds (2 sentences)
            text_lengths = get_text_lengths(L)
            K, V = encode_text(hp, L, g1, sess, speaker_data=speaker_data, labels=labels)
            Y, lengths, alignments = synth_codedtext2mel(hp, K, V, text_lengths, g1, sess, \
                                speaker_data=speaker_data, duration_data=duration_data, \
                                position_in_phone_data=position_in_phone_data,\
                                labels=labels)
        else: ## 5.68, 5.43, 5.38 seconds (2 sentences)
            Y, lengths = synth_text2mel(hp, L, g1, sess, speaker_data=speaker_data, \
                                            duration_data=duration_data, \
                                            position_in_phone_data=position_in_phone_data, \
                                            labels=labels)
        stop_clock(t)

        ### TODO: useful to test this?
        # print(Y[0,:,:])
        # print (np.isnan(Y).any())
        # print('nan1')
        # Then pass output Y of Text2Mel Graph through SSRN graph to get high res spectrogram Z.
        t = start_clock('Mel2Mag generating...')
        Z = synth_mel2mag(hp, Y, g2, sess)
        stop_clock(t) 

        if (np.isnan(Z).any()):  ### TODO: keep?
            Z = np.nan_to_num(Z)

        # Generate wav files
        if not topoutdir:
            topoutdir = hp.sampledir
        outdir = os.path.join(topoutdir, 't2m%s_ssrn%s'%(t2m_epoch, ssrn_epoch))
        if speaker_id:
            outdir += '_speaker-%s'%(speaker_id)
        safe_makedir(outdir)
        print("Generating wav files, will save to following dir: %s"%(outdir))

        
        assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

        if ncores==1:
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length
                synth_wave(hp, mag, outfile)
        else:
            executor = ProcessPoolExecutor(max_workers=ncores)    
            futures = []
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length
                futures.append(executor.submit(synth_wave, hp, mag, outfile))
            proc_list = [future.result() for future in tqdm(futures)]

        # for i, mag in enumerate(Z):
        #     print("Working on %s"%(bases[i]))
        #     mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length
            
        #     if hp.vocoder=='magphase_compressed':
        #         mag = denorm(mag, s, hp.normtype)
        #         streams = split_streams(mag, ['mag', 'lf0', 'vuv', 'real', 'imag'], [60,1,1,45,45])
        #         wav = magphase_synth_from_compressed(streams, samplerate=hp.sr)
        #     elif hp.vocoder=='griffin_lim':                
        #         wav = spectrogram2wav(hp, mag)
        #     else:
        #         sys.exit('Unsupported vocoder type: %s'%(hp.vocoder))
        #     #write(outdir + "/{}.wav".format(bases[i]), hp.sr, wav)
        #     soundfile.write(outdir + "/{}.wav".format(bases[i]), wav, hp.sr)
            

            
        # Plot attention alignments 
        for i in range(num_sentences):
            plot_alignment(hp, alignments[i], utt_idx=i+1, t2m_epoch=t2m_epoch, dir=outdir)
Example #4
def main():
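    # End-to-end Tacotron training driver: load LJSpeech, split it into train/validation, build the
    # graph, restore any existing checkpoint, then train with periodic validation, alignment plotting
    # and checkpointing.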
    # DataSet Loader
    if args.dataset == "ljspeech":
        from datasets.ljspeech import LJSpeech

        # LJSpeech-1.1 dataset loader
        ljs = LJSpeech(
            path=cfg.dataset_path,
            save_to='npy',
            load_from="npy" if os.path.exists(cfg.dataset_path + "/npy") else None,
            verbose=cfg.verbose)
    else:
        raise NotImplementedError("[-] Not Implemented Yet...")

    # Train/Test split
    tr_size = int(len(ljs) * (1. - cfg.test_size))

    tr_text_data, va_text_data = \
        ljs.text_data[:tr_size], ljs.text_data[tr_size:]
    tr_text_len_data, va_text_len_data = \
        ljs.text_len_data[:tr_size], ljs.text_len_data[tr_size:]
    tr_mels, va_mels = ljs.mels[:tr_size], ljs.mels[tr_size:]
    tr_mags, va_mags = ljs.mags[:tr_size], ljs.mags[tr_size:]

    del ljs  # memory release

    # Data Iterator
    di = DataIterator(text=tr_text_data,
                      text_len=tr_text_len_data,
                      mel=tr_mels,
                      mag=tr_mags,
                      batch_size=cfg.batch_size)

    if cfg.verbose:
        print("[*] Train/Test split : %d/%d (%.2f/%.2f)" %
              (tr_text_data.shape[0], va_text_data.shape[0],
               1. - cfg.test_size, cfg.test_size))
        print("  Train")
        print("\ttext     : ", tr_text_data.shape)
        print("\ttext_len : ", tr_text_len_data.shape)
        print("\tmels     : ", tr_mels.shape)
        print("\tmags     : ", tr_mags.shape)
        print("  Test")
        print("\ttext     : ", va_text_data.shape)
        print("\ttext_len : ", va_text_len_data.shape)
        print("\tmels     : ", va_mels.shape)
        print("\tmags     : ", va_mags.shape)

    # Model Loading
    gpu_config = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False,
                            gpu_options=gpu_config)

    with tf.Session(config=config) as sess:
        if cfg.model == "Tacotron":
            model = Tacotron(sess=sess,
                             mode=args.mode,
                             sample_rate=cfg.sample_rate,
                             vocab_size=cfg.vocab_size,
                             embed_size=cfg.embed_size,
                             n_mels=cfg.n_mels,
                             n_fft=cfg.n_fft,
                             reduction_factor=cfg.reduction_factor,
                             n_encoder_banks=cfg.n_encoder_banks,
                             n_decoder_banks=cfg.n_decoder_banks,
                             n_highway_blocks=cfg.n_highway_blocks,
                             lr=cfg.lr,
                             lr_decay=cfg.lr_decay,
                             optimizer=cfg.optimizer,
                             grad_clip=cfg.grad_clip,
                             model_path=cfg.model_path)
        else:
            raise NotImplementedError("[-] Not Implemented Yet...")

        if cfg.verbose:
            print("[*] %s model is loaded!" % cfg.model)

        # Initializing
        sess.run(tf.global_variables_initializer())

        # Load model & Graph & Weights
        global_step = 0
        ckpt = tf.train.get_checkpoint_state(cfg.model_path)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            model.saver.restore(sess, ckpt.model_checkpoint_path)

            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print("[+] global step : %d" % global_step, " successfully loaded")
        else:
            print('[-] No checkpoint file found')

        start_time = time.time()

        best_loss = np.inf
        batch_size = cfg.batch_size
        sess.run(model.global_step.assign(global_step))  # run the assign op so the restored step takes effect
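        # resume from the epoch implied by the restored step (steps per epoch = train set size // batch size)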
        restored_epochs = global_step // (di.text.shape[0] // batch_size)
        for epoch in range(restored_epochs, cfg.epochs):
            for text, text_len, mel, mag in di.iterate():
                batch_start = time.time()
                _, y_loss, z_loss = sess.run(
                    [model.train_op, model.y_loss, model.z_loss],
                    feed_dict={
                        model.x: text,
                        model.x_len: text_len,
                        model.y: mel,
                        model.z: mag,
                    })
                batch_end = time.time()

                if global_step and global_step % cfg.logging_step == 0:
                    va_y_loss, va_z_loss = 0., 0.

                    va_batch = 20
                    va_iter = len(va_text_data)
                    for idx in range(0, va_iter, va_batch):
                        # idx already advances by va_batch, so slice [idx:idx + va_batch] directly
                        va_y, va_z = sess.run(
                            [model.y_loss, model.z_loss],
                            feed_dict={
                                model.x: va_text_data[idx:idx + va_batch],
                                model.x_len: va_text_len_data[idx:idx + va_batch],
                                model.y: va_mels[idx:idx + va_batch],
                                model.z: va_mags[idx:idx + va_batch],
                            })

                        va_y_loss += va_y
                        va_z_loss += va_z

                    va_y_loss /= (va_iter // va_batch)
                    va_z_loss /= (va_iter // va_batch)

                    print(
                        "[*] epoch %03d global step %07d [%.03f sec/step]" %
                        (epoch, global_step, (batch_end - batch_start)),
                        " Train \n"
                        " y_loss : {:.6f} z_loss : {:.6f}".format(
                            y_loss, z_loss), " Valid \n"
                        " y_loss : {:.6f} z_loss : {:.6f}".format(
                            va_y_loss, va_z_loss))

                    # summary
                    summary = sess.run(model.merged,
                                       feed_dict={
                                           model.x:
                                           va_text_data[:batch_size],
                                           model.x_len:
                                           va_text_len_data[:batch_size],
                                           model.y:
                                           va_mels[:batch_size],
                                           model.z:
                                           va_mags[:batch_size],
                                       })

                    # getting/plotting alignment (important)
                    alignment = sess.run(model.alignments,
                                         feed_dict={
                                             model.x:
                                             va_text_data[:batch_size],
                                             model.x_len:
                                             va_text_len_data[:batch_size],
                                             model.y:
                                             va_mels[:batch_size],
                                         })

                    plot_alignment(alignments=alignment,
                                   gs=global_step,
                                   path=os.path.join(cfg.model_path,
                                                     "alignments"))

                    # Summary saver
                    model.writer.add_summary(summary, global_step)

                    # Model save
                    model.saver.save(sess,
                                     cfg.model_path + '%s.ckpt' % cfg.model,
                                     global_step=global_step)

                    if va_y_loss + va_z_loss < best_loss:
                        model.best_saver.save(sess,
                                              cfg.model_path +
                                              '%s-best_loss.ckpt' % cfg.model,
                                              global_step=global_step)
                        best_loss = va_y_loss + va_z_loss

                sess.run(model.global_step.assign_add(1))  # actually execute the increment op
                global_step += 1

        end_time = time.time()

        print("[+] Training Done! Elapsed {:.8f}s".format(end_time -
                                                          start_time))
Example #5
def synthesize(hp, speaker_id='', num_sentences=0, ncores=1, topoutdir='', t2m_epoch=-1, ssrn_epoch=-1):
    '''
    topoutdir: store samples under here; defaults to hp.sampledir
    t2m_epoch and ssrn_epoch: default -1 means use latest. Otherwise go to archived models.
    '''
    assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

    dataset = load_data(hp, mode="synthesis") #since mode != 'train' or 'validation', will load test_transcript rather than transcript
    fpaths, L = dataset['fpaths'], dataset['texts']
    position_in_phone_data = duration_data = labels = None # default
    if hp.use_external_durations:
        duration_data = dataset['durations']
        if num_sentences > 0:
            duration_data = duration_data[:num_sentences, :, :]

    if 'position_in_phone' in hp.history_type:
        ## TODO: combine + deduplicate with relevant code in train.py for making validation set
        def duration2position(duration, fractional=False):     
            ### very roundabout -- need to deflate A matrix back to integers:
            duration = duration.sum(axis=0)
            #print(duration)
            # sys.exit('evs')   
            positions = durations_to_position(duration, fractional=fractional)
            ###positions = end_pad_for_reduction_shape_sync(positions, hp)
            positions = positions[0::hp.r, :]         
            #print(positions)
            return positions

        position_in_phone_data = [duration2position(dur, fractional=('fractional' in hp.history_type)) \
                        for dur in duration_data]       
        position_in_phone_data = list2batch(position_in_phone_data, hp.max_T)



    # Ensure we aren't trying to generate more utterances than are actually in our test_transcript
    if num_sentences > 0:
        assert num_sentences <= len(fpaths)
        L = L[:num_sentences, :]
        fpaths = fpaths[:num_sentences]

    bases = [basename(fpath) for fpath in fpaths]

    if hp.merlin_label_dir:
        labels = []
        for fpath in fpaths:
            label = np.load("{}/{}".format(hp.merlin_label_dir, basename(fpath)+".npy"))
            if hp.select_central:
                central_ind = get_labels_indices(hp.merlin_lab_dim)
                label = label[:,central_ind==1] 
            labels.append(label)

        labels = list2batch(labels, hp.max_N)


    if speaker_id:
        speaker2ix = dict(zip(hp.speaker_list, range(len(hp.speaker_list))))
        speaker_ix = speaker2ix[speaker_id]

        ## Speaker codes are held in (batch, 1) matrix -- tiling is done inside the graph:
        speaker_data = np.ones((len(L), 1))  *  speaker_ix
    else:
        speaker_data = None
   
    if hp.turn_off_monotonic_for_synthesis: # if the FIA mechanism is turned off
        text_lengths = get_text_lengths(L)
        hp.text_lengths = text_lengths + 1
     
    # Load graph 
    ## TODO: generalise to combine other types of models into a synthesis pipeline?
    g1 = Text2MelGraph(hp, mode="synthesize"); print("Graph 1 (t2m) loaded")
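    # SSRN (graph 2) is built with the layer-norm/Adam defaults set below; when the t2m config
    # disables normalisation, hp is switched to those defaults only while graph 2 is constructed
    # and the original values are restored afterwards.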

    if hp.norm is None:
        t2m_layer_norm = False
        hp.norm = 'layer'
        hp.lr = 0.001
        hp.beta1 = 0.9
        hp.beta2 = 0.999
        hp.epsilon = 0.00000001
        hp.decay_lr = True
        hp.batchsize = {'t2m': 32, 'ssrn': 8}
    else:
        t2m_layer_norm = True

    g2 = SSRNGraph(hp, mode="synthesize"); print("Graph 2 (ssrn) loaded")

    if not t2m_layer_norm:
        hp.norm = None
        hp.lr = 0.0002
        hp.beta1 = 0.5
        hp.beta2 = 0.9
        hp.epsilon = 0.000001
        hp.decay_lr = False
        hp.batchsize = {'t2m': 16, 'ssrn': 8}
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        ### TODO: specify epoch from comm line?
        ### TODO: t2m and ssrn from separate configs?

        if t2m_epoch > -1:
            restore_archived_model_parameters(sess, hp, 't2m', t2m_epoch)
        else:
            t2m_epoch = restore_latest_model_parameters(sess, hp, 't2m')

        if ssrn_epoch > -1:    
            restore_archived_model_parameters(sess, hp, 'ssrn', ssrn_epoch)
        else:
            ssrn_epoch = restore_latest_model_parameters(sess, hp, 'ssrn')

        # Pass input L through Text2Mel Graph
        t = start_clock('Text2Mel generating...')
        ### TODO: after further efficiency testing, remove this fork
        if 1:  ### efficient route -- only make K&V once  ## 3.86, 3.70, 3.80 seconds (2 sentences)
            text_lengths = get_text_lengths(L)
            K, V = encode_text(hp, L, g1, sess, speaker_data=speaker_data, labels=labels)
            Y, lengths, alignments = synth_codedtext2mel(hp, K, V, text_lengths, g1, sess, \
                                speaker_data=speaker_data, duration_data=duration_data, \
                                position_in_phone_data=position_in_phone_data,\
                                labels=labels)
        else: ## 5.68, 5.43, 5.38 seconds (2 sentences)
            Y, lengths = synth_text2mel(hp, L, g1, sess, speaker_data=speaker_data, \
                                            duration_data=duration_data, \
                                            position_in_phone_data=position_in_phone_data, \
                                            labels=labels)
        stop_clock(t)

        ### TODO: useful to test this?
        # print(Y[0,:,:])
        # print (np.isnan(Y).any())
        # print('nan1')
        # Then pass output Y of Text2Mel Graph through SSRN graph to get high res spectrogram Z.
        t = start_clock('Mel2Mag generating...')
        Z = synth_mel2mag(hp, Y, g2, sess)
        stop_clock(t) 

        if (np.isnan(Z).any()):  ### TODO: keep?
            Z = np.nan_to_num(Z)

        # Generate wav files
        if not topoutdir:
            topoutdir = hp.sampledir
        outdir = os.path.join(topoutdir, 't2m%s_ssrn%s'%(t2m_epoch, ssrn_epoch))
        if speaker_id:
            outdir += '_speaker-%s'%(speaker_id)
        safe_makedir(outdir)

        # Plot trimmed attention alignment with filename
        print("Plot attention, will save to following dir: %s"%(outdir))
        print("File |  CDP | Ain")
        for i, mag in enumerate(Z):
            outfile = os.path.join(outdir, bases[i])
            trimmed_alignment = alignments[i,:text_lengths[i],:lengths[i]]
            plot_alignment(hp, trimmed_alignment, utt_idx=i+1, t2m_epoch=t2m_epoch, dir=outdir, outfile=outfile)
            CDP = getCDP(trimmed_alignment)
            APin, APout = getAP(trimmed_alignment)
            print("%s | %.2f | %.2f"%( bases[i], CDP, APin))

        print("Generating wav files, will save to following dir: %s"%(outdir))

        
        assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

        if ncores==1:
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length
                synth_wave(hp, mag, outfile)
        else:
            executor = ProcessPoolExecutor(max_workers=ncores)    
            futures = []
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length
                futures.append(executor.submit(synth_wave, hp, mag, outfile))
            proc_list = [future.result() for future in tqdm(futures)]
Example #6
    def synthesize(self,
                   text=None,
                   emo_code=None,
                   mels=None,
                   speaker_id='',
                   num_sentences=0,
                   ncores=1,
                   topoutdir=''):
        '''
        topoutdir: store samples under here; defaults to hp.sampledir
        Uses the graphs, session and checkpoint epochs already restored on this object (self.g1, self.g2, self.sess, self.t2m_epoch, self.ssrn_epoch).
        '''
        assert self.hp.vocoder in [
            'griffin_lim', 'world'
        ], 'Other vocoders than griffin_lim/world not yet supported'

        if text is not None:
            text_to_phonetic(text=text)
            dataset = load_data(self.hp, mode='demo')
        else:
            dataset = load_data(
                self.hp, mode="synthesis"
            )  # since mode is neither 'train' nor 'validation', this loads test_transcript rather than transcript
        fpaths, L = dataset['fpaths'], dataset['texts']
        position_in_phone_data = duration_data = labels = None  # default
        if self.hp.use_external_durations:
            duration_data = dataset['durations']
            if num_sentences > 0:
                duration_data = duration_data[:num_sentences, :, :]

        if 'position_in_phone' in self.hp.history_type:
            ## TODO: combine + deduplicate with relevant code in train.py for making validation set
            def duration2position(duration, fractional=False):
                ### very roundabout -- need to deflate A matrix back to integers:
                duration = duration.sum(axis=0)
                #print(duration)
                # sys.exit('evs')
                positions = durations_to_position(duration,
                                                  fractional=fractional)
                ###positions = end_pad_for_reduction_shape_sync(positions, hp)
                positions = positions[0::self.hp.r, :]
                #print(positions)
                return positions

            position_in_phone_data = [duration2position(dur, fractional=('fractional' in self.hp.history_type)) \
                            for dur in duration_data]
            position_in_phone_data = list2batch(position_in_phone_data,
                                                self.hp.max_T)

        # Ensure we aren't trying to generate more utterances than are actually in our test_transcript
        if num_sentences > 0:
            assert num_sentences <= len(fpaths)
            L = L[:num_sentences, :]
            fpaths = fpaths[:num_sentences]

        bases = [basename(fpath) for fpath in fpaths]

        if self.hp.merlin_label_dir:
            labels = [np.load("{}/{}".format(hp.merlin_label_dir, basename(fpath)+".npy")) \
                                for fpath in fpaths ]
            labels = list2batch(labels, hp.max_N)

        if speaker_id:
            speaker2ix = dict(zip(self.hp.speaker_list,
                                  range(len(self.hp.speaker_list))))
            speaker_ix = speaker2ix[speaker_id]

            ## Speaker codes are held in (batch, 1) matrix -- tiling is done inside the graph:
            speaker_data = np.ones((len(L), 1)) * speaker_ix
        else:
            speaker_data = None

        # Pass input L through Text2Mel Graph
        t = start_clock('Text2Mel generating...')
        ### TODO: after further efficiency testing, remove this fork
        if 1:  ### efficient route -- only make K&V once  ## 3.86, 3.70, 3.80 seconds (2 sentences)
            text_lengths = get_text_lengths(L)
            if mels is not None:
                emo_code = encode_audio2emo(self.hp, mels, self.g1, self.sess)
            K, V = encode_text(self.hp,
                               L,
                               self.g1,
                               self.sess,
                               emo_mean=emo_code,
                               speaker_data=speaker_data,
                               labels=labels)
            Y, lengths, alignments = synth_codedtext2mel(self.hp, K, V, text_lengths, self.g1, self.sess, \
                                speaker_data=speaker_data, duration_data=duration_data, \
                                position_in_phone_data=position_in_phone_data,\
                                labels=labels)
        else:  ## 5.68, 5.43, 5.38 seconds (2 sentences)
            Y, lengths = synth_text2mel(self.hp, L, self.g1, self.sess, speaker_data=speaker_data, \
                                            duration_data=duration_data, \
                                            position_in_phone_data=position_in_phone_data, \
                                            labels=labels)
        stop_clock(t)

        ### TODO: useful to test this?
        # print(Y[0,:,:])
        # print (np.isnan(Y).any())
        # print('nan1')
        # Then pass output Y of Text2Mel Graph through SSRN graph to get high res spectrogram Z.
        t = start_clock('Mel2Mag generating...')
        Z = synth_mel2mag(self.hp, Y, self.g2, self.sess)
        stop_clock(t)

        if (np.isnan(Z).any()):  ### TODO: keep?
            Z = np.nan_to_num(Z)

        # Generate wav files
        if not topoutdir:
            topoutdir = self.hp.sampledir
        outdir = os.path.join(
            topoutdir, 't2m%s_ssrn%s' % (self.t2m_epoch, self.ssrn_epoch))
        if speaker_id:
            outdir += '_speaker-%s' % (speaker_id)
        safe_makedir(outdir)
        print("Generating wav files, will save to following dir: %s" %
              (outdir))

        assert self.hp.vocoder in [
            'griffin_lim', 'world'
        ], 'Other vocoders than griffin_lim/world not yet supported'

        if ncores == 1:
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i] *
                          self.hp.r, :]  ### trim to generated length
                synth_wave(self.hp, mag, outfile)
        else:
            executor = ProcessPoolExecutor(max_workers=ncores)
            futures = []
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i] *
                          self.hp.r, :]  ### trim to generated length
                futures.append(
                    executor.submit(synth_wave, self.hp, mag, outfile))
            proc_list = [future.result() for future in tqdm(futures)]

        # Plot attention alignments
        for i in range(num_sentences):
            plot_alignment(self.hp,
                           alignments[i],
                           utt_idx=i + 1,
                           t2m_epoch=self.t2m_epoch,
                           dir=outdir)

        self.outdir = outdir
Example #7
                ]  # closes the `sentences` list begun in earlier (omitted) lines
                for idx, sent in enumerate(sentences):
                    wav, attn = eval_model(
                        dv3, sent, replace_pronounciation_prob, min_level_db,
                        ref_level_db, power, n_iter, win_length, hop_length,
                        preemphasis)
                    wav_path = os.path.join(
                        state_dir, "waveform",
                        "eval_sample_{:09d}.wav".format(global_step))
                    sf.write(wav_path, wav, sample_rate)
                    writer.add_audio(
                        "eval_sample_{}".format(idx),
                        wav,
                        global_step,
                        sample_rate=sample_rate)
                    attn_path = os.path.join(
                        state_dir, "alignments",
                        "eval_sample_attn_{:09d}.png".format(global_step))
                    plot_alignment(attn, attn_path)
                    writer.add_image(
                        "eval_sample_attn{}".format(idx),
                        cm.viridis(attn),
                        global_step,
                        dataformats="HWC")

            # save checkpoint
            if global_step % save_interval == 0:
                io.save_parameters(ckpt_dir, global_step, dv3, optim)

            global_step += 1
Example #8
def main(args):
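    # STYLER training driver: build the dataset/loader, model, optimizer and losses, restore a
    # checkpoint if present, then train with periodic logging, checkpointing, synthesis and evaluation.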
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(STYLER()).to(device)
    print("Model Has Been Defined")

    # Parameters
    num_param = utils.get_param_num(model)
    text_encoder = utils.get_param_num(
        model.module.style_modeling.style_encoder.text_encoder)
    audio_encoder = utils.get_param_num(
        model.module.style_modeling.style_encoder.audio_encoder)
    predictors = utils.get_param_num(model.module.style_modeling.duration_predictor)\
         + utils.get_param_num(model.module.style_modeling.pitch_predictor)\
              + utils.get_param_num(model.module.style_modeling.energy_predictor)
    decoder = utils.get_param_num(model.module.decoder)
    print('Number of Model Parameters          :', num_param)
    print('Number of Text Encoder Parameters   :', text_encoder)
    print('Number of Audio Encoder Parameters  :', audio_encoder)
    print('Number of Predictor Parameters      :', predictors)
    print('Number of Decoder Parameters        :', decoder)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = STYLERLoss().to(device)
    DATLoss = DomainAdversarialTrainingLoss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = hp.checkpoint_path()
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    vocoder = utils.get_vocoder()

    # Init logger
    log_path = hp.log_path()
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path()
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i*hp.batch_size + j + args.restore_step + \
                    epoch*len(loader)*hp.batch_size + 1
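                # (step arithmetic above: the DataLoader batch size is hp.batch_size**2 and
                #  collate_fn splits each loader batch into hp.batch_size sub-batches)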

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                mel_aug = torch.from_numpy(
                    data_of_batch["mel_aug"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                f0_norm = torch.from_numpy(
                    data_of_batch["f0_norm"]).float().to(device)
                f0_norm_aug = torch.from_numpy(
                    data_of_batch["f0_norm_aug"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                energy_input = torch.from_numpy(
                    data_of_batch["energy_input"]).float().to(device)
                energy_input_aug = torch.from_numpy(
                    data_of_batch["energy_input_aug"]).float().to(device)
                speaker_embed = torch.from_numpy(
                    data_of_batch["speaker_embed"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_outputs, mel_postnet_outputs, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _, aug_posteriors = model(
                    text,
                    mel_target,
                    mel_aug,
                    f0_norm,
                    energy_input,
                    src_len,
                    mel_len,
                    D,
                    f0,
                    energy,
                    max_src_len,
                    max_mel_len,
                    speaker_embed=speaker_embed)

                # Cal Loss Clean
                mel_output, mel_postnet_output = mel_outputs[
                    0], mel_postnet_outputs[0]
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss, classifier_loss_a = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, ~src_mask, ~mel_mask, src_len, mel_len,\
                        aug_posteriors, torch.zeros(mel_target.size(0)).long().to(device))

                # Cal Loss Noisy
                mel_output_noisy, mel_postnet_output_noisy = mel_outputs[
                    1], mel_postnet_outputs[1]
                mel_noisy_loss, mel_postnet_noisy_loss = Loss.cal_mel_loss(
                    mel_output_noisy, mel_postnet_output_noisy, mel_aug,
                    ~mel_mask)

                # Forward DAT
                enc_cat = model.module.style_modeling.style_encoder.encoder_input_cat(
                    mel_aug, f0_norm_aug, energy_input_aug, mel_aug)
                duration_encoding, pitch_encoding, energy_encoding, _ = model.module.style_modeling.style_encoder.audio_encoder(
                    enc_cat, mel_len, src_len, mask=None)
                aug_posterior_d = model.module.style_modeling.augmentation_classifier_d(
                    duration_encoding)
                aug_posterior_p = model.module.style_modeling.augmentation_classifier_p(
                    pitch_encoding)
                aug_posterior_e = model.module.style_modeling.augmentation_classifier_e(
                    energy_encoding)

                # Cal Loss DAT
                classifier_loss_a_dat = DATLoss(
                    (aug_posterior_d, aug_posterior_p, aug_posterior_e),
                    torch.ones(mel_target.size(0)).long().to(device))

                # Total loss
                total_loss = mel_loss + mel_postnet_loss + mel_noisy_loss + mel_postnet_noisy_loss + d_loss + f_loss + e_loss\
                    + hp.dat_weight*(classifier_loss_a + classifier_loss_a_dat)

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                m_n_l = mel_noisy_loss.item()
                m_p_n_l = mel_postnet_noisy_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                cl_a = classifier_loss_a.item()
                cl_a_dat = classifier_loss_a_dat.item()

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step == 1 or current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_noisy_loss', m_n_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_postnet_noisy_loss',
                                            m_p_n_l, current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)
                    train_logger.add_scalar('Loss/dat_clean_loss', cl_a,
                                            current_step)
                    train_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat,
                                            current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step == 1 or current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
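                    # Take the first utterance of the batch: batched torch views feed the vocoder,
                    # transposed CPU copies are used for plotting.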
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_aug_torch = mel_aug[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_aug = mel_aug[0, :length].detach().cpu().transpose(
                        0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_noisy_torch = mel_output_noisy[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_noisy = mel_output_noisy[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet_noisy_torch = mel_postnet_output_noisy[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_noisy = mel_postnet_output_noisy[
                        0, :length].detach().cpu().transpose(0, 1)
                    # Audio.tools.inv_mel_spec(mel, os.path.join(
                    #     synth_path, "step_{}_{}_griffin_lim.wav".format(current_step, "c")))
                    # Audio.tools.inv_mel_spec(mel_postnet, os.path.join(
                    #     synth_path, "step_{}_{}_postnet_griffin_lim.wav".format(current_step, "c")))
                    # Audio.tools.inv_mel_spec(mel_noisy, os.path.join(
                    #     synth_path, "step_{}_{}_griffin_lim.wav".format(current_step, "n")))
                    # Audio.tools.inv_mel_spec(mel_postnet_noisy, os.path.join(
                    #     synth_path, "step_{}_{}_postnet_griffin_lim.wav".format(current_step, "n")))

                    wav_mel = utils.vocoder_infer(
                        mel_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_{}.wav'.format(current_step, "c",
                                                       hp.vocoder)))
                    wav_mel_postnet = utils.vocoder_infer(
                        mel_postnet_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_postnet_{}.wav'.format(
                                current_step, "c", hp.vocoder)))
                    wav_ground_truth = utils.vocoder_infer(
                        mel_target_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_ground-truth_{}.wav'.format(
                                current_step, "c", hp.vocoder)))
                    wav_mel_noisy = utils.vocoder_infer(
                        mel_noisy_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_{}.wav'.format(current_step, "n",
                                                       hp.vocoder)))
                    wav_mel_postnet_noisy = utils.vocoder_infer(
                        mel_postnet_noisy_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_postnet_{}.wav'.format(
                                current_step, "n", hp.vocoder)))
                    wav_aug = utils.vocoder_infer(
                        mel_aug_torch, vocoder,
                        os.path.join(
                            hp.synth_path(),
                            'step_{}_{}_ground-truth_{}.wav'.format(
                                current_step, "n", hp.vocoder)))

                    # Model duration prediction
                    log_duration_output = log_duration_output[
                        0, :src_len[0].item()].detach().cpu()  # [seg_len]
                    log_duration_output = torch.clamp(torch.round(
                        torch.exp(log_duration_output) - hp.log_offset),
                                                      min=0).int()
                    model_duration = utils.get_alignment_2D(
                        log_duration_output).T  # [seg_len, mel_len]
                    model_duration = utils.plot_alignment([model_duration])

                    # Model mel prediction
                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()
                    mel_predicted = utils.plot_data(
                        [(mel_postnet.numpy(), f0_output, energy_output),
                         (mel_target.numpy(), f0, energy)], [
                             'Synthesized Spectrogram Clean',
                             'Ground-Truth Spectrogram'
                         ],
                        filename=os.path.join(
                            synth_path,
                            'step_{}_{}.png'.format(current_step, "c")))
                    mel_noisy_predicted = utils.plot_data(
                        [(mel_postnet_noisy.numpy(), f0_output, energy_output),
                         (mel_aug.numpy(), f0, energy)],
                        ['Synthesized Spectrogram Noisy', 'Aug Spectrogram'],
                        filename=os.path.join(
                            synth_path,
                            'step_{}_{}.png'.format(current_step, "n")))

                    # Normalize audio for tensorboard logger. See https://github.com/lanpa/tensorboardX/issues/511#issuecomment-537600045
                    wav_ground_truth = wav_ground_truth / max(wav_ground_truth)
                    wav_mel = wav_mel / max(wav_mel)
                    wav_mel_postnet = wav_mel_postnet / max(wav_mel_postnet)
                    wav_aug = wav_aug / max(wav_aug)
                    wav_mel_noisy = wav_mel_noisy / max(wav_mel_noisy)
                    wav_mel_postnet_noisy = wav_mel_postnet_noisy / max(
                        wav_mel_postnet_noisy)

                    train_logger.add_image("model_duration",
                                           model_duration,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_image("mel_predicted/Clean",
                                           mel_predicted,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_image("mel_predicted/Noisy",
                                           mel_noisy_predicted,
                                           current_step,
                                           dataformats='HWC')
                    train_logger.add_audio("Clean/wav_ground_truth",
                                           wav_ground_truth,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Clean/wav_mel",
                                           wav_mel,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Clean/wav_mel_postnet",
                                           wav_mel_postnet,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_aug",
                                           wav_aug,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_mel_noisy",
                                           wav_mel_noisy,
                                           current_step,
                                           sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_mel_postnet_noisy",
                                           wav_mel_postnet_noisy,
                                           current_step,
                                           sample_rate=hp.sampling_rate)

                if current_step == 1 or current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, cl_a, cl_a_dat, m_l, m_p_l, m_n_l, m_p_n_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l + m_n_l + m_p_n_l\
                            + hp.dat_weight*(cl_a + cl_a_dat)

                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_noisy_loss', m_n_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_noisy_loss',
                                              m_p_n_l, current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)
                        val_logger.add_scalar('Loss/dat_clean_loss', cl_a,
                                              current_step)
                        val_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat,
                                              current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # collapse the timing window to its mean to keep it from growing
                    Time = np.array([np.mean(Time)])
Example #9
def train(train_loader, model, device, mels_criterion, stop_criterion,
          optimizer, scheduler, writer, train_dir):
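    '''Train for one epoch over train_loader, logging losses per step and saving an alignment plot of the first utterance of the last batch under train_dir.'''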
    batch_time = ValueWindow()
    data_time = ValueWindow()
    losses = ValueWindow()

    # switch to train mode
    model.train()

    end = time.time()
    global global_epoch
    global global_step
    for i, (txts, mels, stop_tokens, txt_lengths,
            mels_lengths) in enumerate(train_loader):
        scheduler.adjust_learning_rate(optimizer, global_step)
        # measure data loading time
        data_time.update(time.time() - end)

        if device > -1:
            txts = txts.cuda(device)
            mels = mels.cuda(device)
            stop_tokens = stop_tokens.cuda(device)
            txt_lengths = txt_lengths.cuda(device)
            mels_lengths = mels_lengths.cuda(device)

        # compute output
        frames, decoder_frames, stop_tokens_predict, alignment = model(
            txts, txt_lengths, mels)
        decoder_frames_loss = mels_criterion(decoder_frames,
                                             mels,
                                             lengths=mels_lengths)
        frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)
        stop_token_loss = stop_criterion(stop_tokens_predict,
                                         stop_tokens,
                                         lengths=mels_lengths)
        loss = decoder_frames_loss + frames_loss + stop_token_loss

        #print(frames_loss, decoder_frames_loss)
        losses.update(loss.item())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if hparams.clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.get_trainable_parameters(), hparams.clip_thresh)

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % hparams.print_freq == 0:
            log('Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                    global_epoch,
                    i,
                    len(train_loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses))

        # Logs
        writer.add_scalar("loss", float(loss.item()), global_step)
        writer.add_scalar(
            "avg_loss in {} window".format(losses.get_dinwow_size),
            float(losses.avg), global_step)
        writer.add_scalar("stop_token_loss", float(stop_token_loss.item()),
                          global_step)
        writer.add_scalar("decoder_frames_loss",
                          float(decoder_frames_loss.item()), global_step)
        writer.add_scalar("output_frames_loss", float(frames_loss.item()),
                          global_step)

        if hparams.clip_thresh > 0:
            writer.add_scalar("gradient norm", grad_norm, global_step)
        writer.add_scalar("learning rate", optimizer.param_groups[0]['lr'],
                          global_step)
        global_step += 1

    dst_alignment_path = join(train_dir,
                              "{}_alignment.png".format(global_step))
    alignment = alignment.cpu().detach().numpy()
    plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]],
                   dst_alignment_path,
                   info="{}, {}".format(hparams.builder, global_step))
Example #10
def validate(val_loader, model, device, mels_criterion, stop_criterion, writer,
             val_dir):
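    '''Evaluate on val_loader, log losses, save an alignment plot of the first utterance of the last batch under val_dir, and return the windowed average loss.'''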
    batch_time = ValueWindow()
    losses = ValueWindow()

    # switch to evaluate mode
    model.eval()

    global global_epoch
    global global_step
    with torch.no_grad():
        end = time.time()
        for i, (txts, mels, stop_tokens, txt_lengths,
                mels_lengths) in enumerate(val_loader):
            # measure data loading time
            batch_time.update(time.time() - end)

            if device > -1:
                txts = txts.cuda(device)
                mels = mels.cuda(device)
                stop_tokens = stop_tokens.cuda(device)
                txt_lengths = txt_lengths.cuda(device)
                mels_lengths = mels_lengths.cuda(device)

            # compute output
            frames, decoder_frames, stop_tokens_predict, alignment = model(
                txts, txt_lengths, mels)
            decoder_frames_loss = mels_criterion(decoder_frames,
                                                 mels,
                                                 lengths=mels_lengths)
            frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)
            stop_token_loss = stop_criterion(stop_tokens_predict,
                                             stop_tokens,
                                             lengths=mels_lengths)
            loss = decoder_frames_loss + frames_loss + stop_token_loss

            losses.update(loss.item())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % hparams.print_freq == 0:
                log('Epoch: [{0}]\t'
                    'Test: [{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                        global_epoch,
                        i,
                        len(val_loader),
                        batch_time=batch_time,
                        loss=losses))
            # Logs
            writer.add_scalar("loss", float(loss.item()), global_step)
            writer.add_scalar(
                "avg_loss in {} window".format(losses.get_dinwow_size),
                float(losses.avg), global_step)
            writer.add_scalar("stop_token_loss", float(stop_token_loss.item()),
                              global_step)
            writer.add_scalar("decoder_frames_loss",
                              float(decoder_frames_loss.item()), global_step)
            writer.add_scalar("output_frames_loss", float(frames_loss.item()),
                              global_step)

        dst_alignment_path = join(val_dir,
                                  "{}_alignment.png".format(global_step))
        alignment = alignment.cpu().detach().numpy()
        plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]],
                       dst_alignment_path,
                       info="{}, {}".format(hparams.builder, global_step))

    return losses.avg
Example #11
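        # Excerpt from the synthesis branch of a larger script: `config`, `args`, `dv3`, `iteration`
        # and `sample_rate` are defined in earlier (omitted) lines.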
        ref_level_db = transform_config["ref_level_db"]
        preemphasis = transform_config["preemphasis"]
        win_length = transform_config["win_length"]
        hop_length = transform_config["hop_length"]

        synthesis_config = config["synthesis"]
        power = synthesis_config["power"]
        n_iter = synthesis_config["n_iter"]

        synthesis_dir = os.path.join(args.output, "synthesis")
        if not os.path.exists(synthesis_dir):
            os.makedirs(synthesis_dir)

        with open(args.text, "rt", encoding="utf-8") as f:
            lines = f.readlines()
            for idx, line in enumerate(lines):
                text = line[:-1]
                dv3.eval()
                wav, attn = eval_model(dv3, text, replace_pronounciation_prob,
                                       min_level_db, ref_level_db, power,
                                       n_iter, win_length, hop_length,
                                       preemphasis)
                plot_alignment(
                    attn,
                    os.path.join(synthesis_dir,
                                 "test_{}_step_{}.png".format(idx, iteration)))
                sf.write(
                    os.path.join(synthesis_dir,
                                 "test_{}_step{}.wav".format(idx, iteration)),
                    wav, sample_rate)