def pre_save(path):
    processor = AudioProcessor()
    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(args.model_path))
    embedder_net.eval()

    origin_path = join(path, 'origin')
    for speaker in os.listdir(origin_path):
        speaker = join(origin_path, speaker)
        for corpus in os.listdir(speaker):
            corpus = join(speaker, corpus)
            os.makedirs(corpus.replace('origin', 'audio'), exist_ok=True)
            os.makedirs(corpus.replace('origin', 'spectrogram'), exist_ok=True)
            os.makedirs(corpus.replace('origin', 'text'), exist_ok=True)
            os.makedirs(corpus.replace('origin', 'dvector'), exist_ok=True)
            for item in os.listdir(corpus):
                if item.endswith('.flac'):
                    item = join(corpus, item)
                    audio = processor.load_audio(item)
                    audio_path = item.replace('origin', 'audio')
                    melspec = processor.process_audio(audio, audio_path)
                    np.save(
                        item.replace('origin', 'spectrogram')[:-5], melspec)
                    dvector = dvector_make(item, embedder_net)
                    np.save(item.replace('origin', 'dvector')[:-5], dvector)
                elif item.endswith('.txt'):
                    srcpath = join(corpus, item)
                    trgpath = srcpath.replace('origin', 'text')
                    shutil.copy2(srcpath, trgpath)
                else:
                    print("There are unexpected files!")
    return
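
# dvector_make is called above but not shown in this example. A minimal
# sketch of what it plausibly does, reusing the same AudioProcessor front
# end; the process_audio signature and the (1, frames, n_mels) input layout
# are assumptions, not confirmed by the original code.
def dvector_make(path, embedder_net):
    processor = AudioProcessor()
    audio = processor.load_audio(path)
    melspec = processor.process_audio(audio, path)  # assumed (frames, n_mels)
    with torch.no_grad():
        batch = torch.Tensor(melspec).unsqueeze(0)  # (1, frames, n_mels)
        dvec = embedder_net(batch)                  # (1, proj)
    return dvec.squeeze(0).cpu().numpy()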
def main(args):
    if args['--datadir']:
        data_dir = args['--datadir']
    else:
        data_dir = hp.data.eval_path
    device = torch.device(hp.device)
    print('[INFO] device: %s' % device)
    dataset_name = os.path.basename(os.path.normpath(data_dir))
    print('[INFO] dataset: %s' % dataset_name)

    # Load model
    embed_net = SpeechEmbedder().to(device)
    embed_net.load_state_dict(torch.load(hp.model.model_path))
    embed_net.eval()
    # Features
    eval_gen = DL.ARKUtteranceGenerator(data_dir, apply_vad=True)
    eval_loader = DataLoader(eval_gen,
                             batch_size=hp.test.M,
                             shuffle=False,
                             num_workers=hp.test.num_workers,
                             drop_last=False)
    dwriter = kaldiio.WriteHelper('ark,scp:%s_dvecs.ark,%s_dvecs.scp' %
                                  (dataset_name, dataset_name))

    cnt = 0
    processed = []
    for key_bt, feat_bt in eval_loader:
        feat_bt = feat_bt.to(device)
        t_start = time.time()
        # feat dim [M_files, n_chunks_in_file, frames, n_mels]
        mf, nchunks, frames, nmels = feat_bt.shape
        print(feat_bt.shape)
        stack_shape = (mf * nchunks, frames, nmels)

        feat_stack = torch.reshape(feat_bt, stack_shape)
        with torch.no_grad():  # inference only; no gradients needed
            dvec_stack = embed_net(feat_stack)
        dvec_bt = torch.reshape(
            dvec_stack, (mf, dvec_stack.size(0) // mf, dvec_stack.size(1)))

        for key, dvec in zip(key_bt, dvec_bt):
            mean_dvec = torch.mean(dvec, dim=0).detach()
            mean_dvec = mean_dvec.cpu().numpy()
            dwriter(key, mean_dvec)
            processed.append(key)
            print('%d. Processed: %s' % (cnt, key))
            cnt += 1
        t_end = time.time()
        print('Elapsed: %.4f' % (t_end - t_start))
    dwriter.close()  # flush the ark/scp pair to disk
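
# The ark/scp pair written above can be read back with kaldiio for a quick
# sanity check, once main() has produced it; 'eval' below stands in for
# whatever dataset_name actually was.
import kaldiio

dvecs = kaldiio.load_scp('eval_dvecs.scp')
for key in dvecs:
    print(key, dvecs[key].shape)  # one mean d-vector per utterance key
    break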
def test(model_path):
    if hp.data.data_preprocessed:
        test_dataset = SpeakerDatasetTIMITPreprocessed()
    else:
        test_dataset = SpeakerDatasetTIMIT()
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    device = torch.device(hp.device)
    count = 0
    embeddings = []
    devector = []
    for e in range(hp.test.epochs):
        print("hp.test.epochs", hp.test.epochs)
        for batch_id, mel_db_batch in enumerate(test_loader):
            #print("mel_db_batch.shape",batch_id,mel_db_batch.shape)   #(1,10,160,40)
            assert hp.test.M % 2 == 0
            test_batch = mel_db_batch
            test_batch = torch.reshape(
                test_batch, (hp.test.N * hp.test.M, test_batch.size(2),
                             test_batch.size(3)))
            #print("test_batch.shape",test_batch.shape)    #(10,160,40)

            enrollment_embeddings = embedder_net(test_batch)
            #print("enrollment_embeddings.shape", enrollment_embeddings.shape)  # (10,256)
            # enrollment_embeddings = torch.reshape(enrollment_embeddings,(hp.test.N, hp.test.M, enrollment_embeddings.size(1)))

            embedding = enrollment_embeddings.detach().numpy()
            embeddings.append(embedding)
            #print('embedding.shape', type(embedding), embedding.shape)  # (10,256)
            count = count + 1
    # concatenate once after all batches instead of on every iteration
    devector = np.concatenate(embeddings, axis=0)
    np.save('/run/media/rice/DATA/speakerdvector.npy', devector)
class SpeakerIdentifier:
    def __init__(self, model_path, enroll_dir):
        self.embedder = SpeechEmbedder()
        self.embedder.load_state_dict(torch.load(model_path))
        self.embedder.eval()

        self.speakers = dict()
        files = os.listdir(enroll_dir)
        for spkr_file in files:
            speaker_id = os.path.splitext(spkr_file)[0]
            path = os.path.join(enroll_dir, spkr_file)
            self.speakers[speaker_id] = np.load(path)

    def identify(self, samples):
        S = librosa.core.stft(y=samples,
                              n_fft=hp.data.nfft,
                              win_length=int(hp.data.window * hp.data.sr),
                              hop_length=int(hp.data.hop * hp.data.sr))
        S = np.abs(S)**2
        mel_basis = librosa.filters.mel(sr=hp.data.sr,
                                        n_fft=hp.data.nfft,
                                        n_mels=hp.data.nmels)
        S = np.log10(np.dot(mel_basis, S) + 1e-6)

        S = S.T
        S = np.reshape(S, (1, -1, hp.data.nmels))

        batch = torch.Tensor(S)

        with torch.no_grad():  # inference only
            results = self.embedder(batch)
        results = results.reshape((1, hp.model.proj))

        scores = dict()
        for speaker_id, speaker_emb in self.speakers.items():
            speaker_emb_tensor = torch.Tensor(speaker_emb).reshape((1, -1))
            output = F.cosine_similarity(results, speaker_emb_tensor)
            output = output.cpu().detach().numpy()[0]

            scores[speaker_id] = output

        return scores
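
# A minimal usage sketch, assuming the example's imports (librosa, hp) are in
# scope. The checkpoint and enrollment paths are placeholders, and the
# enrollment directory is assumed to hold one per-speaker .npy d-vector, as
# __init__ expects.
identifier = SpeakerIdentifier('final_model.pth', 'enrollments/')
samples, _ = librosa.load('query.wav', sr=hp.data.sr)
scores = identifier.identify(samples)
best = max(scores, key=scores.get)  # enrolled speaker with highest cosine similarity
print(best, scores[best])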
Example #5
def get_embeddings(model_path):
    #confirm that hp.training is True
    assert hp.training, 'mode should be set as train mode'
    train_dataset = SpeakerDatasetTIMITPreprocessed(shuffle=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=False,
                              num_workers=hp.test.num_workers,
                              drop_last=True)

    embedder_net = SpeechEmbedder().cuda()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    epoch_embeddings = []
    with torch.no_grad():
        for e in range(hp.test.epochs):  # 'epoch' was undefined in this excerpt; hp.test.epochs assumed
            batch_embeddings = []
            print('Processing epoch %d:' % (1 + e))
            for batch_id, mel_db_batch in enumerate(train_loader):
                print(mel_db_batch.shape)
                mel_db_batch = torch.reshape(
                    mel_db_batch, (hp.train.N * hp.train.M,
                                   mel_db_batch.size(2), mel_db_batch.size(3)))
                batch_embedding = embedder_net(mel_db_batch.cuda())
                batch_embedding = torch.reshape(
                    batch_embedding,
                    (hp.train.N, hp.train.M, batch_embedding.size(1)))
                batch_embedding = get_centroids(batch_embedding.cpu().clone())
                batch_embeddings.append(batch_embedding)

            epoch_embedding = torch.cat(batch_embeddings, 0)
            epoch_embedding = epoch_embedding.unsqueeze(1)
            epoch_embeddings.append(epoch_embedding)

    avg_embeddings = torch.cat(epoch_embeddings, 1)
    avg_embeddings = get_centroids(avg_embeddings)
    return avg_embeddings
def train(model_path):
    device = torch.device(hp.device)

    train_dataset = SpeakerDatasetTIMITPreprocessed()
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)

    embedder_net = SpeechEmbedder().to(device)
    if hp.train.restore:
        embedder_net.load_state_dict(torch.load(model_path))
    ge2e_loss = GE2ELoss(device)
    #Both net and loss have trainable parameters
    optimizer = torch.optim.SGD([{
        'params': embedder_net.parameters()
    }, {
        'params': ge2e_loss.parameters()
    }],
                                lr=hp.train.lr)

    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)

    embedder_net.train()
    iteration = 0
    for e in range(hp.train.epochs):
        total_loss = 0
        for batch_id, mel_db_batch in enumerate(train_loader):
            mel_db_batch = mel_db_batch.to(device)

            mel_db_batch = torch.reshape(
                mel_db_batch, (hp.train.N * hp.train.M, mel_db_batch.size(2),
                               mel_db_batch.size(3)))
            perm = random.sample(range(0, hp.train.N * hp.train.M),
                                 hp.train.N * hp.train.M)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
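            # unperm is the inverse permutation: x[perm][unperm] == x,
            # so speaker grouping can be restored after the forward pass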
            mel_db_batch = mel_db_batch[perm]
            #gradient accumulates
            optimizer.zero_grad()

            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(
                embeddings, (hp.train.N, hp.train.M, embeddings.size(1)))

            #get loss, call backward, step optimizer
            loss = ge2e_loss(
                embeddings)  #wants (Speaker, Utterances, embedding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()

            total_loss = total_loss + loss.item()  # .item() avoids retaining the graph
            iteration += 1
            if (batch_id + 1) % hp.train.log_interval == 0:
                mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(
                    time.ctime(), e + 1, batch_id + 1,
                    len(train_dataset) // hp.train.N, iteration, loss,
                    total_loss / (batch_id + 1))
                print(mesg)
                if hp.train.log_file is not None:
                    with open(hp.train.log_file, 'a') as f:  # append; 'w' would truncate each time
                        f.write(mesg)

        if hp.train.checkpoint_dir is not None and (
                e + 1) % hp.train.checkpoint_interval == 0:
            embedder_net.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(
                e + 1) + "_batch_id_" + str(batch_id + 1) + ".pth"
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir,
                                           ckpt_model_filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            embedder_net.to(device).train()

    #save model
    embedder_net.eval().cpu()
    save_model_filename = "final_epoch_" + str(e + 1) + "_batch_id_" + str(
        batch_id + 1) + ".model"
    save_model_path = os.path.join(hp.train.checkpoint_dir,
                                   save_model_filename)
    torch.save(embedder_net.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def test(model_path):

    test_dataset = SpeakerDatasetTIMITPreprocessed()
    test_loader = DataLoader(test_dataset,
                             batch_size=hp.test.N,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    avg_EER = 0
    for e in range(hp.test.epochs):
        batch_avg_EER = 0
        for batch_id, mel_db_batch in enumerate(test_loader):
            assert hp.test.M % 2 == 0
            enrollment_batch, verification_batch = torch.split(
                mel_db_batch, int(mel_db_batch.size(1) / 2), dim=1)

            enrollment_batch = torch.reshape(
                enrollment_batch,
                (hp.test.N * hp.test.M // 2, enrollment_batch.size(2),
                 enrollment_batch.size(3)))
            verification_batch = torch.reshape(
                verification_batch,
                (hp.test.N * hp.test.M // 2, verification_batch.size(2),
                 verification_batch.size(3)))

            perm = random.sample(range(0, verification_batch.size(0)),
                                 verification_batch.size(0))
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i

            verification_batch = verification_batch[perm]
            with torch.no_grad():  # inference only
                enrollment_embeddings = embedder_net(enrollment_batch)
                verification_embeddings = embedder_net(verification_batch)
            verification_embeddings = verification_embeddings[unperm]

            enrollment_embeddings = torch.reshape(
                enrollment_embeddings,
                (hp.test.N, hp.test.M // 2, enrollment_embeddings.size(1)))
            verification_embeddings = torch.reshape(
                verification_embeddings,
                (hp.test.N, hp.test.M // 2, verification_embeddings.size(1)))

            enrollment_centroids = get_centroids(enrollment_embeddings)

            sim_matrix = get_cossim(verification_embeddings,
                                    enrollment_centroids)

            # calculating EER
            diff = 1
            EER = 0
            EER_thresh = 0
            EER_FAR = 0
            EER_FRR = 0

            for thres in [0.01 * i + 0.3 for i in range(70)]:
                sim_matrix_thresh = sim_matrix > thres

                FAR = (sum([
                    sim_matrix_thresh[i].float().sum() -
                    sim_matrix_thresh[i, :, i].float().sum()
                    for i in range(int(hp.test.N))
                ]) / (hp.test.N - 1.0) / (float(hp.test.M / 2)) / hp.test.N)

                FRR = (sum([
                    hp.test.M / 2 - sim_matrix_thresh[i, :, i].float().sum()
                    for i in range(int(hp.test.N))
                ]) / (float(hp.test.M / 2)) / hp.test.N)

                # Save threshold when FAR = FRR (=EER)
                if diff > abs(FAR - FRR):
                    diff = abs(FAR - FRR)
                    EER = (FAR + FRR) / 2
                    EER_thresh = thres
                    EER_FAR = FAR
                    EER_FRR = FRR
            batch_avg_EER += EER
            print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)" %
                  (EER, EER_thresh, EER_FAR, EER_FRR))
        avg_EER += batch_avg_EER / (batch_id + 1)

    avg_EER = avg_EER / hp.test.epochs
    print("\n EER across {0} epochs: {1:.4f}".format(hp.test.epochs, avg_EER))
Example #8
import random

#
from hparam import hparam as hp
from speech_embedder_net import SpeechEmbedder
from VAD_segments import VAD_chunk
#

import torch
import librosa
import math
import numpy as np

encoder = SpeechEmbedder()
encoder.load_state_dict(
    torch.load("speaker_verification/final_epoch_950_batch_id_103.model"))
encoder.eval()


def concat_segs(times, segs):
    #Concatenate continuous voiced segments
    concat_seg = []
    seg_concat = segs[0]
    for i in range(0, len(times) - 1):
        if times[i][1] == times[i + 1][0]:  # contiguous in time: keep merging
            seg_concat = np.concatenate((seg_concat, segs[i + 1]))
        else:
            concat_seg.append(seg_concat)
            seg_concat = segs[i + 1]
    concat_seg.append(seg_concat)  # flush the final run
    return concat_seg
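
# For example: segments whose VAD time stamps touch are merged into one
# array, while a gap starts a new one.
times = [(0.0, 1.0), (1.0, 2.0), (3.0, 4.0)]
segs = [np.zeros(10), np.zeros(10), np.zeros(10)]
merged = concat_segs(times, segs)
assert len(merged) == 2          # first two are contiguous, third stands alone
assert merged[0].shape == (20,)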
Example #9
import glob
import os
import librosa
import numpy as np
from hparam import hparam as hp
from speech_embedder_net import SpeechEmbedder, GE2ELoss, get_centroids, get_cossim
import torch
import pandas as pd
import pickle

audio_path = glob.glob(os.path.dirname(hp.unprocessed_data))
model_path = hp.model.model_path
embedder_net = SpeechEmbedder()
embedder_net.load_state_dict(torch.load(model_path))
embedder_net.eval()


def save_traindevector():

    print("start text independent utterance feature extraction")
    os.makedirs(hp.data.train_path,
                exist_ok=True)  # make folder to save train file
    os.makedirs(hp.data.test_path,
                exist_ok=True)  # make folder to save test file

    utter_min_len = (hp.data.tisv_frame * hp.data.hop + hp.data.window
                     ) * hp.data.sr  # lower bound of utterance length
    total_speaker_num = len(audio_path)
    train_speaker_num = (total_speaker_num //
                         10) * 9  # split total data 90% train and 10% test
def test_my(model_path, threshold):
    assert hp.test.M % 2 == 0, 'hp.test.M should be set even'
    assert not hp.training, 'mode should be set to test mode'
    # prepare the enrollment and verification datasets
    test_dataset_enrollment = SpeakerDatasetTIMITPreprocessed()
    test_dataset_enrollment.path = hp.data.test_path
    test_dataset_enrollment.file_list = os.listdir(
        test_dataset_enrollment.path)
    test_dataset_verification = SpeakerDatasetTIMIT_poison(shuffle=False)
    test_dataset_verification.path = hp.poison.poison_test_path
    try_times = hp.poison.num_centers * 2

    test_dataset_verification.file_list = os.listdir(
        test_dataset_verification.path)

    test_loader_enrollment = DataLoader(test_dataset_enrollment,
                                        batch_size=hp.test.N,
                                        shuffle=True,
                                        num_workers=hp.test.num_workers,
                                        drop_last=True)
    test_loader_verification = DataLoader(test_dataset_verification,
                                          batch_size=1,
                                          shuffle=False,
                                          num_workers=hp.test.num_workers,
                                          drop_last=True)

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()
    results_line = []
    results_success = []
    for e in range(hp.test.epochs):
        for batch_id, mel_db_batch_enrollment in enumerate(
                test_loader_enrollment):

            mel_db_batch_verification = next(iter(test_loader_verification))
            mel_db_batch_verification = mel_db_batch_verification.repeat(
                (hp.test.N, 1, 1, 1))

            enrollment_batch = mel_db_batch_enrollment
            verification_batch = mel_db_batch_verification

            enrollment_batch = torch.reshape(
                enrollment_batch,
                (hp.test.N * hp.test.M, enrollment_batch.size(2),
                 enrollment_batch.size(3)))
            verification_batch = torch.reshape(
                verification_batch,
                (hp.test.N * try_times, verification_batch.size(2),
                 verification_batch.size(3)))

            perm = random.sample(range(0, verification_batch.size(0)),
                                 verification_batch.size(0))
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i

            verification_batch = verification_batch[perm]
            with torch.no_grad():  # inference only
                enrollment_embeddings = embedder_net(enrollment_batch)
                verification_embeddings = embedder_net(verification_batch)
            verification_embeddings = verification_embeddings[unperm]

            enrollment_embeddings = torch.reshape(
                enrollment_embeddings,
                (hp.test.N, hp.test.M, enrollment_embeddings.size(1)))
            verification_embeddings = torch.reshape(
                verification_embeddings,
                (hp.test.N, try_times, verification_embeddings.size(1)))

            enrollment_centroids = get_centroids(enrollment_embeddings)

            sim_matrix = get_cossim_nosame(verification_embeddings,
                                           enrollment_centroids)

            ########################
            # calculating ASR
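            # assumed: sim_matrix is (N, try_times, N); the double max yields
            # the best score any poisoned utterance achieves per centroid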

            res = sim_matrix.max(0)[0].max(0)[0]

            result_line = torch.Tensor([
                (res >= i / 10).sum().float() / hp.test.N
                for i in range(0, 10)
            ])
            #print(result_line )
            results_line.append(result_line)

            result_success = (res >= threshold).sum() / hp.test.N
            print('ASR for Epoch %d : %.3f' % (e + 1, result_success.item()))
            results_success.append(result_success)

    print('Overall ASR : %.3f' %
          (sum(results_success).item() / len(results_success)))
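
# get_cossim_nosame is not defined in this excerpt. Given how sim_matrix is
# indexed above, it presumably computes plain all-pairs cosine similarity
# between verification embeddings and enrollment centroids, without GE2E's
# same-speaker centroid exclusion. A hedged sketch only:
import torch
import torch.nn.functional as F

def get_cossim_nosame_sketch(embeddings, centroids):
    # (N, U, proj) x (K, proj) -> (N, U, K), plain cosine similarity (assumed)
    e = F.normalize(embeddings, dim=2)
    c = F.normalize(centroids, dim=1)
    return torch.einsum('nup,kp->nuk', e, c)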
Example #11
            writer.add_scalar('data/train_loss', train_loss, iteration)
            writer.add_scalar('data/train_total_loss', train_total_loss, iteration)

            iteration += 1

            if (batch_id + 1) % hp.train.log_interval == 0:
                mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(time.ctime(), e+1,
                        batch_id+1, len(train_dataset)//hp.train.N, iteration,loss, total_loss / (batch_id + 1))
                print(mesg)
                if hp.train.log_file is not None:
                    with open(hp.train.log_file,'a') as f:
                        f.write(mesg)
                    
        if hp.train.checkpoint_dir is not None and (e + 1) % hp.train.checkpoint_interval == 0:
            embedder_net.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(e+1) + "_batch_id_" + str(batch_id+1) + ".pth"
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir, ckpt_model_filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            embedder_net.to(device).train()
    
    # save model
    embedder_net.eval().cpu()
    save_model_filename = "final_epoch_" + str(e + 1) + "_batch_id_" + str(batch_id + 1) + ".model"
    save_model_path = os.path.join(hp.train.checkpoint_dir, save_model_filename)
    torch.save(embedder_net.state_dict(), save_model_path)
    
    print("\nDone, trained model saved at", save_model_path)
    with torch.no_grad():
        avg_EER = 0
        i = 0
Example #12
def test(model_path):
    utterances_spec = []
    for utter_name in os.listdir(predict_folder):
        print(utter_name)
        if utter_name[-4:] == '.wav':
            utter_path = os.path.join(predict_folder, utter_name)  # path of each utterance
            utter, sr = librosa.core.load(utter_path, sr=hp.data.sr)  # load utterance audio
            intervals = librosa.effects.split(utter, top_db=30)  # voice activity detection
            utter_min_len = (hp.data.tisv_frame * hp.data.hop + hp.data.window) * hp.data.sr  # lower bound of utterance length
            for interval in intervals:
                if (interval[1] - interval[0]) > utter_min_len:  # If partial utterance is sufficient long,
                    utter_part = utter[interval[0]:interval[1]]  # save first and last 180 frames of spectrogram.
                    S = librosa.core.stft(y=utter_part, n_fft=hp.data.nfft,
                                          win_length=int(hp.data.window * sr), hop_length=int(hp.data.hop * sr))
                    S = np.abs(S) ** 2
                    mel_basis = librosa.filters.mel(sr=hp.data.sr, n_fft=hp.data.nfft, n_mels=hp.data.nmels)
                    S = np.log10(np.dot(mel_basis, S) + 1e-6)  # log mel spectrogram of utterances
                    utterances_spec.append(S[:, :hp.data.tisv_frame])  # first 180 frames of partial utterance
                    utterances_spec.append(S[:, -hp.data.tisv_frame:])  # last 180 frames of partial utterance

    utterances_spec = np.array(utterances_spec)

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    device = torch.device(hp.device)
    avg_EER = 0

    # torch.split below needs a 4-D tensor (N, M, frames, n_mels), not a numpy array
    predict_loader = torch.Tensor(utterances_spec)

    enrollment_batch, verification_batch = torch.split(
        predict_loader, int(predict_loader.size(1) / 2), dim=1)
    enrollment_batch = torch.reshape(
        enrollment_batch, (hp.test.N * hp.test.M // 2,
                           enrollment_batch.size(2), enrollment_batch.size(3)))
    verification_batch = torch.reshape(
        verification_batch, (hp.test.N * hp.test.M // 2,
                             verification_batch.size(2), verification_batch.size(3)))

    perm = random.sample(range(0, verification_batch.size(0)), verification_batch.size(0))
    unperm = list(perm)

    for i, j in enumerate(perm):
        unperm[j] = i

    verification_batch = verification_batch[perm]
    with torch.no_grad():  # inference only
        enrollment_embeddings = embedder_net(enrollment_batch)
        verification_embeddings = embedder_net(verification_batch)
    verification_embeddings = verification_embeddings[unperm]

    enrollment_embeddings = torch.reshape(
        enrollment_embeddings,
        (hp.test.N, hp.test.M // 2, enrollment_embeddings.size(1)))
    verification_embeddings = torch.reshape(
        verification_embeddings,
        (hp.test.N, hp.test.M // 2, verification_embeddings.size(1)))

    enrollment_centroids = get_centroids(enrollment_embeddings)

    sim_matrix = get_cossim(verification_embeddings, enrollment_centroids)

    diff = 1
    EER = 0
    EER_thresh = 0
    EER_FAR = 0
    EER_FRR = 0

    for thres in [0.01 * i + 0.5 for i in range(50)]:
        sim_matrix_thresh = sim_matrix > thres

        FAR = (sum([
            sim_matrix_thresh[i].float().sum() -
            sim_matrix_thresh[i, :, i].float().sum()
            for i in range(int(hp.test.N))
        ]) / (hp.test.N - 1.0) / (float(hp.test.M / 2)) / hp.test.N)

        FRR = (sum([
            hp.test.M / 2 - sim_matrix_thresh[i, :, i].float().sum()
            for i in range(int(hp.test.N))
        ]) / (float(hp.test.M / 2)) / hp.test.N)

        if diff > abs(FAR - FRR):
            diff = abs(FAR - FRR)
            EER = (FAR + FRR) / 2
            EER_thresh = thres
            EER_FAR = FAR
            EER_FRR = FRR

    avg_EER += EER  # accumulate after the threshold sweep, not inside it
    print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)" %
          (EER, EER_thresh, EER_FAR, EER_FRR))

    print("\n EER across {0} epochs: {:.4f}".format(avg_EER))
Example #13
def get_model(path):
    speech_net = SpeechEmbedder()
    speech_net.load_state_dict(torch.load(path))
    return speech_net.eval()
Example #14
def train(model_path):
    FNULL = open(os.devnull, 'w')
    device = torch.device(hp.device)

    if hp.data.data_preprocessed:
        train_dataset = SpeakerDatasetTIMITPreprocessed(is_training=True)
        test_dataset = SpeakerDatasetTIMITPreprocessed(is_training=False)
    else:
        train_dataset = SpeakerDatasetTIMIT(is_training=True)
        test_dataset = SpeakerDatasetTIMIT(is_training=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=hp.test.N,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder().to(device)
    if hp.train.restore:
        subprocess.call([
            'gsutil', 'cp', 'gs://edinquake/asr/baseline_TIMIT/model_best.pkl',
            model_path
        ],
                        stdout=FNULL,
                        stderr=subprocess.STDOUT)
        embedder_net.load_state_dict(torch.load(model_path))
    ge2e_loss = GE2ELoss(device)
    #Both net and loss have trainable parameters
    optimizer = torch.optim.SGD([{
        'params': embedder_net.parameters()
    }, {
        'params': ge2e_loss.parameters()
    }],
                                lr=hp.train.lr)

    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)

    iteration = 0
    best_validate = float('inf')
    print('***Started training at {}***'.format(datetime.now()))
    for e in range(hp.train.epochs):
        total_loss = 0
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(e),
                            leave=False,
                            disable=False)
        embedder_net.train()
        # Iterate over the training set
        for batch_id, mel_db_batch in enumerate(progress_bar):
            mel_db_batch = mel_db_batch.to(device)

            mel_db_batch = torch.reshape(
                mel_db_batch, (hp.train.N * hp.train.M, mel_db_batch.size(2),
                               mel_db_batch.size(3)))
            perm = random.sample(range(0, hp.train.N * hp.train.M),
                                 hp.train.N * hp.train.M)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]
            #gradient accumulates
            optimizer.zero_grad()

            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(
                embeddings, (hp.train.N, hp.train.M, embeddings.size(1)))

            #get loss, call backward, step optimizer
            loss = ge2e_loss(
                embeddings)  #wants (Speaker, Utterances, embedding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()

            total_loss = total_loss + loss.item()
            iteration += 1

            # Update statistics for progress bar
            progress_bar.set_postfix(iteration=iteration,
                                     loss=loss.item(),
                                     total_loss=total_loss / (batch_id + 1))

        print('| Epoch {:03d}: total_loss {}'.format(e, total_loss))
        # Perform validation
        embedder_net.eval()
        validation_loss = 0
        with torch.no_grad():  # no gradients needed for validation
            for batch_id, mel_db_batch in enumerate(test_loader):
                mel_db_batch = mel_db_batch.to(device)

                mel_db_batch = torch.reshape(
                    mel_db_batch, (hp.test.N * hp.test.M, mel_db_batch.size(2),
                                   mel_db_batch.size(3)))
                perm = random.sample(range(0, hp.test.N * hp.test.M),
                                     hp.test.N * hp.test.M)
                unperm = list(perm)
                for i, j in enumerate(perm):
                    unperm[j] = i
                mel_db_batch = mel_db_batch[perm]

                embeddings = embedder_net(mel_db_batch)
                embeddings = embeddings[unperm]
                embeddings = torch.reshape(
                    embeddings, (hp.test.N, hp.test.M, embeddings.size(1)))
                #get loss
                loss = ge2e_loss(
                    embeddings)  #wants (Speaker, Utterances, embedding)
                validation_loss += loss.item()

        validation_loss /= len(test_loader)
        print('validation_loss: {}'.format(validation_loss))
        if validation_loss <= best_validate:
            best_validate = validation_loss
            # Save best
            filename = 'model_best.pkl'
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir, filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            subprocess.call([
                'gsutil', 'cp', ckpt_model_path,
                'gs://edinquake/asr/baseline_TIMIT/model_best.pkl'
            ],
                            stdout=FNULL,
                            stderr=subprocess.STDOUT)

        filename = 'model_last.pkl'
        ckpt_model_path = os.path.join(hp.train.checkpoint_dir, filename)
        torch.save(embedder_net.state_dict(), ckpt_model_path)
def main():
    args = get_args()
    if args.corpus == 'CAAML':
        dataset = CAAMLDataset(args.data_path, args.save_path, args.split)
    elif args.corpus == 'ICSI':
        dataset = ICSIDataset(args.data_path, args.save_path)
    elif args.corpus == 'TIMIT':
        print('Dataset not yet implemented...')
        exit()
    else:
        print('Unknown corpus: %s' % args.corpus)
        exit()

    # Load speech embedder net
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(hp.model.model_path))
    embedder_net = embedder_net.to(device)
    embedder_net.eval()

    all_seqs = []
    all_cids = []
    all_times = []
    sessions = []
    for path in dataset.data_path:
        print('\n============== Processing {} ============== '.format(path))

        # Get session name
        session = dataset.get_session(path)
        if session is None:
            print('ERR: Session not found in any split, skipping...')
            continue
 
        # Get annotations
        annotations = dataset.get_annotations(path)
        if annotations is None:
            print('ERR: No suitable annotations found, skipping...')
            continue

        # Segment the audio with VAD
        times, segments = timeit(VAD_chunk, msg='Getting VAD chunks')(hp.data.aggressiveness, path)
        if not segments:
            print('ERR: No segments found, skipping...')
            continue

        # Concatenate segments
        segments, times = concat(segments, times)

        # Get STFT frames
        frames, times = get_STFTs(segments, times)
        frames = np.stack(frames, axis=2)
        frames = torch.tensor(np.transpose(frames, axes=(2,1,0))).to(device)

        # Get speaker embeddings
        embeddings = get_speaker_embeddings(embedder_net, frames)

        # Align speaker embeddings into a standard sequence of embeddings
        sequence, times = align_embeddings(embeddings.cpu().detach().numpy(), times)
        
        # Get cluster ids for each frame
        cluster_ids = get_cluster_ids(times, dataset, annotations)

        # Add the sequence and cluster ids to the list of all sessions
        all_seqs.append(sequence) 
        all_cids.append(cluster_ids)
        all_times.append(times)
        sessions.append(session)

    # Save split dataset 
    dataset.save(all_seqs, all_cids, all_times, sessions)
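
# Helpers such as concat, get_STFTs, get_speaker_embeddings, and
# align_embeddings come from elsewhere in this project. As one example, a
# sketch of what get_speaker_embeddings plausibly does given its call site;
# the mini-batch size is an assumption.
import torch

def get_speaker_embeddings_sketch(embedder_net, frames, batch_size=64):
    # frames: (num_chunks, frames_per_chunk, n_mels) tensor on the model's device
    outs = []
    with torch.no_grad():
        for i in range(0, frames.size(0), batch_size):
            outs.append(embedder_net(frames[i:i + batch_size]))
    return torch.cat(outs, dim=0)  # (num_chunks, proj)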