def test(model_path, output_path='/run/media/rice/DATA/speakerdvector.npy'):
    """Embed every test batch with a trained model and save the d-vectors.

    Each batch from the loader is flattened to
    ``(hp.test.N * hp.test.M, frames, mels)``, passed through the embedder,
    and the resulting embeddings are stacked along axis 0 and written with
    ``np.save``.

    Args:
        model_path: Path of the ``SpeechEmbedder`` state dict to load.
        output_path: Destination ``.npy`` file. Defaults to the path that
            was previously hard-coded, so existing callers are unaffected.
    """
    if hp.data.data_preprocessed:
        test_dataset = SpeakerDatasetTIMITPreprocessed()
    else:
        test_dataset = SpeakerDatasetTIMIT()
    # batch_size=1 together with the reshape below implies hp.test.N == 1
    # for this extraction path.
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    embeddings = []
    # Pure inference: no autograd graph is needed for any forward pass.
    with torch.no_grad():
        for _ in range(hp.test.epochs):
            print("hp.test.epochs", hp.test.epochs)
            for mel_db_batch in test_loader:
                # Kept from the original; nothing below splits the batch in
                # half, so this only guards the configuration value.
                assert hp.test.M % 2 == 0
                # Original debug notes give (1,10,160,40) -> (10,160,40)
                # here for the author's settings.
                test_batch = torch.reshape(
                    mel_db_batch,
                    (hp.test.N * hp.test.M,
                     mel_db_batch.size(2),
                     mel_db_batch.size(3)))
                enrollment_embeddings = embedder_net(test_batch)
                embeddings.append(enrollment_embeddings.detach().numpy())

    # Concatenate once at the end instead of re-concatenating the whole
    # history after every batch.
    # NOTE(review): the original indentation was lost; the save is assumed
    # to happen once, after all epochs -- confirm.
    if embeddings:
        devector = np.concatenate(embeddings, axis=0)
    else:
        # Matches what the original would have saved (an empty list).
        devector = np.array([])
    np.save(output_path, devector)
