def test_epoch_end(self, outputs):
    """Average per-slice test embeddings per manifest entry and pickle them.

    Concatenates the ``'embs'`` and ``'slices'`` entries of every step output,
    flattens the embeddings to ``(-1, emb_dim)``, normalizes them with
    ``embedding_normalize`` and, walking ``self.test_manifest`` line by line,
    averages the next ``slices[idx]`` rows into one embedding per entry.

    The result is written to
    ``<self.embedding_dir>/embeddings/<manifest stem>_embeddings.pkl`` as a dict
    keyed by the last three components of ``audio_filepath`` joined with ``'@'``.

    Args:
        outputs: list of per-step dicts with ``'embs'`` and ``'slices'`` tensors.

    Returns:
        An empty dict.

    Raises:
        KeyError: if two manifest entries map to the same key.
    """
    embs = torch.cat([x['embs'] for x in outputs])
    slices = torch.cat([x['slices'] for x in outputs])
    emb_shape = embs.shape[-1]
    embs = embs.view(-1, emb_shape).cpu().numpy()
    embs = embedding_normalize(embs)
    # Plain Python ints so they can be used directly as numpy slice bounds.
    slice_counts = [int(n) for n in slices.reshape(-1).tolist()]

    out_embeddings = {}
    start_idx = 0
    with open(self.test_manifest, 'r') as manifest:
        for idx, line in enumerate(manifest):
            dic = json.loads(line.strip())
            structure = dic['audio_filepath'].split('/')[-3:]
            uniq_name = '@'.join(structure)
            if uniq_name in out_embeddings:
                raise KeyError(
                    "Embeddings for label {} already present in emb dictionary".format(uniq_name)
                )
            # Each manifest entry owns a consecutive run of slice embeddings.
            end_idx = start_idx + slice_counts[idx]
            out_embeddings[uniq_name] = embs[start_idx:end_idx].mean(axis=0)
            start_idx = end_idx

    embedding_dir = os.path.join(self.embedding_dir, 'embeddings')
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(embedding_dir, exist_ok=True)
    # Strip only the final extension so multi-dot names (e.g. "train.v2.json")
    # keep their full stem, and extension-less names do not raise.
    prefix = os.path.splitext(os.path.basename(self.test_manifest))[0]
    name = os.path.join(embedding_dir, prefix)
    with open(name + '_embeddings.pkl', 'wb') as pkl_file:
        pkl.dump(out_embeddings, pkl_file)
    logging.info("Saved embedding files to {}".format(embedding_dir))
    return {}
def get_embeddings(speaker_model, manifest_file, batch_size=1, embedding_dir='./', device='cuda'):
    """Extract normalized speaker embeddings for a manifest and pickle them.

    Configures the model's test dataloader for ``manifest_file`` (16 kHz,
    ``time_length=20``, no shuffling) and moves the model to ``device`` in eval
    mode. It then runs ``speaker_model.forward`` under ``autocast`` and takes the
    second returned value as the embeddings. Rows are flattened to
    ``(-1, emb_dim)`` and normalized with ``embedding_normalize``.

    Row ``i`` is stored under the key built from manifest line ``i``: the last
    three components of ``audio_filepath`` joined with ``'@'``. The dict is
    written to ``<embedding_dir>/embeddings/<manifest stem>_embeddings.pkl``.

    Args:
        speaker_model: model exposing ``setup_test_data``, ``test_dataloader``
            and ``forward``.
        manifest_file: path to a JSON-lines manifest.
        batch_size: test dataloader batch size.
        embedding_dir: parent directory for the ``embeddings`` output folder.
        device: device the model and batches are moved to.
    """
    import torch

    test_config = OmegaConf.create(
        dict(
            manifest_filepath=manifest_file,
            sample_rate=16000,
            labels=None,
            batch_size=batch_size,
            shuffle=False,
            time_length=20,
        )
    )
    speaker_model.setup_test_data(test_config)
    speaker_model = speaker_model.to(device)
    speaker_model.eval()

    all_embs = []
    # Pure inference: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for test_batch in tqdm(speaker_model.test_dataloader()):
            test_batch = [x.to(device) for x in test_batch]
            audio_signal, audio_signal_len, _labels, _slices = test_batch
            with autocast():
                _, embs = speaker_model.forward(
                    input_signal=audio_signal, input_signal_length=audio_signal_len
                )
            emb_shape = embs.shape[-1]
            embs = embs.view(-1, emb_shape)
            all_embs.extend(embs.cpu().detach().numpy())
            del test_batch

    all_embs = embedding_normalize(np.asarray(all_embs))

    # NOTE(review): assumes the dataloader yields exactly one embedding row per
    # manifest line (row i <-> line i); the per-batch slice counts are ignored
    # here -- confirm against the test dataset's collate function.
    out_embeddings = {}
    with open(manifest_file, 'r') as manifest:
        for i, line in enumerate(manifest):
            dic = json.loads(line.strip())
            uniq_name = '@'.join(dic['audio_filepath'].split('/')[-3:])
            out_embeddings[uniq_name] = all_embs[i]

    embedding_dir = os.path.join(embedding_dir, 'embeddings')
    os.makedirs(embedding_dir, exist_ok=True)
    # Strip only the final extension; extension-less names no longer raise.
    prefix = os.path.splitext(os.path.basename(manifest_file))[0]
    name = os.path.join(embedding_dir, prefix)
    embeddings_file = name + '_embeddings.pkl'
    with open(embeddings_file, 'wb') as pkl_file:
        pkl.dump(out_embeddings, pkl_file)
    logging.info("Saved embedding files to {}".format(embedding_dir))