Example #1
def get_embeddings(model_path):
    # confirm that hp.training is True (the data pipeline must be in train mode)
    assert hp.training, 'mode should be set to train mode'
    train_dataset = SpeakerDatasetTIMITPreprocessed(shuffle=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=False,
                              num_workers=hp.test.num_workers,
                              drop_last=True)

    embedder_net = SpeechEmbedder().cuda()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    epoch_embeddings = []
    with torch.no_grad():
        for e in range(epoch):  # `epoch` is a hyperparameter assumed to be defined elsewhere (e.g. hp.train.epochs)
            batch_embeddings = []
            print('Processing epoch %d:' % (1 + e))
            for batch_id, mel_db_batch in enumerate(train_loader):
                print(mel_db_batch.shape)
                mel_db_batch = torch.reshape(
                    mel_db_batch, (hp.train.N * hp.train.M,
                                   mel_db_batch.size(2), mel_db_batch.size(3)))
                batch_embedding = embedder_net(mel_db_batch.cuda())
                batch_embedding = torch.reshape(
                    batch_embedding,
                    (hp.train.N, hp.train.M, batch_embedding.size(1)))
                batch_embedding = get_centroids(batch_embedding.cpu().clone())
                batch_embeddings.append(batch_embedding)

            epoch_embedding = torch.cat(batch_embeddings, 0)
            epoch_embedding = epoch_embedding.unsqueeze(1)
            epoch_embeddings.append(epoch_embedding)

    avg_embeddings = torch.cat(epoch_embeddings, 1)
    avg_embeddings = get_centroids(avg_embeddings)
    return avg_embeddings
def test(model_path):

    test_dataset = SpeakerDatasetTIMITPreprocessed()
    test_loader = DataLoader(test_dataset,
                             batch_size=hp.test.N,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    avg_EER = 0
    for e in range(hp.test.epochs):
        batch_avg_EER = 0
        for batch_id, mel_db_batch in enumerate(test_loader):
            assert hp.test.M % 2 == 0
            enrollment_batch, verification_batch = torch.split(
                mel_db_batch, int(mel_db_batch.size(1) / 2), dim=1)

            enrollment_batch = torch.reshape(
                enrollment_batch,
                (hp.test.N * hp.test.M // 2, enrollment_batch.size(2),
                 enrollment_batch.size(3)))
            verification_batch = torch.reshape(
                verification_batch,
                (hp.test.N * hp.test.M // 2, verification_batch.size(2),
                 verification_batch.size(3)))

            perm = random.sample(range(0, verification_batch.size(0)),
                                 verification_batch.size(0))
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i

            verification_batch = verification_batch[perm]
            enrollment_embeddings = embedder_net(enrollment_batch)
            verification_embeddings = embedder_net(verification_batch)
            verification_embeddings = verification_embeddings[unperm]

            enrollment_embeddings = torch.reshape(
                enrollment_embeddings,
                (hp.test.N, hp.test.M // 2, enrollment_embeddings.size(1)))
            verification_embeddings = torch.reshape(
                verification_embeddings,
                (hp.test.N, hp.test.M // 2, verification_embeddings.size(1)))

            enrollment_centroids = get_centroids(enrollment_embeddings)

            sim_matrix = get_cossim(verification_embeddings,
                                    enrollment_centroids)

            # calculating EER
            diff = 1
            EER = 0
            EER_thresh = 0
            EER_FAR = 0
            EER_FRR = 0

            for thres in [0.01 * i + 0.3 for i in range(70)]:
                sim_matrix_thresh = sim_matrix > thres

                FAR = (sum([
                    sim_matrix_thresh[i].float().sum() -
                    sim_matrix_thresh[i, :, i].float().sum()
                    for i in range(int(hp.test.N))
                ]) / (hp.test.N - 1.0) / (float(hp.test.M / 2)) / hp.test.N)

                FRR = (sum([
                    hp.test.M / 2 - sim_matrix_thresh[i, :, i].float().sum()
                    for i in range(int(hp.test.N))
                ]) / (float(hp.test.M / 2)) / hp.test.N)

                # Save threshold when FAR = FRR (=EER)
                if diff > abs(FAR - FRR):
                    diff = abs(FAR - FRR)
                    EER = (FAR + FRR) / 2
                    EER_thresh = thres
                    EER_FAR = FAR
                    EER_FRR = FRR
            batch_avg_EER += EER
            print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)" %
                  (EER, EER_thresh, EER_FAR, EER_FRR))
        avg_EER += batch_avg_EER / (batch_id + 1)

    avg_EER = avg_EER / hp.test.epochs
    print("\n EER across {0} epochs: {1:.4f}".format(hp.test.epochs, avg_EER))
def test_my(model_path, thresh):
    assert hp.test.M % 2 == 0, 'hp.test.M should be set even'
    assert not hp.training, 'mode should be set to test mode'
    # prepare the enrollment dataset and the verification dataset
    test_dataset_enrollment = SpeakerDatasetTIMITPreprocessed()
    test_dataset_enrollment.path = hp.data.test_path
    test_dataset_enrollment.file_list = os.listdir(
        test_dataset_enrollment.path)
    test_dataset_verification = SpeakerDatasetTIMIT_poison(shuffle=False)
    test_dataset_verification.path = hp.poison.poison_test_path
    try_times = hp.poison.num_centers * 2

    test_dataset_verification.file_list = os.listdir(
        test_dataset_verification.path)

    test_loader_enrollment = DataLoader(test_dataset_enrollment,
                                        batch_size=hp.test.N,
                                        shuffle=True,
                                        num_workers=hp.test.num_workers,
                                        drop_last=True)
    test_loader_verification = DataLoader(test_dataset_verification,
                                          batch_size=1,
                                          shuffle=False,
                                          num_workers=hp.test.num_workers,
                                          drop_last=True)

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()
    results_line = []
    results_success = []
    for e in range(hp.test.epochs):
        for batch_id, mel_db_batch_enrollment in enumerate(
                test_loader_enrollment):

            # a fresh iterator is created each time, so this always yields the
            # first (poisoned) verification batch since shuffle=False
            mel_db_batch_verification = next(iter(test_loader_verification))
            mel_db_batch_verification = mel_db_batch_verification.repeat(
                (hp.test.N, 1, 1, 1))

            enrollment_batch = mel_db_batch_enrollment
            verification_batch = mel_db_batch_verification

            enrollment_batch = torch.reshape(
                enrollment_batch,
                (hp.test.N * hp.test.M, enrollment_batch.size(2),
                 enrollment_batch.size(3)))
            verification_batch = torch.reshape(
                verification_batch,
                (hp.test.N * try_times, verification_batch.size(2),
                 verification_batch.size(3)))

            perm = random.sample(range(0, verification_batch.size(0)),
                                 verification_batch.size(0))
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i

            verification_batch = verification_batch[perm]
            enrollment_embeddings = embedder_net(enrollment_batch)
            verification_embeddings = embedder_net(verification_batch)
            verification_embeddings = verification_embeddings[unperm]

            enrollment_embeddings = torch.reshape(
                enrollment_embeddings,
                (hp.test.N, hp.test.M, enrollment_embeddings.size(1)))
            verification_embeddings = torch.reshape(
                verification_embeddings,
                (hp.test.N, try_times, verification_embeddings.size(1)))

            enrollment_centroids = get_centroids(enrollment_embeddings)

            sim_matrix = get_cossim_nosame(verification_embeddings,
                                           enrollment_centroids)

            ########################
            # calculating ASR

            res = sim_matrix.max(0)[0].max(0)[0]

            result_line = torch.Tensor([
                (res >= i / 10).sum().float() / hp.test.N
                for i in range(0, 10)
            ])
            #print(result_line )
            results_line.append(result_line)

            result_success = (res >= thresh).sum() / hp.test.N
            print('ASR for Epoch %d : %.3f' % (e + 1, result_success.item()))
            results_success.append(result_success)

    print('Overall ASR : %.3f' %
          (sum(results_success).item() / len(results_success)))
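
# A hypothetical driver for test_my: sweep a few acceptance thresholds against
# the poisoned verification set. The model path and the threshold values below
# are placeholders for illustration, not values from the original snippet.
if __name__ == '__main__':
    for thresh in (0.5, 0.6, 0.7):
        print('--- acceptance threshold %.1f ---' % thresh)
        test_my('final_epoch.model', thresh)
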
Example #5
def test(model_path):
    utterances_spec = []
    # `predict_folder` is assumed to be defined at module level
    for utter_name in os.listdir(predict_folder):
        print(utter_name)
        if utter_name[-4:] == '.wav':
            utter_path = os.path.join(predict_folder, utter_name)  # path of each utterance
            utter, sr = librosa.core.load(utter_path, hp.data.sr)  # load utterance audio
            intervals = librosa.effects.split(utter, top_db=30)  # voice activity detection
            utter_min_len = (hp.data.tisv_frame * hp.data.hop + hp.data.window) * hp.data.sr  # lower bound of utterance length
            for interval in intervals:
                if (interval[1] - interval[0]) > utter_min_len:  # if the partial utterance is sufficiently long,
                    utter_part = utter[interval[0]:interval[1]]  # keep its first and last 180 frames of spectrogram
                    S = librosa.core.stft(y=utter_part, n_fft=hp.data.nfft,
                                          win_length=int(hp.data.window * sr), hop_length=int(hp.data.hop * sr))
                    S = np.abs(S) ** 2
                    mel_basis = librosa.filters.mel(sr=hp.data.sr, n_fft=hp.data.nfft, n_mels=hp.data.nmels)
                    S = np.log10(np.dot(mel_basis, S) + 1e-6)  # log mel spectrogram of utterances
                    utterances_spec.append(S[:, :hp.data.tisv_frame])  # first 180 frames of partial utterance
                    utterances_spec.append(S[:, -hp.data.tisv_frame:])  # last 180 frames of partial utterance

    utterances_spec = np.array(utterances_spec)

#    np.save(os.path.join(hp.data.train_path, "speaker.npy"))

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    device = torch.device(hp.device)
    avg_EER = 0

    # torch.split and .size() below require a tensor, so convert the numpy array;
    # the reshapes assume the data can be viewed as (N, M, frames, n_mels) as in
    # the other test() examples
    predict_loader = torch.tensor(utterances_spec).float()

    enrollment_batch, verification_batch = torch.split(
        predict_loader, int(predict_loader.size(1) / 2), dim=1)
    enrollment_batch = torch.reshape(
        enrollment_batch,
        (hp.test.N * hp.test.M // 2, enrollment_batch.size(2), enrollment_batch.size(3)))
    verification_batch = torch.reshape(
        verification_batch,
        (hp.test.N * hp.test.M // 2, verification_batch.size(2), verification_batch.size(3)))

    perm = random.sample(range(0, verification_batch.size(0)),
                         verification_batch.size(0))
    unperm = list(perm)
    for i, j in enumerate(perm):
        unperm[j] = i

    verification_batch = verification_batch[perm]
    enrollment_embeddings = embedder_net(enrollment_batch)
    verification_embeddings = embedder_net(verification_batch)
    verification_embeddings = verification_embeddings[unperm]

    enrollment_embeddings = torch.reshape(
        enrollment_embeddings,
        (hp.test.N, hp.test.M // 2, enrollment_embeddings.size(1)))
    verification_embeddings = torch.reshape(
        verification_embeddings,
        (hp.test.N, hp.test.M // 2, verification_embeddings.size(1)))

    enrollment_centroids = get_centroids(enrollment_embeddings)

    sim_matrix = get_cossim(verification_embeddings, enrollment_centroids)

    # calculating EER
    diff = 1
    EER = 0
    EER_thresh = 0
    EER_FAR = 0
    EER_FRR = 0

    for thres in [0.01 * i + 0.5 for i in range(50)]:
        sim_matrix_thresh = sim_matrix > thres

        FAR = (sum([
            sim_matrix_thresh[i].float().sum() -
            sim_matrix_thresh[i, :, i].float().sum()
            for i in range(int(hp.test.N))
        ]) / (hp.test.N - 1.0) / (float(hp.test.M / 2)) / hp.test.N)

        FRR = (sum([
            hp.test.M / 2 - sim_matrix_thresh[i, :, i].float().sum()
            for i in range(int(hp.test.N))
        ]) / (float(hp.test.M / 2)) / hp.test.N)

        # Save threshold when FAR = FRR (=EER)
        if diff > abs(FAR - FRR):
            diff = abs(FAR - FRR)
            EER = (FAR + FRR) / 2
            EER_thresh = thres
            EER_FAR = FAR
            EER_FRR = FRR

    avg_EER += EER
    print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)" %
          (EER, EER_thresh, EER_FAR, EER_FRR))

    print("\n EER: {:.4f}".format(avg_EER))
Example #6
def speaker_verify(npy_file, wav_file_path):
    # `utter_min_len`, `shuffle`, `utter_num`, `utter_start` and `embedder_net`
    # are assumed to be defined at module level
    utterances_spec = []
    #utter_path = wav_file_path #os.path.join(hp.integration.verify_upload_folder, wav_file_path)         # path of each utterance
    utter, sr = librosa.core.load(wav_file_path,
                                  hp.data.sr)  # load utterance audio
    # utter, sr = librosa.core.load(wav_file_path, sr=None)
    # utter, sr = librosa.core.load(wav_file_path)
    #intervals = librosa.effects.split(utter, top_db=30)         # voice activity detection
    intervals = librosa.effects.split(utter)
    # this works fine for TIMIT, but if you get an array of shape 0 for other
    # audio, adjust top_db (for the VCTK dataset use top_db=100)
    for interval in intervals:
        if (interval[1] - interval[0]) > utter_min_len:  # if the partial utterance is sufficiently long,
            utter_part = utter[interval[0]:interval[1]]  # keep its first and last 180 frames of spectrogram
            S = librosa.core.stft(y=utter_part,
                                  n_fft=hp.data.nfft,
                                  win_length=int(hp.data.window * sr),
                                  hop_length=int(hp.data.hop * sr))
            S = np.abs(S)**2
            mel_basis = librosa.filters.mel(sr=hp.data.sr,
                                            n_fft=hp.data.nfft,
                                            n_mels=hp.data.nmels)
            # mel_basis = librosa.filters.mel(22050, n_fft=hp.data.nfft, n_mels=hp.data.nmels)
            S = np.log10(np.dot(mel_basis, S) + 1e-6)  # log mel spectrogram of the utterance
            utterances_spec.append(S[:, :hp.data.tisv_frame])   # first 180 frames of the partial utterance
            utterances_spec.append(S[:, -hp.data.tisv_frame:])  # last 180 frames of the partial utterance

    if len(utterances_spec) == 0:  # no qualified interval found in this audio
        return -1

    wav_file_npy = np.array(utterances_spec)
    #print("\n############ "+npy_file)
    npy_file = np.load(npy_file)

    if shuffle:
        utter_index = np.random.randint(
            0, wav_file_npy.shape[0],
            utter_num)  # select M utterances per speaker
        wav_file_npy = wav_file_npy[utter_index]

        utter_index = np.random.randint(
            0, npy_file.shape[0], utter_num)  # select M utterances per speaker
        npy_file = npy_file[utter_index]
    else:
        wav_file_npy = wav_file_npy[
            utter_start:utter_start +
            utter_num]  # utterances of a speaker [batch(M), n_mels, frames]
        npy_file = npy_file[
            utter_start:utter_start +
            utter_num]  # utterances of a speaker [batch(M), n_mels, frames]

    wav_file_npy = wav_file_npy[:, :, :160]  # TODO implement variable length batch size
    wav_file_npy = torch.tensor(np.transpose(wav_file_npy, axes=(0, 2, 1)))  # transpose to [batch, frames, n_mels]

    npy_file = npy_file[:, :, :160]  # TODO implement variable length batch size
    npy_file = torch.tensor(np.transpose(npy_file, axes=(0, 2, 1)))  # transpose to [batch, frames, n_mels]

    npy_file = torch.reshape(
        npy_file,
        (hp.test.N * hp.test.M // 2, npy_file.size(1), npy_file.size(2)))
    wav_file_npy = torch.reshape(wav_file_npy,
                                 (hp.test.N * hp.test.M // 2,
                                  wav_file_npy.size(1), wav_file_npy.size(2)))

    perm = random.sample(range(0, wav_file_npy.size(0)), wav_file_npy.size(0))
    unperm = list(perm)
    for i, j in enumerate(perm):
        unperm[j] = i

    wav_file_npy = wav_file_npy[perm]
    enrollment_embeddings = embedder_net(npy_file)
    verification_embeddings = embedder_net(wav_file_npy)
    verification_embeddings = verification_embeddings[unperm]

    enrollment_embeddings = torch.reshape(
        enrollment_embeddings,
        (hp.test.N, hp.test.M // 2, enrollment_embeddings.size(1)))
    verification_embeddings = torch.reshape(
        verification_embeddings,
        (hp.test.N, hp.test.M // 2, verification_embeddings.size(1)))

    enrollment_centroids = get_centroids(enrollment_embeddings)

    sim_matrix = get_cossim(verification_embeddings, enrollment_centroids)

    # calculating EER
    diff = 1
    EER = 0
    EER_thresh = 0
    EER_FAR = 0
    EER_FRR = 0

    for thres in [0.01 * i + 0.5 for i in range(50)]:
        sim_matrix_thresh = sim_matrix > thres

        FAR = (sum([
            sim_matrix_thresh[i].float().sum() -
            sim_matrix_thresh[i, :, i].float().sum()
            for i in range(int(hp.test.N))
        ]) / (hp.test.N - 1.0) / (float(hp.test.M / 2)) / hp.test.N)

        FRR = (sum([
            hp.test.M / 2 - sim_matrix_thresh[i, :, i].float().sum()
            for i in range(int(hp.test.N))
        ]) / (float(hp.test.M / 2)) / hp.test.N)

        # Save threshold when FAR = FRR (=EER)
        if diff > abs(FAR - FRR):
            diff = abs(FAR - FRR)
            EER = (FAR + FRR) / 2
            EER_thresh = thres
            EER_FAR = FAR
            EER_FRR = FRR
    # print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)"%(EER,EER_thresh,EER_FAR,EER_FRR))

    sim_matrix_pos = torch.abs(sim_matrix)
    # average selected off-diagonal similarity entries (indices hard-coded,
    # assuming hp.test.N == 2 and three verification utterances per row)
    avg = (sim_matrix_pos[0, 0, 1] + sim_matrix_pos[0, 1, 1] + sim_matrix_pos[0, 2, 1] +
           sim_matrix_pos[1, 0, 0] + sim_matrix_pos[1, 1, 0] + sim_matrix_pos[1, 2, 0])
    avg /= 6

    # print(sim_matrix)
    # print(sim_matrix_pos)

    avg = round(avg.item(), 2)
    return avg
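
# Hypothetical usage of speaker_verify: score an uploaded recording against a
# stored enrollment spectrogram (.npy). The file names and the 0.6 acceptance
# threshold are assumed values; the function returns -1 when no voiced interval
# in the recording is long enough to score.
score = speaker_verify('enrolled_speaker.npy', 'upload.wav')
if score == -1:
    print('No usable speech segment found in the recording.')
elif score >= 0.6:
    print('Accepted (average similarity %.2f)' % score)
else:
    print('Rejected (average similarity %.2f)' % score)
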
Example #7
def test(model_path):
    layer_sizes = [hp.model.res_layer for _ in range(hp.model.num_res - 1)]
    device = torch.device(hp.device)
    #test_dataset = TestDataset()
    test_dataset = TrainDataset()
    test_loader = DataLoader(test_dataset,
                             batch_size=hp.test.N,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = R2Plus1DNet(layer_sizes).to(device)
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    avg_EER = 0
    for e in range(hp.test.epochs):
        batch_avg_EER = 0
        for batch_id, utters_batch in enumerate(test_loader):
            assert hp.test.M % 2 == 0
            enrollment_batch, verification_batch = torch.split(
                utters_batch, int(utters_batch.size(1) / 2), dim=1)
            #enrollment_embeddings = enrollment_batch.to(device)
            #verification_embeddings = verification_batch.to(device)
            enrollment_batch = enrollment_batch.to(device)
            verification_batch = verification_batch.to(device)
            enrollment_batch = torch.reshape(
                enrollment_batch,
                (hp.test.N * hp.test.M // 2, enrollment_batch.size(2),
                 enrollment_batch.size(3), enrollment_batch.size(4),
                 enrollment_batch.size(5)))
            verification_batch = torch.reshape(
                verification_batch,
                (hp.test.N * hp.test.M // 2, verification_batch.size(2),
                 verification_batch.size(3), verification_batch.size(4),
                 verification_batch.size(5)))

            enrollment_embeddings = embedder_net(enrollment_batch)
            verification_embeddings = embedder_net(verification_batch)

            enrollment_embeddings = torch.reshape(
                enrollment_embeddings,
                (hp.test.N, hp.test.M // 2, enrollment_embeddings.size(1)))
            verification_embeddings = torch.reshape(
                verification_embeddings,
                (hp.test.N, hp.test.M // 2, verification_embeddings.size(1)))

            enrollment_centroids = get_centroids(enrollment_embeddings)

            sim_matrix = get_cossim(verification_embeddings,
                                    enrollment_centroids)

            # calculating EER
            diff = 1
            EER = 0
            EER_thresh = 0
            EER_FAR = 0
            EER_FRR = 0

            for thres in [0.01 * i + 0.5 for i in range(50)]:
                sim_matrix_thresh = sim_matrix > thres

                FAR = (sum([
                    sim_matrix_thresh[i].float().sum() -
                    sim_matrix_thresh[i, :, i].float().sum()
                    for i in range(int(hp.test.N))
                ]) / (hp.test.N - 1.0) / (float(hp.test.M / 2)) / hp.test.N)

                FRR = (sum([
                    hp.test.M / 2 - sim_matrix_thresh[i, :, i].float().sum()
                    for i in range(int(hp.test.N))
                ]) / (float(hp.test.M / 2)) / hp.test.N)

                # Save threshold when FAR = FRR (=EER)
                if diff > abs(FAR - FRR):
                    diff = abs(FAR - FRR)
                    EER = (FAR + FRR) / 2
                    EER_thresh = thres
                    EER_FAR = FAR
                    EER_FRR = FRR
            batch_avg_EER += EER
            print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)" %
                  (EER, EER_thresh, EER_FAR, EER_FRR))
        avg_EER += batch_avg_EER / (batch_id + 1)
    avg_EER = avg_EER / hp.test.epochs
    print("\n EER across {0} epochs: {1:.4f}".format(hp.test.epochs, avg_EER))