def main(cfg) -> None:
    """Fit and then test a ``MegatronXNlIModel`` configured by ``cfg``.

    Args:
        cfg: Experiment configuration. Reads ``cfg.model``, ``cfg.trainer``,
            ``cfg.exp_manager`` and the optional top-level ``cluster_type``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)

    # Only bf16 uses fp32_grad_accum.
    skip_ddp_hook = megatron_amp_o2 and cfg.trainer.precision == 'bf16'
    ddp_plugin = NLPDDPPlugin(
        no_ddp_communication_hook=skip_ddp_hook,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )
    plugins = [ddp_plugin]

    precision = cfg.trainer.precision
    if precision in (16, 'bf16'):
        # A gradient scaler is only built for fp16; bf16 passes scaler=None.
        if precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2**32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        else:
            scaler = None
        precision_plugin_cls = MegatronHalfPrecisionPlugin if megatron_amp_o2 else NativeMixedPrecisionPlugin
        plugins.append(precision_plugin_cls(precision=precision, device='cuda', scaler=scaler))

    on_bcp_cluster = cfg.get('cluster_type', None) == 'BCP'
    if on_bcp_cluster:
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # Re-create the checkpoint connector with the resume path found by exp_manager.
    ckpt_to_resume = trainer.checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {ckpt_to_resume}')
    trainer.checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=ckpt_to_resume)

    # Swap every Timer callback for a stateless one.
    for position, existing in enumerate(trainer.callbacks):
        if not isinstance(existing, Timer):
            continue
        trainer.callbacks[position] = StatelessTimer(cfg.trainer.max_time)

    # hydra interpolation does not work here as the interpolation key is lost
    # when PTL saves hparams, so the precision is copied over explicitly.
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    model = MegatronXNlIModel(cfg.model, trainer)
    trainer.fit(model)
    trainer.test(model)
def main(cfg) -> None:
    """Restore a finetuned ``MegatronT5GLUEModel`` and run validation on it.

    Args:
        cfg: Experiment configuration. Reads ``cfg.model`` (including
            ``restore_from_finetuned_path``), ``cfg.trainer``,
            ``cfg.exp_manager`` and the optional top-level ``cluster_type``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)

    node_count = cfg.trainer.num_nodes
    # Only bf16 uses fp32_grad_accum.
    skip_ddp_hook = megatron_amp_o2 and cfg.trainer.precision == 'bf16'
    ddp_plugin = NLPDDPPlugin(
        num_nodes=node_count,
        no_ddp_communication_hook=skip_ddp_hook,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )
    plugins = [ddp_plugin]

    precision = cfg.trainer.precision
    if precision in (16, 'bf16'):
        # A gradient scaler is only built for fp16; bf16 passes scaler=None.
        if precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2**32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        else:
            scaler = None
        precision_plugin_cls = MegatronHalfPrecisionPlugin if megatron_amp_o2 else NativeMixedPrecisionPlugin
        plugins.append(precision_plugin_cls(precision=precision, device='cuda', scaler=scaler))

    on_bcp_cluster = cfg.get('cluster_type', None) == 'BCP'
    if on_bcp_cluster:
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # Swap every Timer callback for a stateless one.
    for position, existing in enumerate(trainer.callbacks):
        if not isinstance(existing, Timer):
            continue
        trainer.callbacks[position] = StatelessTimer(cfg.trainer.max_time)

    finetuned_path = cfg.model.restore_from_finetuned_path
    model = MegatronT5GLUEModel.restore_from(restore_path=finetuned_path, trainer=trainer)
    model.freeze()
    trainer.validate(model)
