def gen_embedding(speaker):
    """Extract a speaker embedding for every training utterance of `speaker`
    and save the stacked embeddings (plus a plot) under outdir/embeddings/."""

    training_list = hparams.training_list

    train_set_A = TextMelIDLoader(training_list,
                                  hparams.mel_mean_std,
                                  hparams.speaker_A,
                                  hparams.speaker_B,
                                  shuffle=False,
                                  pids=[speaker])

    collate_fn = TextMelIDCollate(
        lcm(hparams.n_frames_per_step_encoder,
            hparams.n_frames_per_step_decoder))

    train_loader_A = DataLoader(train_set_A,
                                num_workers=1,
                                shuffle=False,
                                sampler=None,
                                batch_size=1,
                                pin_memory=False,
                                drop_last=True,
                                collate_fn=collate_fn)

    with torch.no_grad():

        speaker_embeddings = []

        for i, batch in enumerate(train_loader_A):
            x, y = model.parse_batch(batch)
            text_input_padded, mel_padded, text_lengths, mel_lengths, speaker_id = x
            # Run only the speaker encoder to obtain one embedding per utterance.
            speaker_id, speaker_embedding = model.speaker_encoder.inference(
                mel_padded)

            speaker_embedding = speaker_embedding.data.cpu().numpy()
            speaker_embeddings.append(speaker_embedding)

        speaker_embeddings = np.vstack(speaker_embeddings)

    print(speaker_embeddings.shape)
    os.makedirs('outdir/embeddings', exist_ok=True)

    np.save('outdir/embeddings/%s.npy' % speaker, speaker_embeddings)
    plot_data([speaker_embeddings], 'outdir/embeddings/%s.pdf' % speaker)
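
A minimal usage sketch for gen_embedding, assuming the script's module-level hparams and model objects are already loaded (as the function body expects) and that hparams.speaker_A / hparams.speaker_B name the two training speakers; this is an illustration, not part of the repository.

# Hypothetical driver: extract and save embeddings for both training speakers.
for spk in [hparams.speaker_A, hparams.speaker_B]:
    gen_embedding(spk)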
Example 2
def generate(loader,
             reference_mel,
             beam_width,
             path_save,
             ref_sp,
             sample_list,
             num=10,
             input_text=False):

    with torch.no_grad():
        errs = []
        totalphs = []

        for i, batch in enumerate(loader):
            if i == num:
                break

            sample_id = sample_list[i].split('/')[-1]
            print('index:%d, decoding %s ...' % (i, sample_id))

            text_file = '{}.txt'.format(os.path.basename(sample_list[i]))
            text_path = os.path.dirname(sample_list[i]).replace(
                'spec-wgan', 'text')
            text_path = os.path.join(text_path, text_file)
            with open(text_path, 'r') as f:
                text = f.readline().rstrip()
            print('{}: {}'.format(text_file, text))

            text_path_output = os.path.join(path_save,
                                            'txt/Txt_{}'.format(text_file))
            copyfile(text_path, text_path_output)

            wav_file = '{}.wav'.format(os.path.basename(sample_list[i]))
            wav_path = os.path.dirname(sample_list[i]).replace(
                'spec-wgan', 'wav22_silence_trimmed')
            wav_path = os.path.join(wav_path, wav_file)

            wav_path_output = os.path.join(path_save,
                                           'wav/Wav_{}'.format(wav_file))
            copyfile(wav_path, wav_path_output)

            x, y = model.parse_batch(batch)
            predicted_mel, post_output, predicted_stop, alignments, \
                text_hidden, audio_seq2seq_hidden, audio_seq2seq_phids, audio_seq2seq_alignments, \
                speaker_id = model.inference(x, input_text, reference_mel, beam_width)

            post_output = post_output.data.cpu().numpy()[0]
            alignments = alignments.data.cpu().numpy()[0].T
            audio_seq2seq_alignments = audio_seq2seq_alignments.data.cpu(
            ).numpy()[0].T

            text_hidden = text_hidden.data.cpu().numpy(
            )[0].T  #-> [hidden_dim, max_text_len]
            audio_seq2seq_hidden = audio_seq2seq_hidden.data.cpu().numpy()[0].T
            audio_seq2seq_phids = audio_seq2seq_phids.data.cpu().numpy()[
                0]  # [T + 1]
            speaker_id = speaker_id.data.cpu().numpy()[0]  # scalar

            task = 'TTS' if input_text else 'VC'

            wav_path = os.path.join(
                path_save,
                'wav_mel/Wav_%s_ref_%s_%s.wav' % (sample_id, ref_sp, task))
            # recover_wav(post_output, wav_path, hparams.mel_mean_std, ismel=ISMEL)
            recover_wav_wgan(post_output,
                             wav_path,
                             hparams.mel_mean_std,
                             ismel=ISMEL,
                             n_fft=1024,
                             win_length=1024,
                             hop_length=256)

            post_output_path = os.path.join(
                path_save,
                'mel/Mel_%s_ref_%s_%s.npy' % (sample_id, ref_sp, task))
            np.save(post_output_path, post_output.T)

            plot_data([alignments, audio_seq2seq_alignments],
                      os.path.join(
                          path_save, 'ali/Ali_%s_ref_%s_%s.pdf' %
                          (sample_id, ref_sp, task)))

            plot_data([np.hstack([text_hidden, audio_seq2seq_hidden])],
                      os.path.join(
                          path_save, 'hid/Hid_%s_ref_%s_%s.pdf' %
                          (sample_id, ref_sp, task)))

            audio_seq2seq_phids = [
                id2ph[phid] for phid in audio_seq2seq_phids[:-1]
            ]
            target_text = y[0].data.cpu().numpy()[0]
            target_text = [id2ph[phid] for phid in target_text]

            if not input_text:
                # Voice conversion: compare the decoded phoneme sequence with the reference text.
                print(audio_seq2seq_phids)
                print(target_text)

            err = levenshteinDistance(audio_seq2seq_phids, target_text)
            print(err, len(target_text))

            errs.append(err)
            totalphs.append(len(target_text))

    # Save the phone error rate along with per-utterance error counts and lengths.
    per = float(sum(errs)) / float(sum(totalphs))
    per_file = os.path.join(path_save, 'per.txt')
    with open(per_file, 'w') as f:
        f.write('{:.3f}\n'.format(per))
        f.write(','.join([str(e) for e in errs]) + '\n')
        f.write(','.join([str(l) for l in totalphs]) + '\n')

    return per
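
The phone error rate above relies on levenshteinDistance, which is defined elsewhere in the repository. As an illustration only (not the repository's implementation), a standard edit distance over phoneme lists can be sketched as:

def levenshteinDistance(hyp, ref):
    # Dynamic-programming edit distance between two phoneme sequences.
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, start=1):
        curr = [i]
        for j, r in enumerate(ref, start=1):
            cost = 0 if h == r else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution
        prev = curr
    return prev[len(ref)]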
Example 3
def generate(loader, reference_mel, beam_width, path_save, ref_sp,
             sample_list, num=10, input_text=False):

    with torch.no_grad():
        errs = 0
        totalphs = 0

        for i, batch in enumerate(loader):
            if i == num:
                break
            
            sample_id = sample_list[i].split('/')[-1][9:17 + 4]
            print('%d index %s, decoding ...' % (i, sample_id))

            x, y = model.parse_batch(batch)
            predicted_mel, post_output, predicted_stop, alignments, \
                text_hidden, audio_seq2seq_hidden, audio_seq2seq_phids, audio_seq2seq_alignments, \
                speaker_id = model.inference(x, input_text, reference_mel, beam_width)

            post_output = post_output.data.cpu().numpy()[0]
            alignments = alignments.data.cpu().numpy()[0].T
            audio_seq2seq_alignments = audio_seq2seq_alignments.data.cpu().numpy()[0].T

            text_hidden = text_hidden.data.cpu().numpy()[0].T #-> [hidden_dim, max_text_len]
            audio_seq2seq_hidden = audio_seq2seq_hidden.data.cpu().numpy()[0].T
            audio_seq2seq_phids = audio_seq2seq_phids.data.cpu().numpy()[0] # [T + 1]
            speaker_id = speaker_id.data.cpu().numpy()[0] # scalar

            task = 'TTS' if input_text else 'VC'

            recover_wav(post_output, 
                        os.path.join(path_save, 'wav_mel/Wav_%s_ref_%s_%s.wav'%(sample_id, ref_sp, task)),
                        hparams.mel_mean_std, 
                        ismel=ISMEL)
            
            post_output_path = os.path.join(path_save, 'mel/Mel_%s_ref_%s_%s.npy'%(sample_id, ref_sp, task))
            np.save(post_output_path, post_output)
                    
            plot_data([audio_seq2seq_alignments.T, alignments],
                      os.path.join(path_save, 'ali/Ali_%s_ref_%s_%s.pdf' % (sample_id, ref_sp, task)))

            plot_data([np.hstack([text_hidden, audio_seq2seq_hidden])],
                      os.path.join(path_save, 'hid/Hid_%s_ref_%s_%s.pdf' % (sample_id, ref_sp, task)))

            audio_seq2seq_phids = [id2ph[phid] for phid in audio_seq2seq_phids[:-1]]
            target_text = y[0].data.cpu().numpy()[0]
            target_text = [id2ph[phid] for phid in target_text]

            if not input_text:
                # Voice conversion: compare the decoded phoneme sequence with the reference text.
                print(audio_seq2seq_phids)
                print(target_text)
        
            err = levenshteinDistance(audio_seq2seq_phids, target_text)
            print(err, len(target_text))

            errs += err
            totalphs += len(target_text)

    # Return the overall phone error rate.
    return float(errs) / float(totalphs)
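
A hedged invocation sketch for either generate variant; test_loader, ref_mel, ref_speaker, and sample_list are placeholder names for objects the caller must prepare (for example with TextMelIDLoader / TextMelIDCollate, as in gen_embedding above), not names defined by the repository.

# Hypothetical usage: run voice conversion over a prepared DataLoader and
# report the resulting phone error rate.
per = generate(test_loader,
               ref_mel,
               beam_width=10,
               path_save='outdir/generation',
               ref_sp=ref_speaker,
               sample_list=sample_list,
               num=10,
               input_text=False)  # False -> voice conversion, True -> TTS
print('phone error rate: {:.3f}'.format(per))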