Example #1
def main(cfg: MTBottleneckConfig) -> None:
    # merge default config with user specified config
    default_cfg = MTBottleneckConfig()
    cfg = update_model_config(default_cfg, cfg)
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')

    # training is managed by PyTorch Lightning
    trainer_cfg = OmegaConf.to_container(cfg.trainer)
    trainer_cfg.pop('plugins', None)
    trainer = Trainer(plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes)], **trainer_cfg)

    # tokenizers will be trained and tarred training data will be created if needed
    # model config is then updated
    if cfg.model.preproc_out_dir is not None:
        MTDataPreproc(cfg=cfg.model, trainer=trainer)

    # experiment logs, checkpoints, and auto-resume are managed by exp_manager and PyTorch Lightning
    exp_manager(trainer, cfg.exp_manager)

    # everything needed to train translation models is encapsulated in the NeMo MTBottleneckModel
    mt_model = MTBottleneckModel(cfg.model, trainer=trainer)

    logging.info("\n\n************** Model parameters and their sizes ***********")
    for name, param in mt_model.named_parameters():
        print(name, param.size())
    logging.info("***********************************************************\n\n")

    if cfg.do_training:
        trainer.fit(mt_model)

    if cfg.do_testing:
        trainer.test(mt_model)
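
The merge of the default config with the user-specified config above is, at its core, a structured-config merge. A minimal, simplified sketch of that pattern using plain OmegaConf (the DefaultConfig dataclass is a hypothetical stand-in, not NeMo's actual update_model_config implementation):

from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class DefaultConfig:
    lr: float = 1e-3
    batch_size: int = 32


# dataclass defaults become a structured config; user-supplied values are layered on top
default_cfg = OmegaConf.structured(DefaultConfig)
user_cfg = OmegaConf.create({"batch_size": 64})

merged = OmegaConf.merge(default_cfg, user_cfg)
print(merged.lr, merged.batch_size)  # 0.001 64
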
Example #2
def main(cfg: MTEncDecConfig) -> None:
    # merge default config with user specified config
    default_cfg = MTEncDecConfig()
    cfg = update_model_config(default_cfg, cfg)
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')

    # training is managed by PyTorch Lightning
    trainer = Trainer(**cfg.trainer)

    # tokenizers will be trained and tarred training data will be created if needed
    # model config is then updated
    MTDataPreproc(cfg=cfg.model, trainer=trainer)

    if cfg.do_training:
        # experiment logs, checkpoints, and auto-resume are managed by exp_manager and PyTorch Lightning
        exp_manager(trainer, cfg.exp_manager)

        # everything needed to train translation models is encapsulated in the NeMo MTEncDecModel
        mt_model = MTEncDecModel(cfg.model, trainer=trainer)

        logging.info("\n\n************** Model parameters and their sizes ***********")
        for name, param in mt_model.named_parameters():
            print(name, param.size())
        logging.info("***********************************************************\n\n")

        trainer.fit(mt_model)
Example #3
def main(cfg: MTEncDecConfig) -> None:
    # merge default config with user specified config
    default_cfg = MTEncDecConfig()
    cfg = update_model_config(default_cfg, cfg)
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')

    trainer = Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)
    mt_model = MTEncDecModel(cfg.model, trainer=trainer)

    logging.info(
        "\n\n************** Model parameters and their sizes ***********")
    for name, param in mt_model.named_parameters():
        print(name, param.size())
    logging.info(
        "***********************************************************\n\n")
    trainer.fit(mt_model)
Example #4
def main(cfg):
    # Generate default asr model config
    asr_model_config = configs.EncDecCTCModelConfig()

    # Merge hydra updates with model config
    # `drop_missing_subconfigs=True` is necessary here. Without it, while the data class will instantiate and be added
    # to the config, it contains test_ds.sample_rate = MISSING and test_ds.labels = MISSING.
    # This will raise an OmegaConf MissingMandatoryValue error when processing the dataloaders inside
    # model_utils.resolve_test_dataloaders(model=self) (used for multi data loader support).
    # In general, any operation that tries to use a DictConfig with MISSING in it will fail,
    # other than explicit update operations to change MISSING to some actual value.
    asr_model_config = update_model_config(asr_model_config,
                                           cfg,
                                           drop_missing_subconfigs=True)

    # From here on out, it's a general OmegaConf DictConfig, directly usable by our code.
    trainer = pl.Trainer(**asr_model_config.trainer)
    exp_manager(trainer, asr_model_config.get("exp_manager", None))
    asr_model = EncDecCTCModel(cfg=asr_model_config.model, trainer=trainer)

    trainer.fit(asr_model)
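
To illustrate the comment above about MISSING values: reading a MISSING field from an OmegaConf config raises MissingMandatoryValue, while explicitly assigning it a real value is allowed. A small self-contained sketch (TestDSConfig is a hypothetical config, not the real test_ds schema):

from dataclasses import dataclass
from typing import List

from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue


@dataclass
class TestDSConfig:
    sample_rate: int = MISSING
    labels: List[str] = MISSING


cfg = OmegaConf.structured(TestDSConfig)

try:
    _ = cfg.sample_rate           # reading a MISSING value fails
except MissingMandatoryValue:
    print("sample_rate is still MISSING")

cfg.sample_rate = 16000           # explicit updates to MISSING fields are fine
print(cfg.sample_rate)            # 16000
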
Example #5
    def test_dataclass_instantiation(self, asr_model):
        model_cfg = configs.EncDecCTCModelConfig()

        # Update mandatory values
        vocabulary = asr_model.decoder.vocabulary
        model_cfg.model.labels = vocabulary

        # Update encoder
        model_cfg.model.encoder.activation = 'relu'
        model_cfg.model.encoder.feat_in = 64
        model_cfg.model.encoder.jasper = [
            nemo_asr.modules.conv_asr.JasperEncoderConfig(
                filters=1024,
                repeat=1,
                kernel=[1],
                stride=[1],
                dilation=[1],
                dropout=0.0,
                residual=False,
                se=True,
                se_context_size=-1,
            )
        ]

        # Update decoder
        model_cfg.model.decoder.feat_in = 1024
        model_cfg.model.decoder.num_classes = 28
        model_cfg.model.decoder.vocabulary = vocabulary

        # Construct the model
        asr_cfg = OmegaConf.create({'model': asr_model.cfg})
        model_cfg_v1 = update_model_config(model_cfg, asr_cfg)
        new_model = EncDecCTCModel(cfg=model_cfg_v1.model)

        assert new_model.num_weights == asr_model.num_weights
        # trainer and exp manager should be there
        # assert 'trainer' in model_cfg_v1
        # assert 'exp_manager' in model_cfg_v1
        # datasets and optim/sched should not be there after ModelPT.update_model_dataclass()
        assert 'train_ds' not in model_cfg_v1.model
        assert 'validation_ds' not in model_cfg_v1.model
        assert 'test_ds' not in model_cfg_v1.model
        assert 'optim' not in model_cfg_v1.model

        # Construct the model, without dropping additional keys
        asr_cfg = OmegaConf.create({'model': asr_model.cfg})
        model_cfg_v2 = update_model_config(model_cfg,
                                           asr_cfg,
                                           drop_missing_subconfigs=False)

        # Assert all components are in config
        # assert 'trainer' in model_cfg_v2
        # assert 'exp_manager' in model_cfg_v2
        assert 'train_ds' in model_cfg_v2.model
        assert 'validation_ds' in model_cfg_v2.model
        assert 'test_ds' in model_cfg_v2.model
        assert 'optim' in model_cfg_v2.model

        # Remove extra components (optim and sched can be kept without issue)
        with open_dict(model_cfg_v2.model):
            model_cfg_v2.model.pop('train_ds')
            model_cfg_v2.model.pop('validation_ds')
            model_cfg_v2.model.pop('test_ds')

        new_model = EncDecCTCModel(cfg=model_cfg_v2.model)

        assert new_model.num_weights == asr_model.num_weights
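
A side note on the open_dict context manager used above: structured configs built from dataclasses are in OmegaConf's struct mode, which forbids adding or removing keys, and open_dict lifts that restriction temporarily so the dataset sub-configs can be popped. A minimal sketch:

from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"model": {"feat_in": 64}})
OmegaConf.set_struct(cfg, True)      # struct mode: the key set is locked

with open_dict(cfg.model):           # temporarily allow adding/removing keys
    cfg.model.pop("feat_in")
    cfg.model.new_key = 1

print(cfg.model)                     # {'new_key': 1}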