def main(cfg) -> None:
    """Build a trainer from ``cfg`` and fit a ``MegatronT5Model``.

    When exp_manager has found a checkpoint to resume from and tensor model
    parallelism is enabled (size > 1), the checkpoint path is rewritten to
    point at this rank's ``mp_rank_XX`` sub-directory before resuming.

    Args:
        cfg: Experiment configuration. Reads ``cfg.model``, ``cfg.trainer``,
            ``cfg.exp_manager`` and the optional top-level ``cluster_type``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    plugins = [NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes)]

    use_fp16 = cfg.trainer.precision == 16
    if use_fp16:
        grad_scaler = GradScaler(
            init_scale=cfg.model.get('native_amp_init_scale', 2**32),
            growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
        )
        amp_plugin = NativeMixedPrecisionPlugin(precision=16, device='cuda', scaler=grad_scaler)
        plugins.append(amp_plugin)

    on_bcp_cluster = cfg.get('cluster_type', None) == 'BCP'
    if on_bcp_cluster:
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # Start from the resume path found by exp_manager.
    resume_from_checkpoint = trainer.checkpoint_connector.resume_from_checkpoint_fit_path
    if resume_from_checkpoint is not None:
        tp_size = cfg.model.tensor_model_parallel_size
        if tp_size is not None and tp_size > 1:
            # Inject mp_rank: <grandparent>/mp_rank_XX/<checkpoint file name>.
            mp_rank = compute_model_parallel_rank(trainer.local_rank, tp_size)
            found_ckpt = Path(resume_from_checkpoint)
            rank_dir = found_ckpt.parent.parent.joinpath(f'mp_rank_{mp_rank:02d}')
            resume_from_checkpoint = str(rank_dir.joinpath(found_ckpt.name))

    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')
    trainer.checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

    # Swap every Timer callback for a stateless one.
    for position, existing in enumerate(trainer.callbacks):
        if not isinstance(existing, Timer):
            continue
        trainer.callbacks[position] = StatelessTimer(cfg.trainer.max_time)

    # hydra interpolation does not work here as the interpolation key is lost
    # when PTL saves hparams, so the precision is copied over explicitly.
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    model = MegatronT5Model(cfg.model, trainer)
    trainer.fit(model)
def main(cfg) -> None:
    """Fit a ``MegatronGPTPromptLearningModel``, restored or freshly built.

    The model is restored from ``cfg.model.restore_path`` when that key is set
    to a truthy value; otherwise a new model is constructed from ``cfg.model``.

    Args:
        cfg: Experiment configuration. Reads ``cfg.model``, ``cfg.trainer``,
            ``cfg.exp_manager`` and the optional top-level ``cluster_type``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    ddp_plugin = NLPDDPPlugin(no_ddp_communication_hook=True, find_unused_parameters=False)
    plugins = [ddp_plugin]

    precision = cfg.trainer.precision
    if precision == 16:
        grad_scaler = GradScaler(
            init_scale=cfg.model.get('native_amp_init_scale', 2**32),
            growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
            hysteresis=cfg.model.get('hysteresis', 2),
        )
        amp_plugin = PipelineMixedPrecisionPlugin(precision=precision, device='cuda', scaler=grad_scaler)
        plugins.append(amp_plugin)

    on_bcp_cluster = cfg.get('cluster_type', None) == 'BCP'
    if on_bcp_cluster:
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # Swap every Timer callback for a stateless one.
    for position, existing in enumerate(trainer.callbacks):
        if not isinstance(existing, Timer):
            continue
        trainer.callbacks[position] = StatelessTimer(cfg.trainer.max_time)

    # hydra interpolation does not work here as the interpolation key is lost
    # when PTL saves hparams, so the precision is copied over explicitly.
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    # Load an existing soft prompt GPT model, or initialise a new one.
    if not cfg.model.get("restore_path", None):
        model = MegatronGPTPromptLearningModel(cfg.model, trainer=trainer)
    else:
        model = MegatronGPTPromptLearningModel.restore_from(
            cfg.model.restore_path,
            cfg.model,
            trainer=trainer,
            save_restore_connector=NLPSaveRestoreConnector(),
        )

    trainer.fit(model)
