Example #1
def copy_synth_SSRN_GL(hp, outdir):

    safe_makedir(outdir)

    dataset = load_data(hp, mode="synthesis") 
    fnames, texts = dataset['fpaths'], dataset['texts']
    bases = [basename(fname) for fname in fnames]
    mels = [np.load(os.path.join(hp.coarse_audio_dir, base + '.npy')) for base in bases]
    lengths = [a.shape[0] for a in mels]
    mels = list2batch(mels, 0)

    g = SSRNGraph(hp, mode="synthesize"); print("Graph (ssrn) loaded")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ssrn_epoch = restore_latest_model_parameters(sess, hp, 'ssrn')

        print('Run SSRN...')
        Z = synth_mel2mag(hp, mels, g, sess)

        for i, mag in enumerate(Z):
            print("Working on %s"%(bases[i]))
            mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length             
            wav = spectrogram2wav(hp, mag)
            soundfile.write(outdir + "/%s.wav" % (bases[i]), wav, hp.sr)
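
All of the examples on this page call safe_makedir before writing output. The helper itself is not shown here; a minimal sketch of what it presumably does (create a directory tree, silently accepting one that already exists; the exact implementation is an assumption) is:

import os

def safe_makedir(path):
    """Hypothetical sketch: create `path` including parents; no error if it exists."""
    os.makedirs(path, exist_ok=True)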
Example #2
def main_work():

    #################################################

    # ============= Process command line ============

    a = ArgumentParser()
    a.add_argument('-c', dest='config', required=True, type=str)
    a.add_argument('-ncores',
                   default=1,
                   type=int,
                   help='Number of cores for parallel processing')
    opts = a.parse_args()

    # ===============================================

    hp = load_config(opts.config)
    assert hp.attention_guide_dir

    dataset = load_data(hp)
    fpaths, text_lengths = dataset['fpaths'], dataset['text_lengths']

    assert os.path.exists(hp.coarse_audio_dir)
    safe_makedir(hp.attention_guide_dir)

    executor = ProcessPoolExecutor(max_workers=opts.ncores)
    futures = []
    for (fpath, text_length) in zip(fpaths, text_lengths):
        futures.append(executor.submit(proc, fpath, text_length, hp))
    proc_list = [future.result() for future in tqdm.tqdm(futures)]
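
The proc worker submitted above is not shown, but given hp.attention_guide_dir it evidently precomputes one attention guide per (text, audio) pair. In DCTTS-style systems such a guide is a soft diagonal penalty matrix; a hedged sketch of the standard guided-attention formula (the real proc may differ in details such as the width g):

import numpy as np

def attention_guide(text_length, mel_length, g=0.2):
    """Hypothetical sketch of a DCTTS-style guide:
    W[n, t] = 1 - exp(-((n/N - t/T)^2) / (2 g^2))."""
    n = np.arange(text_length).reshape(-1, 1) / float(text_length)  # (N, 1)
    t = np.arange(mel_length).reshape(1, -1) / float(mel_length)    # (1, T)
    return 1.0 - np.exp(-((n - t) ** 2) / (2.0 * g ** 2))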
Example #3
def trim_waves_in_directory(in_dir, out_dir, num_workers=1, tqdm=lambda x: x, \
                nfiles=0, top_db=30, trimonly=False, endpad=0.3):
    safe_makedir(out_dir)
    wave_files = sorted(glob.glob(in_dir + '/*.wav'))
    if nfiles > 0:
        wave_files = wave_files[:min(nfiles, len(wave_files))]

    if num_workers:
        executor = ProcessPoolExecutor(max_workers=num_workers)
        futures = []
        for (index, wave_file) in enumerate(wave_files):
            futures.append(
                executor.submit(
                    partial(_process_utterance,
                            wave_file,
                            out_dir,
                            top_db=top_db,
                            trimonly=trimonly,
                            end_pad_sec=endpad)))
        return [future.result() for future in tqdm(futures)]
    else:  ## serial processing
        for wave_file in tqdm(wave_files):
            _process_utterance(wave_file,
                               out_dir,
                               top_db=top_db,
                               trimonly=trimonly,
                               end_pad_sec=endpad)
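
_process_utterance is defined elsewhere; judging by the top_db, trimonly and end_pad_sec parameters, it trims leading and trailing silence against an energy threshold and pads the end. A sketch under those assumptions, using librosa's trim (the trimonly flag, which presumably toggles splitting versus plain trimming, is ignored here):

import os
import numpy as np
import librosa
import soundfile

def _process_utterance(wave_file, out_dir, top_db=30, trimonly=False, end_pad_sec=0.3):
    """Hypothetical sketch: trim silence quieter than top_db, then
    re-append end_pad_sec seconds of silence at the end."""
    wav, sr = librosa.load(wave_file, sr=None)
    trimmed, _ = librosa.effects.trim(wav, top_db=top_db)
    if end_pad_sec > 0:
        trimmed = np.concatenate([trimmed, np.zeros(int(end_pad_sec * sr))])
    soundfile.write(os.path.join(out_dir, os.path.basename(wave_file)), trimmed, sr)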
Example #4
def babble(hp, num_sentences=0):

    if num_sentences == 0:
        num_sentences = 4 # default
    g1 = BabblerGraph(hp, mode="synthesize"); print("Babbler graph loaded")
    g2 = SSRNGraph(hp, mode="synthesize"); print("SSRN graph loaded")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        babbler_epoch = restore_latest_model_parameters(sess, hp, 'babbler')
        ssrn_epoch = restore_latest_model_parameters(sess, hp, 'ssrn')

        t = start_clock('Babbling...')
        Y = synth_babble(hp, g1, sess, seed=False, nsamples=num_sentences)
        stop_clock(t)

        t = start_clock('Mel2Mag generating...')
        Z = synth_mel2mag(hp, Y, g2, sess)
        stop_clock(t) 

        if (np.isnan(Z).any()):  ### TODO: keep?
            Z = np.nan_to_num(Z)

        # Generate wav files
        outdir = os.path.join(hp.voicedir, 'synth_babble', '%s_%s'%(babbler_epoch, ssrn_epoch))
        safe_makedir(outdir)
        for i, mag in enumerate(Z):
            print("Applying Griffin-Lim to sample number %s"%(i))
            wav = spectrogram2wav(hp, mag)
            write(outdir + "/{:03d}.wav".format(i), hp.sr, wav)
Example #5

def logger_setup(logdir):

    safe_makedir(logdir)

    ## Get new unique named logfile for each run:
    i = 1
    while True:
        logfile = os.path.join(logdir, 'log_{:06d}.txt'.format(i))
        if not os.path.isfile(logfile):
            break
        else:
            i += 1

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter(
        '%(asctime)s | %(threadName)-3.3s | %(levelname)-1.1s | %(message)s')

    fh = logging.FileHandler(logfile)
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    logger.info('Set up logger to write to console and %s' % (logfile))

    log_environment_information(logger, logfile)
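
Typical usage is to call logger_setup once at startup and then log through the standard logging module (the codebase's info() helper presumably forwards there); a small illustration:

import logging

logger_setup('work/exp1/log')       # console + log/log_000001.txt
logging.info('training started')    # goes to both handlers
logging.debug('batch shape: %s', (32, 80))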
Example #6
def split_waves_in_directory(in_dir, out_dir, num_workers=1, tqdm=lambda x: x, nfiles=0):
    safe_makedir(out_dir)
    executor = ProcessPoolExecutor(max_workers=num_workers)
    futures = []

    wave_files = sorted(glob.glob(in_dir + '/*.wav'))
    if nfiles > 0:
        wave_files = wave_files[:min(nfiles, len(wave_files))]

    for (index, wave_file) in enumerate(wave_files):
        futures.append(executor.submit(
            partial(_process_utterance, wave_file, out_dir)))
    return [future.result() for future in tqdm(futures)]
Example #7
def copy_synth_GL(hp, outdir):

    safe_makedir(outdir)

    dataset = load_data(hp, mode="synthesis")
    fnames, texts = dataset['fpaths'], dataset['texts']
    bases = [basename(fname) for fname in fnames]

    for base in bases:
        print("Working on file %s" % (base))
        mag = np.load(os.path.join(hp.full_audio_dir, base + '.npy'))
        wav = spectrogram2wav(hp, mag)
        soundfile.write(outdir + "/%s.wav" % (base), wav, hp.sr)
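
spectrogram2wav is the project's Griffin-Lim inversion and is not shown on this page. A rough stand-in with librosa, assuming mag is a (frames, bins) linear magnitude spectrogram and that hp carries the usual STFT settings (the attribute names power, hop_length and win_length are assumptions):

import numpy as np
import librosa

def spectrogram2wav_sketch(hp, mag, n_iter=50):
    """Hedged stand-in for spectrogram2wav: Griffin-Lim inversion."""
    S = mag.T ** hp.power                       # (bins, frames); sharpen before inversion
    wav = librosa.griffinlim(S, n_iter=n_iter,
                             hop_length=hp.hop_length,
                             win_length=hp.win_length)
    return wav / max(1e-8, np.abs(wav).max())   # peak-normalise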
Example #8

def main_work():

    #################################################

    # ============= Process command line ============

    a = ArgumentParser()

    a.add_argument('-b', dest='binlabdir', required=True)
    a.add_argument('-t', dest='text_lab_dir', required=True)
    a.add_argument('-n', dest='norm_info_fname', required=True)
    a.add_argument('-o', dest='outdir', required=True)
    a.add_argument('-binext', dest='binext', required=False, default='lab')
    a.add_argument('-skipterminals', action='store_true', default=False)

    opts = a.parse_args()

    # ===============================================

    safe_makedir(opts.outdir)

    norm_info = get_speech(opts.norm_info_fname, 425)[:, :-9]
    data_min = norm_info[0, :]
    data_max = norm_info[1, :]
    data_range = data_max - data_min

    text_label_files = set(
        [basename(f) for f in glob.glob(opts.text_lab_dir + '/*.lab')])
    binary_label_files = sorted(glob.glob(opts.binlabdir + '/*.' +
                                          opts.binext))
    print(binary_label_files)
    for binlab in binary_label_files:
        base = basename(binlab)
        if base not in text_label_files:
            continue
        print(base)
        lab = process_merlin_label(binlab, opts.text_lab_dir)
        if opts.skipterminals:
            lab = lab[1:-1, :]  ## NB: don't remove the last 2 as for durations; the final punctuation doesn't feature here
        norm_lab = minmax_norm(lab, data_min, data_max)

        if 0:  ## piano roll style plot:
            pl.imshow(norm_lab, interpolation='nearest')
            pl.gray()
            pl.savefig('/afs/inf.ed.ac.uk/user/o/owatts/temp/fig.pdf')
            sys.exit('abckdubv')

        np.save(opts.outdir + '/' + base, norm_lab)
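
minmax_norm scales each feature dimension into a fixed range using the per-dimension min/max rows read from the norm-info file. A plausible implementation (the target interval [0.01, 0.99], common in Merlin-style pipelines, is an assumption):

import numpy as np

def minmax_norm(x, data_min, data_max, lower=0.01, upper=0.99):
    """Hypothetical sketch: map each column of x from [data_min, data_max]
    to [lower, upper], guarding against zero-range dimensions."""
    data_range = np.maximum(data_max - data_min, 1.0e-8)
    return lower + (upper - lower) * (x - data_min) / data_range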
Example #9

def main_work():

    #################################################

    # ============= Process command line ============

    a = ArgumentParser()

    a.add_argument('-b', dest='binlabdir', required=True)
    a.add_argument('-f', dest='audio_dir', required=True)
    a.add_argument('-n', dest='norm_info_fname', required=True)
    a.add_argument('-o', dest='outdir', required=True)
    a.add_argument('-binext', dest='binext', required=False, default='lab')

    a.add_argument('-ir', dest='inrate', type=float, default=5.0)
    a.add_argument('-or', dest='outrate', type=float, default=12.5)

    opts = a.parse_args()

    # ===============================================

    safe_makedir(opts.outdir)

    norm_info = get_speech(opts.norm_info_fname, 425)[:, -9:]
    data_min = norm_info[0, :]
    data_max = norm_info[1, :]
    data_range = data_max - data_min

    audio_files = set(
        [basename(f) for f in glob.glob(opts.audio_dir + '/*.npy')])
    binary_label_files = sorted(glob.glob(opts.binlabdir + '/*.' +
                                          opts.binext))

    for binlab in binary_label_files:
        base = basename(binlab)
        if base not in audio_files:
            continue
        print(base)
        positions = process_merlin_positions(binlab,
                                             opts.audio_dir,
                                             inrate=opts.inrate,
                                             outrate=opts.outrate)
        norm_positions = minmax_norm(positions, data_min, data_max)

        np.save(opts.outdir + '/' + base, norm_positions)
Example #10
def compute_validation(hp, model_type, epoch, inputs, synth_graph, sess, speaker_codes, \
         valid_filenames, validation_set_reference, duration_data=None, validation_labels=None, position_in_phone_data=None):
    if model_type == 't2m': ## TODO: coded_text2mel here
        validation_set_predictions_tensor, lengths = synth_text2mel(hp, inputs, synth_graph, sess, speaker_data=speaker_codes, duration_data=duration_data, labels=validation_labels, position_in_phone_data=position_in_phone_data)
        validation_set_predictions = split_batch(validation_set_predictions_tensor, lengths)  
        score = compute_dtw_error(validation_set_reference, validation_set_predictions)   
    elif model_type == 'ssrn':
        validation_set_predictions_tensor = synth_mel2mag(hp, inputs, synth_graph, sess)
        lengths = [len(ref) for ref in validation_set_reference]
        validation_set_predictions = split_batch(validation_set_predictions_tensor, lengths)  
        score = compute_simple_LSD(validation_set_reference, validation_set_predictions)
    else:
        info('compute_validation cannot handle model type %s: dummy value (0.0) supplied as validation score'%(model_type)); return 0.0
    ## store parameters for later use:-
    valid_dir = '%s-%s/validation_epoch_%s'%(hp.logdir, model_type, epoch)
    safe_makedir(valid_dir)
    hp.validation_sentences_to_synth_params = min(hp.validation_sentences_to_synth_params, len(valid_filenames)) # in case fewer sentences match the validation pattern than hp.validation_sentences_to_synth_params
    for i in range(hp.validation_sentences_to_synth_params):
        np.save(os.path.join(valid_dir, basename(valid_filenames[i])), validation_set_predictions[i])
    return score
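
compute_dtw_error aligns each predicted mel sequence to its reference before scoring, so that small timing differences do not dominate. A self-contained sketch of the idea with a plain dynamic-time-warping recursion (the project's version may use a library and a different local distance):

import numpy as np

def dtw_distance(ref, pred):
    """Hypothetical sketch: DTW with Euclidean frame distance,
    cost normalised by the summed sequence lengths."""
    n, m = len(ref), len(pred)
    D = np.full((n + 1, m + 1), np.inf)
    D[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = np.linalg.norm(ref[i - 1] - pred[j - 1])
            D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
    return D[n, m] / (n + m)

def compute_dtw_error_sketch(references, predictions):
    return np.mean([dtw_distance(r, p) for r, p in zip(references, predictions)])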
Example #11
def main_work():

    #################################################

    # ============= Process command line ============

    a = ArgumentParser()
    a.add_argument('-c', dest='config', required=True, type=str)
    a.add_argument('-ncores',
                   default=1,
                   type=int,
                   help='Number of cores for parallel processing')
    opts = a.parse_args()

    # ===============================================

    hp = load_config(opts.config)

    fpaths = sorted(glob.glob(hp.waveforms + '/*.wav'))

    safe_makedir(hp.coarse_audio_dir)
    safe_makedir(hp.full_audio_dir)
    safe_makedir(hp.full_mel_dir)

    executor = ProcessPoolExecutor(max_workers=opts.ncores)
    futures = []
    for fpath in fpaths:
        futures.append(executor.submit(proc, fpath, hp))
    proc_list = [future.result() for future in tqdm.tqdm(futures)]
Example #12
def synthesize(hp, speaker_id='', num_sentences=0, ncores=1, topoutdir='', t2m_epoch=-1, ssrn_epoch=-1):
    '''
    topoutdir: store samples under here; defaults to hp.sampledir
    t2m_epoch and ssrn_epoch: default -1 means use latest. Otherwise go to archived models.
    '''
    assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

    dataset = load_data(hp, mode="synthesis") #since mode != 'train' or 'validation', will load test_transcript rather than transcript
    fpaths, L = dataset['fpaths'], dataset['texts']
    position_in_phone_data = duration_data = labels = None # default
    if hp.use_external_durations:
        duration_data = dataset['durations']
        if num_sentences > 0:
            duration_data = duration_data[:num_sentences, :, :]

    if 'position_in_phone' in hp.history_type:
        ## TODO: combine + deduplicate with relevant code in train.py for making validation set
        def duration2position(duration, fractional=False):     
            ### very roundabout -- need to deflate A matrix back to integers:
            duration = duration.sum(axis=0)
            #print(duration)
            # sys.exit('evs')   
            positions = durations_to_position(duration, fractional=fractional)
            ###positions = end_pad_for_reduction_shape_sync(positions, hp)
            positions = positions[0::hp.r, :]         
            #print(positions)
            return positions

        position_in_phone_data = [duration2position(dur, fractional=('fractional' in hp.history_type)) \
                        for dur in duration_data]       
        position_in_phone_data = list2batch(position_in_phone_data, hp.max_T)



    # Ensure we aren't trying to generate more utterances than are actually in our test_transcript
    if num_sentences > 0:
        assert num_sentences <= len(fpaths)
        L = L[:num_sentences, :]
        fpaths = fpaths[:num_sentences]

    bases = [basename(fpath) for fpath in fpaths]

    if hp.merlin_label_dir:
        labels = [np.load("{}/{}".format(hp.merlin_label_dir, basename(fpath)+".npy")) \
                              for fpath in fpaths ]
        labels = list2batch(labels, hp.max_N)


    if speaker_id:
        speaker2ix = dict(zip(hp.speaker_list, range(len(hp.speaker_list))))
        speaker_ix = speaker2ix[speaker_id]

        ## Speaker codes are held in (batch, 1) matrix -- tiling is done inside the graph:
        speaker_data = np.ones((len(L), 1))  *  speaker_ix
    else:
        speaker_data = None

    # Load graph 
    ## TODO: generalise to combine other types of models into a synthesis pipeline?
    g1 = Text2MelGraph(hp, mode="synthesize"); print("Graph 1 (t2m) loaded")
    g2 = SSRNGraph(hp, mode="synthesize"); print("Graph 2 (ssrn) loaded")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        ### TODO: specify epoch from comm line?
        ### TODO: t2m and ssrn from separate configs?

        if t2m_epoch > -1:
            restore_archived_model_parameters(sess, hp, 't2m', t2m_epoch)
        else:
            t2m_epoch = restore_latest_model_parameters(sess, hp, 't2m')

        if ssrn_epoch > -1:    
            restore_archived_model_parameters(sess, hp, 'ssrn', ssrn_epoch)
        else:
            ssrn_epoch = restore_latest_model_parameters(sess, hp, 'ssrn')

        # Pass input L through Text2Mel Graph
        t = start_clock('Text2Mel generating...')
        ### TODO: after further efficiency testing, remove this fork
        if 1:  ### efficient route -- only make K&V once  ## 3.86, 3.70, 3.80 seconds (2 sentences)
            text_lengths = get_text_lengths(L)
            K, V = encode_text(hp, L, g1, sess, speaker_data=speaker_data, labels=labels)
            Y, lengths, alignments = synth_codedtext2mel(hp, K, V, text_lengths, g1, sess, \
                                speaker_data=speaker_data, duration_data=duration_data, \
                                position_in_phone_data=position_in_phone_data,\
                                labels=labels)
        else: ## 5.68, 5.43, 5.38 seconds (2 sentences)
            Y, lengths = synth_text2mel(hp, L, g1, sess, speaker_data=speaker_data, \
                                            duration_data=duration_data, \
                                            position_in_phone_data=position_in_phone_data, \
                                            labels=labels)
        stop_clock(t)

        ### TODO: useful to test this?
        # print(Y[0,:,:])
        # print (np.isnan(Y).any())
        # print('nan1')
        # Then pass output Y of Text2Mel Graph through SSRN graph to get high res spectrogram Z.
        t = start_clock('Mel2Mag generating...')
        Z = synth_mel2mag(hp, Y, g2, sess)
        stop_clock(t) 

        if (np.isnan(Z).any()):  ### TODO: keep?
            Z = np.nan_to_num(Z)

        # Generate wav files
        if not topoutdir:
            topoutdir = hp.sampledir
        outdir = os.path.join(topoutdir, 't2m%s_ssrn%s'%(t2m_epoch, ssrn_epoch))
        if speaker_id:
            outdir += '_speaker-%s'%(speaker_id)
        safe_makedir(outdir)
        print("Generating wav files, will save to following dir: %s"%(outdir))

        
        assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

        if ncores==1:
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length
                synth_wave(hp, mag, outfile)
        else:
            executor = ProcessPoolExecutor(max_workers=ncores)    
            futures = []
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length
                futures.append(executor.submit(synth_wave, hp, mag, outfile))
            proc_list = [future.result() for future in tqdm(futures)]

        # for i, mag in enumerate(Z):
        #     print("Working on %s"%(bases[i]))
        #     mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length
            
        #     if hp.vocoder=='magphase_compressed':
        #         mag = denorm(mag, s, hp.normtype)
        #         streams = split_streams(mag, ['mag', 'lf0', 'vuv', 'real', 'imag'], [60,1,1,45,45])
        #         wav = magphase_synth_from_compressed(streams, samplerate=hp.sr)
        #     elif hp.vocoder=='griffin_lim':                
        #         wav = spectrogram2wav(hp, mag)
        #     else:
        #         sys.exit('Unsupported vocoder type: %s'%(hp.vocoder))
        #     #write(outdir + "/{}.wav".format(bases[i]), hp.sr, wav)
        #     soundfile.write(outdir + "/{}.wav".format(bases[i]), wav, hp.sr)
            

            
        # Plot attention alignments 
        for i in range(num_sentences):
            plot_alignment(hp, alignments[i], utt_idx=i+1, t2m_epoch=t2m_epoch, dir=outdir)
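
plot_alignment just visualises the (encoder step, decoder step) attention matrix. A minimal stand-in with matplotlib (the filename convention is an assumption):

import os
import matplotlib
matplotlib.use('Agg')  # headless plotting
import matplotlib.pyplot as plt

def plot_alignment_sketch(alignment, utt_idx, t2m_epoch, outdir):
    """Hypothetical sketch: save the attention matrix as a heatmap."""
    fig, ax = plt.subplots()
    im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='nearest')
    ax.set_xlabel('Decoder timestep')
    ax.set_ylabel('Encoder timestep')
    fig.colorbar(im, ax=ax)
    fig.savefig(os.path.join(outdir, 'alignment_%s_epoch%s.png' % (utt_idx, t2m_epoch)))
    plt.close(fig)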
Example #13
def main_work():

    #################################################

    # ============= Process command line ============
    a = ArgumentParser()
    a.add_argument('-c', dest='config', required=True, type=str)
    a.add_argument('-m',
                   dest='model_type',
                   required=True,
                   choices=['t2m', 'ssrn', 'babbler'])
    opts = a.parse_args()

    # ===============================================
    model_type = opts.model_type
    hp = load_config(opts.config)
    logdir = hp.logdir + "-" + model_type
    logger_setup.logger_setup(logdir)
    info('Command line: %s' % (" ".join(sys.argv)))

    ### TODO: move this to its own function somewhere. Can be used also at synthesis time?
    ### Prepare reference data for validation set:  ### TODO: alternative to holding in memory?
    dataset = load_data(hp, mode="validation")
    valid_filenames, validation_text = dataset['fpaths'], dataset['texts']

    speaker_codes = validation_duration_data = position_in_phone_data = None  ## defaults
    if hp.multispeaker:
        speaker_codes = dataset['speakers']
    if hp.use_external_durations:
        validation_duration_data = dataset['durations']

    ## take random subset of validation set to avoid 'This is a librivox recording' type sentences
    random.seed(1234)
    v_indices = list(range(len(valid_filenames)))  ## list() so random.shuffle works under Python 3
    random.shuffle(v_indices)
    v = min(hp.validation_sentences_to_evaluate, len(valid_filenames))
    v_indices = v_indices[:v]

    if hp.multispeaker:  ## now come back to this after v computed
        speaker_codes = np.array(speaker_codes)[v_indices].reshape(-1, 1)
    if hp.use_external_durations:
        validation_duration_data = validation_duration_data[v_indices, :, :]

    valid_filenames = np.array(valid_filenames)[v_indices]
    validation_mags = [np.load(hp.full_audio_dir + os.path.sep + basename(fpath)+'.npy') \
                                for fpath in valid_filenames]
    validation_text = validation_text[v_indices, :]
    validation_labels = None  # default
    if hp.merlin_label_dir:
        validation_labels = [np.load("{}/{}".format(hp.merlin_label_dir, basename(fpath)+".npy")) \
                              for fpath in valid_filenames ]
        validation_labels = list2batch(validation_labels, hp.max_N)

    if 'position_in_phone' in hp.history_type:

        def duration2position(duration, fractional=False):
            ### very roundabout -- need to deflate A matrix back to integers:
            duration = duration.sum(axis=0)
            #print(duration)
            # sys.exit('evs')
            positions = durations_to_position(duration, fractional=fractional)
            ###positions = end_pad_for_reduction_shape_sync(positions, hp)
            positions = positions[0::hp.r, :]
            #print(positions)
            return positions

        position_in_phone_data = [duration2position(dur, fractional=('fractional' in hp.history_type)) \
                        for dur in dataset['durations'][v_indices]]
        position_in_phone_data = list2batch(position_in_phone_data, hp.max_T)

    if model_type == 't2m':
        validation_mels = [np.load(hp.coarse_audio_dir + os.path.sep + basename(fpath)+'.npy') \
                                    for fpath in valid_filenames]
        validation_inputs = validation_text
        validation_reference = validation_mels
        validation_lengths = None
    elif model_type == 'ssrn':
        validation_inputs, validation_lengths = make_mel_batch(
            hp, valid_filenames)
        validation_reference = validation_mags
    else:
        info(
            'Undefined model_type {} for making validation inputs -- supply dummy None values'
            .format(model_type))
        validation_inputs = None
        validation_reference = None

    ## Get the text and mel inputs for the utts you would like to plot attention graphs for
    if hp.plot_attention_every_n_epochs and model_type == 't2m':  #check if we want to plot attention
        # TODO do we want to generate and plot attention for validation or training set sentences??? modify attention_inputs accordingly...
        attention_inputs = validation_text[:hp.num_sentences_to_plot_attention]
        attention_mels = validation_mels[:hp.num_sentences_to_plot_attention]
        attention_mels = np.array(
            attention_mels)  #TODO should be able to delete this line...?
        attention_mels_array = np.zeros(
            (hp.num_sentences_to_plot_attention, hp.max_T, hp.n_mels),
            np.float32)  # create fixed size array to hold attention mels
        for i in range(hp.num_sentences_to_plot_attention):  # copy data into this fixed-size array
            attention_mels_array[i, :attention_mels[i].shape[0], :attention_mels[i].shape[1]] = attention_mels[i]
        attention_mels = attention_mels_array  # rename for convenience

    ## Map to appropriate type of graph depending on model_type:
    AppropriateGraph = {
        't2m': Text2MelGraph,
        'ssrn': SSRNGraph,
        'babbler': BabblerGraph
    }[model_type]

    g = AppropriateGraph(hp)
    info("Training graph loaded")
    synth_graph = AppropriateGraph(hp, mode='synthesize', reuse=True)
    info("Synthesis graph loaded")  # reuse=True ensures that 'synth_graph' and 'attention_graph' share weights with training graph 'g'
    attention_graph = AppropriateGraph(hp,
                                       mode='generate_attention',
                                       reuse=True)
    info("Atttention generating graph loaded")
    #TODO is loading three graphs a problem for memory usage?

    if 0:
        print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'Text2Mel'))
        ## [<tf.Variable 'Text2Mel/TextEnc/embed_1/lookup_table:0' shape=(61, 128) dtype=float32_ref>, <tf.Variable 'Text2Mel/TextEnc/C_2/conv1d/kernel:0' shape=(1, 128, 512) dtype=float32_ref>, ...

    ## TODO: tensorflow.python.training.supervisor deprecated: --> switch to tf.train.MonitoredTrainingSession
    sv = tf.train.Supervisor(logdir=logdir,
                             save_model_secs=0,
                             global_step=g.global_step)

    ## Get the current training epoch from the name of the model that we have loaded
    latest_checkpoint = tf.train.latest_checkpoint(logdir)
    if latest_checkpoint:
        epoch = int(
            latest_checkpoint.strip('/ ').split('/')[-1].replace(
                'model_epoch_', ''))
    else:  #did not find a model checkpoint, so we start training from scratch
        epoch = 0

    ## If save_every_n_epochs > 0, models will be stored here every n epochs and not
    ## deleted, regardless of validation improvement etc.:--
    safe_makedir(logdir + '/archive/')

    with sv.managed_session() as sess:
        if 0:  ## Set to 1 to debug NaNs; at tfdbg prompt, type:    run -f has_inf_or_nan
            ## later:    lt  -f has_inf_or_nan -n .*AudioEnc.*
            os.system('rm -rf {}/tmp_tfdbg/'.format(logdir))
            sess = tf_debug.LocalCLIDebugWrapperSession(sess,
                                                        dump_root=logdir +
                                                        '/tmp_tfdbg/')

        if hp.initialise_weights_from_existing:
            info('=====Initialise some variables from existing model(s)=====')
            sess.graph._unsafe_unfinalize()  ## !!! https://stackoverflow.com/questions/41798311/tensorflow-graph-is-finalized-and-cannot-be-modified/41798401
            for (scope, checkpoint) in hp.initialise_weights_from_existing:
                var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope)
                info('----From existing model %s:----' % (checkpoint))
                if var_list:  ## will be empty when training t2m but looking at ssrn
                    saver = tf.train.Saver(var_list=var_list)
                    saver.restore(sess, checkpoint)
                    for var in var_list:
                        info('   %s' % (var.name))
                else:
                    info('   No variables!')
                info(
                    '========================================================')

        if hp.restart_from_savepath:  #set this param to list: [path_to_t2m_model_folder, path_to_ssrn_model_folder]
            # info('Restart from these paths:')
            info(hp.restart_from_savepath)

            # assert len(hp.restart_from_savepath) == 2
            restart_from_savepath1, restart_from_savepath2 = hp.restart_from_savepath
            restart_from_savepath1 = os.path.abspath(restart_from_savepath1)
            restart_from_savepath2 = os.path.abspath(restart_from_savepath2)

            sess.graph._unsafe_unfinalize()  ## !!! https://stackoverflow.com/questions/41798311/tensorflow-graph-is-finalized-and-cannot-be-modified/41798401
            sess.run(tf.global_variables_initializer())

            print('Restore parameters')
            if model_type == 't2m':
                var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             'Text2Mel')
                saver1 = tf.train.Saver(var_list=var_list)
                latest_checkpoint = tf.train.latest_checkpoint(
                    restart_from_savepath1)
                saver1.restore(sess, restart_from_savepath1)
                print("Text2Mel Restored!")
            elif model_type == 'ssrn':
                var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'SSRN') + \
                           tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'gs')
                saver2 = tf.train.Saver(var_list=var_list)
                latest_checkpoint = tf.train.latest_checkpoint(
                    restart_from_savepath2)
                saver2.restore(sess, restart_from_savepath2)
                print("SSRN Restored!")
            epoch = int(
                latest_checkpoint.strip('/ ').split('/')[-1].replace(
                    'model_epoch_', ''))
            # TODO: this counter won't work if training restarts in same directory.
            ## Get epoch from gs?

        loss_history = []  #any way to restore loss history too?

        #plot attention generated from freshly initialised model
        if hp.plot_attention_every_n_epochs and model_type == 't2m' and epoch == 0:  # ssrn model doesn't generate alignments
            get_and_plot_alignments(
                hp, epoch - 1, attention_graph, sess, attention_inputs,
                attention_mels, logdir +
                "/alignments")  # epoch-1 refers to freshly initialised model

        current_score = compute_validation(
            hp,
            model_type,
            epoch,
            validation_inputs,
            synth_graph,
            sess,
            speaker_codes,
            valid_filenames,
            validation_reference,
            duration_data=validation_duration_data,
            validation_labels=validation_labels,
            position_in_phone_data=position_in_phone_data)
        info('validation epoch {0}: {1:0.3f}'.format(epoch, current_score))

        while 1:
            progress_bar_text = '%s/%s; ep. %s' % (hp.config_name, model_type,
                                                   epoch)
            for batch_in_current_epoch in tqdm(range(g.num_batch),
                                               total=g.num_batch,
                                               ncols=80,
                                               leave=True,
                                               unit='b',
                                               desc=progress_bar_text):
                gs, loss_components, _ = sess.run(
                    [g.global_step, g.loss_components, g.train_op])
                loss_history.append(loss_components)

            ### End of epoch: validate?
            if hp.validate_every_n_epochs:
                if epoch % hp.validate_every_n_epochs == 0:

                    loss_history = np.array(loss_history)
                    train_loss_mean_std = np.concatenate(
                        [loss_history.mean(axis=0),
                         loss_history.std(axis=0)])
                    loss_history = []

                    train_loss_mean_std = ' '.join([
                        '{:0.3f}'.format(score)
                        for score in train_loss_mean_std
                    ])
                    info('train epoch {0}: {1}'.format(epoch,
                                                       train_loss_mean_std))

                    current_score = compute_validation(
                        hp,
                        model_type,
                        epoch,
                        validation_inputs,
                        synth_graph,
                        sess,
                        speaker_codes,
                        valid_filenames,
                        validation_reference,
                        duration_data=validation_duration_data,
                        validation_labels=validation_labels,
                        position_in_phone_data=position_in_phone_data)
                    info('validation epoch {0:0}: {1:0.3f}'.format(
                        epoch, current_score))

            ### End of epoch: plot attention matrices? #################################
            if hp.plot_attention_every_n_epochs and model_type == 't2m' and epoch % hp.plot_attention_every_n_epochs == 0:  # ssrn model doesn't generate alignments
                get_and_plot_alignments(hp, epoch, attention_graph, sess,
                                        attention_inputs, attention_mels,
                                        logdir + "/alignments")

            ### Save end of each epoch (all but the most recent 5 will be overwritten):
            stem = logdir + '/model_epoch_{0}'.format(epoch)
            sv.saver.save(sess, stem)

            ### Check if we should archive (to files which won't be overwritten):
            if hp.save_every_n_epochs:
                if epoch % hp.save_every_n_epochs == 0:
                    info('Archive model %s' % (stem))
                    for fname in glob.glob(stem + '*'):
                        shutil.copy(fname, logdir + '/archive/')

            epoch += 1
            if epoch > hp.max_epochs:
                info('Max epochs ({}) reached: end training'.format(
                    hp.max_epochs))
                return

    print("Done")
Example #14
def main_work():

    #################################################

    # ============= Process command line ============

    a = ArgumentParser()
    a.add_argument('-meldir',
                   required=True,
                   type=str,
                   help='existing directory with mels; features are padded to match the length of these')
    a.add_argument('-worlddir',
                   required=True,
                   type=str,
                   help='existing directory containing world features')
    a.add_argument('-outdir', required=True, type=str)

    a.add_argument('-testpatt', required=False, type=str, default='')

    a.add_argument('-ncores',
                   default=1,
                   type=int,
                   help='Number of cores for parallel processing')
    opts = a.parse_args()

    # ===============================================

    # hp = load_config(opts.config)

    fpaths = sorted(glob.glob(opts.meldir + '/*.npy'))  # [:10]

    normkind = 'meanvar'

    if normkind == 'minmax':
        scaler = MinMaxScaler()
    elif normkind == 'meanvar':
        scaler = StandardScaler()
    else:
        sys.exit('aedvsv')

    if opts.testpatt:
        train_fpaths = [p for p in fpaths if opts.testpatt not in basename(p)]
    else:
        train_fpaths = fpaths

    for fpath in tqdm(train_fpaths, desc='First pass to get norm stats'):

        data = load_sentence(fpath, worlddir=opts.worlddir, outdir=opts.outdir)
        scaler = update_normalisation_stats(data, scaler)

    safe_makedir(opts.outdir)
    safe_makedir(opts.outdir + '/full_world/')
    safe_makedir(opts.outdir + '/coarse_world/')

    if 0:
        process(fpaths[0],
                worlddir=opts.worlddir,
                outdir=opts.outdir,
                scaler=scaler)
        sys.exit('aedvsfv')

    executor = ProcessPoolExecutor(max_workers=opts.ncores)
    futures = []
    for fpath in fpaths:
        futures.append(
            executor.submit(process,
                            fpath,
                            worlddir=opts.worlddir,
                            outdir=opts.outdir,
                            scaler=scaler))

    proc_list = [
        future.result()
        for future in tqdm(futures,
                           desc='Second pass (parallel) to do normalisation')
    ]

    if normkind == 'minmax':
        mini = scaler.data_min_  ## TODO: per speaker...
        maxi = scaler.data_max_
        stats = np.vstack([mini, maxi])
    elif normkind == 'meanvar':
        mean = scaler.mean_  ## TODO: per speaker...
        std = scaler.scale_
        stats = np.vstack([mean, std])
    else:
        sys.exit('aedvsv2')
    np.save(opts.outdir + '/norm_stats', stats)
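
Example #14 makes two passes over the data: the first accumulates normalisation statistics, the second applies them in parallel. update_normalisation_stats most plausibly wraps scikit-learn's incremental partial_fit, which both StandardScaler and MinMaxScaler support; a sketch under that assumption:

from sklearn.preprocessing import StandardScaler

def update_normalisation_stats(data, scaler):
    """Hypothetical sketch: fold one sentence's (frames, dims) features
    into the scaler's running statistics."""
    scaler.partial_fit(data)
    return scaler

# The second pass then normalises with the frozen statistics, e.g. for 'meanvar':
#     normed = (data - scaler.mean_) / scaler.scale_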
Example #15
def synthesise_from_config(config, model, synthdir, full_trace=False, oracle_pitchmarks=False, dummy_synth=False):

    '''
    TODO: refactor and pair dummy_synth with model loading 
    '''

    safe_makedir(synthdir)
    
    if full_trace:
        print('Make model to output all hidden activations')
        trace_model, layer_names = convert_model_for_trace(model)

    wavedir = config['wavedir']

    basenames = get_testset_names(config['test_pattern'], wavedir)

    nsynth = config.get('n_sentences_to_synth', 1)

    if config.get('normalise_spectrogram_in', 'freq') == 'freq_global_norm':
        model_dir = get_model_dir_name(config) ## repeat this to get norm info        
        norm_mean_fname = os.path.join(model_dir, 'spectrogram_mean.npy')    
        norm_std_fname = os.path.join(model_dir, 'spectrogram_std.npy')    
        assert os.path.isfile(norm_mean_fname) and os.path.isfile(norm_std_fname)
        spectrogram_mean = np.load(norm_mean_fname)
        spectrogram_std = np.load(norm_std_fname)
    else:
        spectrogram_mean = None
        spectrogram_std = None

    ### following lines for compatibility with earlier configs (before norm handling rationalised)
    if 'feat_dim' in config:
        input_dimension = config['feat_dim']
    else:
        input_dimension = config['feat_dim_in']

    if 'normalise_melspectrograms' in config:
        normalise_input_features = config['normalise_melspectrograms']
    else:
        normalise_input_features = config.get('normalise_spectrogram_in', 'freq')


    spectrogram_extractor = get_spectrogram_extractor(n_mels=input_dimension, \
            normalise=normalise_input_features, \
            spectrogram_mean=spectrogram_mean, spectrogram_std=spectrogram_std, \
            dft_window=config.get('dft_window', 512), n_hop=config.get('n_hop', 200))

    ## opt_model : waveform predictor chained with spectrogram extractor
    ## model : waveform predictor -- this is the only bit which is saved
    noise_std = config.get('noise_std', 1.0)
    noise_input = config.get('add_noise', False)

    n_hop = config.get('n_hop', 200)

    ## dummy synthesis on loading (because first use of network not optimised)
    # DUMMY_SYNTH = True
    if dummy_synth:
        print('synthesise dummy audio...')
        wav = exc = np.zeros(n_hop*20).reshape(1,-1)
        (inputs, targets) = tweak_batch((wav, exc), spectrogram_extractor, config, [])  ### []: dummy output transformers 

        combined_prediction = model.predict(x=inputs) 
        print('                           done!')


    i = 0
    for basename in basenames:
        print(basename)
        wave_fname = os.path.join(wavedir, basename + '.wav')
        outfile = os.path.join(synthdir, basename + '.wav')
        wav, sr = soundfile.read(wave_fname, dtype='int16') ## TODO: check wave read/load @343948

        if oracle_pitchmarks:
            excdir = config['excdir']
            exc_fname = os.path.join(excdir, basename + '.wav')
            exc, sr = soundfile.read(exc_fname, dtype='int16')
        else:
            fzerodir = config['fzerodir']
            f0_fname = os.path.join(fzerodir, basename + '.f0')
            exc = synthesise_excitation(f0_fname, len(wav))

        wav = trim_to_nearest(wav, n_hop).reshape(1,-1)
        exc = trim_to_nearest(exc, n_hop).reshape(1,-1)
        (inputs, targets) = tweak_batch((wav, exc), spectrogram_extractor, config, []) # []: dummy output transformers

        start_time = timeit.default_timer()

        combined_prediction = model.predict(x=inputs) # (1, 37800, 1)

        prediction = combined_prediction.flatten()
        write_wave(prediction, outfile, scale=False)

        spec = inputs[0]
        print('>>> %s --> took %.2f seconds (%s frames)' % (basename, (timeit.default_timer() - start_time), spec.shape[1]))

        if full_trace:
            tracefile = outfile.replace('.wav','_trace.hdf')
            f = h5py.File(tracefile, 'w')

            print('store all hidden activations')
            full_trace = trace_model.predict(x=inputs)
            
            ## write list in order so we can retrieve data in order:
            write_textlist_to_hdf(layer_names, 'layer_names', f)            
            for (output, name) in zip(full_trace, layer_names):
                # if name.startswith('multiply'):
                assert output.shape[0] == 1  ## single item test batch
                output = output.squeeze(0)  ## remove batch dimension
                dataset = f.create_dataset(name, output.shape, dtype='f', track_times=False)
                dataset[:, :] = output
            f.close()
            print('Wrote %s' % (tracefile))

        i += 1
        if i >= nsynth:
            print('finished synthesising %s files' % (nsynth))
            break
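
trim_to_nearest truncates a signal so its length is an exact multiple of the hop size, which the frame-synchronous model requires. Presumably something like:

def trim_to_nearest(signal, hop):
    """Hypothetical sketch: drop trailing samples so len(signal) % hop == 0."""
    usable = (len(signal) // hop) * hop
    return signal[:usable]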
Example #16
def synthesize(hp, speaker_id='', num_sentences=0, ncores=1, topoutdir='', t2m_epoch=-1, ssrn_epoch=-1):
    '''
    topoutdir: store samples under here; defaults to hp.sampledir
    t2m_epoch and ssrn_epoch: default -1 means use latest. Otherwise go to archived models.
    '''
    assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

    dataset = load_data(hp, mode="synthesis") #since mode != 'train' or 'validation', will load test_transcript rather than transcript
    fpaths, L = dataset['fpaths'], dataset['texts']
    position_in_phone_data = duration_data = labels = None # default
    if hp.use_external_durations:
        duration_data = dataset['durations']
        if num_sentences > 0:
            duration_data = duration_data[:num_sentences, :, :]

    if 'position_in_phone' in hp.history_type:
        ## TODO: combine + deduplicate with relevant code in train.py for making validation set
        def duration2position(duration, fractional=False):     
            ### very roundabout -- need to deflate A matrix back to integers:
            duration = duration.sum(axis=0)
            #print(duration)
            # sys.exit('evs')   
            positions = durations_to_position(duration, fractional=fractional)
            ###positions = end_pad_for_reduction_shape_sync(positions, hp)
            positions = positions[0::hp.r, :]         
            #print(positions)
            return positions

        position_in_phone_data = [duration2position(dur, fractional=('fractional' in hp.history_type)) \
                        for dur in duration_data]       
        position_in_phone_data = list2batch(position_in_phone_data, hp.max_T)



    # Ensure we aren't trying to generate more utterances than are actually in our test_transcript
    if num_sentences > 0:
        assert num_sentences <= len(fpaths)
        L = L[:num_sentences, :]
        fpaths = fpaths[:num_sentences]

    bases = [basename(fpath) for fpath in fpaths]

    if hp.merlin_label_dir:
        labels = []
        for fpath in fpaths:
            label = np.load("{}/{}".format(hp.merlin_label_dir, basename(fpath)+".npy"))
            if hp.select_central:
                central_ind = get_labels_indices(hp.merlin_lab_dim)
                label = label[:,central_ind==1] 
            labels.append(label)

        labels = list2batch(labels, hp.max_N)


    if speaker_id:
        speaker2ix = dict(zip(hp.speaker_list, range(len(hp.speaker_list))))
        speaker_ix = speaker2ix[speaker_id]

        ## Speaker codes are held in (batch, 1) matrix -- tiling is done inside the graph:
        speaker_data = np.ones((len(L), 1))  *  speaker_ix
    else:
        speaker_data = None
   
    if hp.turn_off_monotonic_for_synthesis: # if the FIA mechanism is turned off
        text_lengths = get_text_lengths(L)
        hp.text_lengths = text_lengths + 1
     
    # Load graph 
    ## TODO: generalise to combine other types of models into a synthesis pipeline?
    g1 = Text2MelGraph(hp, mode="synthesize"); print("Graph 1 (t2m) loaded")

    if hp.norm is None:
        t2m_layer_norm = False
        hp.norm = 'layer'
        hp.lr = 0.001
        hp.beta1 = 0.9
        hp.beta2 = 0.999
        hp.epsilon = 0.00000001
        hp.decay_lr = True
        hp.batchsize = {'t2m': 32, 'ssrn': 8}
    else:
        t2m_layer_norm = True

    g2 = SSRNGraph(hp, mode="synthesize"); print("Graph 2 (ssrn) loaded")

    if not t2m_layer_norm:
        hp.norm = None
        hp.lr = 0.0002
        hp.beta1 = 0.5
        hp.beta2 = 0.9
        hp.epsilon = 0.000001
        hp.decay_lr = False
        hp.batchsize = {'t2m': 16, 'ssrn': 8}
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        ### TODO: specify epoch from comm line?
        ### TODO: t2m and ssrn from separate configs?

        if t2m_epoch > -1:
            restore_archived_model_parameters(sess, hp, 't2m', t2m_epoch)
        else:
            t2m_epoch = restore_latest_model_parameters(sess, hp, 't2m')

        if ssrn_epoch > -1:    
            restore_archived_model_parameters(sess, hp, 'ssrn', ssrn_epoch)
        else:
            ssrn_epoch = restore_latest_model_parameters(sess, hp, 'ssrn')

        # Pass input L through Text2Mel Graph
        t = start_clock('Text2Mel generating...')
        ### TODO: after further efficiency testing, remove this fork
        if 1:  ### efficient route -- only make K&V once  ## 3.86, 3.70, 3.80 seconds (2 sentences)
            text_lengths = get_text_lengths(L)
            K, V = encode_text(hp, L, g1, sess, speaker_data=speaker_data, labels=labels)
            Y, lengths, alignments = synth_codedtext2mel(hp, K, V, text_lengths, g1, sess, \
                                speaker_data=speaker_data, duration_data=duration_data, \
                                position_in_phone_data=position_in_phone_data,\
                                labels=labels)
        else: ## 5.68, 5.43, 5.38 seconds (2 sentences)
            Y, lengths = synth_text2mel(hp, L, g1, sess, speaker_data=speaker_data, \
                                            duration_data=duration_data, \
                                            position_in_phone_data=position_in_phone_data, \
                                            labels=labels)
        stop_clock(t)

        ### TODO: useful to test this?
        # print(Y[0,:,:])
        # print (np.isnan(Y).any())
        # print('nan1')
        # Then pass output Y of Text2Mel Graph through SSRN graph to get high res spectrogram Z.
        t = start_clock('Mel2Mag generating...')
        Z = synth_mel2mag(hp, Y, g2, sess)
        stop_clock(t) 

        if (np.isnan(Z).any()):  ### TODO: keep?
            Z = np.nan_to_num(Z)

        # Generate wav files
        if not topoutdir:
            topoutdir = hp.sampledir
        outdir = os.path.join(topoutdir, 't2m%s_ssrn%s'%(t2m_epoch, ssrn_epoch))
        if speaker_id:
            outdir += '_speaker-%s'%(speaker_id)
        safe_makedir(outdir)

        # Plot trimmed attention alignment with filename
        print("Plot attention, will save to following dir: %s"%(outdir))
        print("File |  CDP | Ain")
        for i, mag in enumerate(Z):
            outfile = os.path.join(outdir, bases[i])
            trimmed_alignment = alignments[i,:text_lengths[i],:lengths[i]]
            plot_alignment(hp, trimmed_alignment, utt_idx=i+1, t2m_epoch=t2m_epoch, dir=outdir, outfile=outfile)
            CDP = getCDP(trimmed_alignment)
            APin, APout = getAP(trimmed_alignment)
            print("%s | %.2f | %.2f"%( bases[i], CDP, APin))

        print("Generating wav files, will save to following dir: %s"%(outdir))

        
        assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

        if ncores==1:
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length
                synth_wave(hp, mag, outfile)
        else:
            executor = ProcessPoolExecutor(max_workers=ncores)    
            futures = []
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length
                futures.append(executor.submit(synth_wave, hp, mag, outfile))
            proc_list = [future.result() for future in tqdm(futures)]
Example #17
def test():
    safe_makedir('/tmp/splitwaves/')
    _process_utterance('/afs/inf.ed.ac.uk/group/cstr/projects/simple4all_2/oliver/data/nick/wav/herald_030.wav', '/tmp/splitwaves/')
Example #18
    def synthesize(self,
                   text=None,
                   emo_code=None,
                   mels=None,
                   speaker_id='',
                   num_sentences=0,
                   ncores=1,
                   topoutdir=''):
        '''
        topoutdir: store samples under here; defaults to hp.sampledir
        (the epochs used come from self.t2m_epoch and self.ssrn_epoch, set when the models were loaded)
        '''
        assert self.hp.vocoder in [
            'griffin_lim', 'world'
        ], 'Other vocoders than griffin_lim/world not yet supported'

        if text is not None:
            text_to_phonetic(text=text)
            dataset = load_data(self.hp, mode='demo')
        else:
            dataset = load_data(
                self.hp, mode="synthesis"
            )  #since mode != 'train' or 'validation', will load test_transcript rather than transcript
        fpaths, L = dataset['fpaths'], dataset['texts']
        position_in_phone_data = duration_data = labels = None  # default
        if self.hp.use_external_durations:
            duration_data = dataset['durations']
            if num_sentences > 0:
                duration_data = duration_data[:num_sentences, :, :]

        if 'position_in_phone' in self.hp.history_type:
            ## TODO: combine + deduplicate with relevant code in train.py for making validation set
            def duration2position(duration, fractional=False):
                ### very roundabout -- need to deflate A matrix back to integers:
                duration = duration.sum(axis=0)
                #print(duration)
                # sys.exit('evs')
                positions = durations_to_position(duration,
                                                  fractional=fractional)
                ###positions = end_pad_for_reduction_shape_sync(positions, hp)
                positions = positions[0::self.hp.r, :]
                #print(positions)
                return positions

            position_in_phone_data = [duration2position(dur, fractional=('fractional' in self.hp.history_type)) \
                            for dur in duration_data]
            position_in_phone_data = list2batch(position_in_phone_data,
                                                self.hp.max_T)

        # Ensure we aren't trying to generate more utterances than are actually in our test_transcript
        if num_sentences > 0:
            assert num_sentences < len(fpaths)
            L = L[:num_sentences, :]
            fpaths = fpaths[:num_sentences]

        bases = [basename(fpath) for fpath in fpaths]

        if self.hp.merlin_label_dir:
            labels = [np.load("{}/{}".format(self.hp.merlin_label_dir, basename(fpath)+".npy")) \
                                for fpath in fpaths ]
            labels = list2batch(labels, self.hp.max_N)

        if speaker_id:
            speaker2ix = dict(zip(self.hp.speaker_list,
                                  range(len(self.hp.speaker_list))))
            speaker_ix = speaker2ix[speaker_id]

            ## Speaker codes are held in (batch, 1) matrix -- tiling is done inside the graph:
            speaker_data = np.ones((len(L), 1)) * speaker_ix
        else:
            speaker_data = None

        # Pass input L through Text2Mel Graph
        t = start_clock('Text2Mel generating...')
        ### TODO: after further efficiency testing, remove this fork
        if 1:  ### efficient route -- only make K&V once  ## 3.86, 3.70, 3.80 seconds (2 sentences)
            text_lengths = get_text_lengths(L)
            if mels is not None:
                emo_code = encode_audio2emo(self.hp, mels, self.g1, self.sess)
            K, V = encode_text(self.hp,
                               L,
                               self.g1,
                               self.sess,
                               emo_mean=emo_code,
                               speaker_data=speaker_data,
                               labels=labels)
            Y, lengths, alignments = synth_codedtext2mel(self.hp, K, V, text_lengths, self.g1, self.sess, \
                                speaker_data=speaker_data, duration_data=duration_data, \
                                position_in_phone_data=position_in_phone_data,\
                                labels=labels)
        else:  ## 5.68, 5.43, 5.38 seconds (2 sentences)
            Y, lengths = synth_text2mel(self.hp, L, self.g1, self.sess, speaker_data=speaker_data, \
                                            duration_data=duration_data, \
                                            position_in_phone_data=position_in_phone_data, \
                                            labels=labels)
        stop_clock(t)

        ### TODO: useful to test this?
        # print(Y[0,:,:])
        # print (np.isnan(Y).any())
        # print('nan1')
        # Then pass output Y of Text2Mel Graph through SSRN graph to get high res spectrogram Z.
        t = start_clock('Mel2Mag generating...')
        Z = synth_mel2mag(self.hp, Y, self.g2, self.sess)
        stop_clock(t)

        if (np.isnan(Z).any()):  ### TODO: keep?
            Z = np.nan_to_num(Z)

        # Generate wav files
        if not topoutdir:
            topoutdir = self.hp.sampledir
        outdir = os.path.join(
            topoutdir, 't2m%s_ssrn%s' % (self.t2m_epoch, self.ssrn_epoch))
        if speaker_id:
            outdir += '_speaker-%s' % (speaker_id)
        safe_makedir(outdir)
        print("Generating wav files, will save to following dir: %s" %
              (outdir))

        assert self.hp.vocoder in [
            'griffin_lim', 'world'
        ], 'Other vocoders than griffin_lim/world not yet supported'

        if ncores == 1:
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i] *
                          self.hp.r, :]  ### trim to generated length
                synth_wave(self.hp, mag, outfile)
        else:
            executor = ProcessPoolExecutor(max_workers=ncores)
            futures = []
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i] *
                          self.hp.r, :]  ### trim to generated length
                futures.append(
                    executor.submit(synth_wave, self.hp, mag, outfile))
            proc_list = [future.result() for future in tqdm(futures)]

        # Plot attention alignments
        for i in range(num_sentences):
            plot_alignment(self.hp,
                           alignments[i],
                           utt_idx=i + 1,
                           t2m_epoch=self.t2m_epoch,
                           dir=outdir)

        self.outdir = outdir