Example #1: fine-tuning a pretrained NeMo speaker recognition model
from argparse import ArgumentParser

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecSpeakerLabelModel
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--pretrained_model",
        type=str,
        default="speakerrecognition_speakernet",
        required=False,
        help="Pass your trained .nemo model",
    )
    parser.add_argument(
        "--finetune_config_file",
        type=str,
        required=True,
        help="path to speakernet config yaml file to load train, validation dataset and also for trainer parameters",
    )

    parser.add_argument(
        "--freeze_encoder",
        # argparse's type=bool treats any non-empty string (including "False")
        # as True, so parse the value explicitly
        type=lambda s: s.lower() in ("true", "1", "yes"),
        required=False,
        default=True,
        help="True if the speakernet encoder parameters need to be frozen while finetuning",
    )

    args = parser.parse_args()

    if args.pretrained_model.endswith('.nemo'):
        logging.info(f"Using local speaker model from {args.pretrained_model}")
        speaker_model = EncDecSpeakerLabelModel.restore_from(restore_path=args.pretrained_model)
    elif args.pretrained_model.endswith('.ckpt'):
        logging.info(f"Using local speaker model from checkpoint {args.pretrained_model}")
        speaker_model = EncDecSpeakerLabelModel.load_from_checkpoint(checkpoint_path=args.pretrained_model)
    else:
        logging.info("Using pretrained speaker recognition model from NGC")
        speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_name=args.pretrained_model)

    finetune_config = OmegaConf.load(args.finetune_config_file)

    if 'test_ds' in finetune_config.model and finetune_config.model.test_ds is not None:
        finetune_config.model.test_ds = None
        logging.warning("Removing test_ds from the finetune config")

    speaker_model.setup_finetune_model(finetune_config.model)
    finetune_trainer = pl.Trainer(**finetune_config.trainer)
    speaker_model.set_trainer(finetune_trainer)

    _ = exp_manager(finetune_trainer, finetune_config.get('exp_manager', None))
    speaker_model.setup_optimization(finetune_config.optim)

    if args.freeze_encoder:
        for param in speaker_model.encoder.parameters():
            param.requires_grad = False

    finetune_trainer.fit(speaker_model)


if __name__ == '__main__':
    main()
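
For reference, the script above dereferences only four top-level config keys: model (from which it drops test_ds), trainer, exp_manager, and optim. Below is a minimal sketch of a compatible config built with OmegaConf in place of a YAML file; every leaf value (manifest paths, batch sizes, epochs, learning rate) is an illustrative NeMo-style assumption, not something the script itself prescribes.

from omegaconf import OmegaConf

# Hypothetical minimal finetune config: the top-level key layout mirrors the
# script above, all leaf values are placeholders.
finetune_config = OmegaConf.create(
    {
        "model": {
            "train_ds": {"manifest_filepath": "train.json", "batch_size": 32},
            "validation_ds": {"manifest_filepath": "dev.json", "batch_size": 32},
        },
        "trainer": {"devices": 1, "accelerator": "auto", "max_epochs": 10},
        "exp_manager": {"name": "speakernet_finetune"},
        "optim": {"name": "sgd", "lr": 0.001},
    }
)
OmegaConf.save(finetune_config, "finetune.yaml")
# then run the script above with: --finetune_config_file finetune.yaml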
Example #2: speaker identification inference with a cosine-similarity or neural-classifier backend
import json

import numpy as np
import torch
from omegaconf import OmegaConf

from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataset
from nemo.collections.asr.models import EncDecSpeakerLabelModel
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.core.config import hydra_runner
from nemo.utils import logging


# Hydra entry point; the config path and name follow NeMo's example layout
@hydra_runner(config_path="conf", config_name="speaker_identification_infer")
def main(cfg):

    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    enrollment_manifest = cfg.data.enrollment_manifest
    test_manifest = cfg.data.test_manifest
    out_manifest = cfg.data.out_manifest
    sample_rate = cfg.data.sample_rate

    backend = cfg.backend.backend_model.lower()

    if backend == 'cosine_similarity':
        model_path = cfg.backend.cosine_similarity.model_path
        batch_size = cfg.backend.cosine_similarity.batch_size
        if model_path.endswith('.nemo'):
            speaker_model = EncDecSpeakerLabelModel.restore_from(model_path)
        else:
            speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path)

        enroll_embs, _, enroll_truelabels, enroll_id2label = EncDecSpeakerLabelModel.get_batch_embeddings(
            speaker_model,
            enrollment_manifest,
            batch_size,
            sample_rate,
            device=device,
        )

        test_embs, _, _, _ = EncDecSpeakerLabelModel.get_batch_embeddings(
            speaker_model,
            test_manifest,
            batch_size,
            sample_rate,
            device=device,
        )

        # length-normalize the embeddings so the dot products below are cosine similarities
        enroll_embs = enroll_embs / np.linalg.norm(enroll_embs, ord=2, axis=-1, keepdims=True)
        test_embs = test_embs / np.linalg.norm(test_embs, ord=2, axis=-1, keepdims=True)

        # reference embeddings: per-speaker centroids of the enrollment embeddings
        reference_embs = []
        keyslist = list(enroll_id2label.keys())
        for label_id in keyslist:
            # np.where returns a tuple of index arrays; take the first element so
            # len(indices) counts the matching utterances instead of always being 1
            indices = np.where(enroll_truelabels == label_id)[0]
            embedding = enroll_embs[indices].sum(axis=0).squeeze() / len(indices)
            reference_embs.append(embedding)

        reference_embs = np.asarray(reference_embs)

        # with unit-norm embeddings, this matrix product yields cosine similarity scores
        scores = np.matmul(test_embs, reference_embs.T)
        matched_labels = scores.argmax(axis=-1)

    elif backend == 'neural_classifier':
        model_path = cfg.backend.neural_classifier.model_path
        batch_size = cfg.backend.neural_classifier.batch_size

        if model_path.endswith('.nemo'):
            speaker_model = EncDecSpeakerLabelModel.restore_from(model_path)
        else:
            speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path)

        featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        dataset = AudioToSpeechLabelDataset(
            manifest_filepath=enrollment_manifest,
            labels=None,
            featurizer=featurizer)
        enroll_id2label = dataset.id2label

        if speaker_model.decoder.final.out_features != len(enroll_id2label):
            raise ValueError(
                "Number of labels mismatch: make sure the neural classifier was trained "
                "or finetuned with the labels from the enrollment manifest_filepath"
            )

        _, test_logits, _, _ = EncDecSpeakerLabelModel.get_batch_embeddings(
            speaker_model,
            test_manifest,
            batch_size,
            sample_rate,
            device=device,
        )
        matched_labels = test_logits.argmax(axis=-1)

    else:
        raise ValueError(f"Unknown backend model: {backend}")

    # annotate each test-manifest entry (one JSON object per line) with the predicted label
    with open(test_manifest, 'rb') as f1, open(out_manifest, 'w', encoding='utf-8') as f2:
        lines = f1.readlines()
        for idx, line in enumerate(lines):
            line = line.strip()
            item = json.loads(line)
            item['infer'] = enroll_id2label[matched_labels[idx]]
            json.dump(item, f2)
            f2.write('\n')

    logging.info("Inference labels have been written to the {} manifest file".format(out_manifest))


if __name__ == '__main__':
    main()
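
The Hydra config this script consumes can be read off from its cfg.* accesses: a data group with the three manifest paths and the sample rate, and a backend group naming the backend plus per-backend model_path and batch_size fields. A minimal sketch follows; all paths and values are placeholders ("titanet_large" stands in for whichever NGC speaker model or local .nemo file you use).

from omegaconf import OmegaConf

# Hypothetical minimal inference config: the key layout mirrors the cfg.*
# accesses in the script above, every leaf value is a placeholder.
cfg = OmegaConf.create(
    {
        "data": {
            "enrollment_manifest": "enroll.json",
            "test_manifest": "test.json",
            "out_manifest": "infer_output.json",
            "sample_rate": 16000,
        },
        "backend": {
            "backend_model": "cosine_similarity",  # or "neural_classifier"
            "cosine_similarity": {"model_path": "titanet_large", "batch_size": 32},
            "neural_classifier": {"model_path": "classifier.nemo", "batch_size": 32},
        },
    }
)
OmegaConf.save(cfg, "conf/speaker_identification_infer.yaml")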