Example #1
    def dump_checkpoint(self):
        checkpoint = {
            'epoch': self.current_epoch + 1,
            'global_step': self.global_step + 1,
        }

        if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
            checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best

        if self.early_stop_callback is not None and self.early_stop_callback is not False:
            checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
            checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience

        # save optimizers
        optimizer_states = []
        for i, optimizer in enumerate(self.optimizers):
            optimizer_states.append(optimizer.state_dict())

        checkpoint['optimizer_states'] = optimizer_states

        # save lr schedulers
        lr_schedulers = []
        for scheduler in self.lr_schedulers:
            lr_schedulers.append(scheduler['scheduler'].state_dict())

        checkpoint['lr_schedulers'] = lr_schedulers

        # add the hparams and state_dict from the model
        model = self.get_model()

        checkpoint['state_dict'] = model.state_dict()

        # save native amp scaling
        if self.use_amp and self.use_native_amp:
            checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()

        if hasattr(model, "hparams") and model.hparams is not None:
            parsing.clean_namespace(model.hparams)
            if isinstance(model.hparams, dict):
                checkpoint['hparams_type'] = 'dict'
                checkpoint['hparams'] = model.hparams
            elif isinstance(model.hparams, Namespace):
                checkpoint['hparams_type'] = 'Namespace'
                checkpoint['hparams'] = vars(model.hparams)
            else:
                raise ValueError(
                    'The acceptable hparams type is dict or argparse.Namespace,'
                    f' not {type(model.hparams)}')
        else:
            rank_zero_warn(
                "Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters."
            )

        # give the model a chance to add a few things
        model.on_save_checkpoint(checkpoint)

        return checkpoint
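
The dictionary returned by dump_checkpoint contains only plain Python values and state_dict tensors, so it can be written and read back with torch.save / torch.load. Below is a minimal, self-contained sketch of that layout using a toy model; the key names mirror the ones assembled above, while the file name and hyperparameter values are arbitrary placeholders.

import torch
from torch import nn

# Build a miniature checkpoint with the same keys dump_checkpoint writes.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

checkpoint = {
    'epoch': 1,
    'global_step': 100,
    'optimizer_states': [optimizer.state_dict()],
    'lr_schedulers': [scheduler.state_dict()],
    'state_dict': model.state_dict(),
    'hparams_type': 'dict',
    'hparams': {'lr': 0.1},  # arbitrary example value
}
torch.save(checkpoint, 'example.ckpt')

# Restoring reverses each entry written above.
restored = torch.load('example.ckpt')
model.load_state_dict(restored['state_dict'])
optimizer.load_state_dict(restored['optimizer_states'][0])
scheduler.load_state_dict(restored['lr_schedulers'][0])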
Example #2
from pytorch_lightning.utilities.parsing import clean_namespace


def test_clean_namespace(tmpdir):
    # See the full list of picklable types at
    # https://docs.python.org/3/library/pickle.html#pickle-picklable
    class UnpicklableClass:
        # Only classes defined at the top level of a module are picklable.
        pass

    def unpicklable_function():
        # Functions defined inside another function cannot be pickled either.
        pass

    test_case = {"1": None, "2": True, "3": 123, "4": unpicklable_function, "5": UnpicklableClass}

    clean_namespace(test_case)

    # Only the picklable entries survive.
    assert test_case == {"1": None, "2": True, "3": 123}
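
The test above only passes if clean_namespace drops every entry whose value cannot be pickled while leaving picklable entries untouched. The function below is an illustrative stand-in for that behaviour on a plain dict, not the pytorch_lightning implementation.

import pickle


def drop_unpicklable(mapping):
    # Remove entries whose values pickle.dumps rejects; this mirrors what the
    # assertion in test_clean_namespace expects from clean_namespace.
    for key in list(mapping):
        try:
            pickle.dumps(mapping[key])
        except (pickle.PicklingError, AttributeError, TypeError):
            del mapping[key]


example = {"a": 1, "b": lambda x: x}  # the lambda cannot be pickled
drop_unpicklable(example)
assert example == {"a": 1}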
Example #3
    def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
        # clean hparams
        if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)

        # links data to the trainer
        self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)

        # check that model is configured correctly
        self.trainer.config_validator.verify_loop_configurations(model)

        # attach model log function to callback
        self.trainer.callback_connector.attach_model_logging_functions(model)
Example #4
    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
        # bind logger and other properties
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # clean hparams
        if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)

        # links data to the trainer
        self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)

        # check that model is configured correctly
        self.trainer.config_validator.verify_loop_configurations(model)
Example #5
    def fit(self,
            model: LightningModule,
            train_dataloader: Optional[DataLoader] = None,
            val_dataloaders: Optional[Union[DataLoader,
                                            List[DataLoader]]] = None):
        r"""
        Runs the full optimization routine.

        Args:
            model: Model to fit.

            train_dataloader: A PyTorch
                DataLoader with training samples. If the model has
                a predefined train_dataloader method, this will be skipped.

            val_dataloaders: Either a single
                PyTorch DataLoader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method, this will be skipped.

        Example::

            # Option 1
            # Define the train_dataloader() and val_dataloader() functions
            # in the LightningModule.
            # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model)

            # Option 2
            # in production cases we might want to pass different datasets to the same model
            # Recommended for PRODUCTION SYSTEMS
            train, val = DataLoader(...), DataLoader(...)
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model, train_dataloader=train, val_dataloaders=val)

            # Options 1 and 2 can be mixed: for example, the training set can be
            # defined as part of the model, and validation can then be fed to .fit()

        """
        # bind logger and other properties
        model.logger = self.logger
        self.copy_trainer_model_properties(model)

        # clean hparams
        if hasattr(model, 'hparams'):
            parsing.clean_namespace(model.hparams)

        # set up the passed in dataloaders (if needed)
        self.__attach_dataloaders(model, train_dataloader, val_dataloaders)

        # check that model is configured correctly
        self.check_model_configuration(model)

        # download the data and do whatever transforms we need
        # do before any spawn calls so that the model can assign properties
        # only on proc 0 because no spawn has happened yet
        if not self._is_data_prepared:
            model.prepare_data()
            self._is_data_prepared = True

        # Run auto batch size scaling
        if self.auto_scale_batch_size:
            if isinstance(self.auto_scale_batch_size, bool):
                self.auto_scale_batch_size = 'power'
            self.scale_batch_size(model, mode=self.auto_scale_batch_size)
            model.logger = self.logger  # reset logger binding

        # Run learning rate finder:
        if self.auto_lr_find:
            self._run_lr_finder_internally(model)
            model.logger = self.logger  # reset logger binding

        # route to appropriate start method
        # when using multi-node or DDP within a node start each module in a separate process
        if self.use_ddp2:
            if self.is_slurm_managing_tasks:
                task = int(os.environ['SLURM_LOCALID'])
            elif 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
                task = int(os.environ['LOCAL_RANK'])
            self.ddp_train(task, model)
        elif self.use_ddp:
            if self.is_slurm_managing_tasks:
                task = int(os.environ['SLURM_LOCALID'])
                self.ddp_train(task, model)
            # torchelastic
            elif 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
                task = int(os.environ['LOCAL_RANK'])
                self.ddp_train(task, model)
            else:
                self.__set_random_port()
                # track for predict
                self.model = model
                # train
                mp.spawn(self.ddp_train,
                         nprocs=self.num_processes,
                         args=(model, ))
                # load weights if not interrupted
                if self.on_colab_kaggle:
                    self.load_spawn_weights(model)
                    self.model = model

        # 1 gpu or dp option triggers training using DP module
        # easier to avoid NCCL issues
        elif self.use_dp:
            self.dp_train(model)

        elif self.use_horovod:
            self.horovod_train(model)

        elif self.single_gpu:
            self.single_gpu_train(model)

        elif self.use_tpu:  # pragma: no-cover
            log.info(f'training on {self.tpu_cores} TPU cores')

            #  COLAB_GPU is an env var available by default in Colab environments.
            start_method = 'fork' if self.on_colab_kaggle else 'spawn'

            # track for predict
            self.model = model

            # train
            if self.tpu_id is not None:
                self.tpu_train(model)
            else:
                xmp.spawn(self.tpu_train,
                          args=(model, ),
                          nprocs=self.tpu_cores,
                          start_method=start_method)

            # load weights if not interrupted
            self.load_spawn_weights(model)
            self.model = model

        # ON CPU
        else:
            # run through amp wrapper
            if self.use_amp:
                raise MisconfigurationException(
                    'amp + cpu is not supported. Please use a GPU option')

            # CHOOSE OPTIMIZER
            # allow for lr schedulers as well
            self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(
                model)

            self.run_pretrain_routine(model)

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return 1
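
Every fit/setup_fit variant above runs parsing.clean_namespace(model.hparams) before training, which is what lets dump_checkpoint later store the hparams inside the checkpoint. A short hedged illustration of that effect follows; the import path matches the parsing module used above, and Namespace support is assumed from the dict/Namespace handling in Example #1.

from argparse import Namespace

from pytorch_lightning.utilities.parsing import clean_namespace

# hparams with one value that cannot be pickled (a lambda)
hparams = Namespace(lr=0.02, batch_size=32, transform=lambda x: x)

# fit() performs this cleanup before training starts; the unpicklable
# attribute is dropped so the remaining hparams can go into the checkpoint.
clean_namespace(hparams)
assert vars(hparams) == {'lr': 0.02, 'batch_size': 32}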