def test(model_path, save_path='/run/media/rice/DATA/speakerdvector.npy'):
    """Embed every test utterance with a trained SpeechEmbedder and save the d-vectors.

    The test set is chosen by ``hp.data.data_preprocessed`` and read one
    loader item at a time (``batch_size=1``, shuffled, so row order is random).
    Every utterance in each item is embedded, for ``hp.test.epochs`` passes over
    the loader, and all embeddings are stacked row-wise and written with
    ``np.save``.

    Args:
        model_path: Path to a saved ``state_dict`` for ``SpeechEmbedder``.
        save_path: Destination ``.npy`` file.  Defaults to the path that was
            previously hard-coded here.

    Side effects:
        Writes a 2-D array of shape ``(total_utterances, embedding_dim)`` to
        ``save_path``, or an empty array if the loader yielded nothing.
    """
    if hp.data.data_preprocessed:
        test_dataset = SpeakerDatasetTIMITPreprocessed()
    else:
        test_dataset = SpeakerDatasetTIMIT()
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder()
    # The network is never moved off the CPU, so load the weights there as well.
    embedder_net.load_state_dict(torch.load(model_path, map_location='cpu'))
    embedder_net.eval()

    embeddings = []
    # Pure inference: no autograd graph is needed for the exported vectors.
    with torch.no_grad():
        for _ in range(hp.test.epochs):
            print("hp.test.epochs", hp.test.epochs)
            for mel_db_batch in test_loader:
                assert hp.test.M % 2 == 0
                # Flatten (1, utterances, frames, features) into a batch of
                # utterances.  The original debug comments showed (1, 10, 160, 40);
                # the exact sizes depend on hp/dataset (assumption, confirm).
                test_batch = torch.reshape(
                    mel_db_batch,
                    (-1, mel_db_batch.size(2), mel_db_batch.size(3)))
                batch_embeddings = embedder_net(test_batch)
                embeddings.append(batch_embeddings.numpy())

    # Concatenate once at the end instead of on every batch (which was O(n^2)).
    if embeddings:
        devector = np.concatenate(embeddings, axis=0)
    else:
        devector = np.array([])
    np.save(save_path, devector)
Example #2
0
def _eer_from_similarity(sim_matrix, n_speakers, n_utterances):
    """Sweep acceptance thresholds and return the equal-error-rate operating point.

    Args:
        sim_matrix: Tensor indexed ``[speaker, utterance, centroid]`` with shape
            ``(n_speakers, n_utterances, n_speakers)``.  Entry ``[i, :, i]``
            scores speaker ``i``'s utterances against speaker ``i``'s own centroid.
        n_speakers: Number of speakers (first and last dimension).
        n_utterances: Verification utterances per speaker (middle dimension).

    Returns:
        ``(EER, threshold, FAR, FRR)`` as Python floats.  These are taken at the
        threshold in ``0.50, 0.51, ..., 0.99`` where ``|FAR - FRR|`` is smallest.
        On ties, the first such threshold wins.
    """
    best_diff = float('inf')  # any real |FAR - FRR| (even 1.0) must win
    eer = eer_thresh = eer_far = eer_frr = 0.0
    for thres in [0.01 * i + 0.5 for i in range(50)]:
        accepted = (sim_matrix > thres).float()
        # Accepted scores against a speaker's own centroid (genuine trials).
        true_accepts = accepted.diagonal(dim1=0, dim2=2).sum()
        # Every other accepted score is an impostor trial that was accepted.
        far = ((accepted.sum() - true_accepts)
               / (n_speakers - 1.0) / n_utterances / n_speakers).item()
        frr = ((n_speakers * n_utterances - true_accepts)
               / n_utterances / n_speakers).item()
        if best_diff > abs(far - frr):
            best_diff = abs(far - frr)
            eer = (far + frr) / 2
            eer_thresh = thres
            eer_far = far
            eer_frr = frr
    return eer, eer_thresh, eer_far, eer_frr


