예제 #1
0
class MTBottleneckConfig(NemoConfig):
    """Top-level experiment config for the ``MTBottleneck`` model.

    Groups the experiment name, the train/test switches, the model config,
    the Lightning ``Trainer`` arguments and the exp_manager settings.
    """

    # Experiment name; reused below as the exp_manager run name so the two
    # cannot drift apart.
    name: Optional[str] = 'MTBottleneck'
    # Whether to perform training / testing of the model.
    do_training: bool = True
    do_testing: bool = False
    # NOTE(review): if this class is a @dataclass (decorator not visible here),
    # the instance defaults below are shared across config objects and are
    # rejected as mutable defaults on Python 3.11+; consider
    # field(default_factory=...) -- confirm against the decorator/runtime.
    model: MTBottleneckModelConfig = MTBottleneckModelConfig()
    # Constructor parameters for the Lightning ``Trainer``.
    trainer: Optional[TrainerConfig] = TrainerConfig()
    # `name` here is the class-body value bound above ('MTBottleneck').
    exp_manager: Optional[ExpManagerConfig] = ExpManagerConfig(name=name, files_to_copy=[])
예제 #2
0
 def setup_model(self):
     """Train ``ExampleModel`` under a 3-second ``StatelessTimer`` and return the trainer.

     ``max_steps`` is deliberately large: the timer is what is meant to stop
     training. Validating every 5 steps makes sure a checkpoint is written
     that a later run can restore from.
     """
     # Checkpoint selection: monitor "val_loss", keep the top 1.
     ckpt_params = CallbackParams()
     ckpt_params.monitor = "val_loss"
     ckpt_params.save_top_k = 1

     # Lightning's own logger and checkpointing are disabled here; a checkpoint
     # callback is requested through exp_manager instead (see below).
     trainer_kwargs = dict(
         devices=1,
         val_check_interval=5,
         max_steps=10000,
         accelerator='gpu',
         strategy='ddp',
         logger=None,
         callbacks=[StatelessTimer('00:00:00:03')],
         checkpoint_callback=False,
     )
     trainer = Trainer(**trainer_kwargs)

     # Fixed, unversioned log dir so a repeated run finds the same location;
     # resume if a checkpoint exists, without failing when none does.
     manager_cfg = ExpManagerConfig(
         explicit_log_dir='./ptl_stateless_timer_check/',
         use_datetime_version=False,
         version="",
         resume_ignore_no_checkpoint=True,
         create_checkpoint_callback=True,
         checkpoint_callback_params=ckpt_params,
         resume_if_exists=True,
     )
     exp_manager(trainer, cfg=OmegaConf.structured(manager_cfg))

     trainer.fit(ExampleModel(trainer=trainer))
     return trainer
class PunctuationCapitalizationConfig(NemoConfig):
    """
    A config for punctuation model training and testing.

    See an example of full config in
    `nemo/examples/nlp/token_classification/conf/punctuation_capitalization_config.yaml
    <https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/token_classification/conf/punctuation_capitalization_config.yaml>`_
    """

    pretrained_model: Optional[str] = None
    """Can be an NVIDIA's NGC cloud model or a path to a .nemo checkpoint. You can get list of possible cloud options
    by calling method
    :func:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_model.PunctuationCapitalizationModel.list_available_models`.
    """

    name: Optional[str] = 'Punctuation_and_Capitalization'
    """A name of the model. Used for naming output directories and ``.nemo`` checkpoints."""

    do_training: bool = True
    """Whether to perform training of the model."""

    do_testing: bool = False
    """Whether ot perform testing of the model."""

    model: PunctuationCapitalizationModelConfig = PunctuationCapitalizationModelConfig()
    """A configuration for the
    :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_model.PunctuationCapitalizationModel`
    model."""

    trainer: Optional[TrainerConfig] = TrainerConfig()
    """Contains ``Trainer`` Lightning class constructor parameters."""

    exp_manager: Optional[ExpManagerConfig] = ExpManagerConfig(name=name, files_to_copy=[])
    """A configuration with various NeMo training options such as output directories, resuming from checkpoint,
예제 #4
0
def instantiate_multinode_ddp_if_possible():
    """Build a DDP ``Trainer`` over every visible CUDA device and attach exp_manager.

    Returns:
        The configured (not yet fitted) ``Trainer``.
    """
    # NOTE(review): despite the name, there is no fallback when
    # torch.cuda.device_count() is 0, and only local devices are counted
    # (no num_nodes is set) -- confirm this is intended.
    trainer = Trainer(
        gpus=torch.cuda.device_count(),
        accelerator='ddp',
        logger=None,
        checkpoint_callback=None,
    )

    # Fixed output directory, datetime versioning off, empty version string.
    manager_cfg = OmegaConf.structured(
        ExpManagerConfig(exp_dir='./ddp_check/', use_datetime_version=False, version="")
    )
    exp_manager(trainer, cfg=manager_cfg)
    return trainer
예제 #5
0
파일: enc_dec_nmt.py 프로젝트: zt706/NeMo
class MTEncDecConfig(NemoConfig):
    """Top-level experiment config: model, Lightning trainer, and exp_manager settings."""

    # Model configuration (an ``AAYNBaseConfig``).
    # NOTE(review): if this class is a @dataclass (decorator not visible here),
    # these instance defaults are shared across config objects and are rejected
    # as mutable defaults on Python 3.11+; consider field(default_factory=...).
    model: AAYNBaseConfig = AAYNBaseConfig()
    # Constructor parameters for the Lightning ``Trainer``.
    trainer: Optional[TrainerConfig] = TrainerConfig()
    # Experiment-manager settings: run name 'MTEncDec', no extra files copied.
    exp_manager: Optional[ExpManagerConfig] = ExpManagerConfig(
        name='MTEncDec', files_to_copy=[])
예제 #6
0
class MTEncDecConfig(NemoConfig):
    """Top-level experiment config for the ``MTEncDec`` model.

    Groups the experiment name, the training switch, the model config, the
    Lightning ``Trainer`` arguments and the exp_manager settings.
    """

    # Experiment name; reused below as the exp_manager run name so the two
    # cannot drift apart.
    name: Optional[str] = 'MTEncDec'
    # Whether to perform training of the model.
    do_training: bool = True
    # NOTE(review): if this class is a @dataclass (decorator not visible here),
    # the instance defaults below are shared across config objects and are
    # rejected as mutable defaults on Python 3.11+; consider
    # field(default_factory=...) -- confirm against the decorator/runtime.
    model: MTEncDecModelConfig = MTEncDecModelConfig()
    # Constructor parameters for the Lightning ``Trainer``.
    trainer: Optional[TrainerConfig] = TrainerConfig()
    # `name` here is the class-body value bound above ('MTEncDec').
    exp_manager: Optional[ExpManagerConfig] = ExpManagerConfig(name=name, files_to_copy=[])