Example #1
    def hpc_save(self, folderpath: str, logger):
        # make sure the checkpoint folder exists
        folderpath = str(folderpath)  # because the tests pass a path object
        if not gfile.exists(folderpath):
            makedirs(folderpath)

        # save logger to make sure we get all the metrics
        logger.save()

        ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

        filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

        # give model a chance to do something on hpc_save
        model = self.get_model()
        checkpoint = self.dump_checkpoint()

        model.on_hpc_save(checkpoint)

        # do the actual save
        # TODO: fix for anything with multiprocess DP, DDP, DDP2
        try:
            atomic_save(checkpoint, filepath)
        except AttributeError as err:
            if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
            rank_zero_warn(
                'warning, `module_arguments` dropped from checkpoint.'
                f' An attribute is not picklable {err}')
            atomic_save(checkpoint, filepath)

        return filepath
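All of these examples funnel writes through `atomic_save`; a minimal sketch of the atomic-write pattern it relies on (serialize in memory, write a temp file, then rename) might look like the following, where the function name and ".part" suffix are illustrative assumptions rather than Lightning's actual internals.

# Minimal sketch of the atomic-write idea behind `atomic_save` (illustrative
# only; the name and ".part" suffix are assumptions, not Lightning's code).
import io
import os
import torch

def atomic_save_sketch(checkpoint, filepath):
    buffer = io.BytesIO()
    torch.save(checkpoint, buffer)      # serialize fully in memory first
    tmp_path = str(filepath) + ".part"  # assumed temp-file naming
    with open(tmp_path, "wb") as f:
        f.write(buffer.getvalue())      # write the serialized bytes to a temp file
    os.replace(tmp_path, filepath)      # atomic rename, so readers never see a partial file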
Example #2
    def save_checkpoint(self, filepath, weights_only: bool):
        """Slightly modified version of PyTorch Lightning's save_checkpoint.

        Args:
            filepath (str): write-target checkpoint file path
            weights_only (bool): if True, save only the model weights

        Returns:
            None
        """
        app_state = AppState()
        if app_state.model_parallel_size is not None:
            # filepath needs to be updated to include mp_rank
            dirname = os.path.dirname(filepath)
            basename = os.path.basename(filepath)
            filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'

            # dump states as a checkpoint dictionary object
            checkpoint = self.dump_checkpoint(weights_only)

            # each model parallel rank needs to save a copy of its model
            if app_state.data_parallel_rank == 0:
                # write the checkpoint dictionary on the file
                if self.trainer.accelerator_backend:
                    checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
                try:
                    atomic_save(checkpoint, filepath)
                except AttributeError as err:
                    if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                        del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
                    rank_zero_warn(
                        'Warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}'
                    )
                    atomic_save(checkpoint, filepath)
        return None
Example #3
    def save_checkpoint(self,
                        checkpoint: Dict[str, Any],
                        path: _PATH,
                        storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            path: write-target path
            storage_options: not used in ``XLACheckpointIO.save_checkpoint``

        Raises:
            TypeError:
                If ``storage_options`` arg is passed in
        """
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
                " to define how you'd like to use `storage_options`.")
        fs = get_filesystem(path)
        fs.makedirs(os.path.dirname(path), exist_ok=True)

        checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
        # write the checkpoint dictionary to the provided path
        atomic_save(checkpoint, path)
Example #4
    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue,
                                                results):
        if self.trainer.distributed_backend.lower() not in [
                'ddp_spawn', 'ddp_cpu', 'tpu'
        ]:
            return

        # track the best model path
        best_model_path = None
        if self.trainer.checkpoint_callback is not None:
            best_model_path = self.trainer.checkpoint_callback.best_model_path

        if self.trainer.global_rank == 0 and mp_queue is not None:
            rank_zero_warn('cleaning up ddp environment...')
            # todo, pass complete checkpoint as state dictionary
            mp_queue.put(best_model_path)
            mp_queue.put(results)

            # save the last weights
            last_path = None
            if not self.trainer.testing and best_model_path is not None and len(
                    best_model_path) > 0:
                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
                atomic_save(model.state_dict(), last_path)
            mp_queue.put(last_path)
Example #5
def run(rank, model, train_dataloader, num_epochs, precision, accelerator,
        tmpdir):
    os.environ["LOCAL_RANK"] = str(rank)
    if torch.distributed.is_available(
    ) and not torch.distributed.is_initialized():
        torch.distributed.init_process_group("gloo", rank=rank, world_size=2)

    to_device = partial(move_data_to_device, device=torch.device("cuda", rank))
    model = DistributedDataParallel(
        to_device(model),
        device_ids=[rank],
    )
    train_dataloader = DataLoader(
        train_dataloader.dataset,
        sampler=DistributedSampler(train_dataloader.dataset,
                                   rank=rank,
                                   num_replicas=2,
                                   seed=42,
                                   drop_last=False),
    )
    with precision_context(precision, accelerator):
        main(to_device, model, train_dataloader, num_epochs=num_epochs)

    if rank == 0:
        atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt"))
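A hedged sketch of how this `run` function could be launched on two local processes with torch.multiprocessing.spawn; the MASTER_ADDR/MASTER_PORT values are assumptions for a single-machine test, not part of the original snippet.

# Hedged launch sketch for the `run` function above (2 local processes).
import os
import torch.multiprocessing as mp

def launch(model, train_dataloader, num_epochs, precision, accelerator, tmpdir):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # assumed rendezvous settings
    os.environ.setdefault("MASTER_PORT", "29500")
    # mp.spawn passes the process rank as the first argument to `run`
    mp.spawn(
        run,
        args=(model, train_dataloader, num_epochs, precision, accelerator, tmpdir),
        nprocs=2,
        join=True,
    )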
Example #6
    def hpc_save(self, folderpath: str, logger):
        # make sure the checkpoint folder exists
        folderpath = str(folderpath)  # because the tests pass a path object
        fs = get_filesystem(folderpath)
        fs.makedirs(folderpath, exist_ok=True)

        # save logger to make sure we get all the metrics
        logger.save()

        max_suffix = self.max_ckpt_version_in_folder(folderpath)
        ckpt_number = (max_suffix if max_suffix is not None else 0) + 1

        filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt")

        # give model a chance to do something on hpc_save
        model = self.trainer.lightning_module
        checkpoint = self.dump_checkpoint()

        model.on_hpc_save(checkpoint)

        # do the actual save
        # TODO: fix for anything with multiprocess DP, DDP, DDP2
        try:
            atomic_save(checkpoint, filepath)
        except AttributeError as err:
            if pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                del checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
            rank_zero_warn(
                f"warning, `hyper_parameters` dropped from checkpoint. An attribute is not picklable {err}"
            )
            atomic_save(checkpoint, filepath)

        return filepath
Example #7
    def save_checkpoint(self,
                        checkpoint: Dict[str, Any],
                        path: _PATH,
                        storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            path: write-target path
            storage_options: not used in ``TorchCheckpointIO.save_checkpoint``

        Raises:
            TypeError:
                If ``storage_options`` arg is passed in
        """
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
                " to define how you'd like to use `storage_options`.")
        fs = get_filesystem(path)
        fs.makedirs(os.path.dirname(path), exist_ok=True)
        try:
            # write the checkpoint dictionary on the file
            atomic_save(checkpoint, path)
        except AttributeError as err:
            # todo (sean): is this try catch necessary still?
            # https://github.com/Lightning-AI/lightning/pull/431
            key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
            checkpoint.pop(key, None)
            rank_zero_warn(
                f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}"
            )
            atomic_save(checkpoint, path)
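A hedged sketch of how a CheckpointIO like the one above could be wired into a Trainer; this assumes a Lightning release around 1.5+ where CheckpointIO plugins are accepted, and the subclass name is hypothetical.

# Hedged sketch (assumes PL ~1.5+): a hypothetical CheckpointIO subclass that
# logs the target path, then defers to the stock torch-based save shown above.
from typing import Any, Dict, Optional

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import TorchCheckpointIO

class LoggingCheckpointIO(TorchCheckpointIO):
    def save_checkpoint(self, checkpoint: Dict[str, Any], path, storage_options: Optional[Any] = None) -> None:
        print(f"writing checkpoint to {path}")
        super().save_checkpoint(checkpoint, path, storage_options=storage_options)

trainer = Trainer(plugins=[LoggingCheckpointIO()])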
Example #8
    def transfer_distrib_spawn_state_on_fit_end(self, results):
        checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

        # requires to compute the state_dict on all processes in case Metrics are present
        state_dict = self.lightning_module.state_dict()

        if self.global_rank == 0 and self.mp_queue is not None:
            rank_zero_warn("cleaning up ddp environment...")

            # save the last weights
            last_path = None
            if (
                self.lightning_module.trainer.state.fn == TrainerFn.FITTING
                and best_model_path is not None
                and len(best_model_path) > 0
            ):
                last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
                atomic_save(self.on_save(state_dict), last_path)

            # todo, pass complete checkpoint as state dictionary
            self.mp_queue.put(best_model_path)
            self.mp_queue.put(last_path)
            self.mp_queue.put(results)
            self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue
Example #9
    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        fs = get_filesystem(path)
        fs.makedirs(os.path.dirname(path), exist_ok=True)
        try:
            # write the checkpoint dictionary on the file
            atomic_save(checkpoint, path)
        except AttributeError as err:
            # todo (sean): is this try catch necessary still?
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/431
            key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
            checkpoint.pop(key, None)
            rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
            atomic_save(checkpoint, path)
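A small hedged sketch of the failure mode that try/except guards against: a checkpoint whose hyper-parameters hold something pickle cannot serialize (a locally defined function here, chosen as an illustrative assumption).

# Hedged sketch of the failure the except-branch above handles: a checkpoint
# carrying a non-picklable hyper-parameter (a locally defined function).
import pickle
import torch

def build_hparams():
    def local_activation(x):      # local objects cannot be pickled by reference
        return x
    return {"lr": 1e-3, "activation": local_activation}

checkpoint = {"state_dict": {}, "hyper_parameters": build_hparams()}
try:
    torch.save(checkpoint, "ckpt_with_bad_hparams.pt")
except (AttributeError, pickle.PicklingError):
    # drop the offending key and retry, mirroring the pattern in the examples
    checkpoint.pop("hyper_parameters", None)
    torch.save(checkpoint, "ckpt_with_bad_hparams.pt")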
Example #10
    def save_checkpoint(self, filepath, weights_only: bool = False):
        checkpoint = self.dump_checkpoint(weights_only)

        if self.is_global_zero:
            # do the actual save
            try:
                atomic_save(checkpoint, filepath)
            except AttributeError as err:
                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
                rank_zero_warn(
                    'Warning, `module_arguments` dropped from checkpoint.'
                    f' An attribute is not picklable {err}')
                atomic_save(checkpoint, filepath)
Example #11
    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
        """
        # dump states as a checkpoint dictionary object
        checkpoint = self.on_save(checkpoint)
        if self.is_global_zero:
            try:
                # write the checkpoint dictionary on the file
                atomic_save(checkpoint, filepath)
            except AttributeError as err:
                key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
                checkpoint.pop(key, None)
                rank_zero_warn(f'Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}')
                atomic_save(checkpoint, filepath)
Example #12
    def run(self, model: nn.Module, train_dataloader: DataLoader, num_epochs: int = 10, tmpdir: str = None):
        optimizer = configure_optimizers(model)
        model, optimizer = self.setup(model, optimizer)
        train_dataloader = self.setup_dataloaders(train_dataloader)

        model.train()
        for _ in range(num_epochs):
            for batch in train_dataloader:
                batch = self.to_device(batch)
                optimizer.zero_grad()
                loss = model(batch)
                self.backward(loss)
                optimizer.step()

        if isinstance(self._strategy, DDPSpawnPlugin) and tmpdir and self.global_rank == 0:
            checkpoint_path = os.path.join(tmpdir, "model.pt")
            atomic_save(model.state_dict(), checkpoint_path)
            return checkpoint_path
Example #13
    def transfer_distrib_spawn_state_on_fit_end(self, results):
        # TODO: is there a better way than accessing callback through model -> trainer -> callback?
        checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

        if self.global_rank == 0 and self.mp_queue is not None:
            rank_zero_warn("cleaning up ddp environment...")

            # save the last weights
            last_path = None
            # TODO: is there a better way than accessing trainer through model -> trainer?
            if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
                last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
                atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

            # todo, pass complete checkpoint as state dictionary
            self.mp_queue.put(best_model_path)
            self.mp_queue.put(last_path)
            self.mp_queue.put(results)
Example #14
    def transfer_distrib_spawn_state_on_fit_end(self, results):
        checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

        if self.global_rank == 0 and self.mp_queue is not None:
            rank_zero_warn("cleaning up ddp environment...")

            # save the last weights
            last_path = None
            if (self.lightning_module.trainer.state == TrainerState.FITTING
                    and best_model_path is not None
                    and len(best_model_path) > 0):
                last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
                atomic_save(self.on_save(self.lightning_module.state_dict()),
                            last_path)

            # todo, pass complete checkpoint as state dictionary
            self.mp_queue.put(best_model_path)
            self.mp_queue.put(last_path)
            self.mp_queue.put(results)
Example #15
    def save_checkpoint(self, filepath, weights_only: bool = False):
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            filepath: write-target file's path
            weights_only: saving model weights only
        """
        # dump states as a checkpoint dictionary object
        checkpoint = self.dump_checkpoint(weights_only)

        if self.trainer.is_global_zero:
            # write the checkpoint dictionary on the file
            try:
                atomic_save(checkpoint, filepath)
            except AttributeError as err:
                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
                rank_zero_warn(
                    'Warning, `hyper_parameters` dropped from checkpoint.'
                    f' An attribute is not picklable {err}')
                atomic_save(checkpoint, filepath)
Example #16
    def save_checkpoint(self, checkpoint: Dict[str, Any],
                        filepath: str) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
        """
        # dump states as a checkpoint dictionary object
        if self.is_global_zero:
            checkpoint = self.on_save(checkpoint)
            try:
                # write the checkpoint dictionary on the file
                atomic_save(checkpoint, filepath)
            except AttributeError as err:
                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
                rank_zero_warn(
                    'Warning, `hyper_parameters` dropped from checkpoint.'
                    f' An attribute is not picklable {err}')
                atomic_save(checkpoint, filepath)
Example #17
    TGT_LANG = "de"
    ENSEMBLE = 10
    LAST_CKPT = 63

    trafo_states = []
    print('Loading...')
    for epoch in tqdm(range(LAST_CKPT - ENSEMBLE, LAST_CKPT)):
        ckpt_path = os.path.join(MODELS_FOLDER, 'transformer',
                                 f'trafo_{SRC_LANG}_{TGT_LANG}_{epoch}.pt')
        trafo_states.append(
            pl_load(ckpt_path,
                    map_location=lambda storage, loc: storage)['state_dict'])

    "Average models into model"
    print("Averaging...")
    avg_state = {}
    for key in trafo_states[-1]:
        mean = 0
        for trafo_state in trafo_states:
            mean += trafo_state[key]
        avg_state[key] = mean / len(trafo_states)

    print('saving...')
    ckpt_path = os.path.join(MODELS_FOLDER, 'transformer',
                             f'trafo_{SRC_LANG}_{TGT_LANG}_{LAST_CKPT}.pt')
    avg_ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
    avg_ckpt['state_dict'] = avg_state
    atomic_save(
        avg_ckpt,
        f'.data/models/transformer/trafo_{SRC_LANG}_{TGT_LANG}_ensemble.pt')
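A hedged follow-up to the averaging snippet above: loading the saved ensemble checkpoint back into a model; `TransformerModel` is a hypothetical placeholder, while `pl_load`, SRC_LANG, and TGT_LANG are the names already used in the snippet.

# Hedged follow-up sketch: load the averaged ensemble checkpoint for evaluation.
from pytorch_lightning.utilities.cloud_io import load as pl_load

ensemble_path = f'.data/models/transformer/trafo_{SRC_LANG}_{TGT_LANG}_ensemble.pt'
avg_ckpt = pl_load(ensemble_path, map_location=lambda storage, loc: storage)
model = TransformerModel()                      # placeholder: must match the saved state_dict
model.load_state_dict(avg_ckpt['state_dict'])
model.eval()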
Example #18
def train(config):
    # ======================================================
    # EXPERIMENT SETUP
    # ======================================================
    from pytorch_lightning import seed_everything
    # Seed
    seed_everything(config.seed)

    # DATASET SETUP
    print("======================================================")
    print("SETTING UP DATASET")
    print("======================================================")
    from ml4floods.models.dataset_setup import get_dataset
    dataset = get_dataset(config.data_params)

    # MODEL SETUP
    print("======================================================")
    print("SETTING UP MODEL")
    print("======================================================")
    from ml4floods.models.model_setup import get_model
    config.model_params.test = False
    config.model_params.train = True
    model = get_model(config.model_params)

    # LOGGING SETUP
    print("======================================================")
    print("SETTING UP LOGGERS")
    print("======================================================")
    import wandb
    from pytorch_lightning.loggers import WandbLogger
    wandb_logger = WandbLogger(
        name=config.experiment_name,
        project=config.wandb_project,
        entity=config.wandb_entity,
        #         save_dir=f"{config.model_params.model_folder}/{config.experiment_name}"
    )

    # CHECKPOINTING SETUP
    print("======================================================")
    print("SETTING UP CHECKPOINTING")
    print("======================================================")
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
    experiment_path = f"{config.model_params.model_folder}/{config.experiment_name}"

    checkpoint_path = f"{experiment_path}/checkpoint"
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        save_top_k=True,
        verbose=True,
        monitor=config.model_params.hyperparameters.metric_monitor,
        mode='min',
        prefix='')

    early_stop_callback = EarlyStopping(
        monitor=config.model_params.hyperparameters.metric_monitor,
        patience=4,
        strict=False,
        verbose=False,
        mode='min')

    callbacks = [checkpoint_callback, early_stop_callback]

    # TRAINING SETUP
    print("======================================================")
    print("START TRAINING")
    print("======================================================")
    from pytorch_lightning import Trainer
    trainer = Trainer(
        fast_dev_run=False,
        logger=wandb_logger,
        callbacks=callbacks,
        default_root_dir=experiment_path,
        accumulate_grad_batches=1,
        gradient_clip_val=0.0,
        auto_lr_find=False,
        benchmark=False,
        distributed_backend=None,
        gpus=config.gpus if config.gpus != '' else None,
        max_epochs=config.model_params.hyperparameters.max_epochs,
        check_val_every_n_epoch=config.model_params.hyperparameters.val_every,
        log_gpu_memory=None,
        resume_from_checkpoint=checkpoint_path
        if config.resume_from_checkpoint else None)

    trainer.fit(model, dataset)

    # ======================================================
    # SAVING SETUP
    # ======================================================
    print("======================================================")
    print("FINISHED TRAINING, SAVING MODEL")
    print("======================================================")
    from pytorch_lightning.utilities.cloud_io import atomic_save
    atomic_save(model.state_dict(), f"{experiment_path}/model.pt")
    torch.save(model.state_dict(),
               os.path.join(wandb_logger.save_dir, 'model.pt'))
    wandb.save(os.path.join(wandb_logger.save_dir, 'model.pt'))
    wandb.finish()

    # Save config file in experiment_path
    config_file_path = f"{experiment_path}/config.json"

    save_json(config, config_file_path)

    return 1