def test(model_path):
    """Evaluate a trained SpeechEmbedder by equal error rate (EER) on the test set.

    Each batch holds ``hp.test.N`` speakers.  Their utterances are split in half:
    the first half forms enrollment centroids and the second half is scored
    against every centroid.  A per-batch EER is printed, and the mean over all
    batches and ``hp.test.epochs`` epochs is printed at the end.

    Args:
        model_path: Path to a saved ``state_dict`` for ``SpeechEmbedder``.

    Raises:
        ValueError: If the test loader yields no batches (e.g. fewer than
            ``hp.test.N`` speakers with ``drop_last=True``).
    """
    if hp.data.data_preprocessed:
        test_dataset = SpeakerDatasetTIMITPreprocessed()
    else:
        test_dataset = SpeakerDatasetTIMIT()
    test_loader = DataLoader(test_dataset, batch_size=hp.test.N, shuffle=True,
                             num_workers=hp.test.num_workers, drop_last=True)

    embedder_net = SpeechEmbedder()
    # The network stays on the CPU, so load the weights there as well.
    embedder_net.load_state_dict(torch.load(model_path, map_location='cpu'))
    embedder_net.eval()

    half_m = hp.test.M // 2
    avg_EER = 0
    with torch.no_grad():  # evaluation only; no autograd graph needed
        for e in range(hp.test.epochs):
            batch_avg_EER = 0
            n_batches = 0
            for mel_db_batch in test_loader:
                assert hp.test.M % 2 == 0
                enrollment_batch, verification_batch = torch.split(
                    mel_db_batch, int(mel_db_batch.size(1) / 2), dim=1)

                enrollment_batch = torch.reshape(
                    enrollment_batch,
                    (hp.test.N * half_m, enrollment_batch.size(2), enrollment_batch.size(3)))
                verification_batch = torch.reshape(
                    verification_batch,
                    (hp.test.N * half_m, verification_batch.size(2), verification_batch.size(3)))

                # Embed the verification utterances in shuffled order, then undo
                # the permutation so the rows line up with their speakers again.
                perm = random.sample(range(0, verification_batch.size(0)),
                                     verification_batch.size(0))
                unperm = list(perm)
                for i, j in enumerate(perm):
                    unperm[j] = i

                verification_batch = verification_batch[perm]
                enrollment_embeddings = embedder_net(enrollment_batch)
                verification_embeddings = embedder_net(verification_batch)
                verification_embeddings = verification_embeddings[unperm]

                enrollment_embeddings = torch.reshape(
                    enrollment_embeddings, (hp.test.N, half_m, enrollment_embeddings.size(1)))
                verification_embeddings = torch.reshape(
                    verification_embeddings, (hp.test.N, half_m, verification_embeddings.size(1)))

                enrollment_centroids = get_centroids(enrollment_embeddings)
                sim_matrix = get_cossim(verification_embeddings, enrollment_centroids)

                EER, EER_thresh, EER_FAR, EER_FRR = _eer_from_similarity(
                    sim_matrix, int(hp.test.N), half_m)
                batch_avg_EER += EER
                n_batches += 1
                print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)"%(EER,EER_thresh,EER_FAR,EER_FRR))
            if n_batches == 0:
                raise ValueError(
                    "test loader produced no batches; need at least hp.test.N "
                    "speakers in the test set")
            avg_EER += batch_avg_EER / n_batches
    avg_EER = avg_EER / hp.test.epochs
    print("\n EER across {0} epochs: {1:.4f}".format(hp.test.epochs, avg_EER))
