Exemple #1
0
def get_model_config(in_features=50, update_adapter_cfg: bool = True):
    """Build the OmegaConf config used to construct the test model.

    Both ``encoder`` and ``decoder`` initially target ``DefaultModule``.

    Args:
        in_features: Value stored under the ``in_features`` key.
        update_adapter_cfg: When True, each module's ``_target_`` is replaced
            by the class path of its registered adapter, if the registry
            returns metadata for the current target.

    Returns:
        The config created by ``OmegaConf.create``.
    """
    default_target = get_classpath(DefaultModule)
    config = OmegaConf.create(
        {
            'in_features': in_features,
            'encoder': {'_target_': default_target},
            'decoder': {'_target_': default_target},
        }
    )

    if not update_adapter_cfg:
        return config

    # Encoder is resolved before decoder, same order as the registry lookups
    # were originally written.
    for module_key in ('encoder', 'decoder'):
        module_cfg = config[module_key]
        metadata = adapter_mixins.get_registered_adapter(module_cfg._target_)
        if metadata is not None:
            module_cfg._target_ = metadata.adapter_class_path

    return config
Exemple #2
0
    def test_constructor_pretrained(self):
        """Restore a pretrained model with an adapter-capable encoder and add an adapter."""
        pretrained_name = 'stt_en_citrinet_256'

        # Fetch only the config first, so the encoder target can be swapped
        # for its registered adapter class path before the model is built.
        cfg = ASRModel.from_pretrained(pretrained_name, map_location='cpu', return_config=True)
        encoder_meta = get_registered_adapter(cfg.encoder._target_)
        if encoder_meta is not None:
            cfg.encoder._target_ = encoder_meta.adapter_class_path
        model = ASRModel.from_pretrained(pretrained_name, override_config_path=cfg)

        assert isinstance(model, AdapterModuleMixin)
        assert hasattr(model, 'encoder')
        assert isinstance(model.encoder, AdapterModuleMixin)

        # Adapter input size is taken from the first jasper block's filter count.
        adapter_cfg = get_adapter_cfg(in_features=cfg.encoder.jasper[0].filters, dim=5)
        model.add_adapter('adapter_0', cfg=adapter_cfg)
        assert model.is_adapter_available()

        # After freezing everything and unfreezing enabled adapters, the
        # reported weight count must stay below 1e5.
        model.freeze()
        model.unfreeze_enabled_adapters()
        assert model.num_weights < 1e5
Exemple #3
0
def update_adapter_global_cfg(cfg: DictConfig, encoder_adapter=True, decoder_adapter=False):
    """Set the encoder/decoder adapter flags in the config's global adapter section.

    If ``cfg`` has no ``adapters`` key, a default adapter config is created
    first via ``adapter_mixins._prepare_default_adapter_config``.

    Args:
        cfg: Config to update; it is modified in place.
        encoder_adapter: Value written to ``adapters.global_cfg.encoder_adapter``.
        decoder_adapter: Value written to ``adapters.global_cfg.decoder_adapter``.

    Returns:
        The same ``cfg`` object that was passed in.
    """
    if 'adapters' not in cfg:
        mixin = AdapterModuleMixin
        cfg.adapters = adapter_mixins._prepare_default_adapter_config(
            global_key=mixin.adapter_global_cfg_key,
            meta_key=mixin.adapter_metadata_cfg_key,
        )

    global_section = cfg.adapters.global_cfg
    global_section.encoder_adapter = encoder_adapter
    global_section.decoder_adapter = decoder_adapter
    return cfg


def get_classpath(cls):
    """Return the dotted ``<module>.<name>`` path of ``cls``."""
    module_name = cls.__module__
    class_name = cls.__name__
    return '{}.{}'.format(module_name, class_name)


# Runs at import time: register DefaultModuleAdapter as the adapter class for
# DefaultModule, but only if the registry has no entry for DefaultModule yet.
# NOTE(review): the guard presumably avoids a duplicate-registration error when
# this module is imported more than once — confirm against register_adapter.
if adapter_mixins.get_registered_adapter(DefaultModule) is None:
    adapter_mixins.register_adapter(DefaultModule, DefaultModuleAdapter)


class TestAdapterModelMixin:
    @pytest.mark.unit
    def test_base_model_no_support_for_adapters(self, caplog):
        logging._logger.propagate = True
        original_verbosity = logging.get_verbosity()
        logging.set_verbosity(logging.WARNING)
        caplog.set_level(logging.WARNING)

        cfg = get_model_config(in_features=50, update_adapter_cfg=False)
        model = DefaultAdapterModel(cfg)

        with pytest.raises(AttributeError):
Exemple #4
0
 def test_module_registered_adapter_by_adapter_class(self):
     """Registry lookup keyed by the adapter class yields the DefaultModule pairing."""
     registered = adapter_mixins.get_registered_adapter(DefaultModuleAdapter)

     assert registered is not None
     assert registered.base_class == DefaultModule
     assert registered.adapter_class == DefaultModuleAdapter
