Example #1
0
def test_create_from_manifest():
    """Build the full pipeline from every sample NeMo config and smoke-test it.

    For each ``*.yaml`` in ``tests/nemo_config_samples`` this builds the
    feature extractor, a ``Quartznet5`` encoder and a decoder from the
    config. It runs a random batch through them and checks the output
    shapes and that no NaNs appear. Finally it loads the encoder's state
    dict into the matching fixed-architecture encoder (5x5 or 15x5).
    """
    path = Path("tests/nemo_config_samples")
    # The path is relative to the working directory. Fail loudly instead of
    # passing vacuously when no sample config is found (e.g. wrong CWD).
    configs = sorted(path.glob("*.yaml"))
    assert configs, f"no sample configs found in {path.resolve()}"

    for cfg in configs:
        encoder_params, initial_vocab, preprocess_params = read_params_from_config(cfg)
        fb = FilterbankFeatures(**preprocess_params)
        encoder = Quartznet5(64, **encoder_params)
        decoder = Quartznet_decoder(len(initial_vocab))

        batch_size = 10
        x = torch.randn(batch_size, 1337)
        feat = fb(x)
        lens = torch.randint(10, 100, (batch_size,))
        out, _ = encoder(feat, lens)
        out2 = decoder(out)
        assert feat.shape[0] == x.shape[0]
        # The feature extractor is expected to emit 64 channels, matching the
        # 64 passed to the encoders below.
        assert feat.shape[1] == 64
        assert out.shape[0] == x.shape[0]
        assert not torch.isnan(out).any()
        assert out2.shape[0] == x.shape[0]
        assert not torch.isnan(out2).any()

        # Pick the fixed-architecture encoder for this config. load_state_dict
        # (strict by default) fails if the generic encoder's parameters differ.
        if "Net5x5" in cfg.name:
            encoder_cls = Quartznet5x5_encoder
        else:
            encoder_cls = Quartznet15x5_encoder
        encoder2 = encoder_cls(64)
        encoder2.load_state_dict(encoder.state_dict())
Example #2
0
    def load_from_nemo(cls,
                       *,
                       nemo_filepath: str = None,
                       checkpoint_name: NemoCheckpoint = None):
        """Build a module from an original NeMo ``.nemo`` archive.

        Args:
            nemo_filepath: Path to a local ``.nemo`` file.
            checkpoint_name: Checkpoint to fetch with ``download_checkpoint``.
                When given, it is used instead of ``nemo_filepath``.

        Raises:
            ValueError: If neither ``nemo_filepath`` nor ``checkpoint_name``
                is passed.

        Returns:
            The loaded module, switched to eval mode.
        """
        if checkpoint_name is None and nemo_filepath is None:
            raise ValueError(
                "Either nemo_filepath or checkpoint_name must be passed")
        if checkpoint_name is not None:
            nemo_filepath = download_checkpoint(checkpoint_name)

        with TemporaryDirectory() as tmp_dir:
            archive_dir = Path(tmp_dir)
            extract_archive(str(nemo_filepath), archive_dir)
            # The extracted archive holds the model config next to the weights.
            params = read_params_from_config(archive_dir / "model_config.yaml")
            encoder_params, initial_vocab, preprocess_params = params
            module = cls(
                initial_vocab_tokens=initial_vocab,
                **encoder_params,
                **preprocess_params,
                nemo_compat_vocab=True,
            )
            load_quartznet_weights(module.encoder, module.decoder,
                                   archive_dir / "model_weights.ckpt")
        # Default to eval mode. The common uses are (1) load and run inference
        # directly, where forgetting eval() would be a silent bug, or
        # (2) fine-tuning, where the training loop switches back to train mode.
        module.eval()
        return module
Example #3
0
def test_can_load_weights():
    """Download the ``QuartzNet5x5LS_En`` checkpoint and load its weights.

    Quartznet 5x5 is small (25mb), so it can be downloaded while testing.
    If the download fails with an ``HTTPError`` (e.g. no network access),
    the test returns early without checking anything.
    """
    # Only the download touches the network, so only it is guarded. An
    # HTTPError from anywhere else should surface instead of being masked.
    try:
        cfg = download_checkpoint(NemoCheckpoint.QuartzNet5x5LS_En)
    except HTTPError:
        return

    with TemporaryDirectory() as extract_path:
        extract_path = Path(extract_path)
        extract_archive(str(cfg), extract_path)
        config_path = extract_path / "model_config.yaml"
        encoder_params, initial_vocab, _ = read_params_from_config(config_path)
        encoder = Quartznet5(64, **encoder_params)
        # One output more than the vocabulary size; presumably the extra
        # (blank) token of the checkpoint's output layer. TODO confirm.
        decoder = Quartznet_decoder(len(initial_vocab) + 1)
        load_quartznet_weights(
            encoder, decoder, extract_path / "model_weights.ckpt"
        )
Example #4
0
    def load_from_nemo(
            cls,
            *,
            nemo_filepath: str = None,
            checkpoint_name: NemoCheckpoint = None) -> "QuartznetModule":
        """Load from the original nemo checkpoint.

        Args:
            nemo_filepath : Path to local .nemo file.
            checkpoint_name : Name of checkpoint to be downloaded locally and loaded.
                If given, it takes precedence and ``nemo_filepath`` is ignored.

        Raises:
            ValueError: If neither ``nemo_filepath`` nor ``checkpoint_name`` is passed.

        Returns:
            The model loaded from the checkpoint, already switched to eval mode.
        """
        # Validate before doing any work (the download is only attempted when
        # checkpoint_name is given, so this ordering is equivalent but clearer).
        if nemo_filepath is None and checkpoint_name is None:
            raise ValueError(
                "Either nemo_filepath or checkpoint_name must be passed")
        if checkpoint_name is not None:
            nemo_filepath = download_checkpoint(checkpoint_name)

        with TemporaryDirectory() as extract_path:
            extract_path = Path(extract_path)
            extract_archive(str(nemo_filepath), extract_path)
            config_path = extract_path / "model_config.yaml"
            encoder_params, initial_vocab, preprocess_params = read_params_from_config(
                config_path)
            module = cls(
                initial_vocab_tokens=initial_vocab,
                **encoder_params,
                **preprocess_params,
                nemo_compat_vocab=True,
            )
            weights_path = extract_path / "model_weights.ckpt"
            load_quartznet_weights(module.encoder, module.decoder,
                                   weights_path)
        # Here we set it in eval mode, so it correctly works during inference.
        # We expect most applications to either (1) load a checkpoint and
        # directly run inference, or (2) fine-tune. Either way this prevents a
        # silent bug (case 1) or is ignored (case 2).
        module.eval()
        return module