def main(cfg) -> None:
    """Build a trainer from ``cfg`` and fit a ``MegatronBertModel``.

    Args:
        cfg: Experiment configuration. Reads ``cfg.model``, ``cfg.trainer``,
            ``cfg.exp_manager`` and the optional top-level ``cluster_type``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    plugins = [NLPDDPPlugin(find_unused_parameters=False)]

    use_fp16 = cfg.trainer.precision == 16
    if use_fp16:
        grad_scaler = GradScaler(
            init_scale=cfg.model.get('native_amp_init_scale', 2**32),
            growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
        )
        amp_plugin = NativeMixedPrecisionPlugin(precision=16, device='cuda', scaler=grad_scaler)
        plugins.append(amp_plugin)

    on_bcp_cluster = cfg.get('cluster_type', None) == 'BCP'
    if on_bcp_cluster:
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # Re-create the checkpoint connector with the resume path found by exp_manager.
    ckpt_to_resume = trainer.checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {ckpt_to_resume}')
    trainer.checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=ckpt_to_resume)

    # Swap every Timer callback for a stateless one.
    for position, existing in enumerate(trainer.callbacks):
        if not isinstance(existing, Timer):
            continue
        trainer.callbacks[position] = StatelessTimer(cfg.trainer.max_time)

    # hydra interpolation does not work here as the interpolation key is lost
    # when PTL saves hparams, so the precision is copied over explicitly.
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    model = MegatronBertModel(cfg.model, trainer)
    trainer.fit(model)
def main(cfg) -> None:
    """Train a ``MegatronT5PTuneModel``, then optionally test, export and infer.

    Steps, in order:
      1. Register one ``TemplateProcessor`` per entry of
         ``cfg.model.task_processors``.
      2. Build the trainer and fit the model.
      3. If ``cfg.model.data.test_ds.file_path`` is set, test the trained
         model, reload the best checkpoint reported by the checkpoint
         callback, and test that as well.
      4. If ``cfg.model.nemo_path`` is set, save the selected model there.
      5. If ``cfg.model.infer_samples`` is present and non-empty, run
         ``ptune_inference`` on those queries and log the predictions.

    Steps 4 and 5 use the best-checkpoint model when step 3 ran, and fall back
    to the in-memory trained model otherwise.

    Args:
        cfg: Experiment configuration. Reads ``cfg.model``, ``cfg.trainer``,
            ``cfg.exp_manager`` and the optional top-level ``cluster_type``.
    """
    separator = "==========================================================================================="

    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    # setup the data processor
    for processor_config in cfg.model.task_processors:
        processor = TemplateProcessor(
            template=processor_config.template, limit_length_field=processor_config.limit_length_field
        )
        register_taskdata_processor(processor_config.taskname, processor)

    plugins = [NLPDDPPlugin()]
    if cfg.trainer.precision == 16:
        scaler = GradScaler(
            init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
            growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
        )
        plugins.append(NativeMixedPrecisionPlugin(precision=16, device='cuda', scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
    resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')
    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

    # Override timer callback to a stateless one
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time)

    # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    model = MegatronT5PTuneModel(cfg.model, trainer)
    trainer.fit(model)

    # Default to the model we just trained so the export and inference steps
    # below never hit an unbound name when no test file is configured.
    best_eval_model = model

    # NOTE(review): the extent of this branch was reconstructed from
    # whitespace-collapsed source; confirm that the best-checkpoint reload is
    # meant to run only when a test file is configured.
    if cfg.model.data.test_ds.file_path:
        logging.info(separator)
        logging.info("Starting the testing of the trained model on test set...")
        trainer.test(model)
        logging.info("Testing finished!")
        logging.info(separator)

        # extract the path of the best checkpoint from the training, you may update it to any checkpoint
        checkpoint_path = trainer.checkpoint_callback.best_model_path
        tensor_parallel_size = cfg.model.tensor_model_parallel_size
        pathobj = Path(checkpoint_path)
        checkpoint_folder = str(pathobj.parent)
        checkpoint_name = str(pathobj.name)

        rank = trainer.accelerator.training_type_plugin.local_rank
        if tensor_parallel_size > 1:
            # inject model parallel rank
            checkpoint_path = os.path.join(checkpoint_folder, f'mp_rank_{rank:02d}', checkpoint_name)
        else:
            checkpoint_path = os.path.join(checkpoint_folder, checkpoint_name)

        # Load the checkpoint
        best_eval_model = MegatronT5PTuneModel.load_from_checkpoint(
            checkpoint_path=checkpoint_path, strict=False, trainer=trainer
        )
        logging.info(f'Best checkpoint path: {checkpoint_path}')
        logging.info("Running Test with best EVAL checkpoint!")

        # Synchronise ranks (when a process group exists) before the second test run.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        trainer.test(model=best_eval_model, ckpt_path=None, verbose=False)
        logging.info("Best EVAL Testing finished!")
        logging.info(separator)

    if cfg.model.nemo_path:
        # Export the selected model to a `.nemo` file.
        best_eval_model.save_to(cfg.model.nemo_path)
        logging.info(f'Model is saved into `.nemo` file: {cfg.model.nemo_path}')

    # perform inference on a list of queries.
    if "infer_samples" in cfg.model and cfg.model.infer_samples:
        logging.info(separator)
        logging.info("Starting the inference on some sample queries...")

        results = best_eval_model.cuda().ptune_inference(
            queries=cfg.model.infer_samples, batch_size=1, decode_token_len=5
        )
        logging.info('The prediction results of some sample queries with the trained model:')
        for query, result in zip(cfg.model.infer_samples, results):
            logging.info(f'Query : {query}')
            logging.info(f'Predicted label: {result}')

        logging.info("Inference finished!")
        logging.info(separator)
def main(cfg) -> None:
    """Train and/or test a ``MegatronNMTModel`` as directed by ``cfg``.

    When ``cfg.model.pretrained_model_path`` is set, the configuration stored
    with that pretrained T5 or BART model is loaded, overridden with the
    translation-specific settings from ``cfg.model``, and used to restore a
    ``MegatronNMTModel``. Otherwise a new model is built from ``cfg.model``.
    Training runs when ``cfg.do_training`` is truthy and testing when
    ``cfg.do_testing`` is truthy.

    Args:
        cfg: Experiment configuration. Reads ``cfg.model``, ``cfg.trainer``,
            ``cfg.exp_manager``, ``cfg.do_training``, ``cfg.do_testing`` and
            the optional top-level ``cluster_type``.

    Raises:
        ValueError: If a pretrained model path is given but
            ``cfg.model.pretrained_model_type`` is missing or is not one of
            ``'T5'`` / ``'BART'``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    plugins = [
        NLPDDPPlugin(
            no_ddp_communication_hook=True,
            gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
            find_unused_parameters=False,
        )
    ]
    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2**32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        if megatron_amp_o2:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
        else:
            plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer, callbacks=[ModelSummary(max_depth=3)])

    # tokenizers will be trained and tarred training data will be created if needed
    # model config is then updated
    if cfg.model.preproc_out_dir is not None:
        MTDataPreproc(cfg=cfg.model, trainer=trainer)

    exp_manager(trainer, cfg.exp_manager)

    # An explicit resume path in the model config wins over the one found by exp_manager.
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

    # Override timer callback to a stateless one
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time)

    # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    if hasattr(cfg.model, 'pretrained_model_path') and cfg.model.pretrained_model_path is not None:
        if not hasattr(cfg.model, 'pretrained_model_type'):
            raise ValueError("Pretrained model type must be in [T5, BART].")
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would let an unknown type fall through to BART.
        if cfg.model.pretrained_model_type not in ('T5', 'BART'):
            raise ValueError(
                f"Pretrained model type must be in [T5, BART], got {cfg.model.pretrained_model_type!r}."
            )

        if cfg.model.pretrained_model_type == 'T5':
            pretrained_cfg = MegatronT5Model.restore_from(
                cfg.model.pretrained_model_path, trainer=trainer, return_config=True
            )
        else:
            pretrained_cfg = MegatronBARTModel.restore_from(
                cfg.model.pretrained_model_path, trainer=trainer, return_config=True
            )

        OmegaConf.set_struct(pretrained_cfg, True)
        with open_dict(pretrained_cfg):
            pretrained_cfg.masked_softmax_fusion = False
            # Set source and target language/multilingual
            pretrained_cfg.src_language = cfg.model.src_language
            pretrained_cfg.tgt_language = cfg.model.tgt_language
            pretrained_cfg.multilingual = cfg.model.multilingual
            pretrained_cfg.shared_tokenizer = True

            # Max generation delta
            pretrained_cfg.max_generation_delta = cfg.model.max_generation_delta

            # Set label smoothing
            pretrained_cfg.label_smoothing = cfg.model.label_smoothing

            # Set tokenizer paths:
            pretrained_cfg.encoder_tokenizer = pretrained_cfg.tokenizer
            pretrained_cfg.decoder_tokenizer = pretrained_cfg.tokenizer

            # Pre-trained models should use the legacy sentencepiece tokenizer ex: mT5
            pretrained_cfg.encoder_tokenizer.sentencepiece_legacy = True
            pretrained_cfg.decoder_tokenizer.sentencepiece_legacy = True

            # Override dropout
            pretrained_cfg.hidden_dropout = cfg.model.hidden_dropout
            pretrained_cfg.attention_dropout = cfg.model.attention_dropout

            # Override precision
            pretrained_cfg.precision = cfg.model.precision  # Set above from trainer.precision

            # Override data and global/micro batch size.
            pretrained_cfg.train_ds = cfg.model.train_ds
            pretrained_cfg.validation_ds = cfg.model.validation_ds
            pretrained_cfg.test_ds = cfg.model.test_ds
            pretrained_cfg.micro_batch_size = cfg.model.micro_batch_size
            pretrained_cfg.global_batch_size = cfg.model.global_batch_size

            # Class target for the new class being restored.
            pretrained_cfg.target = (
                "nemo.collections.nlp.models.machine_translation.megatron_nmt_model.MegatronNMTModel"
            )

            # Optimizer overrides.
            pretrained_cfg.optim = cfg.model.optim

        model = MegatronNMTModel.restore_from(
            cfg.model.pretrained_model_path,
            trainer=trainer,
            override_config_path=pretrained_cfg,
            save_restore_connector=NLPSaveRestoreConnector(),
        )
    else:
        model = MegatronNMTModel(cfg.model, trainer)

    if cfg.do_training:
        trainer.fit(model)

    if cfg.do_testing:
        trainer.test(model)