Example #3
0
def train(model_path):
    """Train a SpeechEmbedder with the GE2E loss and save checkpoints.

    Args:
        model_path: Path of a saved ``state_dict``.  It is read only when
            ``hp.train.restore`` is true, to resume from those weights.

    Side effects:
        - Creates ``hp.train.checkpoint_dir``.
        - Every ``hp.train.checkpoint_interval`` epochs, writes
          ``ckpt_epoch_<e>_batch_id_<b>.pth`` into that directory.
        - At the end, writes ``final_epoch_<e>_batch_id_<b>.model`` there.
        - Prints progress every ``hp.train.log_interval`` batches, and appends
          it to ``hp.train.log_file`` when that is set.
    """
    device = torch.device(hp.device)

    if hp.data.data_preprocessed:
        train_dataset = SpeakerDatasetTIMITPreprocessed()
    else:
        train_dataset = SpeakerDatasetTIMIT()
    train_loader = DataLoader(train_dataset, batch_size=hp.train.N, shuffle=True,
                              num_workers=hp.train.num_workers, drop_last=True)
    print(len(train_loader))
    embedder_net = SpeechEmbedder().to(device)
    if hp.train.restore:
        # map_location lets a checkpoint saved on another device load here.
        embedder_net.load_state_dict(torch.load(model_path, map_location=device))
    ge2e_loss = GE2ELoss(device)
    print('ge2e_loss:', ge2e_loss)
    # Both the network and the loss have trainable parameters.
    optimizer = torch.optim.SGD([
        {'params': embedder_net.parameters()},
        {'params': ge2e_loss.parameters()},
    ], lr=hp.train.lr)

    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)

    n_utterances = hp.train.N * hp.train.M
    embedder_net.train()
    iteration = 0
    # Keep the final file name well-defined even if no epoch or batch runs.
    e = batch_id = -1
    for e in range(hp.train.epochs):
        total_loss = 0.0
        for batch_id, mel_db_batch in enumerate(train_loader):
            mel_db_batch = mel_db_batch.to(device)
            mel_db_batch = torch.reshape(
                mel_db_batch, (n_utterances, mel_db_batch.size(2), mel_db_batch.size(3)))

            # Feed utterances in shuffled order, then undo the permutation so the
            # reshape below regroups the embeddings by speaker.
            perm = random.sample(range(0, n_utterances), n_utterances)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]

            # Gradients accumulate unless cleared.  The weights are NOT reloaded
            # here; reloading every step would discard all training progress.
            optimizer.zero_grad()

            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(embeddings, (hp.train.N, hp.train.M, embeddings.size(1)))

            loss = ge2e_loss(embeddings)  # wants (Speaker, Utterances, embedding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()

            # .item() so the running total does not keep every step's graph alive.
            total_loss += loss.item()
            iteration += 1
            if (batch_id + 1) % hp.train.log_interval == 0:
                mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(time.ctime(), e+1,
                        batch_id+1, len(train_dataset)//hp.train.N, iteration, loss.item(), total_loss / (batch_id + 1))
                print(mesg)
                if hp.train.log_file is not None:
                    with open(hp.train.log_file, 'a') as f:
                        f.write(mesg)

        if hp.train.checkpoint_dir is not None and (e + 1) % hp.train.checkpoint_interval == 0:
            # Save CPU tensors, then move back to the training device and mode.
            embedder_net.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(e+1) + "_batch_id_" + str(batch_id+1) + ".pth"
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir, ckpt_model_filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            embedder_net.to(device).train()

    # Save the final model.
    embedder_net.eval().cpu()
    save_model_filename = "final_epoch_" + str(e + 1) + "_batch_id_" + str(batch_id + 1) + ".model"
    save_model_path = os.path.join(hp.train.checkpoint_dir, save_model_filename)
    torch.save(embedder_net.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
# Default checkpoint path for this script; the commented-out lines are
# alternative checkpoints.  Presumably passed to train()/test() further down
# the __main__ section (not visible in this snippet — confirm).
#model_path = './speech_id_checkpoint/512_ckpt_epoch_2880_batch_id_246.pth'
model_path = './speech_id_checkpoint/4_lstmlayer_ckpt_epoch_4320_batch_id_246.pth'
#model_path = './speech_id_checkpoint/ckpt_epoch_9840_batch_id_6.pth'
if (__name__ == '__main__'):

    # Summary writer for logging; NOTE(review): no use of `writer` is visible
    # in this snippet.
    writer = SummaryWriter()

    device = torch.device(hp.device)
    #model_path = hp.model.model_path

    # Here the datasets are built from explicit hp.data paths plus
    # hp.train.M / hp.test.M, unlike the zero-argument constructor calls
    # used inside the train()/test() functions.
    if hp.data.data_preprocessed:
        train_dataset = SpeakerDatasetTIMITPreprocessed(
            hp.data.train_path, hp.train.M)
    else:
        train_dataset = SpeakerDatasetTIMIT(hp.data.train_path, hp.train.M)

    if hp.data.data_preprocessed:
        test_dataset = SpeakerDatasetTIMITPreprocessed(hp.data.test_path,
                                                       hp.test.M)
    else:
        test_dataset = SpeakerDatasetTIMIT(hp.data.test_path, hp.test.M)

    # if hp.data.data_preprocessed:
    #     test_dataset = SpeakerDatasetTIMITPreprocessed(hp.data.zhouxingchi_path, hp.test.M)
    # else:
    #     test_dataset = SpeakerDatasetTIMIT(hp.data.zhouxingchi_path, hp.test.M)

    # NOTE(review): this DataLoader(...) call is cut off at the end of the
    # snippet; its remaining arguments and the rest of the __main__ section
    # are not visible here.
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
Example #5
0
def train(model_path):
    """Train a SpeechEmbedder with GE2E loss, validating after every epoch.

    Args:
        model_path: Local path for the model ``state_dict``.  When
            ``hp.train.restore`` is true, ``gsutil`` first copies
            ``gs://edinquake/asr/baseline_TIMIT/model_best.pkl`` to this path,
            and the weights are then loaded from it.

    Side effects:
        - Creates ``hp.train.checkpoint_dir``.
        - Whenever the mean validation loss is <= the best seen so far, saves
          ``model_best.pkl`` there and uploads it to the same GCS object with
          ``gsutil``.
        - Saves ``model_last.pkl`` after every epoch.
    """
    # Sink for gsutil's output.  NOTE(review): this handle is never closed
    # in the visible code.
    FNULL = open(os.devnull, 'w')
    device = torch.device(hp.device)

    # Build the train and test splits from the same dataset class.
    if hp.data.data_preprocessed:
        train_dataset = SpeakerDatasetTIMITPreprocessed(is_training=True)
        test_dataset = SpeakerDatasetTIMITPreprocessed(is_training=False)
    else:
        train_dataset = SpeakerDatasetTIMIT(is_training=True)
        test_dataset = SpeakerDatasetTIMIT(is_training=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=hp.test.N,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder().to(device)
    if hp.train.restore:
        # Fetch the best checkpoint from GCS.  NOTE(review): gsutil's return
        # code is ignored and its output discarded, so a failed download only
        # surfaces when torch.load below fails or reads a stale file.
        subprocess.call([
            'gsutil', 'cp', 'gs://edinquake/asr/baseline_TIMIT/model_best.pkl',
            model_path
        ],
                        stdout=FNULL,
                        stderr=subprocess.STDOUT)
        embedder_net.load_state_dict(torch.load(model_path))
    ge2e_loss = GE2ELoss(device)
    #Both net and loss have trainable parameters
    optimizer = torch.optim.SGD([{
        'params': embedder_net.parameters()
    }, {
        'params': ge2e_loss.parameters()
    }],
                                lr=hp.train.lr)

    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)

    iteration = 0
    best_validate = float('inf')  # lowest mean validation loss so far
    print('***Started training at {}***'.format(datetime.now()))
    for e in range(hp.train.epochs):
        total_loss = 0
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(e),
                            leave=False,
                            disable=False)
        embedder_net.train()
        # Iterate over the training set
        for batch_id, mel_db_batch in enumerate(progress_bar):
            mel_db_batch = mel_db_batch.to(device)

            # Flatten speakers x utterances into one batch of utterances.
            mel_db_batch = torch.reshape(
                mel_db_batch, (hp.train.N * hp.train.M, mel_db_batch.size(2),
                               mel_db_batch.size(3)))
            # Shuffle utterances before the forward pass; `unperm` is the
            # inverse permutation used to restore speaker order afterwards.
            perm = random.sample(range(0, hp.train.N * hp.train.M),
                                 hp.train.N * hp.train.M)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]
            #gradient accumulates
            optimizer.zero_grad()

            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(
                embeddings, (hp.train.N, hp.train.M, embeddings.size(1)))

            #get loss, call backward, step optimizer
            loss = ge2e_loss(
                embeddings)  #wants (Speaker, Utterances, embedding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()

            total_loss = total_loss + loss.item()
            iteration += 1

            # Update statistics for progress bar
            progress_bar.set_postfix(iteration=iteration,
                                     loss=loss.item(),
                                     total_loss=total_loss / (batch_id + 1))

        print('| Epoch {:03d}: total_loss {}'.format(e, total_loss))
        # Perform validation
        # NOTE(review): this loop runs without torch.no_grad(), so autograd
        # graphs are built for every validation batch; wrapping it would save
        # memory without changing the loss values.
        embedder_net.eval()
        validation_loss = 0
        for batch_id, mel_db_batch in enumerate(test_loader):
            mel_db_batch = mel_db_batch.to(device)

            mel_db_batch = torch.reshape(
                mel_db_batch, (hp.test.N * hp.test.M, mel_db_batch.size(2),
                               mel_db_batch.size(3)))
            # Same shuffle / un-shuffle scheme as in training.
            perm = random.sample(range(0, hp.test.N * hp.test.M),
                                 hp.test.N * hp.test.M)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]

            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(
                embeddings, (hp.test.N, hp.test.M, embeddings.size(1)))
            #get loss
            loss = ge2e_loss(
                embeddings)  #wants (Speaker, Utterances, embedding)
            validation_loss += loss.item()

        # Mean GE2E loss per validation batch.  NOTE(review): raises
        # ZeroDivisionError if test_loader yields no batches.
        validation_loss /= len(test_loader)
        print('validation_loss: {}'.format(validation_loss))
        # `<=` means a tie with the best loss also overwrites model_best.pkl.
        if validation_loss <= best_validate:
            best_validate = validation_loss
            # Save best
            filename = 'model_best.pkl'
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir, filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            # Upload the new best model; gsutil's return code is ignored.
            subprocess.call([
                'gsutil', 'cp', ckpt_model_path,
                'gs://edinquake/asr/baseline_TIMIT/model_best.pkl'
            ],
                            stdout=FNULL,
                            stderr=subprocess.STDOUT)

        # Always keep the most recent epoch's weights as well.
        filename = 'model_last.pkl'
        ckpt_model_path = os.path.join(hp.train.checkpoint_dir, filename)
        torch.save(embedder_net.state_dict(), ckpt_model_path)