def dump_checkpoint(self):
    checkpoint = {
        'epoch': self.current_epoch + 1,
        'global_step': self.global_step + 1,
    }

    if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
        checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best

    if self.early_stop_callback is not None and self.early_stop_callback is not False:
        checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
        checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience

    # save optimizers
    optimizer_states = []
    for optimizer in self.optimizers:
        optimizer_states.append(optimizer.state_dict())

    checkpoint['optimizer_states'] = optimizer_states

    # save lr schedulers
    lr_schedulers = []
    for scheduler in self.lr_schedulers:
        lr_schedulers.append(scheduler['scheduler'].state_dict())

    checkpoint['lr_schedulers'] = lr_schedulers

    # add the hparams and state_dict from the model
    model = self.get_model()
    checkpoint['state_dict'] = model.state_dict()

    # save native amp scaling
    if self.use_amp and self.use_native_amp:
        checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()

    if hasattr(model, "hparams") and model.hparams is not None:
        parsing.clean_namespace(model.hparams)
        if isinstance(model.hparams, dict):
            checkpoint['hparams_type'] = 'dict'
            checkpoint['hparams'] = model.hparams
        elif isinstance(model.hparams, Namespace):
            checkpoint['hparams_type'] = 'Namespace'
            checkpoint['hparams'] = vars(model.hparams)
        else:
            raise ValueError(
                'The acceptable hparams type is dict or argparse.Namespace,'
                f' not {type(model.hparams)}'
            )
    else:
        rank_zero_warn(
            "Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters."
        )

    # give the model a chance to add a few things
    model.on_save_checkpoint(checkpoint)

    return checkpoint
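# For orientation, a minimal sketch of persisting and restoring the dict that
# dump_checkpoint() returns. `trainer`, `model`, and the file name are
# placeholders for illustration, not part of the function above.
import torch

checkpoint = trainer.dump_checkpoint()
torch.save(checkpoint, 'example.ckpt')

# restoring reverses the mapping assembled above, key by key
loaded = torch.load('example.ckpt')
model.load_state_dict(loaded['state_dict'])
for optimizer, state in zip(trainer.optimizers, loaded['optimizer_states']):
    optimizer.load_state_dict(state)
for scheduler, state in zip(trainer.lr_schedulers, loaded['lr_schedulers']):
    scheduler['scheduler'].load_state_dict(state)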
from pytorch_lightning.utilities.parsing import clean_namespace


def test_clean_namespace(tmpdir):
    # See the full list of picklable types at
    # https://docs.python.org/3/library/pickle.html#pickle-picklable
    class UnpicklableClass:
        # Only classes defined at the top level of a module are picklable.
        pass

    # Lambdas cannot be pickled either, so this entry should also be dropped.
    unpicklable_function = lambda: None

    test_case = {"1": None, "2": True, "3": 123, "4": unpicklable_function, "5": UnpicklableClass}
    clean_namespace(test_case)

    assert test_case == {"1": None, "2": True, "3": 123}
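# The test above depends on clean_namespace() dropping every entry whose value
# cannot be pickled. A minimal illustrative reimplementation of that behaviour
# (a sketch, not the library's actual code):
import pickle


def clean_namespace_sketch(hparams):
    """Drop dict or Namespace entries whose values cannot be pickled."""
    container = hparams if isinstance(hparams, dict) else vars(hparams)
    for key in list(container):
        try:
            pickle.dumps(container[key])
        except (pickle.PicklingError, AttributeError, TypeError):
            # module-level lambdas raise PicklingError; locally defined
            # classes and functions typically raise AttributeError
            del container[key]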
def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
    # clean hparams
    if hasattr(model, "hparams"):
        parsing.clean_namespace(model.hparams)

    # links data to the trainer
    self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)

    # check that model is configured correctly
    self.trainer.config_validator.verify_loop_configurations(model)

    # attach model log function to callback
    self.trainer.callback_connector.attach_model_logging_functions(model)
def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
    # bind logger and other properties
    self.trainer.model_connector.copy_trainer_model_properties(model)

    # clean hparams
    if hasattr(model, "hparams"):
        parsing.clean_namespace(model.hparams)

    # links data to the trainer
    self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)

    # check that model is configured correctly
    self.trainer.config_validator.verify_loop_configurations(model)
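# Both setup_fit variants share the hasattr guard before cleaning
# hyperparameters. A small self-contained illustration of that guard;
# DummyModule and its lambda attribute are made up for the example.
from argparse import Namespace

from pytorch_lightning.utilities import parsing


class DummyModule:
    def __init__(self):
        # the lambda cannot be pickled, so clean_namespace should drop it
        self.hparams = Namespace(lr=0.01, transform=lambda x: x)


model = DummyModule()
if hasattr(model, "hparams"):
    parsing.clean_namespace(model.hparams)

assert vars(model.hparams) == {"lr": 0.01}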
def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
):
    r"""
    Runs the full optimization routine.

    Args:
        model: Model to fit.

        train_dataloader: A PyTorch DataLoader with training samples. If the model
            has a predefined train_dataloader method, this will be skipped.

        val_dataloaders: Either a single PyTorch DataLoader or a list of them,
            specifying validation samples. If the model has a predefined
            val_dataloaders method, this will be skipped.

    Example::

        # Option 1,
        # Define the train_dataloader() and val_dataloader() fxs
        # in the lightningModule
        # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
        trainer = Trainer()
        model = LightningModule()
        trainer.fit(model)

        # Option 2
        # in production cases we might want to pass different datasets to the same model
        # Recommended for PRODUCTION SYSTEMS
        train, val = DataLoader(...), DataLoader(...)
        trainer = Trainer()
        model = LightningModule()
        trainer.fit(model, train_dataloader=train, val_dataloaders=val)

        # Option 1 & 2 can be mixed, for example the training set can be
        # defined as part of the model, and validation can then be fed to .fit()
    """
    # bind logger and other properties
    model.logger = self.logger
    self.copy_trainer_model_properties(model)

    # clean hparams
    if hasattr(model, 'hparams'):
        parsing.clean_namespace(model.hparams)

    # set up the passed in dataloaders (if needed)
    self.__attach_dataloaders(model, train_dataloader, val_dataloaders)

    # check that model is configured correctly
    self.check_model_configuration(model)

    # download the data and do whatever transforms we need
    # do before any spawn calls so that the model can assign properties
    # only on proc 0 because no spawn has happened yet
    if not self._is_data_prepared:
        model.prepare_data()
        self._is_data_prepared = True

    # Run auto batch size scaling
    if self.auto_scale_batch_size:
        if isinstance(self.auto_scale_batch_size, bool):
            self.auto_scale_batch_size = 'power'
        self.scale_batch_size(model, mode=self.auto_scale_batch_size)
        model.logger = self.logger  # reset logger binding

    # Run learning rate finder:
    if self.auto_lr_find:
        self._run_lr_finder_internally(model)
        model.logger = self.logger  # reset logger binding

    # route to appropriate start method
    # when using multi-node or DDP within a node, start each module in a separate process
    if self.use_ddp2:
        if self.is_slurm_managing_tasks:
            task = int(os.environ['SLURM_LOCALID'])
        elif 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
            task = int(os.environ['LOCAL_RANK'])
        self.ddp_train(task, model)

    elif self.use_ddp:
        if self.is_slurm_managing_tasks:
            task = int(os.environ['SLURM_LOCALID'])
            self.ddp_train(task, model)

        # torchelastic
        elif 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
            task = int(os.environ['LOCAL_RANK'])
            self.ddp_train(task, model)

        else:
            self.__set_random_port()

            # track for predict
            self.model = model

            # train
            mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

            # load weights if not interrupted
            if self.on_colab_kaggle:
                self.load_spawn_weights(model)

            self.model = model

    # 1 gpu or dp option triggers training using DP module
    # easier to avoid NCCL issues
    elif self.use_dp:
        self.dp_train(model)

    elif self.use_horovod:
        self.horovod_train(model)

    elif self.single_gpu:
        self.single_gpu_train(model)

    elif self.use_tpu:  # pragma: no-cover
        log.info(f'training on {self.tpu_cores} TPU cores')

        # COLAB_GPU is an env var available by default in Colab environments.
        start_method = 'fork' if self.on_colab_kaggle else 'spawn'

        # track for predict
        self.model = model

        # train
        if self.tpu_id is not None:
            self.tpu_train(model)
        else:
            xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method)

        # load weights if not interrupted
        self.load_spawn_weights(model)
        self.model = model

    # ON CPU
    else:
        # run through amp wrapper
        if self.use_amp:
            raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

        self.run_pretrain_routine(model)

    # return 1 when finished
    # used for testing or when we need to know that training succeeded
    return 1
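# To round out the docstring on fit(), a compact sketch of the two call styles
# it describes; `model` and the datasets are placeholders for illustration.
from torch.utils.data import DataLoader

# Option 1: the LightningModule defines its own dataloaders
trainer = Trainer(max_epochs=1)
trainer.fit(model)

# Option 2: pass the dataloaders explicitly, e.g. in production
train_loader = DataLoader(train_dataset, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)
trainer = Trainer(max_epochs=1)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)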