def test_create_from_manifest():
    """Smoke-test model construction from every sample NeMo config.

    For each ``*.yaml`` file in ``tests/nemo_config_samples`` the filterbank,
    encoder and decoder are built from the parsed parameters and a random
    batch is pushed through them. The test checks the batch dimension of
    every output, the feature dimension of the filterbank output and the
    absence of NaNs, then verifies that the encoder's state dict loads into
    the encoder built by the matching factory function.
    """
    # Presumably the encoder's input feature size; the same value is asserted
    # on the second dimension of the filterbank output below.
    feat_dim = 64
    for config_file in Path("tests/nemo_config_samples").glob("*.yaml"):
        encoder_params, initial_vocab, preprocess_params = read_params_from_config(
            config_file
        )
        filterbank = FilterbankFeatures(**preprocess_params)
        encoder = Quartznet5(feat_dim, **encoder_params)
        decoder = Quartznet_decoder(len(initial_vocab))

        audio = torch.randn(10, 1337)
        batch_size = audio.shape[0]
        features = filterbank(audio)
        lengths = torch.randint(10, 100, (10,))
        encoded, _ = encoder(features, lengths)
        decoded = decoder(encoded)

        assert features.shape[0] == batch_size
        assert features.shape[1] == feat_dim
        assert encoded.shape[0] == batch_size
        assert not torch.isnan(encoded).any()
        assert decoded.shape[0] == batch_size
        assert not torch.isnan(decoded).any()

        # The config file name decides which factory-built encoder must be
        # weight-compatible with the encoder created from the config.
        if "Net5x5" in config_file.name:
            build_encoder = Quartznet5x5_encoder
        else:
            build_encoder = Quartznet15x5_encoder
        reference_encoder = build_encoder(feat_dim)
        reference_encoder.load_state_dict(encoder.state_dict())
def load_from_nemo(cls, *, nemo_filepath: str = None, checkpoint_name: NemoCheckpoint = None):
    """Create a module from a NeMo archive and load its weights.

    Args:
        nemo_filepath: Path to a local NeMo archive. It is ignored when
            ``checkpoint_name`` is also given.
        checkpoint_name: Checkpoint fetched through ``download_checkpoint``;
            the downloaded file is used instead of ``nemo_filepath``.

    Raises:
        ValueError: If neither ``nemo_filepath`` nor ``checkpoint_name`` is
            passed.

    Returns:
        An instance of ``cls`` built from the archive's ``model_config.yaml``,
        with ``model_weights.ckpt`` loaded into its encoder and decoder, and
        switched to eval mode.
    """
    if checkpoint_name is None:
        if nemo_filepath is None:
            raise ValueError("Either nemo_filepath or checkpoint_name must be passed")
        archive = nemo_filepath
    else:
        # A named checkpoint takes precedence over any local path.
        archive = download_checkpoint(checkpoint_name)

    with TemporaryDirectory() as tmp_dir:
        root = Path(tmp_dir)
        extract_archive(str(archive), root)
        encoder_params, initial_vocab, preprocess_params = read_params_from_config(
            root / "model_config.yaml"
        )
        model = cls(
            initial_vocab_tokens=initial_vocab,
            **encoder_params,
            **preprocess_params,
            nemo_compat_vocab=True,
        )
        load_quartznet_weights(
            model.encoder, model.decoder, root / "model_weights.ckpt"
        )

    # Return the model in eval mode so that inference right after loading
    # behaves correctly. Callers that fine-tune are expected to switch back
    # to train mode themselves, so this only guards against a silent bug in
    # the load-then-infer case.
    model.eval()
    return model
def test_can_load_weights():
    """Download the QuartzNet 5x5 checkpoint and load it into fresh modules.

    The whole body is best-effort with respect to HTTP failures: if an
    ``HTTPError`` is raised anywhere inside it, the test returns without
    failing.
    """
    # The 5x5 checkpoint is small (about 25 MB), so downloading it during
    # the test run is acceptable.
    try:
        archive = download_checkpoint(NemoCheckpoint.QuartzNet5x5LS_En)
        with TemporaryDirectory() as tmp_dir:
            root = Path(tmp_dir)
            extract_archive(str(archive), root)
            encoder_params, initial_vocab, _ = read_params_from_config(
                root / "model_config.yaml"
            )
            encoder = Quartznet5(64, **encoder_params)
            # NOTE(review): the extra output presumably accounts for a blank
            # token stored in the checkpoint's decoder - confirm.
            n_outputs = len(initial_vocab) + 1
            decoder = Quartznet_decoder(n_outputs)
            load_quartznet_weights(encoder, decoder, root / "model_weights.ckpt")
    except HTTPError:
        # An HTTP failure is not treated as a test failure.
        return
def load_from_nemo(
    cls, *, nemo_filepath: str = None, checkpoint_name: NemoCheckpoint = None
) -> "QuartznetModule":
    """Load the module from an original NeMo checkpoint.

    Args:
        nemo_filepath: Path to a local .nemo file. Ignored when
            ``checkpoint_name`` is also given.
        checkpoint_name: Name of a checkpoint to be downloaded locally and
            loaded; it takes precedence over ``nemo_filepath``.

    Raises:
        ValueError: Neither ``nemo_filepath`` nor ``checkpoint_name`` was
            passed.

    Returns:
        The model loaded from the checkpoint, set to eval mode.
    """
    if nemo_filepath is None and checkpoint_name is None:
        raise ValueError("Either nemo_filepath or checkpoint_name must be passed")

    source = nemo_filepath
    if checkpoint_name is not None:
        # A named checkpoint overrides any local path that was supplied.
        source = download_checkpoint(checkpoint_name)

    with TemporaryDirectory() as scratch:
        extracted = Path(scratch)
        extract_archive(str(source), extracted)

        config_file = extracted / "model_config.yaml"
        weights_file = extracted / "model_weights.ckpt"

        encoder_params, initial_vocab, preprocess_params = read_params_from_config(
            config_file
        )
        loaded = cls(
            initial_vocab_tokens=initial_vocab,
            **encoder_params,
            **preprocess_params,
            nemo_compat_vocab=True,
        )
        load_quartznet_weights(loaded.encoder, loaded.decoder, weights_file)

    # Eval mode is set on purpose. Most callers either (1) load a checkpoint
    # and run inference directly, or (2) fine-tune. This prevents a silent
    # bug in case (1) and is expected to have no effect in case (2).
    loaded.eval()
    return loaded