def main(cfg: MTBottleneckConfig) -> None:
    # merge default config with user specified config
    default_cfg = MTBottleneckConfig()
    cfg = update_model_config(default_cfg, cfg)
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')

    # training is managed by PyTorch Lightning
    trainer_cfg = OmegaConf.to_container(cfg.trainer)
    trainer_cfg.pop('plugins', None)
    trainer = Trainer(plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes)], **trainer_cfg)

    # tokenizers will be trained and tarred training data will be created if needed
    # model config is then updated
    if cfg.model.preproc_out_dir is not None:
        MTDataPreproc(cfg=cfg.model, trainer=trainer)

    # experiment logs, checkpoints, and auto-resume are managed by exp_manager and PyTorch Lightning
    exp_manager(trainer, cfg.exp_manager)

    # everything needed to train translation models is encapsulated in the NeMo MTBottleneckModel
    mt_model = MTBottleneckModel(cfg.model, trainer=trainer)

    logging.info("\n\n************** Model parameters and their sizes ***********")
    for name, param in mt_model.named_parameters():
        print(name, param.size())
    logging.info("***********************************************************\n\n")

    if cfg.do_training:
        trainer.fit(mt_model)

    if cfg.do_testing:
        trainer.test(mt_model)
def main(cfg: MTEncDecConfig) -> None:
    # merge default config with user specified config
    default_cfg = MTEncDecConfig()
    cfg = update_model_config(default_cfg, cfg)
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'Config: {cfg.pretty()}')

    # training is managed by PyTorch Lightning
    trainer = Trainer(**cfg.trainer)

    # tokenizers will be trained and tarred training data will be created if needed
    # model config is then updated
    MTDataPreproc(cfg=cfg.model, trainer=trainer)

    if cfg.do_training:
        # experiment logs, checkpoints, and auto-resume are managed by exp_manager and PyTorch Lightning
        exp_manager(trainer, cfg.exp_manager)

        # everything needed to train translation models is encapsulated in the NeMo MTEncDecModel
        mt_model = MTEncDecModel(cfg.model, trainer=trainer)

        logging.info("\n\n************** Model parameters and their sizes ***********")
        for name, param in mt_model.named_parameters():
            print(name, param.size())
        logging.info("***********************************************************\n\n")

        trainer.fit(mt_model)
def main(cfg: MTEncDecConfig) -> None:
    # merge default config with user specified config
    default_cfg = MTEncDecConfig()
    cfg = update_model_config(default_cfg, cfg)
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'Config: {cfg.pretty()}')

    trainer = Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    mt_model = MTEncDecModel(cfg.model, trainer=trainer)

    logging.info("\n\n************** Model parameters and their sizes ***********")
    for name, param in mt_model.named_parameters():
        print(name, param.size())
    logging.info("***********************************************************\n\n")

    trainer.fit(mt_model)
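# The main() functions above are normally wired to Hydra so that every config field can be
# overridden from the command line. A minimal sketch of that wiring, assuming NeMo's
# hydra_runner decorator and an illustrative conf/aayn_base.yaml config file (the config
# path and name are assumptions, not taken from the snippets above; MTEncDecConfig and
# OmegaConf are assumed to be imported as in the functions above):

from nemo.core.config import hydra_runner
from nemo.utils import logging


@hydra_runner(config_path="conf", config_name="aayn_base")
def main(cfg: MTEncDecConfig) -> None:
    # cfg already contains the YAML defaults merged with command-line overrides,
    # e.g. `python enc_dec_nmt.py trainer.max_steps=100000 do_training=true`
    logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')


if __name__ == '__main__':
    main()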
def main(cfg):
    # Generate default asr model config
    asr_model_config = configs.EncDecCTCModelConfig()

    # Merge hydra updates with model config
    # `drop_missing_subconfigs=True` is necessary here. Without it, while the data class will instantiate and be added
    # to the config, it contains test_ds.sample_rate = MISSING and test_ds.labels = MISSING.
    # This will raise an OmegaConf MissingMandatoryValue error when processing the dataloaders inside
    # model_utils.resolve_test_dataloaders(model=self) (used for multi data loader support).
    # In general, any operation that tries to use a DictConfig with MISSING in it will fail,
    # other than explicit update operations to change MISSING to some actual value.
    asr_model_config = update_model_config(asr_model_config, cfg, drop_missing_subconfigs=True)

    # From here on out, it's a general OmegaConf DictConfig, directly usable by our code.
    trainer = pl.Trainer(**asr_model_config.trainer)
    exp_manager(trainer, asr_model_config.get("exp_manager", None))

    asr_model = EncDecCTCModel(cfg=asr_model_config.model, trainer=trainer)

    trainer.fit(asr_model)
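# The comments above rely on OmegaConf's handling of MISSING values. A small standalone
# sketch (plain omegaconf, no NeMo code; the DatasetConfig dataclass is illustrative)
# of how reading a MISSING field fails while an explicit assignment resolves it:

from dataclasses import dataclass

from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue


@dataclass
class DatasetConfig:
    manifest_filepath: str = MISSING
    sample_rate: int = MISSING


ds_cfg = OmegaConf.structured(DatasetConfig)

print(OmegaConf.is_missing(ds_cfg, 'sample_rate'))  # True

try:
    _ = ds_cfg.sample_rate  # reading a MISSING value raises
except MissingMandatoryValue:
    print('sample_rate is still MISSING')

ds_cfg.sample_rate = 16000  # explicit update is allowed and clears MISSING
print(ds_cfg.sample_rate)   # 16000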
def test_dataclass_instantiation(self, asr_model):
    model_cfg = configs.EncDecCTCModelConfig()

    # Update mandatory values
    vocabulary = asr_model.decoder.vocabulary
    model_cfg.model.labels = vocabulary

    # Update encoder
    model_cfg.model.encoder.activation = 'relu'
    model_cfg.model.encoder.feat_in = 64
    model_cfg.model.encoder.jasper = [
        nemo_asr.modules.conv_asr.JasperEncoderConfig(
            filters=1024,
            repeat=1,
            kernel=[1],
            stride=[1],
            dilation=[1],
            dropout=0.0,
            residual=False,
            se=True,
            se_context_size=-1,
        )
    ]

    # Update decoder
    model_cfg.model.decoder.feat_in = 1024
    model_cfg.model.decoder.num_classes = 28
    model_cfg.model.decoder.vocabulary = vocabulary

    # Construct the model
    asr_cfg = OmegaConf.create({'model': asr_model.cfg})
    model_cfg_v1 = update_model_config(model_cfg, asr_cfg)
    new_model = EncDecCTCModel(cfg=model_cfg_v1.model)

    assert new_model.num_weights == asr_model.num_weights

    # trainer and exp manager should be there
    # assert 'trainer' in model_cfg_v1
    # assert 'exp_manager' in model_cfg_v1

    # datasets and optim/sched should not be there after ModelPT.update_model_dataclass()
    assert 'train_ds' not in model_cfg_v1.model
    assert 'validation_ds' not in model_cfg_v1.model
    assert 'test_ds' not in model_cfg_v1.model
    assert 'optim' not in model_cfg_v1.model

    # Construct the model, without dropping additional keys
    asr_cfg = OmegaConf.create({'model': asr_model.cfg})
    model_cfg_v2 = update_model_config(model_cfg, asr_cfg, drop_missing_subconfigs=False)

    # Assert all components are in config
    # assert 'trainer' in model_cfg_v2
    # assert 'exp_manager' in model_cfg_v2
    assert 'train_ds' in model_cfg_v2.model
    assert 'validation_ds' in model_cfg_v2.model
    assert 'test_ds' in model_cfg_v2.model
    assert 'optim' in model_cfg_v2.model

    # Remove extra components (optim and sched can be kept without issue)
    with open_dict(model_cfg_v2.model):
        model_cfg_v2.model.pop('train_ds')
        model_cfg_v2.model.pop('validation_ds')
        model_cfg_v2.model.pop('test_ds')

    new_model = EncDecCTCModel(cfg=model_cfg_v2.model)

    assert new_model.num_weights == asr_model.num_weights
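# The test above pops keys inside open_dict() because structured configs are in "struct" mode,
# where adding or deleting keys is normally rejected. A minimal standalone sketch of that
# behavior (plain omegaconf, not tied to the NeMo configs used above; the keys are illustrative):

from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({'model': {'feat_in': 64, 'train_ds': {'batch_size': 32}}})
OmegaConf.set_struct(cfg, True)  # struct mode: adding/removing keys is an error from here on

# open_dict temporarily relaxes struct mode so keys can be removed (or added)
with open_dict(cfg.model):
    cfg.model.pop('train_ds')

print(OmegaConf.to_yaml(cfg))  # model now only contains feat_in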