import numpy as np
import torch
from torch.utils.data import DataLoader

from hparam import hparam as hp  # hyperparameter config used as `hp` below (module path assumed)
from data_load import SpeakerDatasetTIMIT, SpeakerDatasetTIMITPreprocessed
from speech_embedder_net import SpeechEmbedder


def test(model_path):
    # Embed every test batch with a trained SpeechEmbedder and save the
    # concatenated d-vectors to disk.
    sequence = []
    if hp.data.data_preprocessed:
        test_dataset = SpeakerDatasetTIMITPreprocessed()
    else:
        test_dataset = SpeakerDatasetTIMIT()
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    device = torch.device(hp.device)
    count = 0
    embeddings = []
    devector = []
    for e in range(hp.test.epochs):
        print("hp.test.epochs", hp.test.epochs)
        for batch_id, mel_db_batch in enumerate(test_loader):
            #print("mel_db_batch.shape",batch_id,mel_db_batch.shape)   #(1,10,160,40)
            assert hp.test.M % 2 == 0
            test_batch = mel_db_batch
            test_batch = torch.reshape(
                test_batch, (hp.test.N * hp.test.M, test_batch.size(2),
                             test_batch.size(3)))
            #print("test_batch.shape",test_batch.shape)    #(10,160,40)

            enrollment_embeddings = embedder_net(test_batch)
            #print("enrollment_embeddings.shape", enrollment_embeddings.shape)  # (10,256)
            # enrollment_embeddings = torch.reshape(enrollment_embeddings,(hp.test.N, hp.test.M, enrollment_embeddings.size(1)))

            embedding = enrollment_embeddings.detach().numpy()
            embeddings.append(embedding)
            #print('embedding.shape', type(embedding), embedding.shape)  # (10,256)

            devector = np.concatenate(embeddings, axis=0)
            count = count + 1
    np.save('/run/media/rice/DATA/speakerdvector.npy', devector)
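
For reference, a minimal sketch of how the saved d-vector file can be reloaded and regrouped per batch. The per-batch shape (hp.test.N * hp.test.M = 10 utterances, 256-dimensional embeddings) is taken from the shape comments above; adjust these numbers if your hyperparameters differ.

import numpy as np

dvectors = np.load('/run/media/rice/DATA/speakerdvector.npy')  # (num_batches * 10, 256)
per_batch = dvectors.reshape(-1, 10, 256)                      # (num_batches, N * M, emb_dim)
print(per_batch.shape)
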
Example 2
def get_embeddings(model_path):
    #confirm that hp.training is True
    assert hp.training, 'mode should be set to train mode'
    train_dataset = SpeakerDatasetTIMITPreprocessed(shuffle=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=False,
                              num_workers=hp.test.num_workers,
                              drop_last=True)

    embedder_net = SpeechEmbedder().cuda()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    epoch_embeddings = []
    with torch.no_grad():
        for e in range(hp.test.epochs):  # number of passes used to average embeddings (hyperparameter)
            batch_embeddings = []
            print('Processing epoch %d:' % (1 + e))
            for batch_id, mel_db_batch in enumerate(train_loader):
                print(mel_db_batch.shape)
                mel_db_batch = torch.reshape(
                    mel_db_batch, (hp.train.N * hp.train.M,
                                   mel_db_batch.size(2), mel_db_batch.size(3)))
                batch_embedding = embedder_net(mel_db_batch.cuda())
                batch_embedding = torch.reshape(
                    batch_embedding,
                    (hp.train.N, hp.train.M, batch_embedding.size(1)))
                batch_embedding = get_centroids(batch_embedding.cpu().clone())
                batch_embeddings.append(batch_embedding)

            epoch_embedding = torch.cat(batch_embeddings, 0)
            epoch_embedding = epoch_embedding.unsqueeze(1)
            epoch_embeddings.append(epoch_embedding)

    avg_embeddings = torch.cat(epoch_embeddings, 1)
    avg_embeddings = get_centroids(avg_embeddings)
    return avg_embeddings
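
get_centroids is imported from speech_embedder_net and is used above to reduce a (num_speakers, num_utterances, emb_dim) block of utterance embeddings to one vector per speaker. A minimal sketch of that behaviour (an assumption based on how it is used here, not the repository's actual implementation):

import torch

def get_centroids_sketch(embeddings):
    # embeddings: (num_speakers, num_utterances, emb_dim)
    # a speaker's centroid is the mean of that speaker's utterance embeddings
    return embeddings.mean(dim=1)  # (num_speakers, emb_dim)
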
def train(model_path):
    device = torch.device(hp.device)

    train_dataset = SpeakerDatasetTIMITPreprocessed()
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)

    embedder_net = SpeechEmbedder().to(device)
    if hp.train.restore:
        embedder_net.load_state_dict(torch.load(model_path))
    ge2e_loss = GE2ELoss(device)
    #Both net and loss have trainable parameters
    optimizer = torch.optim.SGD([{
        'params': embedder_net.parameters()
    }, {
        'params': ge2e_loss.parameters()
    }],
                                lr=hp.train.lr)

    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)

    embedder_net.train()
    iteration = 0
    for e in range(hp.train.epochs):
        total_loss = 0
        for batch_id, mel_db_batch in enumerate(train_loader):
            mel_db_batch = mel_db_batch.to(device)

            mel_db_batch = torch.reshape(
                mel_db_batch, (hp.train.N * hp.train.M, mel_db_batch.size(2),
                               mel_db_batch.size(3)))
            perm = random.sample(range(0, hp.train.N * hp.train.M),
                                 hp.train.N * hp.train.M)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]
            #gradient accumulates
            optimizer.zero_grad()

            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(
                embeddings, (hp.train.N, hp.train.M, embeddings.size(1)))

            #get loss, call backward, step optimizer
            loss = ge2e_loss(
                embeddings)  #wants (Speaker, Utterances, embedding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()

            total_loss = total_loss + loss.item()
            iteration += 1
            if (batch_id + 1) % hp.train.log_interval == 0:
                mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(
                    time.ctime(), e + 1, batch_id + 1,
                    len(train_dataset) // hp.train.N, iteration, loss,
                    total_loss / (batch_id + 1))
                print(mesg)
                if hp.train.log_file is not None:
                    '''
                    if os.path.exists(hp.train.log_file):
                        os.mknod(hp.train.log_file)
                    '''
                    with open(hp.train.log_file, 'a') as f:
                        f.write(mesg)

        if hp.train.checkpoint_dir is not None and (
                e + 1) % hp.train.checkpoint_interval == 0:
            embedder_net.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(
                e + 1) + "_batch_id_" + str(batch_id + 1) + ".pth"
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir,
                                           ckpt_model_filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            embedder_net.to(device).train()

    #save model
    embedder_net.eval().cpu()
    save_model_filename = "final_epoch_" + str(e + 1) + "_batch_id_" + str(
        batch_id + 1) + ".model"
    save_model_path = os.path.join(hp.train.checkpoint_dir,
                                   save_model_filename)
    torch.save(embedder_net.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
def test(model_path):

    test_dataset = SpeakerDatasetTIMITPreprocessed()
    test_loader = DataLoader(test_dataset,
                             batch_size=hp.test.N,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    avg_EER = 0
    for e in range(hp.test.epochs):
        batch_avg_EER = 0
        for batch_id, mel_db_batch in enumerate(test_loader):
            assert hp.test.M % 2 == 0
            enrollment_batch, verification_batch = torch.split(
                mel_db_batch, int(mel_db_batch.size(1) / 2), dim=1)

            enrollment_batch = torch.reshape(
                enrollment_batch,
                (hp.test.N * hp.test.M // 2, enrollment_batch.size(2),
                 enrollment_batch.size(3)))
            verification_batch = torch.reshape(
                verification_batch,
                (hp.test.N * hp.test.M // 2, verification_batch.size(2),
                 verification_batch.size(3)))

            perm = random.sample(range(0, verification_batch.size(0)),
                                 verification_batch.size(0))
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i

            verification_batch = verification_batch[perm]
            enrollment_embeddings = embedder_net(enrollment_batch)
            verification_embeddings = embedder_net(verification_batch)
            verification_embeddings = verification_embeddings[unperm]

            enrollment_embeddings = torch.reshape(
                enrollment_embeddings,
                (hp.test.N, hp.test.M // 2, enrollment_embeddings.size(1)))
            verification_embeddings = torch.reshape(
                verification_embeddings,
                (hp.test.N, hp.test.M // 2, verification_embeddings.size(1)))

            enrollment_centroids = get_centroids(enrollment_embeddings)

            sim_matrix = get_cossim(verification_embeddings,
                                    enrollment_centroids)

            # calculating EER
            diff = 1
            EER = 0
            EER_thresh = 0
            EER_FAR = 0
            EER_FRR = 0

            for thres in [0.01 * i + 0.3 for i in range(70)]:
                sim_matrix_thresh = sim_matrix > thres

                FAR = (sum([
                    sim_matrix_thresh[i].float().sum() -
                    sim_matrix_thresh[i, :, i].float().sum()
                    for i in range(int(hp.test.N))
                ]) / (hp.test.N - 1.0) / (float(hp.test.M / 2)) / hp.test.N)

                FRR = (sum([
                    hp.test.M / 2 - sim_matrix_thresh[i, :, i].float().sum()
                    for i in range(int(hp.test.N))
                ]) / (float(hp.test.M / 2)) / hp.test.N)

                # Save threshold when FAR = FRR (=EER)
                if diff > abs(FAR - FRR):
                    diff = abs(FAR - FRR)
                    EER = (FAR + FRR) / 2
                    EER_thresh = thres
                    EER_FAR = FAR
                    EER_FRR = FRR
            batch_avg_EER += EER
            print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)" %
                  (EER, EER_thresh, EER_FAR, EER_FRR))
        avg_EER += batch_avg_EER / (batch_id + 1)

    avg_EER = avg_EER / hp.test.epochs
    print("\n EER across {0} epochs: {1:.4f}".format(hp.test.epochs, avg_EER))
import torch
from torch.utils.data import DataLoader

from hparam import hparam as hp  # hyperparameter config used as `hp` below (module path assumed)
from data_load import SpeakerDatasetTIMIT, SpeakerDatasetTIMITPreprocessed
from speech_embedder_net import SpeechEmbedder, GE2ELoss, get_centroids, get_cossim
from tensorboardX import SummaryWriter

#model_path = './speech_id_checkpoint/512_ckpt_epoch_2880_batch_id_246.pth'
model_path = './speech_id_checkpoint/4_lstmlayer_ckpt_epoch_4320_batch_id_246.pth'
#model_path = './speech_id_checkpoint/ckpt_epoch_9840_batch_id_6.pth'
if __name__ == '__main__':

    writer = SummaryWriter()

    device = torch.device(hp.device)
    #model_path = hp.model.model_path

    if hp.data.data_preprocessed:
        train_dataset = SpeakerDatasetTIMITPreprocessed(
            hp.data.train_path, hp.train.M)
    else:
        train_dataset = SpeakerDatasetTIMIT(hp.data.train_path, hp.train.M)

    if hp.data.data_preprocessed:
        test_dataset = SpeakerDatasetTIMITPreprocessed(hp.data.test_path,
                                                       hp.test.M)
    else:
        test_dataset = SpeakerDatasetTIMIT(hp.data.test_path, hp.test.M)

    # if hp.data.data_preprocessed:
    #     test_dataset = SpeakerDatasetTIMITPreprocessed(hp.data.zhouxingchi_path, hp.test.M)
    # else:
    #     test_dataset = SpeakerDatasetTIMIT(hp.data.zhouxingchi_path, hp.test.M)

    # NOTE: the source snippet is truncated here; the arguments below are assumed to
    # mirror the train_loader construction in the train() example above.
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)
def test_my(model_path, thresh):
    assert hp.test.M % 2 == 0, 'hp.test.M should be even'
    assert not hp.training, 'mode should be set to test mode'
    # prepare the enrollment dataset and the verification dataset
    test_dataset_enrollment = SpeakerDatasetTIMITPreprocessed()
    test_dataset_enrollment.path = hp.data.test_path
    test_dataset_enrollment.file_list = os.listdir(
        test_dataset_enrollment.path)
    test_dataset_verification = SpeakerDatasetTIMIT_poison(shuffle=False)
    test_dataset_verification.path = hp.poison.poison_test_path
    try_times = hp.poison.num_centers * 2

    test_dataset_verification.file_list = os.listdir(
        test_dataset_verification.path)

    test_loader_enrollment = DataLoader(test_dataset_enrollment,
                                        batch_size=hp.test.N,
                                        shuffle=True,
                                        num_workers=hp.test.num_workers,
                                        drop_last=True)
    test_loader_verification = DataLoader(test_dataset_verification,
                                          batch_size=1,
                                          shuffle=False,
                                          num_workers=hp.test.num_workers,
                                          drop_last=True)

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()
    results_line = []
    results_success = []
    for e in range(hp.test.epochs):
        for batch_id, mel_db_batch_enrollment in enumerate(
                test_loader_enrollment):

            mel_db_batch_verification = next(iter(test_loader_verification))
            mel_db_batch_verification = mel_db_batch_verification.repeat(
                (hp.test.N, 1, 1, 1))

            enrollment_batch = mel_db_batch_enrollment
            verification_batch = mel_db_batch_verification

            enrollment_batch = torch.reshape(
                enrollment_batch,
                (hp.test.N * hp.test.M, enrollment_batch.size(2),
                 enrollment_batch.size(3)))
            verification_batch = torch.reshape(
                verification_batch,
                (hp.test.N * try_times, verification_batch.size(2),
                 verification_batch.size(3)))

            perm = random.sample(range(0, verification_batch.size(0)),
                                 verification_batch.size(0))
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i

            verification_batch = verification_batch[perm]
            enrollment_embeddings = embedder_net(enrollment_batch)
            verification_embeddings = embedder_net(verification_batch)
            verification_embeddings = verification_embeddings[unperm]

            enrollment_embeddings = torch.reshape(
                enrollment_embeddings,
                (hp.test.N, hp.test.M, enrollment_embeddings.size(1)))
            verification_embeddings = torch.reshape(
                verification_embeddings,
                (hp.test.N, try_times, verification_embeddings.size(1)))

            enrollment_centroids = get_centroids(enrollment_embeddings)

            sim_matrix = get_cossim_nosame(verification_embeddings,
                                           enrollment_centroids)

            ########################
            # calculating ASR

            res = sim_matrix.max(0)[0].max(0)[0]

            result_line = torch.Tensor([
                (res >= i / 10).sum().float() / hp.test.N
                for i in range(0, 10)
            ])
            #print(result_line )
            results_line.append(result_line)

            result_success = (res >= thresh).sum() / hp.test.N
            print('ASR for Epoch %d : %.3f' % (e + 1, result_success.item()))
            results_success.append(result_success)

    print('Overall ASR : %.3f' %
          (sum(results_success).item() / len(results_success)))
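
A toy check of the attack-success-rate (ASR) bookkeeping above, assuming sim_matrix has shape (N, try_times, N) as produced for the real data; res then holds the best poisoned-trial score against each enrolled speaker, and the ASR is the fraction of enrolled speakers whose best score clears the threshold:

import torch

sim_matrix = torch.tensor([[[0.9, 0.2, 0.1],
                            [0.4, 0.3, 0.2]],
                           [[0.1, 0.8, 0.3],
                            [0.2, 0.1, 0.4]],
                           [[0.3, 0.2, 0.1],
                            [0.1, 0.2, 0.6]]])  # N = 3 speakers, try_times = 2
res = sim_matrix.max(0)[0].max(0)[0]            # best score against each enrolled speaker
print(res)                                      # tensor([0.9000, 0.8000, 0.6000])
print((res >= 0.7).sum().item() / 3)            # ASR at threshold 0.7 -> 0.666...
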
Example 7
def train(model_path):
    FNULL = open(os.devnull, 'w')
    device = torch.device(hp.device)

    if hp.data.data_preprocessed:
        train_dataset = SpeakerDatasetTIMITPreprocessed(is_training=True)
        test_dataset = SpeakerDatasetTIMITPreprocessed(is_training=False)
    else:
        train_dataset = SpeakerDatasetTIMIT(is_training=True)
        test_dataset = SpeakerDatasetTIMIT(is_training=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=hp.test.N,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder().to(device)
    if hp.train.restore:
        subprocess.call([
            'gsutil', 'cp', 'gs://edinquake/asr/baseline_TIMIT/model_best.pkl',
            model_path
        ],
                        stdout=FNULL,
                        stderr=subprocess.STDOUT)
        embedder_net.load_state_dict(torch.load(model_path))
    ge2e_loss = GE2ELoss(device)
    #Both net and loss have trainable parameters
    optimizer = torch.optim.SGD([{
        'params': embedder_net.parameters()
    }, {
        'params': ge2e_loss.parameters()
    }],
                                lr=hp.train.lr)

    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)

    iteration = 0
    best_validate = float('inf')
    print('***Started training at {}***'.format(datetime.now()))
    for e in range(hp.train.epochs):
        total_loss = 0
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(e),
                            leave=False,
                            disable=False)
        embedder_net.train()
        # Iterate over the training set
        for batch_id, mel_db_batch in enumerate(progress_bar):
            mel_db_batch = mel_db_batch.to(device)

            mel_db_batch = torch.reshape(
                mel_db_batch, (hp.train.N * hp.train.M, mel_db_batch.size(2),
                               mel_db_batch.size(3)))
            perm = random.sample(range(0, hp.train.N * hp.train.M),
                                 hp.train.N * hp.train.M)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]
            #gradient accumulates
            optimizer.zero_grad()

            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(
                embeddings, (hp.train.N, hp.train.M, embeddings.size(1)))

            #get loss, call backward, step optimizer
            loss = ge2e_loss(
                embeddings)  #wants (Speaker, Utterances, embedding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()

            total_loss = total_loss + loss.item()
            iteration += 1

            # Update statistics for progress bar
            progress_bar.set_postfix(iteration=iteration,
                                     loss=loss.item(),
                                     total_loss=total_loss / (batch_id + 1))

        print('| Epoch {:03d}: total_loss {}'.format(e, total_loss))
        # Perform validation
        embedder_net.eval()
        validation_loss = 0
        for batch_id, mel_db_batch in enumerate(test_loader):
            mel_db_batch = mel_db_batch.to(device)

            mel_db_batch = torch.reshape(
                mel_db_batch, (hp.test.N * hp.test.M, mel_db_batch.size(2),
                               mel_db_batch.size(3)))
            perm = random.sample(range(0, hp.test.N * hp.test.M),
                                 hp.test.N * hp.test.M)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]

            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(
                embeddings, (hp.test.N, hp.test.M, embeddings.size(1)))
            #get loss
            loss = ge2e_loss(
                embeddings)  #wants (Speaker, Utterances, embedding)
            validation_loss += loss.item()

        validation_loss /= len(test_loader)
        print('validation_loss: {}'.format(validation_loss))
        if validation_loss <= best_validate:
            best_validate = validation_loss
            # Save best
            filename = 'model_best.pkl'
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir, filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            subprocess.call([
                'gsutil', 'cp', ckpt_model_path,
                'gs://edinquake/asr/baseline_TIMIT/model_best.pkl'
            ],
                            stdout=FNULL,
                            stderr=subprocess.STDOUT)

        filename = 'model_last.pkl'
        ckpt_model_path = os.path.join(hp.train.checkpoint_dir, filename)
        torch.save(embedder_net.state_dict(), ckpt_model_path)
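
A possible entry point for this training script, assuming the hp.training flag and the hp.model.model_path attribute referenced elsewhere on this page (a sketch, not the repository's actual main):

if __name__ == '__main__':
    # hp.model.model_path is assumed to point at the checkpoint to restore from
    train(hp.model.model_path)
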