def get_model_config(in_features=50, update_adapter_cfg: bool = True):
    """Build a minimal OmegaConf model config whose encoder and decoder both target ``DefaultModule``.

    Args:
        in_features: Value stored under the top-level ``in_features`` key.
        update_adapter_cfg: When True, the ``encoder`` and ``decoder`` sections are each checked
            against the adapter registry. Any section whose ``_target_`` has a registered adapter is
            re-pointed to that adapter's ``adapter_class_path``.

    Returns:
        The config object returned by ``OmegaConf.create``.
    """
    module_path = get_classpath(DefaultModule)
    config = OmegaConf.create(
        {
            'in_features': in_features,
            'encoder': {'_target_': module_path},
            'decoder': {'_target_': module_path},
        }
    )

    if not update_adapter_cfg:
        return config

    # Encoder first, then decoder -- same lookup order as writing the two checks out by hand.
    for section in ('encoder', 'decoder'):
        metadata = adapter_mixins.get_registered_adapter(config[section]._target_)
        if metadata is not None:
            config[section]._target_ = metadata.adapter_class_path

    return config
def test_constructor_pretrained(self):
    """Restore a pretrained Citrinet model with its encoder swapped for the registered adapter class.

    Checks three things:
    - The restored model and its encoder are ``AdapterModuleMixin`` instances.
    - An adapter named ``adapter_0`` can be added.
    - After ``freeze()`` followed by ``unfreeze_enabled_adapters()``, ``num_weights`` is below 1e5.
    """
    model_name = 'stt_en_citrinet_256'

    # Fetch only the config first, so the encoder target can be rewritten before instantiation.
    cfg = ASRModel.from_pretrained(model_name, map_location='cpu', return_config=True)
    adapter_metadata = get_registered_adapter(cfg.encoder._target_)
    if adapter_metadata is not None:
        cfg.encoder._target_ = adapter_metadata.adapter_class_path

    # Restore onto CPU as well, matching the config fetch above. Without this, the restored
    # model's device would depend on whether a GPU happens to be available.
    model = ASRModel.from_pretrained(model_name, override_config_path=cfg, map_location='cpu')

    assert isinstance(model, AdapterModuleMixin)
    assert hasattr(model, 'encoder')
    assert isinstance(model.encoder, AdapterModuleMixin)

    # The adapter's in_features is taken from the first Jasper block's `filters` in the encoder config.
    model.add_adapter('adapter_0', cfg=get_adapter_cfg(in_features=cfg.encoder.jasper[0].filters, dim=5))
    assert model.is_adapter_available()

    # After freezing everything and re-enabling only the adapters, the test expects
    # `num_weights` to fall below 1e5.
    model.freeze()
    model.unfreeze_enabled_adapters()
    assert model.num_weights < 1e5
def update_adapter_global_cfg(cfg: DictConfig, encoder_adapter=True, decoder_adapter=False): if 'adapters' not in cfg: cfg.adapters = adapter_mixins._prepare_default_adapter_config( global_key=AdapterModuleMixin.adapter_global_cfg_key, meta_key=AdapterModuleMixin.adapter_metadata_cfg_key ) cfg.adapters.global_cfg.encoder_adapter = encoder_adapter cfg.adapters.global_cfg.decoder_adapter = decoder_adapter return cfg def get_classpath(cls): return f'{cls.__module__}.{cls.__name__}' if adapter_mixins.get_registered_adapter(DefaultModule) is None: adapter_mixins.register_adapter(DefaultModule, DefaultModuleAdapter) class TestAdapterModelMixin: @pytest.mark.unit def test_base_model_no_support_for_adapters(self, caplog): logging._logger.propagate = True original_verbosity = logging.get_verbosity() logging.set_verbosity(logging.WARNING) caplog.set_level(logging.WARNING) cfg = get_model_config(in_features=50, update_adapter_cfg=False) model = DefaultAdapterModel(cfg) with pytest.raises(AttributeError):
def test_module_registered_adapter_by_adapter_class(self):
    """Querying the adapter registry with the adapter class itself returns the DefaultModule pairing."""
    metadata = adapter_mixins.get_registered_adapter(DefaultModuleAdapter)
    assert metadata is not None

    # Both ends of the registration must point back at the expected classes.
    assert (metadata.base_class, metadata.adapter_class) == (DefaultModule, DefaultModuleAdapter)
def test_module_registered_adapter_by_class_path(self):
    """Querying the adapter registry with the base class's dotted path returns the DefaultModule pairing."""
    metadata = adapter_mixins.get_registered_adapter(get_classpath(DefaultModule))
    assert metadata is not None

    # Both ends of the registration must point back at the expected classes.
    assert (metadata.base_class, metadata.adapter_class) == (DefaultModule, DefaultModuleAdapter)
def test_adapter_registry_via_adapter_class(self, model):
    """Looking up the registry with the encoder's own class returns metadata whose adapter is that class.

    The model's encoder is already an adapter-compatible class, so the lookup is done with the
    adapter side of the registration.
    """
    encoder_cls = model.encoder.__class__
    metadata = get_registered_adapter(encoder_cls)

    assert metadata is not None
    assert metadata.adapter_class == encoder_cls
for conformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin conformer_layer.add_adapter(name, cfg) def is_adapter_available(self) -> bool: return any([ conformer_layer.is_adapter_available() for conformer_layer in self.layers ]) def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): for conformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin conformer_layer.set_enabled_adapters(name=name, enabled=enabled) def get_enabled_adapters(self) -> List[str]: names = set([]) for conformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin names.update(conformer_layer.get_enabled_adapters()) names = sorted(list(names)) return names """ Register any additional information """ if adapter_mixins.get_registered_adapter(ConformerEncoder) is None: adapter_mixins.register_adapter(base_class=ConformerEncoder, adapter_class=ConformerEncoderAdapter)
conv_mask: bool = True frame_splicing: int = 1 init_mode: Optional[str] = "xavier_uniform" @dataclass class ConvASRDecoderConfig: _target_: str = 'nemo.collections.asr.modules.ConvASRDecoder' feat_in: int = MISSING num_classes: int = MISSING init_mode: Optional[str] = "xavier_uniform" vocabulary: Optional[List[str]] = field(default_factory=list) @dataclass class ConvASRDecoderClassificationConfig: _target_: str = 'nemo.collections.asr.modules.ConvASRDecoderClassification' feat_in: int = MISSING num_classes: int = MISSING init_mode: Optional[str] = "xavier_uniform" return_logits: bool = True pooling_type: str = 'avg' """ Register any additional information """ if adapter_mixins.get_registered_adapter(ConvASREncoder) is None: adapter_mixins.register_adapter(base_class=ConvASREncoder, adapter_class=ConvASREncoderAdapter)