def hpc_save(self, folderpath: str, logger):
    # make sure the checkpoint folder exists
    folderpath = str(folderpath)  # because the tests pass a path object
    if not gfile.exists(folderpath):
        makedirs(folderpath)

    # save logger to make sure we get all the metrics
    logger.save()

    ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

    if not gfile.exists(folderpath):
        makedirs(folderpath)
    filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

    # give model a chance to do something on hpc_save
    model = self.get_model()
    checkpoint = self.dump_checkpoint()

    model.on_hpc_save(checkpoint)

    # do the actual save
    # TODO: fix for anything with multiprocess DP, DDP, DDP2
    try:
        atomic_save(checkpoint, filepath)
    except AttributeError as err:
        if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
            del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
        rank_zero_warn(
            'warning, `module_arguments` dropped from checkpoint.'
            f' An attribute is not picklable {err}'
        )
        atomic_save(checkpoint, filepath)

    return filepath
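Every example in this collection funnels its file write through `atomic_save`. As a point of reference, the sketch below shows the general write-to-temporary-then-rename pattern such a helper is built around, so that a crash mid-save cannot corrupt an existing checkpoint. It is a simplified stand-in with a made-up name, not Lightning's actual implementation.

import os
import tempfile

import torch


def atomic_save_sketch(obj, path):
    """Simplified illustration of an atomic save: never leaves a half-written file at `path`."""
    dirname = os.path.dirname(path) or "."
    # write into a temporary file in the same directory so the final rename stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=dirname, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as f:
            torch.save(obj, f)
        os.replace(tmp_path, path)  # atomic rename on POSIX and Windows
    except BaseException:
        os.remove(tmp_path)  # clean up the partial file, keep any previous checkpoint intact
        raise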
def save_checkpoint(self, filepath, weights_only: bool):
    """Slightly modified version of PyTorch Lightning's ``save_checkpoint``.

    Args:
        filepath (str): write-target file's path.
        weights_only (bool): saving model weights only.

    Returns:
        None
    """
    app_state = AppState()
    if app_state.model_parallel_size is not None:
        # filepath needs to be updated to include mp_rank
        dirname = os.path.dirname(filepath)
        basename = os.path.basename(filepath)
        filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'

    # dump states as a checkpoint dictionary object
    checkpoint = self.dump_checkpoint(weights_only)

    # each model parallel rank needs to save a copy of its model
    if app_state.data_parallel_rank == 0:
        # write the checkpoint dictionary on the file
        if self.trainer.accelerator_backend:
            checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
        try:
            atomic_save(checkpoint, filepath)
        except AttributeError as err:
            if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
            rank_zero_warn(
                'Warning, `hyper_parameters` dropped from checkpoint.'
                f' An attribute is not picklable {err}'
            )
            atomic_save(checkpoint, filepath)
    return None
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        checkpoint: dict containing model and trainer state
        path: write-target path
        storage_options: not used in ``XLACheckpointIO.save_checkpoint``

    Raises:
        TypeError:
            If ``storage_options`` arg is passed in
    """
    if storage_options is not None:
        raise TypeError(
            "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
            f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
            " to define how you'd like to use `storage_options`."
        )
    fs = get_filesystem(path)
    fs.makedirs(os.path.dirname(path), exist_ok=True)
    checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
    # write the checkpoint dictionary to the provided path
    atomic_save(checkpoint, path)
def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
    if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
        return

    # track the best model path
    best_model_path = None
    if self.trainer.checkpoint_callback is not None:
        best_model_path = self.trainer.checkpoint_callback.best_model_path

    if self.trainer.global_rank == 0 and mp_queue is not None:
        rank_zero_warn('cleaning up ddp environment...')

        # todo, pass complete checkpoint as state dictionary
        mp_queue.put(best_model_path)
        mp_queue.put(results)

        # save the last weights
        last_path = None
        if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
            last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
            atomic_save(model.state_dict(), last_path)
        mp_queue.put(last_path)
def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir):
    os.environ["LOCAL_RANK"] = str(rank)
    if torch.distributed.is_available() and not torch.distributed.is_initialized():
        torch.distributed.init_process_group("gloo", rank=rank, world_size=2)

    to_device = partial(move_data_to_device, device=torch.device("cuda", rank))
    model = DistributedDataParallel(
        to_device(model),
        device_ids=[rank],
    )
    train_dataloader = DataLoader(
        train_dataloader.dataset,
        sampler=DistributedSampler(train_dataloader.dataset, rank=rank, num_replicas=2, seed=42, drop_last=False),
    )
    with precision_context(precision, accelerator):
        main(to_device, model, train_dataloader, num_epochs=num_epochs)

    if rank == 0:
        atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt"))
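A minimal, hypothetical launcher for the rank-parameterized `run` above, assuming it is driven by `torch.multiprocessing.spawn` (which passes the process index as the first argument). The surrounding variables are placeholders, not part of the original test.

import torch.multiprocessing as mp

if __name__ == "__main__":
    # spawn two worker processes; each receives its rank as the first argument,
    # matching the world_size=2 used inside `run`
    mp.spawn(
        run,
        args=(model, train_dataloader, num_epochs, precision, accelerator, tmpdir),
        nprocs=2,
    )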
def hpc_save(self, folderpath: str, logger):
    # make sure the checkpoint folder exists
    folderpath = str(folderpath)  # because the tests pass a path object
    fs = get_filesystem(folderpath)
    fs.makedirs(folderpath, exist_ok=True)

    # save logger to make sure we get all the metrics
    logger.save()

    max_suffix = self.max_ckpt_version_in_folder(folderpath)
    ckpt_number = (max_suffix if max_suffix is not None else 0) + 1

    fs.makedirs(folderpath, exist_ok=True)
    filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt")

    # give model a chance to do something on hpc_save
    model = self.trainer.lightning_module
    checkpoint = self.dump_checkpoint()

    model.on_hpc_save(checkpoint)

    # do the actual save
    # TODO: fix for anything with multiprocess DP, DDP, DDP2
    try:
        atomic_save(checkpoint, filepath)
    except AttributeError as err:
        if pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
            del checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
        rank_zero_warn(f"warning, `hyper_parameters` dropped from checkpoint. An attribute is not picklable {err}")
        atomic_save(checkpoint, filepath)

    return filepath
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        checkpoint: dict containing model and trainer state
        path: write-target path
        storage_options: not used in ``TorchCheckpointIO.save_checkpoint``

    Raises:
        TypeError:
            If ``storage_options`` arg is passed in
    """
    if storage_options is not None:
        raise TypeError(
            "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
            f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
            " to define how you'd like to use `storage_options`."
        )
    fs = get_filesystem(path)
    fs.makedirs(os.path.dirname(path), exist_ok=True)
    try:
        # write the checkpoint dictionary on the file
        atomic_save(checkpoint, path)
    except AttributeError as err:
        # todo (sean): is this try catch necessary still?
        # https://github.com/Lightning-AI/lightning/pull/431
        key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
        checkpoint.pop(key, None)
        rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
        atomic_save(checkpoint, path)
def transfer_distrib_spawn_state_on_fit_end(self, results):
    checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
    best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

    # requires to compute the state_dict on all processes in case Metrics are present
    state_dict = self.lightning_module.state_dict()

    if self.global_rank == 0 and self.mp_queue is not None:
        rank_zero_warn("cleaning up ddp environment...")

        # save the last weights
        last_path = None
        if (
            self.lightning_module.trainer.state.fn == TrainerFn.FITTING
            and best_model_path is not None
            and len(best_model_path) > 0
        ):
            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
            atomic_save(self.on_save(state_dict), last_path)

        # todo, pass complete checkpoint as state dictionary
        self.mp_queue.put(best_model_path)
        self.mp_queue.put(last_path)
        self.mp_queue.put(results)
        self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
    fs = get_filesystem(path)
    fs.makedirs(os.path.dirname(path), exist_ok=True)
    try:
        # write the checkpoint dictionary on the file
        atomic_save(checkpoint, path)
    except AttributeError as err:
        # todo (sean): is this try catch necessary still?
        # https://github.com/PyTorchLightning/pytorch-lightning/pull/431
        key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
        checkpoint.pop(key, None)
        rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
        atomic_save(checkpoint, path)
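Several of the snippets in this collection wrap the save in a try/except that drops `hyper_parameters` on `AttributeError`. The hedged illustration below shows the kind of situation that fallback targets: a checkpoint carrying an object the pickler cannot serialize. The key name and exception tuple are illustrative assumptions; depending on the object, pickle may raise `PicklingError` rather than `AttributeError`.

import pickle

import torch


def make_unpicklable():
    # locally defined functions cannot be pickled, which is one way a
    # checkpoint attribute ends up unserializable
    def act(x):
        return x
    return act


checkpoint = {"state_dict": {}, "hyper_parameters": {"activation": make_unpicklable()}}

try:
    torch.save(checkpoint, "ckpt.pt")
except (AttributeError, pickle.PicklingError):
    # mirror the fallback used above: drop the offending key and retry the save
    checkpoint.pop("hyper_parameters", None)
    torch.save(checkpoint, "ckpt.pt")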
def save_checkpoint(self, filepath, weights_only: bool = False):
    checkpoint = self.dump_checkpoint(weights_only)

    if self.is_global_zero:
        # do the actual save
        try:
            atomic_save(checkpoint, filepath)
        except AttributeError as err:
            if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
            rank_zero_warn(
                'Warning, `module_arguments` dropped from checkpoint.'
                f' An attribute is not picklable {err}'
            )
            atomic_save(checkpoint, filepath)
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        checkpoint: dict containing model and trainer state
        filepath: write-target file's path
    """
    # dump states as a checkpoint dictionary object
    checkpoint = self.on_save(checkpoint)
    if self.is_global_zero:
        try:
            # write the checkpoint dictionary on the file
            atomic_save(checkpoint, filepath)
        except AttributeError as err:
            key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
            checkpoint.pop(key, None)
            rank_zero_warn(f'Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}')
            atomic_save(checkpoint, filepath)
def run(self, model: nn.Module, train_dataloader: DataLoader, num_epochs: int = 10, tmpdir: str = None):
    optimizer = configure_optimizers(model)
    model, optimizer = self.setup(model, optimizer)
    train_dataloader = self.setup_dataloaders(train_dataloader)

    model.train()
    for _ in range(num_epochs):
        for batch in train_dataloader:
            batch = self.to_device(batch)
            optimizer.zero_grad()
            loss = model(batch)
            self.backward(loss)
            optimizer.step()

    if isinstance(self._strategy, DDPSpawnPlugin) and tmpdir and self.global_rank == 0:
        checkpoint_path = os.path.join(tmpdir, "model.pt")
        atomic_save(model.state_dict(), checkpoint_path)
        return checkpoint_path
def transfer_distrib_spawn_state_on_fit_end(self, results):
    # TODO: is there a better way than accessing callback through model -> trainer -> callback?
    checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
    best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

    if self.global_rank == 0 and self.mp_queue is not None:
        rank_zero_warn("cleaning up ddp environment...")

        # save the last weights
        last_path = None
        # TODO: is there a better way than accessing trainer through model -> trainer?
        if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
            atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

        # todo, pass complete checkpoint as state dictionary
        self.mp_queue.put(best_model_path)
        self.mp_queue.put(last_path)
        self.mp_queue.put(results)
def transfer_distrib_spawn_state_on_fit_end(self, results):
    checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
    best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

    if self.global_rank == 0 and self.mp_queue is not None:
        rank_zero_warn("cleaning up ddp environment...")

        # save the last weights
        last_path = None
        if (
            self.lightning_module.trainer.state == TrainerState.FITTING
            and best_model_path is not None
            and len(best_model_path) > 0
        ):
            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
            atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

        # todo, pass complete checkpoint as state dictionary
        self.mp_queue.put(best_model_path)
        self.mp_queue.put(last_path)
        self.mp_queue.put(results)
def save_checkpoint(self, filepath, weights_only: bool = False):
    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        filepath: write-target file's path
        weights_only: saving model weights only
    """
    # dump states as a checkpoint dictionary object
    checkpoint = self.dump_checkpoint(weights_only)

    if self.trainer.is_global_zero:
        # write the checkpoint dictionary on the file
        try:
            atomic_save(checkpoint, filepath)
        except AttributeError as err:
            if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
            rank_zero_warn(
                'Warning, `hyper_parameters` dropped from checkpoint.'
                f' An attribute is not picklable {err}'
            )
            atomic_save(checkpoint, filepath)
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
    """Save model/training states as a checkpoint file through state-dump and file-write.

    Args:
        checkpoint: dict containing model and trainer state
        filepath: write-target file's path
    """
    # dump states as a checkpoint dictionary object
    if self.is_global_zero:
        checkpoint = self.on_save(checkpoint)
        try:
            # write the checkpoint dictionary on the file
            atomic_save(checkpoint, filepath)
        except AttributeError as err:
            if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
            rank_zero_warn(
                'Warning, `hyper_parameters` dropped from checkpoint.'
                f' An attribute is not picklable {err}'
            )
            atomic_save(checkpoint, filepath)
TGT_LANG = "de"
ENSEMBLE = 10
LAST_CKPT = 63

trafo_states = []
print('Loading...')
for epoch in tqdm(range(LAST_CKPT - ENSEMBLE, LAST_CKPT)):
    ckpt_path = os.path.join(MODELS_FOLDER, 'transformer', f'trafo_{SRC_LANG}_{TGT_LANG}_{epoch}.pt')
    trafo_states.append(pl_load(ckpt_path, map_location=lambda storage, loc: storage)['state_dict'])

# Average models into model
print("Averaging...")
avg_state = {}
for key in trafo_states[-1]:
    mean = 0
    for trafo_state in trafo_states:
        mean += trafo_state[key]
    avg_state[key] = mean / len(trafo_states)

print('saving...')
ckpt_path = os.path.join(MODELS_FOLDER, 'transformer', f'trafo_{SRC_LANG}_{TGT_LANG}_{LAST_CKPT}.pt')
avg_ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
avg_ckpt['state_dict'] = avg_state
atomic_save(avg_ckpt, f'.data/models/transformer/trafo_{SRC_LANG}_{TGT_LANG}_ensemble.pt')
def train(config):
    # ======================================================
    # EXPERIMENT SETUP
    # ======================================================
    from pytorch_lightning import seed_everything

    # Seed
    seed_everything(config.seed)

    # DATASET SETUP
    print("======================================================")
    print("SETTING UP DATASET")
    print("======================================================")
    from ml4floods.models.dataset_setup import get_dataset
    dataset = get_dataset(config.data_params)

    # MODEL SETUP
    print("======================================================")
    print("SETTING UP MODEL")
    print("======================================================")
    from ml4floods.models.model_setup import get_model
    config.model_params.test = False
    config.model_params.train = True
    model = get_model(config.model_params)

    # LOGGING SETUP
    print("======================================================")
    print("SETTING UP LOGGERS")
    print("======================================================")
    import wandb
    from pytorch_lightning.loggers import WandbLogger
    wandb_logger = WandbLogger(
        name=config.experiment_name,
        project=config.wandb_project,
        entity=config.wandb_entity,
        # save_dir=f"{config.model_params.model_folder}/{config.experiment_name}"
    )

    # CHECKPOINTING SETUP
    print("======================================================")
    print("SETTING UP CHECKPOINTING")
    print("======================================================")
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
    experiment_path = f"{config.model_params.model_folder}/{config.experiment_name}"
    checkpoint_path = f"{experiment_path}/checkpoint"
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        save_top_k=True,
        verbose=True,
        monitor=config.model_params.hyperparameters.metric_monitor,
        mode='min',
        prefix=''
    )
    early_stop_callback = EarlyStopping(
        monitor=config.model_params.hyperparameters.metric_monitor,
        patience=4,
        strict=False,
        verbose=False,
        mode='min'
    )
    callbacks = [checkpoint_callback, early_stop_callback]

    # TRAINING SETUP
    print("======================================================")
    print("START TRAINING")
    print("======================================================")
    from pytorch_lightning import Trainer
    trainer = Trainer(
        fast_dev_run=False,
        logger=wandb_logger,
        callbacks=callbacks,
        default_root_dir=f"{config.model_params.model_folder}/{config.experiment_name}",
        accumulate_grad_batches=1,
        gradient_clip_val=0.0,
        auto_lr_find=False,
        benchmark=False,
        distributed_backend=None,
        gpus=config.gpus if config.gpus != '' else None,
        max_epochs=config.model_params.hyperparameters.max_epochs,
        check_val_every_n_epoch=config.model_params.hyperparameters.val_every,
        log_gpu_memory=None,
        resume_from_checkpoint=checkpoint_path if config.resume_from_checkpoint else None
    )

    trainer.fit(model, dataset)

    # ======================================================
    # SAVING SETUP
    # ======================================================
    print("======================================================")
    print("FINISHED TRAINING, SAVING MODEL")
    print("======================================================")
    from pytorch_lightning.utilities.cloud_io import atomic_save
    atomic_save(model.state_dict(), f"{experiment_path}/model.pt")
    torch.save(model.state_dict(), os.path.join(wandb_logger.save_dir, 'model.pt'))
    wandb.save(os.path.join(wandb_logger.save_dir, 'model.pt'))
    wandb.finish()

    # Save config file in experiment_path
    config_file_path = f"{experiment_path}/config.json"
    save_json(config, config_file_path)

    return 1
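As a hedged companion to the training script above: the `model.pt` it writes with `atomic_save` holds a plain `state_dict`, so it can be restored with `torch.load` plus `load_state_dict`. The snippet assumes `config` and `experiment_path` refer to the same objects used inside `train()` and that `get_model(config.model_params)` rebuilds an identical architecture; these are illustrative assumptions, not part of the original code.

import torch

from ml4floods.models.model_setup import get_model

# rebuild the architecture, then restore the weights written by train()
model = get_model(config.model_params)  # assumes the same config used for training
state_dict = torch.load(f"{experiment_path}/model.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()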