# Speaker recognition fine-tuning script.
from argparse import ArgumentParser

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecSpeakerLabelModel
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--pretrained_model",
        type=str,
        default="speakerrecognition_speakernet",
        required=False,
        help="Pass your trained .nemo model",
    )
    parser.add_argument(
        "--finetune_config_file",
        type=str,
        required=True,
        help="Path to the speakernet config yaml file that provides the train/validation datasets and trainer parameters",
    )
    parser.add_argument(
        "--freeze_encoder",
        # argparse's type=bool would treat any non-empty string as True, so parse the flag explicitly
        type=lambda s: str(s).lower() in ('true', '1', 'yes'),
        required=False,
        default=True,
        help="True if the speakernet encoder parameters need to be frozen while fine-tuning",
    )
    args = parser.parse_args()

    # Load the speaker model from a local .nemo file, a Lightning checkpoint, or a pretrained NGC model.
    if args.pretrained_model.endswith('.nemo'):
        logging.info(f"Using local speaker model from {args.pretrained_model}")
        speaker_model = EncDecSpeakerLabelModel.restore_from(restore_path=args.pretrained_model)
    elif args.pretrained_model.endswith('.ckpt'):
        logging.info(f"Using local speaker model from checkpoint {args.pretrained_model}")
        speaker_model = EncDecSpeakerLabelModel.load_from_checkpoint(checkpoint_path=args.pretrained_model)
    else:
        logging.info("Using pretrained speaker recognition model from NGC")
        speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_name=args.pretrained_model)

    # Fine-tuning only needs train/validation data, so drop any test set from the config.
    finetune_config = OmegaConf.load(args.finetune_config_file)
    if 'test_ds' in finetune_config.model and finetune_config.model.test_ds is not None:
        finetune_config.model.test_ds = None
        logging.warning("Removing test ds")

    speaker_model.setup_finetune_model(finetune_config.model)
    finetune_trainer = pl.Trainer(**finetune_config.trainer)
    speaker_model.set_trainer(finetune_trainer)

    _ = exp_manager(finetune_trainer, finetune_config.get('exp_manager', None))
    speaker_model.setup_optimization(finetune_config.optim)

    # Optionally freeze the encoder so only the decoder/classifier weights are updated.
    if args.freeze_encoder:
        for param in speaker_model.encoder.parameters():
            param.requires_grad = False

    finetune_trainer.fit(speaker_model)


if __name__ == '__main__':
    main()
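# Example: building the fine-tune config.
# The fine-tuning script above only reads a few sections from --finetune_config_file:
# model.train_ds / model.validation_ds (consumed by setup_finetune_model), trainer,
# exp_manager, and optim. The sketch below shows one way to generate such a file
# programmatically. It is a minimal, assumed layout: the concrete dataset and
# optimizer fields (manifest_filepath, labels, batch_size, sgd settings, file names)
# are illustrative, not taken from an official config.
from omegaconf import OmegaConf

example_finetune_cfg = OmegaConf.create(
    {
        "model": {
            "train_ds": {"manifest_filepath": "train_manifest.json", "labels": None, "batch_size": 32, "shuffle": True},
            "validation_ds": {"manifest_filepath": "dev_manifest.json", "labels": None, "batch_size": 32, "shuffle": False},
        },
        "trainer": {"devices": 1, "accelerator": "gpu", "max_epochs": 10},
        "exp_manager": {"exp_dir": "./speakernet_finetune_exp", "name": "speakernet_finetune"},
        "optim": {"name": "sgd", "lr": 0.001, "weight_decay": 0.001},
    }
)
OmegaConf.save(example_finetune_cfg, "speakernet_finetune.yaml")
# Then (assuming the script above is saved as speaker_reco_finetune.py):
#   python speaker_reco_finetune.py --finetune_config_file speakernet_finetune.yaml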
# Speaker identification inference script (a separate file from the fine-tuning script above).
import json

import numpy as np
import torch
from omegaconf import OmegaConf

from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataset
from nemo.collections.asr.models import EncDecSpeakerLabelModel
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.core.config import hydra_runner
from nemo.utils import logging


@hydra_runner(config_path="conf", config_name="speaker_identification_infer")  # config path/name assumed
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    enrollment_manifest = cfg.data.enrollment_manifest
    test_manifest = cfg.data.test_manifest
    out_manifest = cfg.data.out_manifest
    sample_rate = cfg.data.sample_rate

    backend = cfg.backend.backend_model.lower()

    if backend == 'cosine_similarity':
        model_path = cfg.backend.cosine_similarity.model_path
        batch_size = cfg.backend.cosine_similarity.batch_size
        if model_path.endswith('.nemo'):
            speaker_model = EncDecSpeakerLabelModel.restore_from(model_path)
        else:
            speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path)

        # Extract embeddings for the enrollment and test utterances.
        enroll_embs, _, enroll_truelabels, enroll_id2label = EncDecSpeakerLabelModel.get_batch_embeddings(
            speaker_model, enrollment_manifest, batch_size, sample_rate, device=device,
        )

        test_embs, _, _, _ = EncDecSpeakerLabelModel.get_batch_embeddings(
            speaker_model, test_manifest, batch_size, sample_rate, device=device,
        )

        # Length-normalize the embeddings.
        enroll_embs = enroll_embs / (np.linalg.norm(enroll_embs, ord=2, axis=-1, keepdims=True))
        test_embs = test_embs / (np.linalg.norm(test_embs, ord=2, axis=-1, keepdims=True))

        # Build one reference embedding per enrolled speaker by averaging that speaker's
        # enrollment embeddings. np.where returns a tuple, so take its first element to
        # get the actual indices (otherwise len(indices) is always 1 and the mean is wrong).
        reference_embs = []
        keyslist = list(enroll_id2label.keys())
        for label_id in keyslist:
            indices = np.where(enroll_truelabels == label_id)[0]
            embedding = (enroll_embs[indices].sum(axis=0).squeeze()) / len(indices)
            reference_embs.append(embedding)

        reference_embs = np.asarray(reference_embs)

        # Score each test embedding against every reference and pick the best match.
        scores = np.matmul(test_embs, reference_embs.T)
        matched_labels = scores.argmax(axis=-1)

    elif backend == 'neural_classifier':
        model_path = cfg.backend.neural_classifier.model_path
        batch_size = cfg.backend.neural_classifier.batch_size

        if model_path.endswith('.nemo'):
            speaker_model = EncDecSpeakerLabelModel.restore_from(model_path)
        else:
            speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path)

        # Recover the label mapping from the enrollment manifest and make sure it matches
        # the classifier head of the model.
        featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        dataset = AudioToSpeechLabelDataset(manifest_filepath=enrollment_manifest, labels=None, featurizer=featurizer)
        enroll_id2label = dataset.id2label

        if speaker_model.decoder.final.out_features != len(enroll_id2label):
            raise ValueError(
                "Number of labels mismatch. Make sure you trained or fine-tuned the neural classifier "
                "with labels from the enrollment manifest_filepath"
            )

        _, test_logits, _, _ = EncDecSpeakerLabelModel.get_batch_embeddings(
            speaker_model, test_manifest, batch_size, sample_rate, device=device,
        )
        matched_labels = test_logits.argmax(axis=-1)

    else:
        raise ValueError(f"Unknown backend_model: {backend}")

    # Write the predicted speaker label for each test utterance back into the manifest.
    with open(test_manifest, 'rb') as f1, open(out_manifest, 'w', encoding='utf-8') as f2:
        lines = f1.readlines()
        for idx, line in enumerate(lines):
            line = line.strip()
            item = json.loads(line)
            item['infer'] = enroll_id2label[matched_labels[idx]]
            json.dump(item, f2)
            f2.write('\n')

    logging.info("Inference labels have been written to {} manifest file".format(out_manifest))


if __name__ == '__main__':
    main()
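# Example: manifest layout expected by the inference script.
# Both cfg.data.enrollment_manifest and cfg.data.test_manifest are NeMo-style
# JSON-lines files, and the script writes its prediction into the 'infer' field of
# each test entry. A minimal sketch of creating an enrollment manifest is shown
# below; the file names, paths, durations, and speaker labels are made-up examples.
# The test manifest is assumed to use the same layout, typically with a placeholder
# label since the true speaker is unknown.
import json

enrollment_items = [
    {"audio_filepath": "/data/enroll/spk1_utt1.wav", "offset": 0, "duration": 3.2, "label": "speaker_1"},
    {"audio_filepath": "/data/enroll/spk2_utt1.wav", "offset": 0, "duration": 2.7, "label": "speaker_2"},
]
with open("enrollment_manifest.json", "w", encoding="utf-8") as f:
    for item in enrollment_items:
        f.write(json.dumps(item) + "\n")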