def test(model_path):
    """Evaluate a trained embedder by its equal error rate (EER).

    Every batch is split in half along the utterance axis: the first half
    is used to build per-speaker enrollment centroids, the second half is
    scored against those centroids with ``get_cossim``. For each batch the
    threshold in ``0.50 .. 0.99`` (step 0.01) whose FAR and FRR are closest
    is taken as the EER point. Per-batch EERs are averaged per epoch and
    then across ``hp.test.epochs``; results are printed, nothing is
    returned.

    Args:
        model_path: Path of the ``SpeechEmbedder`` state dict to load.

    Raises:
        ValueError: If the test loader yields no batches in an epoch (the
            original failed with a NameError in that case).
    """
    if hp.data.data_preprocessed:
        test_dataset = SpeakerDatasetTIMITPreprocessed()
    else:
        test_dataset = SpeakerDatasetTIMIT()
    test_loader = DataLoader(test_dataset,
                             batch_size=hp.test.N,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder()
    embedder_net.load_state_dict(torch.load(model_path))
    embedder_net.eval()

    n_speakers = hp.test.N
    thresholds = [0.01 * i + 0.5 for i in range(50)]

    avg_EER = 0
    # Evaluation only: skip autograd bookkeeping for all forward passes.
    with torch.no_grad():
        for e in range(hp.test.epochs):
            batch_avg_EER = 0
            num_batches = 0
            for mel_db_batch in test_loader:
                # The enrollment/verification split needs an even M.
                assert hp.test.M % 2 == 0
                half_m = hp.test.M // 2
                num_batches += 1

                enrollment_batch, verification_batch = torch.split(
                    mel_db_batch, mel_db_batch.size(1) // 2, dim=1)
                enrollment_batch = torch.reshape(
                    enrollment_batch,
                    (n_speakers * half_m,
                     enrollment_batch.size(2),
                     enrollment_batch.size(3)))
                verification_batch = torch.reshape(
                    verification_batch,
                    (n_speakers * half_m,
                     verification_batch.size(2),
                     verification_batch.size(3)))

                # Feed the verification utterances in random order, then
                # restore the original order on the embeddings.
                perm = random.sample(range(0, verification_batch.size(0)),
                                     verification_batch.size(0))
                unperm = list(perm)
                for i, j in enumerate(perm):
                    unperm[j] = i

                verification_batch = verification_batch[perm]
                enrollment_embeddings = embedder_net(enrollment_batch)
                verification_embeddings = embedder_net(verification_batch)
                verification_embeddings = verification_embeddings[unperm]

                enrollment_embeddings = torch.reshape(
                    enrollment_embeddings,
                    (n_speakers, half_m, enrollment_embeddings.size(1)))
                verification_embeddings = torch.reshape(
                    verification_embeddings,
                    (n_speakers, half_m, verification_embeddings.size(1)))

                enrollment_centroids = get_centroids(enrollment_embeddings)
                sim_matrix = get_cossim(verification_embeddings,
                                        enrollment_centroids)

                # calculating EER
                diff = 1
                EER = 0
                EER_thresh = 0
                EER_FAR = 0
                EER_FRR = 0
                for thres in thresholds:
                    sim_matrix_thresh = sim_matrix > thres

                    # Accepted trials where the claimed speaker is the true
                    # speaker (entries [i, :, i]) versus all accepted trials.
                    accepted_own = sum(
                        sim_matrix_thresh[i, :, i].float().sum()
                        for i in range(int(n_speakers)))
                    accepted_all = sim_matrix_thresh.float().sum()

                    FAR = ((accepted_all - accepted_own)
                           / (n_speakers - 1.0) / float(half_m) / n_speakers)
                    FRR = ((n_speakers * half_m - accepted_own)
                           / float(half_m) / n_speakers)

                    # Save threshold when FAR = FRR (=EER)
                    if diff > abs(FAR - FRR):
                        diff = abs(FAR - FRR)
                        EER = (FAR + FRR) / 2
                        EER_thresh = thres
                        EER_FAR = FAR
                        EER_FRR = FRR
                batch_avg_EER += EER
                print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)"%(EER,EER_thresh,EER_FAR,EER_FRR))

            if num_batches == 0:
                raise ValueError(
                    "test loader produced no batches in epoch %d" % e)
            avg_EER += batch_avg_EER / num_batches

    avg_EER = avg_EER / hp.test.epochs
    print("\n EER across {0} epochs: {1:.4f}".format(hp.test.epochs, avg_EER))
def train(model_path):
    """Train the speech embedder with the GE2E loss and save checkpoints.

    Builds the training loader, optionally restores weights from
    ``model_path`` (when ``hp.train.restore`` is set), then runs
    ``hp.train.epochs`` epochs of SGD over both the embedder and the loss
    parameters. Progress is printed (and appended to ``hp.train.log_file``
    if configured) every ``hp.train.log_interval`` batches; a checkpoint is
    written every ``hp.train.checkpoint_interval`` epochs and a final model
    after the last epoch.

    Args:
        model_path: Path of a state dict to restore from; only read when
            ``hp.train.restore`` is truthy.
    """
    device = torch.device(hp.device)

    if hp.data.data_preprocessed:
        train_dataset = SpeakerDatasetTIMITPreprocessed()
    else:
        train_dataset = SpeakerDatasetTIMIT()
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)
    print(len(train_loader))

    embedder_net = SpeechEmbedder().to(device)
    if hp.train.restore:
        embedder_net.load_state_dict(torch.load(model_path))
    ge2e_loss = GE2ELoss(device)
    print('ge2e_loss:', ge2e_loss)
    #Both net and loss have trainable parameters
    optimizer = torch.optim.SGD([
        {'params': embedder_net.parameters()},
        {'params': ge2e_loss.parameters()}
    ], lr=hp.train.lr)

    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)

    embedder_net.train()
    iteration = 0
    batch_size_total = hp.train.N * hp.train.M
    for e in range(hp.train.epochs):
        total_loss = 0
        for batch_id, mel_db_batch in enumerate(train_loader):
            mel_db_batch = mel_db_batch.to(device)

            mel_db_batch = torch.reshape(
                mel_db_batch,
                (batch_size_total, mel_db_batch.size(2), mel_db_batch.size(3)))
            # Feed utterances in random order; `unperm` restores the
            # speaker-grouped order on the embeddings afterwards.
            perm = random.sample(range(0, batch_size_total), batch_size_total)
            unperm = list(perm)
            for i, j in enumerate(perm):
                unperm[j] = i
            mel_db_batch = mel_db_batch[perm]

            #gradient accumulates
            optimizer.zero_grad()

            # The checkpoint is deliberately NOT reloaded here: doing so on
            # every iteration would overwrite the weights just updated by
            # optimizer.step() and would require `model_path` to exist even
            # when hp.train.restore is disabled.
            embeddings = embedder_net(mel_db_batch)
            embeddings = embeddings[unperm]
            embeddings = torch.reshape(
                embeddings, (hp.train.N, hp.train.M, embeddings.size(1)))
            print('embeddings size is:', embeddings.size())

            #get loss, call backward, step optimizer
            loss = ge2e_loss(embeddings) #wants (Speaker, Utterances, embedding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()

            # Accumulate a plain float so the running total does not hold on
            # to each batch's autograd history.
            loss_value = loss.item()
            total_loss = total_loss + loss_value
            iteration += 1
            if (batch_id + 1) % hp.train.log_interval == 0:
                mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(
                    time.ctime(), e+1, batch_id+1,
                    len(train_dataset)//hp.train.N, iteration,
                    loss_value, total_loss / (batch_id + 1))
                print(mesg)
                if hp.train.log_file is not None:
                    with open(hp.train.log_file,'a') as f:
                        f.write(mesg)

        if hp.train.checkpoint_dir is not None and (e + 1) % hp.train.checkpoint_interval == 0:
            # Save from CPU in eval mode, then move back and resume training.
            embedder_net.eval().cpu()
            ckpt_model_filename = "ckpt_epoch_" + str(e+1) + "_batch_id_" + str(batch_id+1) + ".pth"
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir, ckpt_model_filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            embedder_net.to(device).train()

    #save model
    embedder_net.eval().cpu()
    save_model_filename = "final_epoch_" + str(e + 1) + "_batch_id_" + str(batch_id + 1) + ".model"
    save_model_path = os.path.join(hp.train.checkpoint_dir, save_model_filename)
    torch.save(embedder_net.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
#model_path = './speech_id_checkpoint/512_ckpt_epoch_2880_batch_id_246.pth' model_path = './speech_id_checkpoint/4_lstmlayer_ckpt_epoch_4320_batch_id_246.pth' #model_path = './speech_id_checkpoint/ckpt_epoch_9840_batch_id_6.pth' if (__name__ == '__main__'): writer = SummaryWriter() device = torch.device(hp.device) #model_path = hp.model.model_path if hp.data.data_preprocessed: train_dataset = SpeakerDatasetTIMITPreprocessed( hp.data.train_path, hp.train.M) else: train_dataset = SpeakerDatasetTIMIT(hp.data.train_path, hp.train.M) if hp.data.data_preprocessed: test_dataset = SpeakerDatasetTIMITPreprocessed(hp.data.test_path, hp.test.M) else: test_dataset = SpeakerDatasetTIMIT(hp.data.test_path, hp.test.M) # if hp.data.data_preprocessed: # test_dataset = SpeakerDatasetTIMITPreprocessed(hp.data.zhouxingchi_path, hp.test.M) # else: # test_dataset = SpeakerDatasetTIMIT(hp.data.zhouxingchi_path, hp.test.M) train_loader = DataLoader(train_dataset, batch_size=hp.train.N, shuffle=True,
def _embed_in_shuffled_order(embedder_net, mel_db_batch, n_speakers,
                             n_utterances):
    """Embed a batch in random utterance order, returned speaker-grouped.

    The batch is flattened to ``(n_speakers * n_utterances, ...)``, passed
    through ``embedder_net`` in a random permutation, un-permuted, and
    reshaped to ``(n_speakers, n_utterances, embedding_dim)``.
    """
    total = n_speakers * n_utterances
    mel_db_batch = torch.reshape(
        mel_db_batch, (total, mel_db_batch.size(2), mel_db_batch.size(3)))
    perm = random.sample(range(0, total), total)
    # unperm[j] is the position in the shuffled batch that holds original
    # row j, so indexing with it restores the original order.
    unperm = list(perm)
    for i, j in enumerate(perm):
        unperm[j] = i
    embeddings = embedder_net(mel_db_batch[perm])
    embeddings = embeddings[unperm]
    return torch.reshape(embeddings,
                         (n_speakers, n_utterances, embeddings.size(1)))


def train(model_path):
    """Train the speech embedder with GE2E loss, validating every epoch.

    Optionally restores weights by first copying the remote best model to
    ``model_path`` with ``gsutil`` (when ``hp.train.restore`` is set). After
    each training epoch the GE2E loss is averaged over the test loader; when
    it is the best seen so far, the weights are saved as ``model_best.pkl``
    and uploaded with ``gsutil``. The latest weights are saved as
    ``model_last.pkl``.

    Args:
        model_path: Local path the restored checkpoint is copied to and
            loaded from; only used when ``hp.train.restore`` is truthy.
    """
    device = torch.device(hp.device)

    if hp.data.data_preprocessed:
        train_dataset = SpeakerDatasetTIMITPreprocessed(is_training=True)
        test_dataset = SpeakerDatasetTIMITPreprocessed(is_training=False)
    else:
        train_dataset = SpeakerDatasetTIMIT(is_training=True)
        test_dataset = SpeakerDatasetTIMIT(is_training=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=hp.test.N,
                             shuffle=True,
                             num_workers=hp.test.num_workers,
                             drop_last=True)

    embedder_net = SpeechEmbedder().to(device)
    if hp.train.restore:
        # subprocess.DEVNULL silences gsutil without opening (and leaking)
        # a handle to os.devnull. The exit status is ignored, as before.
        subprocess.call([
            'gsutil', 'cp',
            'gs://edinquake/asr/baseline_TIMIT/model_best.pkl', model_path
        ],
                        stdout=subprocess.DEVNULL,
                        stderr=subprocess.STDOUT)
        embedder_net.load_state_dict(torch.load(model_path))
    ge2e_loss = GE2ELoss(device)
    #Both net and loss have trainable parameters
    optimizer = torch.optim.SGD([{
        'params': embedder_net.parameters()
    }, {
        'params': ge2e_loss.parameters()
    }],
                                lr=hp.train.lr)

    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)

    iteration = 0
    best_validate = float('inf')
    print('***Started training at {}***'.format(datetime.now()))
    for e in range(hp.train.epochs):
        total_loss = 0
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(e),
                            leave=False,
                            disable=False)
        embedder_net.train()
        # Iterate over the training set
        for batch_id, mel_db_batch in enumerate(progress_bar):
            mel_db_batch = mel_db_batch.to(device)

            #gradient accumulates
            optimizer.zero_grad()

            embeddings = _embed_in_shuffled_order(
                embedder_net, mel_db_batch, hp.train.N, hp.train.M)

            #get loss, call backward, step optimizer
            loss = ge2e_loss(
                embeddings)  #wants (Speaker, Utterances, embedding)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0)
            torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0)
            optimizer.step()

            total_loss = total_loss + loss.item()
            iteration += 1

            # Update statistics for progress bar
            progress_bar.set_postfix(iteration=iteration,
                                     loss=loss.item(),
                                     total_loss=total_loss / (batch_id + 1))
        print('| Epoch {:03d}: total_loss {}'.format(e, total_loss))

        # Perform validation
        embedder_net.eval()
        validation_loss = 0
        # No parameters are updated here, so skip autograd bookkeeping.
        with torch.no_grad():
            for mel_db_batch in test_loader:
                mel_db_batch = mel_db_batch.to(device)
                embeddings = _embed_in_shuffled_order(
                    embedder_net, mel_db_batch, hp.test.N, hp.test.M)
                #get loss
                loss = ge2e_loss(
                    embeddings)  #wants (Speaker, Utterances, embedding)
                validation_loss += loss.item()

        validation_loss /= len(test_loader)
        print('validation_loss: {}'.format(validation_loss))
        if validation_loss <= best_validate:
            best_validate = validation_loss
            # Save best
            filename = 'model_best.pkl'
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir, filename)
            torch.save(embedder_net.state_dict(), ckpt_model_path)
            subprocess.call([
                'gsutil', 'cp', ckpt_model_path,
                'gs://edinquake/asr/baseline_TIMIT/model_best.pkl'
            ],
                            stdout=subprocess.DEVNULL,
                            stderr=subprocess.STDOUT)

        # NOTE(review): the original indentation was lost; the "last" model
        # is assumed to be saved once per epoch -- confirm.
        filename = 'model_last.pkl'
        ckpt_model_path = os.path.join(hp.train.checkpoint_dir, filename)
        torch.save(embedder_net.state_dict(), ckpt_model_path)