def main(cfg) -> None:
    """Finetune a restored T5 model, then validate and optionally test it.

    The configuration stored at ``cfg.model.restore_from_path`` is loaded,
    overridden with values from ``cfg``, and used to restore either a
    ``MegatronT5GLUEModel`` (when ``cfg.model.data.train_ds`` has a
    ``task_name`` attribute) or a ``MegatronT5FinetuneModel``.

    Args:
        cfg: Experiment configuration. Reads ``cfg.model``, ``cfg.trainer``,
            ``cfg.exp_manager`` and the optional top-level ``cluster_type``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)

    ddp_plugin = NLPDDPPlugin(
        no_ddp_communication_hook=True,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )
    plugins = [ddp_plugin]

    precision = cfg.trainer.precision
    if precision in (16, 'bf16'):
        # A gradient scaler is only built for fp16; bf16 passes scaler=None.
        if precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        else:
            scaler = None
        precision_plugin_cls = MegatronHalfPrecisionPlugin if megatron_amp_o2 else PipelineMixedPrecisionPlugin
        plugins.append(precision_plugin_cls(precision=precision, device='cuda', scaler=scaler))

    on_bcp_cluster = cfg.get('cluster_type', None) == 'BCP'
    if on_bcp_cluster:
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # An explicit resume path in the model config wins over the one found by exp_manager.
    ckpt_to_resume = cfg.model.resume_from_checkpoint
    if ckpt_to_resume is None:
        ckpt_to_resume = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {ckpt_to_resume}')
    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=ckpt_to_resume)

    # Swap every Timer callback for a stateless one.
    for position, existing in enumerate(trainer.callbacks):
        if not isinstance(existing, Timer):
            continue
        trainer.callbacks[position] = StatelessTimer(cfg.trainer.max_time)

    # Get the T5 Base configuration.
    t5_cfg = MegatronT5FinetuneModel.restore_from(
        restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
    )

    # Override the T5 configuration with the one from the config file.
    OmegaConf.set_struct(t5_cfg, True)
    with open_dict(t5_cfg):
        t5_cfg.masked_softmax_fusion = False
        t5_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False)
        t5_cfg.hidden_dropout = cfg.model.get('hidden_dropout', 0.1)
        t5_cfg.attention_dropout = cfg.model.get('attention_dropout', 0.1)
        t5_cfg.data = cfg.model.data
        t5_cfg.precision = cfg.trainer.precision
        t5_cfg.optim = cfg.model.optim
        train_ds_cfg = cfg.model.data.train_ds
        t5_cfg.micro_batch_size = train_ds_cfg.micro_batch_size
        t5_cfg.global_batch_size = train_ds_cfg.global_batch_size
        # XNLI has eval languages in the yaml config.
        if hasattr(cfg.model, 'eval_languages'):
            t5_cfg.eval_languages = cfg.model.eval_languages

    # The presence of `task_name` on the training dataset selects the GLUE model class.
    if hasattr(cfg.model.data.train_ds, 'task_name'):
        model_cls = MegatronT5GLUEModel
    else:
        model_cls = MegatronT5FinetuneModel
    model = model_cls.restore_from(
        restore_path=cfg.model.restore_from_path,
        trainer=trainer,
        override_config_path=t5_cfg,
        save_restore_connector=NLPSaveRestoreConnector(),
    )

    trainer.fit(model)
    trainer.validate(model)
    if hasattr(cfg.model.data, 'test_ds'):
        trainer.test(model)
def main(cfg) -> None:
    """Restore a ``MegatronGPTModel`` from ``cfg.restore_from_path`` and fit it.

    When tensor or pipeline model parallelism is enabled (either size > 1),
    the ``AppState`` parallel ranks and model-parallel size are filled in from
    ``fake_initialize_model_parallel`` before the model is restored.

    Args:
        cfg: Experiment configuration. Reads ``cfg.model``, ``cfg.trainer``,
            ``cfg.exp_manager``, ``cfg.restore_from_path`` and the optional
            top-level ``cluster_type``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)

    ddp_plugin = NLPDDPPlugin(
        no_ddp_communication_hook=True,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )
    plugins = [ddp_plugin]

    precision = cfg.trainer.precision
    if precision in (16, 'bf16'):
        # A gradient scaler is only built for fp16; bf16 passes scaler=None.
        if precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2**32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        else:
            scaler = None
        precision_plugin_cls = MegatronHalfPrecisionPlugin if megatron_amp_o2 else PipelineMixedPrecisionPlugin
        plugins.append(precision_plugin_cls(precision=precision, device='cuda', scaler=scaler))

    on_bcp_cluster = cfg.get('cluster_type', None) == 'BCP'
    if on_bcp_cluster:
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    app_state = AppState()
    if cfg.model.tensor_model_parallel_size > 1 or cfg.model.pipeline_model_parallel_size > 1:
        tp_size = cfg.model.tensor_model_parallel_size
        pp_size = cfg.model.pipeline_model_parallel_size
        app_state.model_parallel_size = tp_size * pp_size
        tp_rank, pp_rank, mp_size, _ = fake_initialize_model_parallel(
            world_size=app_state.model_parallel_size,
            rank=trainer.global_rank,
            tensor_model_parallel_size_=tp_size,
            pipeline_model_parallel_size_=pp_size,
        )
        app_state.tensor_model_parallel_rank = tp_rank
        app_state.pipeline_model_parallel_rank = pp_rank
        app_state.model_parallel_size = mp_size

    # Swap every Timer callback for a stateless one.
    for position, existing in enumerate(trainer.callbacks):
        if not isinstance(existing, Timer):
            continue
        trainer.callbacks[position] = StatelessTimer(cfg.trainer.max_time)

    # hydra interpolation does not work here as the interpolation key is lost
    # when PTL saves hparams, so the precision is copied over explicitly.
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    model = MegatronGPTModel.restore_from(cfg.restore_from_path, cfg.model, trainer=trainer)
    trainer.fit(model)