Exemple #5
0
 def test_module_registered_adapter_by_class_path(self):
     """Registry lookup keyed by DefaultModule's dotted class path yields the same pairing."""
     registered = adapter_mixins.get_registered_adapter(get_classpath(DefaultModule))

     assert registered is not None
     assert registered.base_class == DefaultModule
     assert registered.adapter_class == DefaultModuleAdapter
Exemple #6
0
 def test_adapter_registry_via_adapter_class(self, model):
     """The model encoder's own class must be registered as an adapter class."""
     # The encoder is expected to already be an adapter-compatible class, so
     # the registry entry's adapter_class should be that very class.
     encoder_cls = model.encoder.__class__
     registered = get_registered_adapter(encoder_cls)

     assert registered is not None
     assert registered.adapter_class == encoder_cls
Exemple #7
0
        for conformer_layer in self.layers:  # type: adapter_mixins.AdapterModuleMixin
            conformer_layer.add_adapter(name, cfg)

    def is_adapter_available(self) -> bool:
        """Report whether at least one layer in ``self.layers`` has an adapter available.

        Returns:
            True if any layer's ``is_adapter_available()`` is truthy, False
            otherwise (including when ``self.layers`` is empty).
        """
        # A generator (rather than a list) avoids building a throwaway list and
        # lets ``any`` stop at the first layer that reports True.
        return any(layer.is_adapter_available() for layer in self.layers)

    def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
        """Forward an enable/disable request to every layer in ``self.layers``.

        Args:
            name: Adapter name passed through unchanged to each layer
                (``None`` by default).
            enabled: Flag passed through unchanged to each layer.
        """
        # Each layer is expected to be an adapter_mixins.AdapterModuleMixin.
        for layer in self.layers:
            layer.set_enabled_adapters(name=name, enabled=enabled)

    def get_enabled_adapters(self) -> List[str]:
        """Collect the enabled adapter names reported by all layers.

        Returns:
            A sorted list of the distinct names returned by each layer's
            ``get_enabled_adapters()``; empty when there are no layers.
        """
        # Each layer is expected to be an adapter_mixins.AdapterModuleMixin.
        per_layer = (layer.get_enabled_adapters() for layer in self.layers)
        return sorted(set().union(*per_layer))


"""
Register any additional information
"""
# Runs at import time: register ConformerEncoderAdapter as the adapter class
# for ConformerEncoder, but only if the registry has no entry for
# ConformerEncoder yet.
# NOTE(review): the guard presumably avoids a duplicate-registration error on
# repeated imports — confirm against register_adapter.
if adapter_mixins.get_registered_adapter(ConformerEncoder) is None:
    adapter_mixins.register_adapter(base_class=ConformerEncoder,
                                    adapter_class=ConformerEncoderAdapter)
Exemple #8
0
    conv_mask: bool = True
    frame_splicing: int = 1
    init_mode: Optional[str] = "xavier_uniform"


@dataclass
class ConvASRDecoderConfig:
    """Dataclass of configuration fields for the class named in ``_target_``.

    ``feat_in`` and ``num_classes`` default to ``MISSING``, i.e. this class
    supplies no usable default for them.
    """

    # Dotted path of the class this config describes.
    _target_: str = 'nemo.collections.asr.modules.ConvASRDecoder'
    # No default provided (MISSING); presumably the decoder's input feature
    # size — confirm against ConvASRDecoder.
    feat_in: int = MISSING
    # No default provided (MISSING); presumably the number of output classes
    # — confirm against ConvASRDecoder.
    num_classes: int = MISSING
    # Initialization mode name; may be None.
    init_mode: Optional[str] = "xavier_uniform"
    # default_factory gives every instance its own fresh empty list.
    vocabulary: Optional[List[str]] = field(default_factory=list)


@dataclass
class ConvASRDecoderClassificationConfig:
    """Dataclass of configuration fields for the class named in ``_target_``.

    ``feat_in`` and ``num_classes`` default to ``MISSING``, i.e. this class
    supplies no usable default for them.
    """

    # Dotted path of the class this config describes.
    _target_: str = 'nemo.collections.asr.modules.ConvASRDecoderClassification'
    # No default provided (MISSING); presumably the decoder's input feature
    # size — confirm against ConvASRDecoderClassification.
    feat_in: int = MISSING
    # No default provided (MISSING); presumably the number of output classes
    # — confirm against ConvASRDecoderClassification.
    num_classes: int = MISSING
    # Initialization mode name; may be None.
    init_mode: Optional[str] = "xavier_uniform"
    # NOTE(review): presumably selects raw logits over normalized outputs —
    # confirm against the target class.
    return_logits: bool = True
    # Pooling type name, 'avg' by default; accepted values are not visible here.
    pooling_type: str = 'avg'


"""
Register any additional information
"""
# Runs at import time: register ConvASREncoderAdapter as the adapter class for
# ConvASREncoder, but only if the registry has no entry for ConvASREncoder yet.
# NOTE(review): the guard presumably avoids a duplicate-registration error on
# repeated imports — confirm against register_adapter.
if adapter_mixins.get_registered_adapter(ConvASREncoder) is None:
    adapter_mixins.register_adapter(base_class=ConvASREncoder,
                                    adapter_class=ConvASREncoderAdapter)