def main(cfg):
    """Build a structured CTC ASR config from Hydra overrides and run training."""
    # Start from the default structured ASR model config.
    base_cfg = configs.EncDecCTCModelConfig()

    # Merge the Hydra overrides into the structured config.
    # `drop_missing_subconfigs=True` is required here: without it the dataclass
    # instantiates subconfigs that still contain MISSING values (for example
    # test_ds.sample_rate and test_ds.labels), and any non-update access to a
    # DictConfig holding MISSING raises OmegaConf's MissingMandatoryValue —
    # notably inside model_utils.resolve_test_dataloaders(model=self), which is
    # used for multi data loader support. Only explicit update operations that
    # replace MISSING with a real value are safe on such a config.
    merged_cfg = update_model_config(base_cfg, cfg, drop_missing_subconfigs=True)

    # From this point on it is an ordinary OmegaConf DictConfig.
    pl_trainer = pl.Trainer(**merged_cfg.trainer)
    exp_manager(pl_trainer, merged_cfg.get("exp_manager", None))

    model = EncDecCTCModel(cfg=merged_cfg.model, trainer=pl_trainer)
    pl_trainer.fit(model)
def test_dataclass_instantiation(self, asr_model):
    """Verify that the structured EncDecCTCModelConfig, merged with an existing
    model's config, instantiates an equivalent model — both when MISSING
    subconfigs are dropped during the merge and when they are kept."""
    base_cfg = configs.EncDecCTCModelConfig()

    # Fill in the mandatory values.
    vocabulary = asr_model.decoder.vocabulary
    base_cfg.model.labels = vocabulary

    # Encoder settings.
    base_cfg.model.encoder.activation = 'relu'
    base_cfg.model.encoder.feat_in = 64
    base_cfg.model.encoder.jasper = [
        nemo_asr.modules.conv_asr.JasperEncoderConfig(
            filters=1024,
            repeat=1,
            kernel=[1],
            stride=[1],
            dilation=[1],
            dropout=0.0,
            residual=False,
            se=True,
            se_context_size=-1,
        )
    ]

    # Decoder settings.
    base_cfg.model.decoder.feat_in = 1024
    base_cfg.model.decoder.num_classes = 28
    base_cfg.model.decoder.vocabulary = vocabulary

    # Merge (default: drop MISSING subconfigs) and construct the model.
    asr_cfg = OmegaConf.create({'model': asr_model.cfg})
    merged_v1 = update_model_config(base_cfg, asr_cfg)
    new_model = EncDecCTCModel(cfg=merged_v1.model)
    assert new_model.num_weights == asr_model.num_weights

    # trainer and exp manager should be there
    # assert 'trainer' in merged_v1
    # assert 'exp_manager' in merged_v1

    # Datasets and optim/sched must have been dropped by
    # ModelPT.update_model_dataclass().
    for key in ('train_ds', 'validation_ds', 'test_ds', 'optim'):
        assert key not in merged_v1.model

    # Construct the model again, this time keeping the MISSING subconfigs.
    asr_cfg = OmegaConf.create({'model': asr_model.cfg})
    merged_v2 = update_model_config(base_cfg, asr_cfg, drop_missing_subconfigs=False)

    # All components should now be present in the merged config.
    # assert 'trainer' in merged_v2
    # assert 'exp_manager' in merged_v2
    for key in ('train_ds', 'validation_ds', 'test_ds', 'optim'):
        assert key in merged_v2.model

    # Strip the dataset subconfigs before instantiation (optim and sched can
    # be kept without issue).
    with open_dict(merged_v2.model):
        for key in ('train_ds', 'validation_ds', 'test_ds'):
            merged_v2.model.pop(key)

    new_model = EncDecCTCModel(cfg=merged_v2.model)
    assert new_model.num_weights == asr_model.num_weights
# See the License for the specific language governing permissions and # limitations under the License. from dataclasses import asdict import pytorch_lightning as pl import nemo.collections.asr as nemo_asr from nemo.collections.asr.models import EncDecCTCModel, configs from nemo.utils.exp_manager import exp_manager """ python speech_to_text_structured.py """ # Generate default asr model config cfg = configs.EncDecCTCModelConfig() # set global values cfg.model.repeat = 5 cfg.model.separable = True # fmt: off LABELS = [ " ", "a", "b", "c", "d", "e", "f", "g",