def main(cfg) -> None:
    """Restore a finetuned T5 model, freeze it, then validate and optionally test.

    The configuration stored at ``cfg.model.restore_from_path`` is loaded and
    its data section, precision and (when present) ``eval_languages`` are
    overridden from ``cfg``. A ``MegatronT5GLUEModel`` is restored when the
    overridden validation dataset config has a ``task_name`` attribute,
    otherwise a ``MegatronT5FinetuneModel``.

    Args:
        cfg: Experiment configuration. Reads ``cfg.model``, ``cfg.trainer``,
            ``cfg.exp_manager`` and the optional top-level ``cluster_type``.
    """
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)

    ddp_plugin = NLPDDPPlugin(
        no_ddp_communication_hook=True,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )
    plugins = [ddp_plugin]

    precision = cfg.trainer.precision
    if precision in (16, 'bf16'):
        # A gradient scaler is only built for fp16; bf16 passes scaler=None.
        if precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2**32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        else:
            scaler = None
        precision_plugin_cls = MegatronHalfPrecisionPlugin if megatron_amp_o2 else NativeMixedPrecisionPlugin
        plugins.append(precision_plugin_cls(precision=precision, device='cuda', scaler=scaler))

    on_bcp_cluster = cfg.get('cluster_type', None) == 'BCP'
    if on_bcp_cluster:
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # Swap every Timer callback for a stateless one.
    for position, existing in enumerate(trainer.callbacks):
        if not isinstance(existing, Timer):
            continue
        trainer.callbacks[position] = StatelessTimer(cfg.trainer.max_time)

    # Get the T5 Base configuration.
    t5_cfg = MegatronT5GLUEModel.restore_from(
        restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
    )

    # Override the T5 configuration with the one from the config file.
    # NOTE: Only data can be overridden here, since the file being restored
    # should already correspond to a GLUE/XNLI finetuned model.
    OmegaConf.set_struct(t5_cfg, True)
    with open_dict(t5_cfg):
        t5_cfg.masked_softmax_fusion = False
        t5_cfg.precision = cfg.trainer.precision
        # Overwrite data configs
        t5_cfg.data = cfg.model.data
        # XNLI has eval languages in the yaml config.
        if hasattr(cfg.model, 'eval_languages'):
            t5_cfg.eval_languages = cfg.model.eval_languages

    # The presence of `task_name` on the validation dataset selects the GLUE model class.
    if hasattr(t5_cfg.data.validation_ds, 'task_name'):
        model_cls = MegatronT5GLUEModel
    else:
        model_cls = MegatronT5FinetuneModel
    model = model_cls.restore_from(
        restore_path=cfg.model.restore_from_path, trainer=trainer, override_config_path=t5_cfg
    )

    model.freeze()
    trainer.validate(model)
    if hasattr(cfg.model.data, 'test_ds'):
        trainer.test(model)