예제 #1
0
class Trainer(
        TrainerProperties,
        TrainerCallbackHookMixin,
        TrainerModelHooksMixin,
        TrainerOptimizersMixin,
        TrainerLoggingMixin,
        TrainerTrainingTricksMixin,
        TrainerDataLoadingMixin,
        DeprecatedDistDeviceAttributes,
):
    @overwrite_by_env_vars
    def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase],
                      bool] = True,
        checkpoint_callback: bool = True,
        callbacks: Optional[Union[List[Callback], Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: Optional[int] = None,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: Union[int, bool] = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: Optional[int] = None,
        min_epochs: Optional[int] = None,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        limit_train_batches: Union[int, float] = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        limit_predict_batches: Union[int, float] = 1.0,
        val_check_interval: Union[int, float] = 1.0,
        flush_logs_every_n_steps: int = 100,
        log_every_n_steps: int = 50,
        accelerator: Optional[Union[str, Accelerator]] = None,
        sync_batchnorm: bool = False,
        precision: int = 32,
        weights_summary: Optional[str] = 'top',
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[Union[Path, str]] = None,
        profiler: Optional[Union[BaseProfiler, bool, str]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        plugins: Optional[Union[str, list]] = None,
        amp_backend: str = 'native',
        amp_level: str = 'O2',
        distributed_backend: Optional[str] = None,
        automatic_optimization: Optional[bool] = None,
        move_metrics_to_cpu: bool = False,
        enable_pl_optimizer: bool = None,  # todo: remove in v1.3
        multiple_trainloader_mode: str = 'max_size_cycle',
    ):
        r"""
        Customize every aspect of training via flags

        Args:

            accelerator: Previously known as distributed_backend (dp, ddp, ddp2, etc...).
                Can also take in an accelerator object for custom hardware.

            accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

            amp_backend: The mixed precision backend to use ("native" or "apex")

            amp_level: The optimization level to use (O1, O2, etc...).

            auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
                trying to optimize initial learning for faster convergence. trainer.tune() method will
                set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
                To use a different key set a string instead of True with the key name.

            auto_scale_batch_size: If set to True, will `initially` run a batch size
                finder trying to find the largest batch size that fits into memory.
                The result will be stored in self.batch_size in the LightningModule.
                Additionally, can be set to either `power` that estimates the batch size through
                a power search or `binsearch` that estimates the batch size through a binary search.

            auto_select_gpus: If enabled and `gpus` is an integer, pick available
                gpus automatically. This is especially useful when
                GPUs are configured to be in "exclusive mode", such
                that only one process at a time can access them.

            benchmark: If true enables cudnn.benchmark.

            callbacks: Add a callback or list of callbacks.

            checkpoint_callback: If ``True``, enable checkpointing.
                It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.

                .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
                    v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead.

            check_val_every_n_epoch: Check val every n train epochs.

            default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
                Default: ``os.getcwd()``.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

            deterministic: If true enables cudnn.deterministic.

            distributed_backend: deprecated. Please use 'accelerator'

            fast_dev_run: runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
                of train, val and test to find any bugs (ie: a sort of unit test).

            flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).

            gpus: number of gpus to train on (int) or which GPUs to train on (list or str) applied per node

            gradient_clip_val: 0 means don't clip.

            limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)

            limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)

            limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)

            limit_predict_batches: How much of prediction dataset to check (floats = percent, int = num_batches)

            logger: Logger (or iterable collection of loggers) for experiment tracking.

            log_gpu_memory: None, 'min_max', 'all'. Might slow performance

            log_every_n_steps: How often to log within steps (defaults to every 50 steps).

            automatic_optimization: If False you are responsible for calling .backward, .step, zero_grad
                in LightningModule. This argument has been moved to LightningModule. It is deprecated
                here in v1.1 and will be removed in v1.3.

            prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

            process_position: orders the progress bar when running multiple models on same machine.

            progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
                Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means
                a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).

            profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
                value is deprecated in v1.1 and will be removed in v1.3.

            overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0

            plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

            precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.

            max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
                If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.

            min_epochs: Force training for at least these many epochs. Disabled by default (None).
                If both min_epochs and min_steps are not specified, defaults to ``min_epochs`` = 1.

            max_steps: Stop training after this number of steps. Disabled by default (None).

            min_steps: Force training for at least these number of steps. Disabled by default (None).

            num_nodes: number of GPU nodes for distributed training.

            num_processes: number of processes for distributed training with distributed_backend="ddp_cpu"

            num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
                Set it to `-1` to run all batches in all validation dataloaders. Default: 2

            reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

            replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
                will be toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
                train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
                you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.

            resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is
                no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
                training will start from the beginning of the next epoch.

            sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

            terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

            tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

            truncated_bptt_steps: Truncated back propagation performs backprop every k steps of a much
                longer sequence.

            val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                use int to check every n steps (batches).

            weights_summary: Prints a summary of the weights when training begins.

            weights_save_path: Where to save weights if specified. Will override default_root_dir
                for checkpoints only. Use this if for whatever reason you need the checkpoints
                stored in a different place than the logs written in `default_root_dir`.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
                Defaults to `default_root_dir`.

            move_metrics_to_cpu: Whether to force internal logged metrics to be moved to cpu.
                This can save some gpu memory, but can make training slower. Use with attention.

            enable_pl_optimizer: If True, each optimizer will be wrapped by
                `pytorch_lightning.core.optimizer.LightningOptimizer`. It allows Lightning to
                handle AMP, TPU, accumulated_gradients, etc.

                .. warning:: Currently deprecated and it will be removed in v1.3

            multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders.
                In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
                and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
                reload when reaching the minimum length of datasets.
        """
        super().__init__()
        self._running_stage = None

        # the deprecated `distributed_backend` flag wins when given; otherwise fall back to `accelerator`
        distributed_backend = distributed_backend or accelerator

        # init connectors
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.optimizer_connector = OptimizerConnector(self)

        self.accelerator_connector = BackendConnector(
            num_processes, tpu_cores, distributed_backend, auto_select_gpus,
            gpus, num_nodes, sync_batchnorm, benchmark, replace_sampler_ddp,
            deterministic, precision, amp_backend, amp_level, plugins)
        self.logger_connector = LoggerConnector(self, log_gpu_memory)
        self.model_connector = ModelConnector(self)
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.profile_connector = ProfilerConnector(self)
        self.checkpoint_connector = CheckpointConnector(self)
        self.slurm_connector = SLURMConnector(self)
        self.tuner = Tuner(self)
        self.train_loop = TrainLoop(self, multiple_trainloader_mode)
        self.evaluation_loop = EvaluationLoop(self)
        self.predict_loop = PredictLoop(self)

        # training state
        self.weights_summary = weights_summary
        self.shown_warnings = set()

        # init callbacks
        # Declare attributes to be set in callback_connector on_trainer_init
        self.callback_connector.on_trainer_init(
            callbacks,
            checkpoint_callback,
            progress_bar_refresh_rate,
            process_position,
            default_root_dir,
            weights_save_path,
            resume_from_checkpoint,
        )

        # hook
        self.on_init_start()

        # init optimizer + lr scheduler related flags
        self.optimizer_connector.on_trainer_init(enable_pl_optimizer)

        # init data flags
        self.data_connector.on_trainer_init(check_val_every_n_epoch,
                                            reload_dataloaders_every_epoch,
                                            prepare_data_per_node)

        # init training tricks
        self.training_tricks_connector.on_trainer_init(
            gradient_clip_val, track_grad_norm, accumulate_grad_batches,
            truncated_bptt_steps, terminate_on_nan)

        # init train loop related flags
        # TODO: remove in 1.3.0
        if automatic_optimization is None:
            automatic_optimization = True
        else:
            # any explicit value (True or False) means the deprecated trainer flag was used
            rank_zero_warn(
                "Disable automatic optimization with the trainer flag is deprecated and will be removed in v1.3.0! "
                "Please use the property on the LightningModule for disabling automatic optimization"
            )
        self.train_loop.on_trainer_init(
            max_epochs,
            min_epochs,
            max_steps,
            min_steps,
            num_sanity_val_steps,
            automatic_optimization,
            weights_summary,
        )
        self.evaluation_loop.on_trainer_init()

        # configure tuner
        self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

        # configure profiler
        self.profile_connector.on_trainer_init(profiler)

        # init logger flags
        self.logger_connector.on_trainer_init(
            logger,
            flush_logs_every_n_steps,
            log_every_n_steps,
            move_metrics_to_cpu,
        )

        # init debugging flags
        self.debugging_connector.on_init_start(
            limit_train_batches,
            limit_val_batches,
            limit_test_batches,
            limit_predict_batches,
            val_check_interval,
            overfit_batches,
            fast_dev_run,
        )

        # Callback system
        self.on_init_end()

    def setup_trainer(self, model: LightningModule):
        """
        Write the initial experiment records through the attached logger.

        Nothing happens when the trainer has no logger.

        Args:
            model: The module whose ``hparams_initial`` and graph are logged.
        """
        experiment_logger = self.logger
        if experiment_logger is None:
            return

        # this is where the first experiment logs are written:
        # hyper-parameters, then the model graph, then an explicit save
        experiment_logger.log_hyperparams(model.hparams_initial)
        experiment_logger.log_graph(model)
        experiment_logger.save()

    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Runs the full optimization routine.

        Args:
            datamodule: An instance of :class:`LightningDataModule`.

            model: Model to fit.

            train_dataloader: A Pytorch DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped

        Returns:
            ``self.training_type_plugin.results`` when that value is truthy, otherwise ``1``.

        """
        # mark the trainer as running
        self._state = TrainerState.RUNNING

        # bookkeeping
        # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
        if self._running_stage is None:
            self._set_running_stage(RunningStage.TRAINING, model)

        # set local properties on the model
        self.model_connector.copy_trainer_model_properties(model)

        # ----------------------------
        # LINK DATA
        # ----------------------------
        # setup data, etc...
        self.train_loop.setup_fit(model, train_dataloader, val_dataloaders,
                                  datamodule)

        # hook
        self.data_connector.prepare_data(model)
        self.callback_connector._attach_model_callbacks(model, self)

        # ----------------------------
        # SET UP TRAINING
        # ----------------------------
        self.call_setup_hook(model)
        self.call_hook("on_before_accelerator_backend_setup", model)
        self.accelerator_backend.setup(self, model)
        self.setup_trainer(model)

        # ----------------------------
        # INSPECT THE CORE LOOPS
        # ----------------------------
        #             Lightning internal flow looks like this.
        #
        #   trainer.fit(...) or trainer.test(...) or trainer.predict(...)   ||
        #                                |                                  ||
        #                        create accelerator                         ||
        #                                |                                  ||
        #                         trainer.dispatch                          ||  LIGHTNING
        #                                |                                  ||
        #    start_training or start_testing or start_predicting call       ||  FLOW
        #               from `accelerator.training_type_plugin`             ||
        #                                |                                  ||  DIRECTION
        #             run_train or run_test or run_predict call             ||
        #                           from `trainer`                          ||
        #                                |                                  ||
        #                             results                               \/
        # This is used to guide readers to the core loops: train, test, predict.
        # `run_predict` is the simplest to understand, use `Go to Definition` to read it :)
        # Search for `start_training` or `start_testing` or `start_predicting` in
        # `pytorch_lightning/plugins/training_type` folder to find accelerator dispatch functions.
        self.accelerator.train_loop = self.run_train
        self.accelerator.validation_loop = self.run_evaluation
        self.accelerator.test_loop = self.run_evaluation
        self.accelerator.predict_loop = self.run_predict

        # ----------------------------
        # TRAIN
        # ----------------------------
        # hook
        self.call_hook("on_fit_start")

        # plugin will setup fitting (e.g. ddp will launch child processes)
        self.pre_dispatch()

        # dispatch `start_training` or `start_testing` or `start_predicting`
        self.dispatch()

        # plugin will finalize fitting (e.g. ddp_spawn will load trained model)
        self.post_dispatch()

        # ----------------------------
        # POST-Training CLEAN UP
        # ----------------------------
        # hook
        self.call_hook('on_fit_end')

        # hook: trainer-level teardown first, then the model's own teardown if it implements one
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')

        # an interrupted run keeps its INTERRUPTED state; anything else is marked FINISHED
        if self._state != TrainerState.INTERRUPTED:
            self._state = TrainerState.FINISHED

        # clear the running stage on both the trainer and the model
        self._set_running_stage(None, model)

        # return 1 when there are no plugin results
        # used for testing or when we need to know that training succeeded
        return self.training_type_plugin.results or 1

    def pre_dispatch(self):
        """Run the ``pre_dispatch`` step of the training type plugin, then of the precision plugin."""
        training_type_plugin = self.training_type_plugin
        training_type_plugin.pre_dispatch()

        precision_plugin = self.precision_plugin
        precision_plugin.pre_dispatch()

    def post_dispatch(self):
        """
        Run the ``post_dispatch`` step of the training type plugin, then of the
        precision plugin, and finally tear down the accelerator backend.
        """
        training_type_plugin = self.training_type_plugin
        training_type_plugin.post_dispatch()

        precision_plugin = self.precision_plugin
        precision_plugin.post_dispatch()

        backend = self.accelerator_backend
        backend.teardown()

    def dispatch(self):
        """
        Hand control to the training type plugin for the active mode.

        Testing takes precedence over predicting; training is the fallback.
        The trainer itself is passed to the selected ``start_*`` method.
        """
        if self.testing:
            start = self.training_type_plugin.start_testing
        elif self.predicting:
            start = self.training_type_plugin.start_predicting
        else:
            start = self.training_type_plugin.start_training

        start(self)

    def train_or_test_or_predict(self):
        """
        Run the loop matching the active mode and return its results.

        Testing takes precedence over predicting; training is the fallback.
        """
        if self.testing:
            return self.run_test()

        if self.predicting:
            return self.run_predict()

        return self.run_train()

    def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):
        """
        Record ``stage`` as the running stage of the model and of the trainer.

        Args:
            stage: The stage to store.
            model_ref: The module whose ``running_stage`` attribute is updated.
        """
        # targets are assigned left to right: the model first, then the trainer
        model_ref.running_stage = self._running_stage = stage

    def _pre_training_routine(self):
        """
        Prepare the trainer right before the training loop starts.

        Waits on the ``"setup_training"`` barrier, registers the SLURM signal
        handlers, fires the pretrain-routine start hooks, optionally prints the
        model summary, restores weights through the checkpoint connector and
        finally fires the pretrain-routine end hooks.

        Raises:
            MisconfigurationException: If a summary is requested with a
                ``weights_summary`` value that is not in ``ModelSummary.MODES``.
        """
        # wait for all to join if on distributed
        self.accelerator.training_type_plugin.barrier("setup_training")

        # register auto-resubmit when on SLURM
        self.slurm_connector.register_slurm_signal_handlers()

        lightning_module = self.get_model()

        # pretrain routine start: trainer-level hook first, then the model's own if implemented
        self.on_pretrain_routine_start(lightning_module)
        if self.is_function_implemented("on_pretrain_routine_start"):
            lightning_module.on_pretrain_routine_start()

        # the summary is only printed on global rank zero, outside of testing, and when a mode was requested
        wants_summary = self.is_global_zero and self.weights_summary is not None and not self.testing
        if wants_summary:
            if self.weights_summary not in ModelSummary.MODES:
                raise MisconfigurationException(
                    "weights_summary can be None, " +
                    ", ".join(ModelSummary.MODES))
            lightning_module.summarize(mode=self.weights_summary)

        # restore training and model before hpc is called
        self.checkpoint_connector.restore_weights()

        # pretrain routine end: same two-level hook pattern as at the start
        self.on_pretrain_routine_end(lightning_module)
        if self.is_function_implemented("on_pretrain_routine_end"):
            lightning_module.on_pretrain_routine_end()

    def run_train(self):
        """
        Run the pre-training routine, the sanity check and then the epoch loop.

        The loop ends when the epoch range is exhausted, when ``max_steps`` is
        reached, or when ``should_stop`` is set and the minimum epoch/step
        requirements are met. A ``KeyboardInterrupt`` is caught and turned into
        a graceful shutdown with the trainer state set to ``INTERRUPTED``.
        """
        self._pre_training_routine()

        # progress bars are silenced everywhere except on global rank zero
        if not (self.is_global_zero or self.progress_bar_callback is None):
            self.progress_bar_callback.disable()

        self.run_sanity_check(self.get_model())

        # set stage for logging
        self._set_running_stage(RunningStage.TRAINING, self.get_model())

        self.checkpoint_connector.has_trained = False

        # enable train mode
        lightning_module = self.get_model()
        lightning_module.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(lightning_module)

        # hook
        self.train_loop.on_train_start()

        try:
            if self.train_loop.should_skip_training():
                return

            # bounded range when max_epochs is set, otherwise count up indefinitely
            if self.max_epochs:
                epoch_iter = range(self.current_epoch, self.max_epochs)
            else:
                epoch_iter = count(self.current_epoch)

            for epoch in epoch_iter:
                # hook
                self.train_loop.on_train_epoch_start(epoch)

                with self.profiler.profile("run_training_epoch"):
                    self.train_loop.run_training_epoch()

                if self.max_steps and self.max_steps <= self.global_step:
                    return

                # early stopping is only honoured once the minimum requirements are satisfied
                met_min_epochs = not self.min_epochs or epoch >= self.min_epochs - 1
                met_min_steps = not self.min_steps or self.global_step >= self.min_steps

                if not self.should_stop:
                    continue

                if met_min_epochs and met_min_steps:
                    return

                log.info(
                    'Trainer was signaled to stop but required minimum epochs'
                    f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                    ' not been met. Training will continue...')

            # hook
            # NOTE(review): the `finally` clause below calls this hook again, so it runs twice when
            # the epoch loop finishes normally - presumably `on_train_end` guards against repeated
            # calls; confirm in the train loop implementation.
            self.train_loop.on_train_end()

        except KeyboardInterrupt:
            rank_zero_warn(
                'Detected KeyboardInterrupt, attempting graceful shutdown...')

            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()
        finally:
            # hook: also covers the early `return` paths and the interrupt path
            self.train_loop.on_train_end()

    def run_evaluation(self, max_batches=None, on_epoch=False):
        """
        Run one evaluation pass (validation or testing) over all evaluation dataloaders.

        Args:
            max_batches: Forwarded to ``evaluation_loop.get_evaluation_dataloaders``, which
                returns the per-dataloader batch limits used by the loop.
            on_epoch: When truthy, epoch-interval learning rate schedulers are updated
                after the evaluation epoch-end hooks have run.

        Returns:
            A tuple ``(eval_loop_results, deprecated_eval_results)``. When the evaluation
            loop decides to skip, ``([], [])`` is returned before any hook is called.
        """

        # used to know if we are logging for val, test + reset cached results
        self._set_running_stage(
            RunningStage.TESTING if self.testing else RunningStage.EVALUATING,
            self.get_model())
        self.logger_connector.reset()

        # bookkeeping
        self.evaluation_loop.testing = self.testing

        # prepare dataloaders
        dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(
            max_batches)

        # check if we want to skip this evaluation
        if self.evaluation_loop.should_skip_evaluation(max_batches):
            return [], []

        # enable eval mode + no grads
        self.evaluation_loop.on_evaluation_model_eval()
        # ref model
        model = self.get_model()
        model.zero_grad()
        # NOTE(review): grad mode is only switched back on at the end of this method, so an
        # exception raised in between leaves gradients disabled - confirm callers handle that.
        torch.set_grad_enabled(False)

        # hook
        self.evaluation_loop.on_evaluation_start()

        # set up the eval loop
        self.evaluation_loop.setup(model, max_batches, dataloaders)

        # hook
        self.evaluation_loop.on_evaluation_epoch_start()

        # run validation/testing
        for dataloader_idx, dataloader in enumerate(dataloaders):
            # bookkeeping: outputs collected for this dataloader only
            dl_outputs = []
            dataloader = self.training_type_plugin.process_dataloader(
                dataloader)
            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                # a None batch is skipped without counting against the batch limit check below
                if batch is None:
                    continue

                # stop short when running on limited batches
                if batch_idx >= dl_max_batches:
                    break

                # hook
                self.evaluation_loop.on_evaluation_batch_start(
                    batch, batch_idx, dataloader_idx)

                # lightning module methods
                with self.profiler.profile("evaluation_step_and_end"):
                    output = self.evaluation_loop.evaluation_step(
                        batch, batch_idx, dataloader_idx)
                    output = self.evaluation_loop.evaluation_step_end(output)

                # hook + store predictions
                self.evaluation_loop.on_evaluation_batch_end(
                    output, batch, batch_idx, dataloader_idx)

                # log batch metrics
                self.evaluation_loop.log_evaluation_step_metrics(
                    output, batch_idx)

                # track epoch level outputs
                dl_outputs = self.track_output_for_epoch_end(
                    dl_outputs, output)

            # store batch level output per dataloader
            self.evaluation_loop.outputs.append(dl_outputs)

        # lightning module method
        deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end()

        # hook
        self.evaluation_loop.on_evaluation_epoch_end()

        # update epoch-level lr_schedulers
        if on_epoch:
            self.optimizer_connector.update_learning_rates(interval='epoch')

        # hook
        self.evaluation_loop.on_evaluation_end()

        # log epoch metrics
        eval_loop_results = self.evaluation_loop.log_epoch_metrics_on_evaluation_end(
        )

        # save predictions to disk
        self.evaluation_loop.predictions.to_disk()

        # enable train mode again
        self.evaluation_loop.on_evaluation_model_train()

        torch.set_grad_enabled(True)

        return eval_loop_results, deprecated_eval_results

    def track_output_for_epoch_end(self, outputs, output):
        """
        Append a step ``output`` to ``outputs`` for use at epoch end.

        ``Result`` objects are detached (and moved to CPU when
        ``move_metrics_to_cpu`` is set), dicts go through ``recursive_detach``,
        and CUDA tensors are moved to CPU when ``move_metrics_to_cpu`` is set.
        Any other non-``None`` value is appended unchanged.

        Args:
            outputs: The list collecting outputs; it is mutated in place.
            output: The value produced for the current batch, or ``None``.

        Returns:
            The same ``outputs`` list.
        """
        # nothing to track for steps that returned nothing
        if output is None:
            return outputs

        if isinstance(output, Result):
            output.detach()
            if self.move_metrics_to_cpu:
                output.cpu()
        elif isinstance(output, dict):
            output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu)
        elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu:
            output = output.cpu()

        outputs.append(output)
        return outputs

    def run_test(self):
        """
        Run the evaluation loop in test mode and return plain-Python results.

        Returns:
            ``1`` when the evaluation produced no results; otherwise the list of
            results, where tensor values inside dict entries have been replaced
            by the scalar from ``.cpu().item()``.
        """
        # progress bars are silenced everywhere except on global rank zero
        if not (self.is_global_zero or self.progress_bar_callback is None):
            self.progress_bar_callback.disable()

        with self.profiler.profile("run_test_evaluation"):
            eval_loop_results, _ = self.run_evaluation()

        if not len(eval_loop_results):
            return 1

        # remove the tensors from the eval results (dict entries are updated in place)
        for result in eval_loop_results:
            if not isinstance(result, dict):
                continue
            for key, value in result.items():
                if isinstance(value, torch.Tensor):
                    result[key] = value.cpu().item()

        return eval_loop_results

    def run_predict(self):
        """
        Run the prediction loop over every predict dataloader.

        Returns:
            An empty list when the predict loop decides to skip; otherwise
            whatever ``predict_loop.on_predict_epoch_end()`` returns.
        """
        # prepare dataloaders
        dataloaders, max_batches = self.predict_loop.get_predict_dataloaders(None)

        # check if we want to skip prediction altogether
        if self.predict_loop.should_skip_predict(dataloaders, max_batches):
            return []

        lightning_module = self.get_model()

        # enable eval mode + no grads
        self.predict_loop.on_predict_model_eval()
        lightning_module.zero_grad()
        torch.set_grad_enabled(False)

        # set up the predict loop
        self.predict_loop.setup(lightning_module, max_batches, dataloaders)

        for loader_idx, loader in enumerate(dataloaders):
            loader = self.accelerator_backend.process_dataloader(loader)
            batch_limit = self.predict_loop.max_batches[loader_idx]

            for batch_idx, batch in enumerate(loader):
                # a None batch is skipped before the batch limit is checked
                if batch is None:
                    continue

                # stop short when running on limited batches
                if batch_idx >= batch_limit:
                    break

                # lightning module methods
                with self.profiler.profile("predict"):
                    self.predict_loop.predict(batch, batch_idx, loader_idx)

        return self.predict_loop.on_predict_epoch_end()

    def run_sanity_check(self, ref_model):
        """Run a short validation pass before training to surface crashes early.

        Nothing happens unless ``ref_model`` has a ``val_dataloader``, overrides
        ``validation_step``, and both ``num_sanity_val_steps`` and
        ``limit_val_batches`` are greater than zero.

        Args:
            ref_model: The module whose validation setup is exercised.
        """
        has_val_step = ref_model.val_dataloader is not None and is_overridden(
            'validation_step', ref_model)
        if not (has_val_step and self.num_sanity_val_steps > 0
                and self.limit_val_batches > 0):
            return

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        self.reset_val_dataloader(ref_model)
        # never ask for more sanity batches than a dataloader provides
        self.num_sanity_val_batches = [
            min(self.num_sanity_val_steps, available)
            for available in self.num_val_batches
        ]

        # hook and callback
        self.running_sanity_check = True
        self.on_sanity_check_start()

        # run eval step
        _, sanity_results = self.run_evaluation(
            max_batches=self.num_sanity_val_batches)

        # allow no returns from eval
        if sanity_results is not None and len(sanity_results) > 0:
            # when we get a list back, only the last item is used
            if isinstance(sanity_results, list):
                sanity_results = sanity_results[-1]

            _, _, _, metrics_for_callbacks, _ = self.process_dict_result(
                sanity_results)
            self.logger_connector.callback_metrics = metrics_for_callbacks

        self.on_sanity_check_end()
        self.running_sanity_check = False

    def test(
        self,
        model: Optional[LightningModule] = None,
        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = 'best',
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""

        Separates from fit to make sure you never run on your test set until you want to.

        Args:
            ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
                If ``None``, use the weights from the last epoch to test. Default to ``best``.

            datamodule: A instance of :class:`LightningDataModule`.

            model: The model to test.

            test_dataloaders: Either a single
                Pytorch Dataloader or a list of them, specifying test samples.

            verbose: If True, prints the test results

        Returns:
            Returns a list of dictionaries, one for each test dataloader containing their respective metrics.

        Raises:
            MisconfigurationException: If both ``test_dataloaders`` and ``datamodule`` are passed.
        """
        # If you supply a datamodule you can't supply test_dataloaders.
        # Checked with ``is not None`` because the truth value of a DataLoader goes
        # through ``__len__``, and checked before any trainer state is touched so a
        # misconfiguration does not leave the trainer in the testing stage.
        if test_dataloaders is not None and datamodule is not None:
            raise MisconfigurationException(
                'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
            )

        # --------------------
        # SETUP HOOK
        # --------------------
        self.verbose_test = verbose

        self._set_running_stage(RunningStage.TESTING, model
                                or self.get_model())

        # Attach datamodule to get setup/prepare_data added to model before the call to it below
        self.data_connector.attach_datamodule(model or self.get_model(),
                                              datamodule, 'test')

        if model is not None:
            results = self.__test_given_model(model, test_dataloaders)
        else:
            results = self.__test_using_best_weights(ckpt_path,
                                                     test_dataloaders)

        self.teardown('test')
        self._set_running_stage(None, model or self.get_model())
        return results

    def __test_using_best_weights(self, ckpt_path, test_dataloaders):
        """Load checkpoint weights into the attached model and run the test loop.

        Args:
            ckpt_path: ``'best'`` to use the checkpoint callback's best model path,
                an explicit checkpoint path, or ``None`` to keep the current weights.
            test_dataloaders: Optional dataloaders to attach to the model first.

        Returns:
            The result of ``self.fit(model)``, or an empty ``dict`` when the
            resolved checkpoint path is empty.

        Raises:
            MisconfigurationException: If ``ckpt_path`` is ``'best'`` but the
                checkpoint callback has no best model path.
        """
        model = self.get_model()
        wants_best = ckpt_path == 'best'

        # if user requests the best checkpoint but we don't have it, error
        if wants_best and not self.checkpoint_callback.best_model_path:
            raise MisconfigurationException(
                'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.'
            )

        # load the requested weights
        if ckpt_path is not None:
            if wants_best:
                ckpt_path = self.checkpoint_callback.best_model_path

            if len(ckpt_path) == 0:
                rank_zero_warn(
                    f'.test() found no path for the best weights, {ckpt_path}. Please '
                    f'specify a path for a checkpoint .test(ckpt_path=PATH)')
                return {}

            on_tpu = self._device_type == DeviceType.TPU
            if not on_tpu:
                self.training_type_plugin.barrier()

            checkpoint = pl_load(ckpt_path,
                                 map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint['state_dict'])

        # attach dataloaders
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(
                model, test_dataloaders=test_dataloaders)

        # run tests
        self.tested_ckpt_path = ckpt_path
        test_results = self.fit(model)

        # teardown
        if self.is_function_implemented('teardown'):
            self.get_model().teardown('test')

        return test_results

    def __test_given_model(self, model, test_dataloaders):
        """Run the test loop on an explicitly supplied model.

        Args:
            model: The model to test; its current weights are used as-is.
            test_dataloaders: Optional dataloaders to attach to ``model`` first.

        Returns:
            The result of ``self.fit(model)``.
        """
        # attach data
        has_loaders = test_dataloaders is not None
        if has_loaders:
            self.data_connector.attach_dataloaders(
                model, test_dataloaders=test_dataloaders)

        # run test
        # sets up testing so we short circuit to eval
        test_results = self.fit(model)

        # teardown
        teardown_defined = self.is_function_implemented('teardown')
        if teardown_defined:
            model.teardown('test')

        return test_results

    def predict(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""

        Separates from fit to make sure you never run on your predictions set until you want to.

        This will call the model forward function to compute predictions.

        Args:
            model: The model to predict on.

            dataloaders: Either a single
                Pytorch Dataloader or a list of them, specifying inference samples.

            datamodule: A instance of :class:`LightningDataModule`.

        Returns:
            Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.

        Raises:
            MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are passed.
        """
        # If you supply a datamodule you can't supply dataloaders.
        # Checked with ``is not None`` because the truth value of a DataLoader goes
        # through ``__len__``, and checked before any trainer state is touched so a
        # misconfiguration does not leave the trainer in the predicting stage.
        if dataloaders is not None and datamodule is not None:
            raise MisconfigurationException(
                'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
            )

        # --------------------
        # SETUP HOOK
        # --------------------
        model = model or self.get_model()

        self._set_running_stage(RunningStage.PREDICTING, model)

        if datamodule is not None:
            # Attach datamodule to get setup/prepare_data added to model before the call to it below
            self.data_connector.attach_datamodule(model, datamodule, 'predict')

        # attach data
        if dataloaders is not None:
            self.data_connector.attach_dataloaders(
                model, predict_dataloaders=dataloaders)

        self.model = model
        results = self.fit(model)
        self._set_running_stage(None, model)

        return results

    def tune(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Runs routines to tune hyperparameters before training.

        Args:
            datamodule: A instance of :class:`LightningDataModule`.

            model: Model to tune.

            train_dataloader: A Pytorch DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped

        """
        self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule)

    def call_setup_hook(self, model):
        """Invoke the ``setup`` hooks of the datamodule, the trainer and the model.

        The stage passed to every hook is ``'test'`` while the trainer is testing
        and ``'fit'`` otherwise. The datamodule's ``setup`` is skipped when its
        matching ``has_setup_*`` flag says it already ran.

        Args:
            model: The module whose ``setup`` hook is called last.
        """
        # call setup after the ddp process has connected
        if self.testing:
            stage_name = 'test'
        else:
            stage_name = 'fit'

        datamodule = self.datamodule
        if datamodule is not None:
            if self.testing:
                already_set_up = datamodule.has_setup_test
            else:
                already_set_up = datamodule.has_setup_fit
            if not already_set_up:
                datamodule.setup(stage_name)

        self.setup(model, stage_name)
        model.setup(stage_name)

    def _reset_result_and_set_hook_fx_name(self, hook_name):
        """Give the model a fresh ``Result`` and record which hook is running.

        Args:
            hook_name: Name of the hook that is about to be called.

        Returns:
            ``True`` when ``hook_name`` contains ``"batch_start"`` or
            ``"on_before_zero_grad"``, in which case the model is left untouched;
            ``False`` otherwise.
        """
        # on_before_zero_grad is called within training_step
        untracked_markers = ("batch_start", "on_before_zero_grad")
        if any(marker in hook_name for marker in untracked_markers):
            return True

        model_ref = self.get_model()
        if model_ref is None:
            return False

        # used to track current hook name called
        model_ref._results = Result()
        model_ref._current_hook_fx_name = hook_name
        return False

    def _cache_logged_metrics(self):
        """Ask the logger connector to cache logged metrics, if a model is attached."""
        if self.get_model() is None:
            return
        # capture logging for this hook
        self.logger_connector.cache_logged_metrics()

    def call_hook(self, hook_name, *args, **kwargs):
        """Dispatch ``hook_name`` to the trainer and then to the model or accelerator.

        The trainer's own hook runs first when it exists. After that the model's
        hook is called if the model overrides it; otherwise the accelerator
        backend's hook is called if it has one.

        Args:
            hook_name: Name of the hook to invoke.
            *args: Positional arguments forwarded to every hook that is called.
            **kwargs: Keyword arguments forwarded to every hook that is called.

        Returns:
            The value returned by the model hook or the accelerator hook, or
            ``None`` when neither was called.
        """
        # set hook_name to model + reset Result obj
        skip_caching = self._reset_result_and_set_hook_fx_name(hook_name)

        # always profile hooks
        with self.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self, hook_name):
                getattr(self, hook_name)(*args, **kwargs)

            # next call hook in lightningModule
            output = None
            model_ref = self.get_model()
            if is_overridden(hook_name, model_ref):
                output = getattr(model_ref, hook_name)(*args, **kwargs)

            # if the PL module doesn't have the hook then call the accelerator
            # used to auto-reduce things for the user with Results obj
            elif hasattr(self.accelerator_backend, hook_name):
                output = getattr(self.accelerator_backend,
                                 hook_name)(*args, **kwargs)

        if skip_caching:
            return output
        self._cache_logged_metrics()
        return output

    @property
    def training(self) -> bool:
        """Whether the current running stage is ``RunningStage.TRAINING``."""
        current_stage = self._running_stage
        return current_stage == RunningStage.TRAINING

    @training.setter
    def training(self, val: bool) -> None:
        # a truthy value switches to the training stage
        if val:
            self._running_stage = RunningStage.TRAINING
            return
        # a falsy value only clears the stage if it is currently training
        if self.training:
            self._running_stage = None

    @property
    def testing(self) -> bool:
        """Whether the current running stage is ``RunningStage.TESTING``."""
        current_stage = self._running_stage
        return current_stage == RunningStage.TESTING

    @testing.setter
    def testing(self, val: bool) -> None:
        # a truthy value switches to the testing stage
        if val:
            self._running_stage = RunningStage.TESTING
            return
        # a falsy value only clears the stage if it is currently testing
        if self.testing:
            self._running_stage = None

    @property
    def predicting(self) -> bool:
        """Whether the current running stage is ``RunningStage.PREDICTING``."""
        current_stage = self._running_stage
        return current_stage == RunningStage.PREDICTING

    @predicting.setter
    def predicting(self, val: bool) -> None:
        # a truthy value switches to the predicting stage
        if val:
            self._running_stage = RunningStage.PREDICTING
            return
        # a falsy value only clears the stage if it is currently predicting
        if self.predicting:
            self._running_stage = None

    @property
    def tuning(self) -> bool:
        """Whether the current running stage is ``RunningStage.TUNING``."""
        current_stage = self._running_stage
        return current_stage == RunningStage.TUNING

    @tuning.setter
    def tuning(self, val: bool) -> None:
        # a truthy value switches to the tuning stage
        if val:
            self._running_stage = RunningStage.TUNING
            return
        # a falsy value only clears the stage if it is currently tuning
        if self.tuning:
            self._running_stage = None

    @property
    def evaluating(self) -> bool:
        """Whether the current running stage is ``RunningStage.EVALUATING``."""
        current_stage = self._running_stage
        return current_stage == RunningStage.EVALUATING

    @evaluating.setter
    def evaluating(self, val: bool) -> None:
        # a truthy value switches to the evaluating stage
        if val:
            self._running_stage = RunningStage.EVALUATING
            return
        # a falsy value only clears the stage if it is currently evaluating
        if self.evaluating:
            self._running_stage = None
예제 #2
0
    def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase],
                      bool] = True,
        checkpoint_callback: bool = True,
        callbacks: Optional[Union[List[Callback], Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: Optional[int] = None,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: Union[int, bool] = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: Optional[int] = None,
        min_epochs: Optional[int] = None,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        limit_train_batches: Union[int, float] = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        limit_predict_batches: Union[int, float] = 1.0,
        val_check_interval: Union[int, float] = 1.0,
        flush_logs_every_n_steps: int = 100,
        log_every_n_steps: int = 50,
        accelerator: Optional[Union[str, Accelerator]] = None,
        sync_batchnorm: bool = False,
        precision: int = 32,
        weights_summary: Optional[str] = 'top',
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[Union[Path, str]] = None,
        profiler: Optional[Union[BaseProfiler, bool, str]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        plugins: Optional[Union[str, list]] = None,
        amp_backend: str = 'native',
        amp_level: str = 'O2',
        distributed_backend: Optional[str] = None,
        automatic_optimization: Optional[bool] = None,
        move_metrics_to_cpu: bool = False,
        enable_pl_optimizer: Optional[bool] = None,  # todo: remove in v1.3
        multiple_trainloader_mode: str = 'max_size_cycle',
    ):
        r"""
        Customize every aspect of training via flags

        Args:

            accelerator: Previously known as distributed_backend (dp, ddp, ddp2, etc...).
                Can also take in an accelerator object for custom hardware.

            accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

            amp_backend: The mixed precision backend to use ("native" or "apex")

            amp_level: The optimization level to use (O1, O2, etc...).

            auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
                trying to optimize initial learning for faster convergence. trainer.tune() method will
                set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
                To use a different key set a string instead of True with the key name.

            auto_scale_batch_size: If set to True, will `initially` run a batch size
                finder trying to find the largest batch size that fits into memory.
                The result will be stored in self.batch_size in the LightningModule.
                Additionally, can be set to either `power` that estimates the batch size through
                a power search or `binsearch` that estimates the batch size through a binary search.

            auto_select_gpus: If enabled and `gpus` is an integer, pick available
                gpus automatically. This is especially useful when
                GPUs are configured to be in "exclusive mode", such
                that only one process at a time can access them.

            benchmark: If true enables cudnn.benchmark.

            callbacks: Add a callback or list of callbacks.

            checkpoint_callback: If ``True``, enable checkpointing.
                It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.

                .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
                    v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead.

            check_val_every_n_epoch: Check val every n train epochs.

            default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
                Default: ``os.getcwd()``.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

            deterministic: If true enables cudnn.deterministic.

            distributed_backend: deprecated. Please use 'accelerator'

            fast_dev_run: runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
                of train, val and test to find any bugs (ie: a sort of unit test).

            flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).

            gpus: number of gpus to train on (int) or which GPUs to train on (list or str) applied per node

            gradient_clip_val: 0 means don't clip.

            limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)

            limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)

            limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)

            limit_predict_batches: How much of prediction dataset to check (floats = percent, int = num_batches)

            logger: Logger (or iterable collection of loggers) for experiment tracking.

            log_gpu_memory: None, 'min_max', 'all'. Might slow performance

            log_every_n_steps: How often to log within steps (defaults to every 50 steps).

            automatic_optimization: If False you are responsible for calling .backward, .step, zero_grad
                in LightningModule. This argument has been moved to LightningModule. It is deprecated
                here in v1.1 and will be removed in v1.3.

            prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

            process_position: orders the progress bar when running multiple models on same machine.

            progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
                Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means
                a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).

            profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
                value is deprecated in v1.1 and will be removed in v1.3.

            overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0

            plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

            precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.

            max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
                If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.

            min_epochs: Force training for at least these many epochs. Disabled by default (None).
                If both min_epochs and min_steps are not specified, defaults to ``min_epochs`` = 1.

            max_steps: Stop training after this number of steps. Disabled by default (None).

            min_steps: Force training for at least these number of steps. Disabled by default (None).

            num_nodes: number of GPU nodes for distributed training.

            num_processes: number of processes for distributed training with distributed_backend="ddp_cpu"

            num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
                Set it to `-1` to run all batches in all validation dataloaders. Default: 2

            reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

            replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
                will toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
                train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
                you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.

            resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is
                no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
                training will start from the beginning of the next epoch.

            sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

            terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

            tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

            truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much longer
                sequence.

            val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                use int to check every n steps (batches).

            weights_summary: Prints a summary of the weights when training begins.

            weights_save_path: Where to save weights if specified. Will override default_root_dir
                for checkpoints only. Use this if for whatever reason you need the checkpoints
                stored in a different place than the logs written in `default_root_dir`.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
                Defaults to `default_root_dir`.

            move_metrics_to_cpu: Whether to force internal logged metrics to be moved to cpu.
                This can save some gpu memory, but can make training slower. Use with attention.

            enable_pl_optimizer: If True, each optimizer will be wrapped by
                `pytorch_lightning.core.optimizer.LightningOptimizer`. It allows Lightning to
                handle AMP, TPU, accumulated_gradients, etc.

                .. warning:: Currently deprecated and it will be removed in v1.3

            multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders.
                In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
                and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
                reload when reaching the minimum length of datasets.
        """
        super().__init__()
        self._running_stage = None

        # the deprecated ``distributed_backend`` flag wins over ``accelerator`` when both are given
        distributed_backend = distributed_backend or accelerator

        # init connectors
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.optimizer_connector = OptimizerConnector(self)

        self.accelerator_connector = BackendConnector(
            num_processes, tpu_cores, distributed_backend, auto_select_gpus,
            gpus, num_nodes, sync_batchnorm, benchmark, replace_sampler_ddp,
            deterministic, precision, amp_backend, amp_level, plugins)
        self.logger_connector = LoggerConnector(self, log_gpu_memory)
        self.model_connector = ModelConnector(self)
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.profile_connector = ProfilerConnector(self)
        self.checkpoint_connector = CheckpointConnector(self)
        self.slurm_connector = SLURMConnector(self)
        self.tuner = Tuner(self)
        self.train_loop = TrainLoop(self, multiple_trainloader_mode)
        self.evaluation_loop = EvaluationLoop(self)
        self.predict_loop = PredictLoop(self)

        # training state
        self.weights_summary = weights_summary
        self.shown_warnings = set()

        # init callbacks
        # Declare attributes to be set in callback_connector on_trainer_init
        self.callback_connector.on_trainer_init(
            callbacks,
            checkpoint_callback,
            progress_bar_refresh_rate,
            process_position,
            default_root_dir,
            weights_save_path,
            resume_from_checkpoint,
        )

        # hook
        self.on_init_start()

        # init optimizer + lr scheduler related flags
        self.optimizer_connector.on_trainer_init(enable_pl_optimizer)

        # init data flags
        self.data_connector.on_trainer_init(check_val_every_n_epoch,
                                            reload_dataloaders_every_epoch,
                                            prepare_data_per_node)

        # init training tricks
        self.training_tricks_connector.on_trainer_init(
            gradient_clip_val, track_grad_norm, accumulate_grad_batches,
            truncated_bptt_steps, terminate_on_nan)

        # init train loop related flags
        # TODO: remove in 1.3.0
        if automatic_optimization is None:
            automatic_optimization = True
        else:
            # any explicit value (True or False) triggers the deprecation warning
            rank_zero_warn(
                "Disable automatic optimization with the trainer flag is deprecated and will be removed in v1.3.0! "
                "Please use the property on the LightningModule for disabling automatic optimization"
            )
        self.train_loop.on_trainer_init(
            max_epochs,
            min_epochs,
            max_steps,
            min_steps,
            num_sanity_val_steps,
            automatic_optimization,
            weights_summary,
        )
        self.evaluation_loop.on_trainer_init()

        # configure tuner
        self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

        # configure profiler
        self.profile_connector.on_trainer_init(profiler)

        # init logger flags
        self.logger_connector.on_trainer_init(
            logger,
            flush_logs_every_n_steps,
            log_every_n_steps,
            move_metrics_to_cpu,
        )

        # init debugging flags
        self.debugging_connector.on_init_start(
            limit_train_batches,
            limit_val_batches,
            limit_test_batches,
            limit_predict_batches,
            val_check_interval,
            overfit_batches,
            fast_dev_run,
        )

        # Callback system
        self.on_init_end()
예제 #3
0
class Trainer(
        TrainerProperties,
        TrainerCallbackHookMixin,
        TrainerModelHooksMixin,
        TrainerOptimizersMixin,
        TrainerLoggingMixin,
        TrainerTrainingTricksMixin,
        TrainerDataLoadingMixin,
        DeprecatedDistDeviceAttributes,
        DeprecatedTrainerAttributes,
):
    @_defaults_from_env_vars
    def __init__(self,
                 logger: Union[LightningLoggerBase,
                               Iterable[LightningLoggerBase], bool] = True,
                 checkpoint_callback: bool = True,
                 callbacks: Optional[Union[List[Callback], Callback]] = None,
                 default_root_dir: Optional[str] = None,
                 gradient_clip_val: float = 0,
                 gradient_clip_algorithm: str = 'norm',
                 process_position: int = 0,
                 num_nodes: int = 1,
                 num_processes: int = 1,
                 gpus: Optional[Union[List[int], str, int]] = None,
                 auto_select_gpus: bool = False,
                 tpu_cores: Optional[Union[List[int], str, int]] = None,
                 log_gpu_memory: Optional[str] = None,
                 progress_bar_refresh_rate: Optional[int] = None,
                 overfit_batches: Union[int, float] = 0.0,
                 track_grad_norm: Union[int, float, str] = -1,
                 check_val_every_n_epoch: int = 1,
                 fast_dev_run: Union[int, bool] = False,
                 accumulate_grad_batches: Union[int, Dict[int, int],
                                                List[list]] = 1,
                 max_epochs: Optional[int] = None,
                 min_epochs: Optional[int] = None,
                 max_steps: Optional[int] = None,
                 min_steps: Optional[int] = None,
                 max_time: Optional[Union[str, timedelta, Dict[str,
                                                               int]]] = None,
                 limit_train_batches: Union[int, float] = 1.0,
                 limit_val_batches: Union[int, float] = 1.0,
                 limit_test_batches: Union[int, float] = 1.0,
                 limit_predict_batches: Union[int, float] = 1.0,
                 val_check_interval: Union[int, float] = 1.0,
                 flush_logs_every_n_steps: int = 100,
                 log_every_n_steps: int = 50,
                 accelerator: Optional[Union[str, Accelerator]] = None,
                 sync_batchnorm: bool = False,
                 precision: int = 32,
                 weights_summary: Optional[str] = 'top',
                 weights_save_path: Optional[str] = None,
                 num_sanity_val_steps: int = 2,
                 truncated_bptt_steps: Optional[int] = None,
                 resume_from_checkpoint: Optional[Union[Path, str]] = None,
                 profiler: Optional[Union[BaseProfiler, str]] = None,
                 benchmark: bool = False,
                 deterministic: bool = False,
                 reload_dataloaders_every_epoch: bool = False,
                 auto_lr_find: Union[bool, str] = False,
                 replace_sampler_ddp: bool = True,
                 terminate_on_nan: bool = False,
                 auto_scale_batch_size: Union[str, bool] = False,
                 prepare_data_per_node: bool = True,
                 plugins: Optional[Union[Plugin, str, list]] = None,
                 amp_backend: str = 'native',
                 amp_level: str = 'O2',
                 distributed_backend: Optional[str] = None,
                 move_metrics_to_cpu: bool = False,
                 multiple_trainloader_mode: str = 'max_size_cycle',
                 stochastic_weight_avg: bool = False):
        r"""
        Customize every aspect of training via flags

        Args:

            accelerator: Previously known as distributed_backend (dp, ddp, ddp2, etc...).
                Can also take in an accelerator object for custom hardware.

            accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

            amp_backend: The mixed precision backend to use ("native" or "apex")

            amp_level: The optimization level to use (O1, O2, etc...).

            auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
                trying to optimize initial learning for faster convergence. trainer.tune() method will
                set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
                To use a different key set a string instead of True with the key name.

            auto_scale_batch_size: If set to True, will `initially` run a batch size
                finder trying to find the largest batch size that fits into memory.
                The result will be stored in self.batch_size in the LightningModule.
                Additionally, can be set to either `power` that estimates the batch size through
                a power search or `binsearch` that estimates the batch size through a binary search.

            auto_select_gpus: If enabled and `gpus` is an integer, pick available
                gpus automatically. This is especially useful when
                GPUs are configured to be in "exclusive mode", such
                that only one process at a time can access them.

            benchmark: If true enables cudnn.benchmark.

            callbacks: Add a callback or list of callbacks.

            checkpoint_callback: If ``True``, enable checkpointing.
                It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.

            check_val_every_n_epoch: Check val every n train epochs.

            default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
                Default: ``os.getcwd()``.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

            deterministic: If true enables cudnn.deterministic.

            distributed_backend: deprecated. Please use 'accelerator'

            fast_dev_run: runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
                of train, val and test to find any bugs (ie: a sort of unit test).

            flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).

            gpus: number of gpus to train on (int) or which GPUs to train on (list or str) applied per node

            gradient_clip_val: 0 means don't clip.

            gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Default: 'norm'

            limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches)

            limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches)

            limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches)

            limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches)

            logger: Logger (or iterable collection of loggers) for experiment tracking.

            log_gpu_memory: None, 'min_max', 'all'. Might slow performance

            log_every_n_steps: How often to log within steps (defaults to every 50 steps).

            prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

            process_position: orders the progress bar when running multiple models on same machine.

            progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
                Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means
                a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).

            profiler: To profile individual steps during training and assist in identifying bottlenecks.

            overfit_batches: Overfit a fraction of training data (float) or a set number of batches (int).

            plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

            precision: Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or
                TPUs.

            max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
                If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.

            min_epochs: Force training for at least these many epochs. Disabled by default (None).
                If both min_epochs and min_steps are not specified, defaults to ``min_epochs`` = 1.

            max_steps: Stop training after this number of steps. Disabled by default (None).

            min_steps: Force training for at least these number of steps. Disabled by default (None).

            max_time: Stop training after this amount of time has passed. Disabled by default (None).
                The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a
                :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
                :class:`datetime.timedelta`.

            num_nodes: number of GPU nodes for distributed training.

            num_processes: number of processes for distributed training with distributed_backend="ddp_cpu"

            num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
                Set it to `-1` to run all batches in all validation dataloaders.

            reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

            replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
                will be toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
                train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
                you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.

            resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is
                no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
                training will start from the beginning of the next epoch.

            sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

            terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

            tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

            truncated_bptt_steps: Truncated back prop performs backprop every k steps of a much longer
                sequence.

            val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                use int to check every n steps (batches).

            weights_summary: Prints a summary of the weights when training begins.

            weights_save_path: Where to save weights if specified. Will override default_root_dir
                for checkpoints only. Use this if for whatever reason you need the checkpoints
                stored in a different place than the logs written in `default_root_dir`.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
                Defaults to `default_root_dir`.

            move_metrics_to_cpu: Whether to force internal logged metrics to be moved to cpu.
                This can save some gpu memory, but can make training slower. Use with attention.

            multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders.
                In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
                and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
                reload when reaching the minimum length of datasets.

            stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
                <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>`_

        Raises:
            MisconfigurationException: If ``weights_summary`` is neither ``None`` nor one of
                ``ModelSummary.MODES``.

        """
        super().__init__()
        Trainer._log_api_event("init")
        # the deprecated `distributed_backend` flag takes precedence over `accelerator` when both are given
        distributed_backend = distributed_backend or accelerator

        # init connectors
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.optimizer_connector = OptimizerConnector(self)

        self.accelerator_connector = AcceleratorConnector(
            num_processes, tpu_cores, distributed_backend, auto_select_gpus,
            gpus, num_nodes, sync_batchnorm, benchmark, replace_sampler_ddp,
            deterministic, precision, amp_backend, amp_level, plugins)
        self.logger_connector = LoggerConnector(self, log_gpu_memory)
        self.model_connector = ModelConnector(self)
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.profile_connector = ProfilerConnector(self)
        self.checkpoint_connector = CheckpointConnector(self)
        self.slurm_connector = SLURMConnector(self)
        self.tuner = Tuner(self)
        self.train_loop = TrainLoop(self, multiple_trainloader_mode)
        self.evaluation_loop = EvaluationLoop(self)
        self.predict_loop = PredictLoop(self)

        # training state
        # reject unknown summary modes up front instead of failing later when the summary is printed
        if weights_summary is not None and weights_summary not in ModelSummary.MODES:
            raise MisconfigurationException(
                f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}"
            )
        self.weights_summary = weights_summary
        self.shown_warnings = set()

        # init callbacks
        # Declare attributes to be set in callback_connector on_trainer_init
        self.callback_connector.on_trainer_init(
            callbacks,
            checkpoint_callback,
            progress_bar_refresh_rate,
            process_position,
            default_root_dir,
            weights_save_path,
            resume_from_checkpoint,
            stochastic_weight_avg,
            max_time,
        )

        # hook
        self.on_init_start()

        # init optimizer + lr scheduler related flags
        self.optimizer_connector.on_trainer_init()

        # init data flags
        self.data_connector.on_trainer_init(check_val_every_n_epoch,
                                            reload_dataloaders_every_epoch,
                                            prepare_data_per_node)

        # init training tricks
        self.training_tricks_connector.on_trainer_init(
            gradient_clip_val,
            gradient_clip_algorithm,
            track_grad_norm,
            accumulate_grad_batches,
            truncated_bptt_steps,
            terminate_on_nan,
        )
        # init train loop related flags
        self.train_loop.on_trainer_init(
            max_epochs,
            min_epochs,
            max_steps,
            min_steps,
            num_sanity_val_steps,
        )
        self.evaluation_loop.on_trainer_init()

        # configure tuner
        self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

        # configure profiler
        self.profile_connector.on_trainer_init(profiler)

        # init logger flags
        self.logger_connector.on_trainer_init(
            logger,
            flush_logs_every_n_steps,
            log_every_n_steps,
            move_metrics_to_cpu,
        )

        # init debugging flags
        self.debugging_connector.on_init_start(
            limit_train_batches,
            limit_val_batches,
            limit_test_batches,
            limit_predict_batches,
            val_check_interval,
            overfit_batches,
            fast_dev_run,
        )

        # Callback system
        self.on_init_end()

    def fit(
        self,
        model: LightningModule,
        train_dataloader: Any = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Runs the full optimization routine.

        Args:
            datamodule: An instance of :class:`LightningDataModule`.

            model: Model to fit.

            train_dataloader: Either a single PyTorch DataLoader or a collection of these
                (list, dict, nested lists and dicts). In the case of multiple dataloaders, please
                see this :ref:`page <multiple-training-dataloaders>`

            val_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped

        Returns:
            ``self.accelerator.results`` when it is truthy, otherwise ``1``.

        """
        Trainer._log_api_event("fit")
        # we reuse fit for other functions. When already set, it shouldn't be modified.
        if not self.state.running:
            self.state = TrainerState.FITTING
        if self._running_stage is None or self.tuning:
            self.training = True

        # set local properties on the model
        self.model_connector.copy_trainer_model_properties(model)

        # ----------------------------
        # LINK DATA
        # ----------------------------
        # setup data, etc...
        self.train_loop.setup_fit(model, train_dataloader, val_dataloaders,
                                  datamodule)

        # hook
        self.data_connector.prepare_data(model)
        self.callback_connector._attach_model_callbacks(model, self)

        # ----------------------------
        # SET UP TRAINING
        # ----------------------------
        self.call_hook("on_before_accelerator_backend_setup", model)
        self.accelerator.connect(model)
        self.accelerator.setup_environment()
        self.call_setup_hook(
            model
        )  # allow user to setup lightning_module in accelerator environment
        self.call_configure_sharded_model(
            model)  # allow user to setup in model sharded environment
        self.accelerator.setup(
            self, model)  # note: this sets up self.lightning_module

        # ----------------------------
        # INSPECT THE CORE LOOPS
        # ----------------------------
        # NOTE: the f-string below is a bare expression statement: it is evaluated and its value discarded.
        # It is a navigation aid only; the `{...}` fields reference the real attributes so that an
        # editor's `Go to Definition` can jump from the diagram to each step.
        f"""
             Lightning internal flow looks like this:
        {Trainer.fit} or {Trainer.test} or {Trainer.predict}  ||
                                |                             ||
                        create accelerator                    ||
                                |                             ||
                         {self.dispatch}                      ||
                                |                             ||  LIGHTNING
                  {self.accelerator.start_training}           ||
                or {self.accelerator.start_evaluating}        ||
                or {self.accelerator.start_predicting}        ||  FLOW
                                |                             ||
                         {self.run_stage}                     ||
                                |                             ||  DIRECTION
                        {self.run_train}                      ||
                     or {self.run_evaluation}                 ||
                     or {self.run_predict}                    ||
                                |                             ||
                             results                          \/
        This is used to guide readers to the core loops: train, test, predict.
        {self.run_predict} is the simplest to understand, use `Go to Definition` to read it :)
        Search for `start_training` or `start_evaluating` or `start_predicting` in
        `pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions.
        """  # noqa: W605

        # ----------------------------
        # TRAIN
        # ----------------------------
        # hook
        # only fired for a real fit; other entry points reuse this method with a different state
        if self.state == TrainerState.FITTING:
            self.call_hook("on_fit_start")

        # plugin will setup fitting (e.g. ddp will launch child processes)
        self.pre_dispatch()

        # dispatch `start_training` or `start_evaluating` or `start_predicting`
        self.dispatch()

        # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
        self.post_dispatch()

        # ----------------------------
        # POST-Training CLEAN UP
        # ----------------------------
        # hook
        if self.state == TrainerState.FITTING:
            self.call_hook('on_fit_end')

        # teardown
        self.call_teardown_hook(model)

        # keep INTERRUPTED visible to the caller; every other state is marked FINISHED
        if self.state != TrainerState.INTERRUPTED:
            self.state = TrainerState.FINISHED
        self._running_stage = None

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return self.accelerator.results or 1

    def pre_dispatch(self):
        """Run the accelerator's pre-dispatch step, then write the initial logger records.

        When a logger is attached, the module's ``hparams_initial`` and its graph are
        logged and the logger is saved. Without a logger only the accelerator step runs.
        """
        self.accelerator.pre_dispatch(self)

        experiment_logger = self.logger
        if experiment_logger is None:
            return

        # this is where the first experiment logs are written
        module = self.lightning_module
        experiment_logger.log_hyperparams(module.hparams_initial)
        experiment_logger.log_graph(module)
        experiment_logger.save()

    def post_dispatch(self):
        """Run the accelerator's post-dispatch step and then tear the accelerator down."""
        accelerator = self.accelerator
        accelerator.post_dispatch(self)
        accelerator.teardown()

    def dispatch(self):
        """Hand control to the accelerator entry point that matches the current stage.

        ``evaluating`` is checked first, then ``predicting``; anything else starts training.
        """
        if self.evaluating:
            entry_point = self.accelerator.start_evaluating
        elif self.predicting:
            entry_point = self.accelerator.start_predicting
        else:
            entry_point = self.accelerator.start_training
        entry_point(self)

    def run_stage(self):
        """Set up the profiler connector and run the loop for the current stage.

        Returns:
            The value of ``run_evaluate()`` when evaluating, the value of ``run_predict()``
            when predicting, and ``None`` after a training run.
        """
        self.profile_connector.setup()

        if self.evaluating:
            return self.run_evaluate()
        if self.predicting:
            return self.run_predict()

        # the return value of the training loop is not propagated
        self.run_train()
        return None

    def _pre_training_routine(self):
        # wait for all to join if on distributed
        self.accelerator.barrier("setup_training")

        # register auto-resubmit when on SLURM
        self.slurm_connector.register_slurm_signal_handlers()

        # --------------------------
        # Pre-train
        # --------------------------
        # on pretrain routine start
        ref_model = self.lightning_module

        self.on_pretrain_routine_start()
        ref_model.on_pretrain_routine_start()

        # print model summary
        if self.is_global_zero and self.weights_summary is not None and not self.testing:
            ref_model.summarize(mode=self.weights_summary)

        # restore training and model before hpc is called
        self.checkpoint_connector.restore_weights()

        # on pretrain routine end
        self.on_pretrain_routine_end()
        ref_model.on_pretrain_routine_end()

    def run_train(self) -> None:
        """Run the training loop over epochs, including the pre-train routine and sanity check.

        Epochs run from ``current_epoch`` up to ``max_epochs`` (unbounded when ``max_epochs``
        is falsy). The loop ends early when ``max_steps`` is reached, or when ``should_stop``
        is set and both the ``min_epochs`` and ``min_steps`` requirements are met.

        A ``KeyboardInterrupt`` is swallowed after marking the trainer as interrupted; any
        other exception is re-raised after the accelerator's ``on_train_end`` has run.
        """

        self._pre_training_routine()

        # only the global-zero process keeps its progress bar
        if not self.is_global_zero and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        self.run_sanity_check(self.lightning_module)

        self.checkpoint_connector.has_trained = False

        # enable train mode
        self.model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        model = self.lightning_module
        self.train_loop.reset_train_val_dataloaders(model)

        # hook
        self.train_loop.on_train_start()

        try:
            # NOTE: this early exit returns without calling `train_loop.on_train_end()`
            if self.train_loop.should_skip_training():
                return
            # run all epochs
            # `count` yields an unbounded epoch sequence when no `max_epochs` is set
            epochs = range(self.current_epoch,
                           self.max_epochs) if self.max_epochs else count(
                               self.current_epoch)
            for epoch in epochs:

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                with self.profiler.profile("run_training_epoch"):
                    # run train epoch
                    self.train_loop.run_training_epoch()

                # step budget exhausted: finish regardless of the epoch count
                if self.max_steps and self.max_steps <= self.global_step:
                    self.train_loop.on_train_end()
                    return

                # early stopping
                # `epoch` is zero-based, hence the `- 1` when comparing against `min_epochs`
                met_min_epochs = (
                    epoch >= self.min_epochs - 1) if self.min_epochs else True
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if self.should_stop:
                    if met_min_epochs and met_min_steps:
                        self.train_loop.on_train_end()
                        return
                    else:
                        log.info(
                            'Trainer was signaled to stop but required minimum epochs'
                            f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                            ' not been met. Training will continue...')
                        # clear the stop request so training continues until the minimums are met
                        self.should_stop = False

            # hook
            self.train_loop.on_train_end()

        except KeyboardInterrupt:
            rank_zero_warn(
                'Detected KeyboardInterrupt, attempting graceful shutdown...')
            # user could press Ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()
                # same treatment as below
                self.accelerator.on_train_end()
                self._running_stage = None
        except BaseException:
            # give accelerators a chance to finish
            self.accelerator.on_train_end()
            # reset bookkeeping
            self._running_stage = None
            raise

    def run_evaluation(self, on_epoch=False):
        """Run one evaluation pass over every evaluation dataloader.

        Gradients are disabled for the duration of the pass and re-enabled before returning
        on the normal path.

        Args:
            on_epoch: When true, ``optimizer_connector.update_learning_rates(interval='epoch')``
                is called after the epoch-end hooks.

        Returns:
            ``([], [])`` when the evaluation loop decides to skip; otherwise the value of
            ``logger_connector.get_evaluate_epoch_results()``.
        """
        if not (self.evaluating or self.sanity_checking):
            rank_zero_warn(
                f"`trainer.run_evaluation()` was called but the running stage is set to {self._running_stage}."
                " This should not happen normally. Setting it to `RunningStage.VALIDATING`",
                RuntimeWarning)
            self.validating = True

        # prepare dataloaders
        dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(
        )

        # check if we want to skip this evaluation
        # NOTE: the skip path returns a 2-tuple of empty lists, not a single empty list
        if self.evaluation_loop.should_skip_evaluation(max_batches):
            return [], []

        # enable eval mode + no grads
        self.evaluation_loop.on_evaluation_model_eval()
        # ref model
        model = self.lightning_module
        model.zero_grad()
        torch.set_grad_enabled(False)

        # hook
        self.evaluation_loop.on_evaluation_start()

        # set up the eval loop
        self.evaluation_loop.setup(max_batches, dataloaders)

        # hook
        self.evaluation_loop.on_evaluation_epoch_start()

        # run validation/testing
        for dataloader_idx, dataloader in enumerate(dataloaders):
            # bookkeeping
            dl_outputs = []
            dataloader = self.accelerator.process_dataloader(dataloader)
            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                # `None` batches are skipped but still advance `batch_idx`
                if batch is None:
                    continue

                # stop short when running on limited batches
                if batch_idx >= dl_max_batches:
                    break

                # hook
                self.evaluation_loop.on_evaluation_batch_start(
                    batch, batch_idx, dataloader_idx)

                # lightning module methods
                with self.profiler.profile("evaluation_step_and_end"):
                    output = self.evaluation_loop.evaluation_step(
                        batch, batch_idx, dataloader_idx)
                    output = self.evaluation_loop.evaluation_step_end(output)

                # hook + store predictions
                self.evaluation_loop.on_evaluation_batch_end(
                    output, batch, batch_idx, dataloader_idx)

                # log batch metrics
                self.evaluation_loop.log_evaluation_step_metrics(batch_idx)

                # track epoch level outputs
                dl_outputs = self.track_output_for_epoch_end(
                    dl_outputs, output)

            # store batch level output per dataloader
            self.evaluation_loop.outputs.append(dl_outputs)

        outputs = self.evaluation_loop.outputs

        # reset outputs
        self.evaluation_loop.outputs = []

        # with a single dataloader don't pass a 2D list
        if self.evaluation_loop.num_dataloaders == 1:
            outputs = outputs[0]

        # lightning module method
        self.evaluation_loop.evaluation_epoch_end(outputs)

        # hook
        self.evaluation_loop.on_evaluation_epoch_end(outputs)

        # update epoch-level lr_schedulers
        if on_epoch:
            self.optimizer_connector.update_learning_rates(interval='epoch')

        # hook
        self.evaluation_loop.on_evaluation_end()

        # log epoch metrics
        eval_loop_results = self.logger_connector.get_evaluate_epoch_results()

        # save predictions to disk
        self.evaluation_loop.predictions.to_disk()

        # enable train mode again
        self.evaluation_loop.on_evaluation_model_train()

        # reset cached results
        self.logger_connector.reset()

        torch.set_grad_enabled(True)

        return eval_loop_results

    def track_output_for_epoch_end(self, outputs, output):
        """Append a step output to ``outputs`` and return the same list.

        ``None`` outputs are ignored. ``Result`` objects are detached, dicts are detached
        recursively, and both are moved to CPU when ``move_metrics_to_cpu`` is set. A bare
        CUDA tensor is only moved to CPU (under the same flag); it is not detached.
        """
        if output is None:
            return outputs

        if isinstance(output, Result):
            detached = output.detach()
            output = detached.cpu() if self.move_metrics_to_cpu else detached
        elif isinstance(output, dict):
            output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu)
        elif isinstance(output, torch.Tensor):
            if output.is_cuda and self.move_metrics_to_cpu:
                output = output.cpu()

        outputs.append(output)
        return outputs

    def run_evaluate(self):
        """Run a profiled evaluation pass and convert tensor metrics to Python scalars.

        Returns:
            ``1`` when ``run_evaluation()`` yields an empty collection; otherwise the
            evaluation results, where every tensor value inside a dict entry has been
            replaced in place by ``tensor.cpu().item()``.
        """
        # only the global-zero process keeps its progress bar
        if not self.is_global_zero:
            progress_bar = self.progress_bar_callback
            if progress_bar is not None:
                progress_bar.disable()

        assert self.evaluating

        profile_name = f"run_{self._running_stage}_evaluation"
        with self.profiler.profile(profile_name):
            eval_loop_results = self.run_evaluation()

        if len(eval_loop_results) == 0:
            return 1

        # remove the tensors from the eval results
        for entry in eval_loop_results:
            if not isinstance(entry, dict):
                continue
            scalars = {
                key: value.cpu().item()
                for key, value in entry.items()
                if isinstance(value, torch.Tensor)
            }
            # existing keys only, so the dict keeps its size and key order
            entry.update(scalars)

        return eval_loop_results

    def run_predict(self):
        """Run the prediction loop over every predict dataloader.

        Gradients are disabled globally while predicting. They are re-enabled in a
        ``finally`` clause so that an exception raised by a hook or by ``predict_step``
        does not leave autograd switched off for the rest of the process.

        Returns:
            ``[]`` when the predict loop decides to skip; otherwise the value of
            ``predict_loop.on_predict_epoch_end()``.
        """
        # prepare dataloaders
        dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()

        # check if we want to skip this evaluation
        if self.predict_loop.should_skip_predict(max_batches):
            return []

        # ref model
        model = self.lightning_module

        # enable eval mode + no grads
        self.predict_loop.on_predict_model_eval()
        model.zero_grad()
        torch.set_grad_enabled(False)

        try:
            # set up the eval loop
            self.predict_loop.setup(model, max_batches, dataloaders)

            # call hook
            self.call_hook("on_predict_start")
            self.call_hook("on_predict_epoch_start")

            # run prediction
            for dataloader_idx, dataloader in enumerate(dataloaders):
                dataloader = self.accelerator.process_dataloader(dataloader)
                dl_max_batches = self.predict_loop.max_batches[dataloader_idx]
                for batch_idx, batch in enumerate(dataloader):
                    # `None` batches are skipped but still advance `batch_idx`
                    if batch is None:
                        continue

                    # stop short when running on limited batches
                    if batch_idx >= dl_max_batches:
                        break

                    # lightning module methods
                    with self.profiler.profile("predict_step"):
                        self.predict_loop.predict_step(batch, batch_idx,
                                                       dataloader_idx)

            results = self.predict_loop.on_predict_epoch_end()
            self.call_hook("on_predict_end")
        finally:
            # re-enable grads, also when prediction failed part-way through
            torch.set_grad_enabled(True)

        return results

    def run_sanity_check(self, ref_model):
        """Run a short validation pass so a broken validation loop fails early.

        The check runs only when ``ref_model`` has a validation dataloader and
        overrides ``validation_step``, and both ``num_sanity_val_steps`` and
        ``limit_val_batches`` are positive.

        Args:
            ref_model: The module whose validation setup is inspected.
        """
        has_val_step = ref_model.val_dataloader is not None and is_overridden(
            'validation_step', ref_model)
        wants_check = has_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
        if not wants_check:
            return

        # remember the running stage so it can be put back after the check
        previous_stage = self._running_stage
        self.sanity_checking = True

        # hook and callback, the evaluation itself, then the closing hook
        self.on_sanity_check_start()
        self.run_evaluation()
        self.on_sanity_check_end()

        self._running_stage = previous_stage

    def validate(
        self,
        model: Optional[LightningModule] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = 'best',
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Perform one evaluation epoch over the validation set.

        Args:
            model: The model to validate. When omitted, ``self.lightning_module`` is used.

            val_dataloaders: Either a single PyTorch DataLoader or a list of them,
                specifying validation samples.

            ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
                If ``None``, use the current weights of the model.
                When the model is given as argument, this parameter will not apply.

            verbose: If True, prints the validation results.

            datamodule: An instance of :class:`LightningDataModule`.

        Returns:
            The dictionary with final validation results returned by validation_epoch_end.
            If validation_epoch_end is not defined, the output is a list of the dictionaries
            returned by validation_step.

        Raises:
            MisconfigurationException: If both ``val_dataloaders`` and ``datamodule`` are passed.
        """
        # --------------------
        # SETUP HOOK
        # --------------------
        Trainer._log_api_event("validate")
        self.verbose_evaluate = verbose

        # NOTE: the state is switched before the arguments are checked, so it
        # stays set if the check below raises
        self.state = TrainerState.VALIDATING
        self.validating = True

        # If you supply a datamodule you can't supply val_dataloaders
        if val_dataloaders is not None and datamodule:
            raise MisconfigurationException(
                'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`'
            )

        # remember whether the caller passed a model before falling back to the attached one
        model_provided = model is not None
        model = model or self.lightning_module

        # Attach datamodule to get setup/prepare_data added to model before the call to it below
        self.data_connector.attach_datamodule(model, datamodule)
        #  Attach dataloaders (if given)
        self.data_connector.attach_dataloaders(model,
                                               val_dataloaders=val_dataloaders)

        # checkpoint weights are only loaded when the caller did not pass a model
        if not model_provided:
            self.validated_ckpt_path = self.__load_ckpt_weights(
                model, ckpt_path=ckpt_path)

        # run validate (goes through `fit`, which is the shared entry point)
        results = self.fit(model)

        assert self.state.stopped
        self.validating = False

        return results

    def test(
        self,
        model: Optional[LightningModule] = None,
        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = 'best',
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Perform one evaluation epoch over the test set. It's separated from
        fit to make sure you never run on your test set until you want to.

        Args:
            model: The model to test. When omitted, ``self.lightning_module`` is used.

            test_dataloaders: Either a single PyTorch DataLoader or a list of them,
                specifying test samples.

            ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
                If ``None``, use the current weights of the model.
                When the model is given as argument, this parameter will not apply.

            verbose: If True, prints the test results.

            datamodule: An instance of :class:`LightningDataModule`.

        Returns:
            Returns a list of dictionaries, one for each test dataloader containing their respective metrics.

        Raises:
            MisconfigurationException: If both ``test_dataloaders`` and ``datamodule`` are passed.
        """
        # --------------------
        # SETUP HOOK
        # --------------------
        Trainer._log_api_event("test")
        self.verbose_evaluate = verbose

        # NOTE: the state is switched before the arguments are checked, so it
        # stays set if the check below raises
        self.state = TrainerState.TESTING
        self.testing = True

        # If you supply a datamodule you can't supply test_dataloaders
        if test_dataloaders is not None and datamodule:
            raise MisconfigurationException(
                'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`'
            )

        # remember whether the caller passed a model before falling back to the attached one
        model_provided = model is not None
        model = model or self.lightning_module

        # Attach datamodule to get setup/prepare_data added to model before the call to it below
        self.data_connector.attach_datamodule(model, datamodule)
        #  Attach dataloaders (if given)
        self.data_connector.attach_dataloaders(
            model, test_dataloaders=test_dataloaders)

        # checkpoint weights are only loaded when the caller did not pass a model
        if not model_provided:
            self.tested_ckpt_path = self.__load_ckpt_weights(
                model, ckpt_path=ckpt_path)

        # run test (goes through `fit`, which is the shared entry point)
        results = self.fit(model)

        assert self.state.stopped
        self.testing = False

        return results

    def __load_ckpt_weights(
        self,
        model,
        ckpt_path: Optional[str] = None,
    ) -> Optional[str]:
        """Restore model weights from a checkpoint and report which path was used.

        Args:
            model: Accepted for the callers' convenience; not read by this method.
            ckpt_path: ``None`` to keep the current weights, ``'best'`` to use
                ``checkpoint_callback.best_model_path``, or an explicit path.

        Returns:
            ``None`` when ``ckpt_path`` is ``None``, otherwise the path that was loaded.

        Raises:
            MisconfigurationException: If ``'best'`` is requested but no best
                checkpoint path is recorded, or if the resolved path is empty.
        """
        # no checkpoint requested: keep whatever weights the model has now
        if ckpt_path is None:
            return

        # value of the current trainer state, used as the method name in error messages
        fn = self.state.value

        if ckpt_path == 'best':
            best_path = self.checkpoint_callback.best_model_path
            if not best_path:
                # the best checkpoint was requested but none is available
                if self.fast_dev_run:
                    message = (
                        f'You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do'
                        f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.'
                    )
                else:
                    message = (
                        f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
                    )
                raise MisconfigurationException(message)
            # load best weights
            ckpt_path = best_path

        if not ckpt_path:
            raise MisconfigurationException(
                f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
                f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`')

        # only one process running at this point for TPUs, as spawn isn't triggered yet
        # todo: move this logic internally within the barrier.
        # (kept as `not ... ==` rather than `!=` so the comparison goes through `__eq__`)
        on_tpu = self._device_type == DeviceType.TPU
        if not on_tpu:
            self.training_type_plugin.barrier()

        def _keep_storage(storage, loc):
            # map_location callback: hand the storage back unchanged
            return storage

        self.training_type_plugin.restore_model_state_from_ckpt_path(
            ckpt_path, map_location=_keep_storage)
        return ckpt_path

    def predict(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Run prediction. Separated from fit to make sure you never run on your
        predictions set until you want to.
        This will call the model forward function to compute predictions.

        Args:
            model: The model to predict with. When omitted, ``self.lightning_module`` is used.
            dataloaders: Either a single PyTorch DataLoader or a list of them, specifying inference samples.
            datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders.

        Returns:
            Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.

        Raises:
            MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are passed.
        """

        # --------------------
        # SETUP HOOK
        # --------------------
        Trainer._log_api_event("predict")

        model = model or self.lightning_module

        # NOTE: the state is switched before the arguments are checked, so it
        # stays set if the check below raises
        self.state = TrainerState.PREDICTING
        self.predicting = True

        # If you supply a datamodule you can't supply dataloaders
        if dataloaders is not None and datamodule:
            raise MisconfigurationException(
                'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
            )

        # Attach datamodule to get setup/prepare_data added to model before the call to it below
        self.data_connector.attach_datamodule(model, datamodule)
        #  Attach dataloaders (if given)
        self.data_connector.attach_dataloaders(model,
                                               predict_dataloaders=dataloaders)

        # run predict (goes through `fit`, which is the shared entry point)
        results = self.fit(model)

        assert self.state.stopped
        self.predicting = False

        return results

    def tune(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Runs routines to tune hyperparameters before training.

        The work is delegated to ``self.tuner.tune``; nothing is returned.

        Args:
            model: Model to tune.

            train_dataloader: A Pytorch DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped

            datamodule: An instance of :class:`LightningDataModule`.

        """
        Trainer._log_api_event("tune")
        # mark the trainer as tuning for the duration of the call
        self.state = TrainerState.TUNING
        self.tuning = True

        self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule)

        # the tuner is expected to leave the trainer in a stopped state
        assert self.state.stopped
        self.tuning = False

    def call_setup_hook(self, model: LightningModule) -> None:
        assert self.state.running, f"TrainerState: {self.state}"
        state = self._setup_state

        if self.datamodule is not None:
            called = getattr(self.datamodule, f'has_setup_{state}')
            if not called:
                self.datamodule.setup(stage=state)

        self.setup(model, stage=state)
        model.setup(stage=state)

    def call_configure_sharded_model(self, model: LightningModule) -> None:
        # Call configure sharded model hook if accelerator requests. In some cases
        # we will not call the hook; the hook has initialized the sharded model for example.

        # used on the model if the user re-create a trainer with resume_from_checkpoint
        model_call_configure_sharded_model_hook = getattr(
            model, "call_configure_sharded_model_hook", False)
        if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook:
            with self.accelerator.model_sharded_context():
                model.configure_sharded_model()
                self.configure_sharded_model(model)
            model.call_configure_sharded_model_hook = True
            self.accelerator.call_configure_sharded_model_hook = False

    def call_teardown_hook(self, model: LightningModule) -> None:
        state = self._teardown_state

        if self.datamodule is not None:
            called = getattr(self.datamodule, f'has_teardown_{state}')
            if not called:
                self.datamodule.teardown(stage=state)

        self.profiler.teardown(stage=state)
        self.teardown(stage=state)
        model.teardown(stage=state)

    def _reset_result_and_set_hook_fx_name(self, hook_name):
        # on_before_zero_grad is called within training_step
        if "batch_start" in hook_name or "on_before_zero_grad" in hook_name:
            return True
        model_ref = self.lightning_module
        if model_ref is not None:
            # used to track current hook name called
            model_ref._results = Result()
            model_ref._current_hook_fx_name = hook_name
        return False

    def _cache_logged_metrics(self):
        model_ref = self.lightning_module
        if model_ref is not None:
            # capture logging for this hook
            self.logger_connector.cache_logged_metrics()

    def call_hook(self, hook_name, *args, **kwargs):
        """Dispatch a hook to the trainer, then to the module or the accelerator.

        Args:
            hook_name: Name of the hook to call.
            *args: Positional arguments forwarded to every hook that is called.
            **kwargs: Keyword arguments forwarded to every hook that is called.

        Returns:
            The return value of the module's hook when the module overrides
            it, otherwise the accelerator's hook return value when the
            accelerator has one, otherwise ``None``. The trainer-level hook's
            return value is discarded.
        """
        # set hook_name to model + reset Result obj
        skip_caching = self._reset_result_and_set_hook_fx_name(hook_name)

        # always profile hooks
        with self.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self, hook_name):
                getattr(self, hook_name)(*args, **kwargs)

            # next call hook in lightningModule
            output = None
            module = self.lightning_module
            if is_overridden(hook_name, module):
                output = getattr(module, hook_name)(*args, **kwargs)

            # if the PL module doesn't have the hook then call the accelerator
            # used to auto-reduce things for the user with Results obj
            elif hasattr(self.accelerator, hook_name):
                output = getattr(self.accelerator, hook_name)(*args, **kwargs)

        if not skip_caching:
            self._cache_logged_metrics()
        return output

    @staticmethod
    def _log_api_event(event: str) -> None:
        torch._C._log_api_usage_once("lightning.trainer." + event)
예제 #4
0
class Trainer(
        TrainerProperties,
        TrainerCallbackHookMixin,
        TrainerModelHooksMixin,
        TrainerOptimizersMixin,
        TrainerLoggingMixin,
        TrainerTrainingTricksMixin,
        TrainerDataLoadingMixin,
):
    @overwrite_by_env_vars
    def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase],
                      bool] = True,
        checkpoint_callback: bool = True,
        callbacks: Optional[List[Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: int = 1000,
        min_epochs: int = 1,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        limit_train_batches: Union[int, float] = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        val_check_interval: Union[int, float] = 1.0,
        flush_logs_every_n_steps: int = 100,
        log_every_n_steps: int = 50,
        accelerator: Optional[Union[str, Accelerator]] = None,
        sync_batchnorm: bool = False,
        precision: int = 32,
        weights_summary: Optional[str] = 'top',
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[str] = None,
        profiler: Optional[Union[BaseProfiler, bool, str]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        plugins: Optional[list] = None,
        amp_backend: str = 'native',
        amp_level: str = 'O2',
        distributed_backend: Optional[str] = None,
        automatic_optimization: bool = True,
    ):
        r"""
        Customize every aspect of training via flags.

        The constructor first creates the connector, tuner and loop objects and
        then hands each group of flags to the matching ``on_trainer_init`` call;
        the ``on_init_start`` and ``on_init_end`` hooks bracket that work.

        Args:

            accelerator: Previously known as distributed_backend (dp, ddp, ddp2, etc...).
                Can also take in an accelerator object for custom hardware.

            accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

            amp_backend: The mixed precision backend to use ("native" or "apex")

            amp_level: The optimization level to use (O1, O2, etc...).

            auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
                trying to optimize initial learning for faster convergence. trainer.tune() method will
                set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
                To use a different key set a string instead of True with the key name.

            auto_scale_batch_size: If set to True, will `initially` run a batch size
                finder trying to find the largest batch size that fits into memory.
                The result will be stored in self.batch_size in the LightningModule.
                Additionally, can be set to either `power` that estimates the batch size through
                a power search or `binsearch` that estimates the batch size through a binary search.

            auto_select_gpus: If enabled and `gpus` is an integer, pick available
                gpus automatically. This is especially useful when
                GPUs are configured to be in "exclusive mode", such
                that only one process at a time can access them.

            benchmark: If true enables cudnn.benchmark.

            callbacks: Add a list of callbacks.

            checkpoint_callback: If ``True``, enable checkpointing.
                It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.

                .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
                    v1.1.0 and will be unsupported from v1.4.0.

            check_val_every_n_epoch: Check val every n train epochs.

            default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
                Default: ``os.getcwd()``.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

            deterministic: If true enables cudnn.deterministic.

            distributed_backend: Deprecated. Please use 'accelerator'.

            fast_dev_run: Runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).

            flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).

            gpus: Number of gpus to train on (int) or which GPUs to train on (list or str) applied per node

            gradient_clip_val: 0 means don't clip.

            limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)

            limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)

            limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)

            logger: Logger (or iterable collection of loggers) for experiment tracking.

            log_gpu_memory: None, 'min_max', 'all'. Might slow performance

            log_every_n_steps: How often to log within steps (defaults to every 50 steps).

            automatic_optimization: If False you are responsible for calling .backward, .step, zero_grad.
                Meant to be used with multiple optimizers by advanced users.

            prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

            process_position: Orders the progress bar when running multiple models on same machine.

            progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
                Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.

            profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
                value is deprecated in v1.1 and will be removed in v1.3.

            overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0

            plugins: Plugins allow modification of core behavior like ddp and amp.

            precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.

            max_epochs: Stop training once this number of epochs is reached.

            min_epochs: Force training for at least this many epochs.

            max_steps: Stop training after this number of steps. Disabled by default (None).

            min_steps: Force training for at least this number of steps. Disabled by default (None).

            num_nodes: Number of GPU nodes for distributed training.

            num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
                Set it to `-1` to run all batches in all validation dataloaders. Default: 2

            reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

            replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
                will be toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
                train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
                you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.

            resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
                This can be a URL.

            sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

            terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

            tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

            truncated_bptt_steps: Truncated backpropagation through time: performs backprop every k steps
                of a much longer sequence.

            val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                use int to check every n steps (batches).

            weights_summary: Prints a summary of the weights when training begins.

            weights_save_path: Where to save weights if specified. Will override default_root_dir
                    for checkpoints only. Use this if for whatever reason you need the checkpoints
                    stored in a different place than the logs written in `default_root_dir`.
                    Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
                    Defaults to `default_root_dir`.
        """
        super().__init__()

        # init connectors: each one receives this trainer instance
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.optimizer_connector = OptimizerConnector(self)
        self.accelerator_connector = AcceleratorConnector(self)
        self.logger_connector = LoggerConnector(self)
        self.model_connector = ModelConnector(self)
        self.precision_connector = PrecisionConnector(self)
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.profile_connector = ProfilerConnector(self)
        self.checkpoint_connector = CheckpointConnector(self)
        self.slurm_connector = SLURMConnector(self)
        self.tuner = Tuner(self)
        # the accelerator backend is only selected later, inside `fit`
        self.accelerator_backend = None
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)
        self.plugin_connector = PluginConnector(self)

        # training state
        self.weights_summary = weights_summary
        self.model = None
        self.shown_warnings = set()

        # init callbacks
        # Declare attributes to be set in callback_connector on_trainer_init
        self.callback_connector.on_trainer_init(
            callbacks,
            checkpoint_callback,
            progress_bar_refresh_rate,
            process_position,
            default_root_dir,
            weights_save_path,
            resume_from_checkpoint,
        )

        # hook: called after the callbacks have been configured above
        self.on_init_start()

        # init optimizer + lr scheduler related flags
        self.optimizer_connector.on_trainer_init()

        # init data flags
        self.data_connector.on_trainer_init(check_val_every_n_epoch,
                                            reload_dataloaders_every_epoch,
                                            prepare_data_per_node)

        # init training tricks
        self.training_tricks_connector.on_trainer_init(
            gradient_clip_val, track_grad_norm, accumulate_grad_batches,
            truncated_bptt_steps, terminate_on_nan)

        # init accelerator related flags
        self.accelerator_connector.on_trainer_init(
            num_processes,
            tpu_cores,
            accelerator,
            distributed_backend,
            auto_select_gpus,
            gpus,
            num_nodes,
            log_gpu_memory,
            sync_batchnorm,
            benchmark,
            replace_sampler_ddp,
            deterministic,
        )

        # init train loop related flags
        self.train_loop.on_trainer_init(max_epochs, min_epochs, max_steps,
                                        min_steps, num_sanity_val_steps,
                                        automatic_optimization)
        self.evaluation_loop.on_trainer_init()

        # configure tuner
        self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

        # configure profiler
        self.profile_connector.on_trainer_init(profiler)

        # init logger flags
        self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps,
                                              log_every_n_steps)

        # init debugging flags
        self.debugging_connector.on_init_start(
            limit_train_batches,
            limit_val_batches,
            limit_test_batches,
            val_check_interval,
            overfit_batches,
            fast_dev_run,
        )

        # set precision
        self.precision_connector.on_trainer_init(precision, amp_level,
                                                 amp_backend)

        # plugins are handled last so they can override whatever the trainer used by default
        self.plugin_connector.on_trainer_init(plugins)

        # Callback system
        self.on_init_end()

    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Runs the full optimization routine.

        Args:
            model: Model to fit.

            train_dataloader: A Pytorch DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped

            datamodule: An instance of :class:`LightningDataModule`.

        Returns:
            Whatever the accelerator backend's ``train()`` returns, or ``1`` when that value is falsy.

        """
        # bookkeeping
        self._state = TrainerState.RUNNING

        # ----------------------------
        # LINK DATA
        # ----------------------------
        # setup data, etc...
        self.train_loop.setup_fit(model, train_dataloader, val_dataloaders,
                                  datamodule)

        # hook
        self.data_connector.prepare_data(model)

        # bookkeeping
        # we reuse fit in .test() but change its behavior using this flag
        # NOTE(review): when PL_TESTING_MODE is set this becomes the raw env string, not a bool
        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

        # ----------------------------
        # SET UP TRAINING
        # ----------------------------
        self.accelerator_backend = self.accelerator_connector.select_accelerator(
        )
        self.accelerator_backend.setup(model)

        # ----------------------------
        # INSPECT THESE FOR MAIN LOOPS
        # ----------------------------
        # assign training and eval functions... inspect these to see the train and eval loops :)
        # validation and test share the same `run_evaluation` entry point
        self.accelerator_backend.train_loop = self.train
        self.accelerator_backend.validation_loop = self.run_evaluation
        self.accelerator_backend.test_loop = self.run_evaluation

        # ----------------------------
        # TRAIN
        # ----------------------------
        # hook
        self.call_hook('on_fit_start')

        results = self.accelerator_backend.train()
        self.accelerator_backend.teardown()

        # ----------------------------
        # POST-Training CLEAN UP
        # ----------------------------
        # hook
        self.call_hook('on_fit_end')

        # hook: trainer-level teardown always runs; the model's only if it implements one
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')

        # return 1 when finished
        # used for testing or when we need to know that training succeeded

        # an INTERRUPTED state is kept as-is instead of being overwritten with FINISHED
        if self._state != TrainerState.INTERRUPTED:
            self._state = TrainerState.FINISHED
        return results or 1

    def train(self):
        """Run the sanity check and then the epoch loop.

        The loop covers epochs ``current_epoch`` to ``max_epochs - 1``. It ends
        early when ``max_steps`` is reached, or when ``should_stop`` is set and
        the minimum epoch and step requirements are met. A ``KeyboardInterrupt``
        is caught and marks the trainer as interrupted. Nothing is returned.
        """
        self.run_sanity_check(self.get_model())

        self.checkpoint_connector.has_trained = False

        # enable train mode
        model = self.get_model()
        model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(model)

        # hook
        self.train_loop.on_train_start()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                # run train epoch
                self.train_loop.run_training_epoch()

                # stop as soon as the step budget is used up (skipped when max_steps is None or 0)
                if self.max_steps and self.max_steps <= self.global_step:

                    # hook
                    self.train_loop.on_train_end()
                    return

                # update LR schedulers
                self.optimizer_connector.update_learning_rates(
                    interval='epoch')

                # early stopping
                # the -1 presumably accounts for `epoch` being zero-based
                met_min_epochs = epoch >= self.min_epochs - 1
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if self.should_stop:
                    if met_min_epochs and met_min_steps:
                        self.train_loop.on_train_end()
                        return
                    else:
                        log.info(
                            'Trainer was signaled to stop but required minimum epochs'
                            f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                            ' not been met. Training will continue...')

            # hook
            self.train_loop.on_train_end()

        except KeyboardInterrupt:
            rank_zero_warn(
                'Detected KeyboardInterrupt, attempting graceful shutdown...')

            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()

                # hook
                self.train_loop.on_train_end()

    def run_evaluation(self, test_mode: bool = False, max_batches=None):
        """Run one complete evaluation pass over every evaluation dataloader.

        Gradients are disabled for the duration of the pass and re-enabled before the
        final ``on_evaluation_end`` hook. Per-dataloader step outputs and step metrics
        are appended to ``self.evaluation_loop.outputs`` / ``.step_metrics``.

        Args:
            test_mode: stored on ``self.evaluation_loop.testing`` and forwarded to the
                evaluation step and to ``log_epoch_metrics``.
            max_batches: forwarded to ``evaluation_loop.get_evaluation_dataloaders``,
                which returns the per-dataloader limits actually used.

        Returns:
            ``(eval_loop_results, deprecated_eval_results)``, or ``([], [])`` when the
            evaluation loop reports that evaluation should be skipped.
        """
        # bookkeeping: remember the mode, resolve the loaders and their batch limits
        self.evaluation_loop.testing = test_mode
        loaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(
            max_batches)
        nothing_to_run = self.evaluation_loop.should_skip_evaluation(
            loaders, max_batches)
        if nothing_to_run:
            return [], []

        # eval mode, cleared grads, autograd off
        module = self.get_model()
        self.evaluation_loop.on_evaluation_model_eval()

        module.zero_grad()
        torch.set_grad_enabled(False)

        # hook
        self.evaluation_loop.on_evaluation_start()

        # prepare the loop state for this pass
        self.evaluation_loop.setup(module, max_batches, loaders)

        # hook
        # TODO: should this be inside the dataloader loop?
        self.evaluation_loop.on_evaluation_epoch_start()

        for loader_idx, raw_loader in enumerate(loaders):
            # per-dataloader accumulators
            collected_outputs, collected_step_metrics = [], []
            prepared_loader = self.accelerator_backend.process_dataloader(
                raw_loader)
            batch_limit = self.evaluation_loop.max_batches[loader_idx]

            for batch_idx, batch in enumerate(prepared_loader):
                if batch is None:
                    continue

                # respect the limit on the number of batches for this loader
                if batch_idx >= batch_limit:
                    break

                # hook
                self.evaluation_loop.on_evaluation_batch_start(
                    batch, batch_idx, loader_idx)

                # lightning module methods
                step_output = self.evaluation_loop.evaluation_step(
                    test_mode, batch, batch_idx, loader_idx)
                step_output = self.evaluation_loop.evaluation_step_end(
                    step_output)

                # hook
                self.evaluation_loop.on_evaluation_batch_end(
                    step_output, batch, batch_idx, loader_idx)

                # clean up
                self.evaluation_loop.evaluation_batch_end_cleanup(
                    step_output, batch_idx, loader_idx)

                # TODO: deprecate 1.0
                self.evaluation_loop.log_evaluation_step_metrics_legacy(
                    step_output, batch_idx)

                # step-level metrics
                logged = self.evaluation_loop.log_evaluation_step_metrics(
                    batch, batch_idx)

                if logged is not None:
                    collected_step_metrics.append(logged)

                # keep the output for the epoch-end aggregation
                if step_output is not None:
                    collected_outputs.append(step_output)

            self.evaluation_loop.outputs.append(collected_outputs)
            self.evaluation_loop.step_metrics.append(collected_step_metrics)

        # lightning module method
        deprecated_results, epoch_logs = self.evaluation_loop.evaluation_epoch_end(
            num_dataloaders=len(loaders))

        # bookkeeping
        loop_results = self.evaluation_loop.log_epoch_metrics(
            deprecated_results, epoch_logs, test_mode)
        self.evaluation_loop.predictions.to_disk()

        # hook
        self.evaluation_loop.on_evaluation_epoch_end()

        # back to train mode with autograd on
        self.evaluation_loop.on_evaluation_model_train()
        torch.set_grad_enabled(True)

        # hook
        self.evaluation_loop.on_evaluation_end()

        return loop_results, deprecated_results

    def run_test(self):
        """Run evaluation in test mode and return its results with tensors unwrapped.

        Every ``torch.Tensor`` value found in a dict entry of the results is replaced,
        in place, by the Python number obtained from ``tensor.cpu().item()``.

        Returns:
            ``1`` when evaluation produced no results, otherwise the results list.
        """
        # only the test dataloader is needed here
        # self.reset_test_dataloader(ref_model)
        results, _deprecated = self.run_evaluation(test_mode=True)

        if len(results) == 0:
            return 1

        # replace tensors with plain Python numbers
        for entry in results:
            if not isinstance(entry, dict):
                continue
            tensor_keys = [
                key for key, value in entry.items()
                if isinstance(value, torch.Tensor)
            ]
            for key in tensor_keys:
                entry[key] = entry[key].cpu().item()

        return results

    def run_sanity_check(self, ref_model):
        """Run a short validation pass before training to surface crashes early.

        Nothing happens unless the model has a ``val_dataloader``, overrides
        ``validation_step``, and both ``num_sanity_val_steps`` and ``limit_val_batches``
        are positive. When the check runs, the callback metrics of the last evaluation
        result are stored on ``self.logger_connector.callback_metrics``.

        Args:
            ref_model: the module whose validation loop is exercised.
        """
        has_val_step = ref_model.val_dataloader is not None and is_overridden(
            'validation_step', ref_model)
        wants_check = has_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

        if not wants_check:
            return

        # cap every validation loader at the requested number of sanity steps
        self.reset_val_dataloader(ref_model)
        capped_batches = []
        for total_batches in self.num_val_batches:
            capped_batches.append(min(self.num_sanity_val_steps, total_batches))
        self.num_sanity_val_batches = capped_batches

        # hook and callback
        self.running_sanity_check = True
        self.on_sanity_check_start()

        # run the tiny validation pass
        _, eval_results = self.run_evaluation(
            test_mode=False, max_batches=self.num_sanity_val_batches)

        # evaluation is allowed to return nothing
        if eval_results is not None and len(eval_results) > 0:
            # for a list of results only the final one is used
            last_result = eval_results[-1] if isinstance(eval_results, list) else eval_results

            if isinstance(last_result, EvalResult):
                metrics_for_callbacks = last_result.callback_metrics
            else:
                _, _, _, metrics_for_callbacks, _ = self.process_dict_result(
                    last_result)
            self.logger_connector.callback_metrics = metrics_for_callbacks

        self.on_sanity_check_end()
        self.running_sanity_check = False

    def test(
        self,
        model: Optional[LightningModule] = None,
        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = 'best',
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""

        Separates from fit to make sure you never run on your test set until you want to.

        Args:
            ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
                If ``None``, use the weights from the last epoch to test. Default to ``best``.
                Only consulted when ``model`` is ``None``.

            datamodule: An instance of :class:`LightningDataModule`.

            model: The model to test. When ``None``, the trainer's current model is used.

            test_dataloaders: Either a single
                Pytorch Dataloader or a list of them, specifying test samples.

            verbose: If True, prints the test results

        Returns:
            The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries

        Raises:
            MisconfigurationException: If both ``test_dataloaders`` and ``datamodule`` are given.
        """
        # --------------------
        # SETUP HOOK
        # --------------------
        self.verbose_test = verbose

        # A datamodule and explicit test_dataloaders are mutually exclusive.
        # Compare against None instead of using truthiness: bool() on a DataLoader goes
        # through __len__, which can raise for loaders without a length and is False for
        # an empty loader. The helpers below also use ``is not None`` for test_dataloaders.
        if test_dataloaders is not None and datamodule is not None:
            raise MisconfigurationException(
                'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
            )

        # Attach datamodule to get setup/prepare_data added to model before the call to it below
        self.data_connector.attach_datamodule(model or self.get_model(),
                                              datamodule, 'test')

        if model is not None:
            results = self.__test_given_model(model, test_dataloaders)
        else:
            results = self.__test_using_best_weights(ckpt_path,
                                                     test_dataloaders)

        self.teardown('test')

        return results

    def __test_using_best_weights(self, ckpt_path, test_dataloaders):
        """Test the trainer's current model, optionally after loading checkpoint weights.

        Args:
            ckpt_path: ``'best'`` to use ``self.checkpoint_callback.best_model_path``, an
                explicit checkpoint path, or ``None`` to keep the model's current weights.
            test_dataloaders: dataloaders to attach to the model before testing, or ``None``.

        Returns:
            Whatever ``self.fit(model)`` returns, or ``{}`` when the resolved checkpoint
            path is empty.

        Raises:
            MisconfigurationException: If ``ckpt_path`` is ``'best'`` but the checkpoint
                callback has no best model path.
        """
        model = self.get_model()

        # if user requests the best checkpoint but we don't have it, error
        if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
            raise MisconfigurationException(
                'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.'
            )

        # load best weights
        if ckpt_path is not None:
            # ckpt_path is 'best' so load the best model
            if ckpt_path == 'best':
                ckpt_path = self.checkpoint_callback.best_model_path

            if len(ckpt_path) == 0:
                rank_zero_warn(
                    f'.test() found no path for the best weights, {ckpt_path}. Please '
                    f'specify a path for a checkpoint .test(ckpt_path=PATH)')
                return {}
            if self.accelerator_backend is not None:
                self.accelerator_backend.barrier()

            ckpt = pl_load(ckpt_path,
                           map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt['state_dict'])

        # attach dataloaders
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(
                model, test_dataloaders=test_dataloaders)

        # run tests
        self.tested_ckpt_path = ckpt_path
        self.testing = True
        os.environ['PL_TESTING_MODE'] = '1'
        self.model = model
        try:
            results = self.fit(model)
        finally:
            # Always undo the testing flags, even when fit() raises: a PL_TESTING_MODE
            # left behind in the environment would be read by any later fit() call in
            # this process. pop() also tolerates the variable having been removed already.
            self.testing = False
            os.environ.pop('PL_TESTING_MODE', None)

        # teardown
        if self.is_function_implemented('teardown'):
            model_ref = self.get_model()
            model_ref.teardown('test')

        return results

    def __test_given_model(self, model, test_dataloaders):
        """Test an explicitly supplied model.

        Args:
            model: the module to test; it is stored on ``self.model`` and passed to ``fit``.
            test_dataloaders: dataloaders to attach to ``model`` first, or ``None``.

        Returns:
            Whatever ``self.fit(model)`` returns.
        """
        # attach data
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(
                model, test_dataloaders=test_dataloaders)

        # run test
        # fit() is entered with self.testing set so it runs in testing mode
        self.testing = True
        self.model = model
        try:
            results = self.fit(model)
        finally:
            # reset the flag even when fit() raises, so a failed test run does not leave
            # the trainer stuck in testing mode
            self.testing = False

        # teardown
        if self.is_function_implemented('teardown'):
            model.teardown('test')

        return results

    def tune(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Runs routines to tune hyperparameters before training.

        This is a thin wrapper: all four arguments are forwarded, in order, to
        ``self.tuner.tune``. Whatever that call returns is discarded, so this
        method itself returns ``None``.

        Args:
            datamodule: An instance of :class:`LightningDataModule`.

            model: Model to tune.

            train_dataloader: A Pytorch DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped.

        """
        self.tuner.tune(model, train_dataloader, val_dataloaders, datamodule)

    def call_setup_hook(self, model):
        """Invoke the ``setup`` hooks for the current stage.

        The stage is ``'test'`` when ``self.testing`` is truthy and ``'fit'`` otherwise.
        The datamodule (if any) is set up first, but only when its matching
        ``has_setup_test`` / ``has_setup_fit`` flag is falsy; then the trainer's own
        ``setup`` and finally the model's ``setup`` are called.

        Args:
            model: the module whose ``setup(stage)`` hook is called last.
        """
        # call setup after the ddp process has connected
        if self.testing:
            stage = 'test'
        else:
            stage = 'fit'

        if self.datamodule is not None:
            if self.testing:
                already_done = self.datamodule.has_setup_test
            else:
                already_done = self.datamodule.has_setup_fit
            if not already_done:
                self.datamodule.setup(stage)

        self.setup(stage)
        model.setup(stage)

    def call_hook(self, hook_name, *args, **kwargs):
        """Dispatch ``hook_name`` to the trainer, then to the model or the accelerator.

        The trainer's own attribute of that name (if present) is always called first and
        its result ignored. After that, the model's hook is called when ``is_overridden``
        reports it as overridden; otherwise the accelerator backend's attribute of that
        name is called when it exists. The whole dispatch runs inside a profiler section
        named after the hook.

        Args:
            hook_name: name of the hook to call.
            *args: positional arguments passed to every hook that is called.
            **kwargs: keyword arguments passed to every hook that is called.

        Returns:
            The result of the model hook or the accelerator hook, or ``None`` when
            neither was called.
        """
        # hooks are always profiled
        with self.profiler.profile(hook_name):

            # trainer-level hook first; its return value is not used
            if hasattr(self, hook_name):
                getattr(self, hook_name)(*args, **kwargs)

            # then the LightningModule, if it overrides the hook
            module = self.get_model()
            if is_overridden(hook_name, module):
                module_hook = getattr(module, hook_name)
                return module_hook(*args, **kwargs)

            # otherwise fall back to the accelerator
            # used to auto-reduce things for the user with Results obj
            if hasattr(self.accelerator_backend, hook_name):
                backend_hook = getattr(self.accelerator_backend, hook_name)
                return backend_hook(*args, **kwargs)

            return None
예제 #5
0
    def __init__(
            self,
            logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase],
                          bool] = True,
            checkpoint_callback: Union[ModelCheckpoint, bool] = True,
            early_stop_callback: Optional[Union[
                EarlyStopping, bool]] = False,  # todo: remove in v1.0.0
            callbacks: Optional[List[Callback]] = None,
            default_root_dir: Optional[str] = None,
            gradient_clip_val: float = 0,
            process_position: int = 0,
            num_nodes: int = 1,
            num_processes: int = 1,
            gpus: Optional[Union[List[int], str, int]] = None,
            auto_select_gpus: bool = False,
            tpu_cores: Optional[Union[List[int], str, int]] = None,
            log_gpu_memory: Optional[str] = None,
            progress_bar_refresh_rate: int = 1,
            overfit_batches: Union[int, float] = 0.0,
            track_grad_norm: Union[int, float, str] = -1,
            check_val_every_n_epoch: int = 1,
            fast_dev_run: bool = False,
            accumulate_grad_batches: Union[int, Dict[int, int],
                                           List[list]] = 1,
            max_epochs: int = 1000,
            min_epochs: int = 1,
            max_steps: Optional[int] = None,
            min_steps: Optional[int] = None,
            limit_train_batches: Union[int, float] = 1.0,
            limit_val_batches: Union[int, float] = 1.0,
            limit_test_batches: Union[int, float] = 1.0,
            val_check_interval: Union[int, float] = 1.0,
            log_save_interval: int = 100,
            row_log_interval: int = 50,
            distributed_backend: Optional[str] = None,
            sync_batchnorm: bool = False,
            precision: int = 32,
            weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
            weights_save_path: Optional[str] = None,
            num_sanity_val_steps: int = 2,
            truncated_bptt_steps: Optional[int] = None,
            resume_from_checkpoint: Optional[str] = None,
            profiler: Optional[Union[BaseProfiler, bool]] = None,
            benchmark: bool = False,
            deterministic: bool = False,
            reload_dataloaders_every_epoch: bool = False,
            auto_lr_find: Union[bool, str] = False,
            replace_sampler_ddp: bool = True,
            terminate_on_nan: bool = False,
            auto_scale_batch_size: Union[str, bool] = False,
            prepare_data_per_node: bool = True,
            cluster_environment: Optional[ClusterEnvironment] = None,
            amp_backend: str = 'native',
            amp_level: str = 'O2',  # backward compatible, todo: remove in v1.0.0
            overfit_pct: Optional[float] = None,  # backward compatible, todo: remove in v1.0.0
    ):
        r"""
        Customize every aspect of training via flags

        The constructor creates the trainer's connectors and loops and then forwards
        each group of flags to the matching connector's ``on_trainer_init`` (or
        ``on_init_start``) call. The ``on_init_start`` hook runs right after the
        callback connector has been initialised and ``on_init_end`` runs last.

        Args:

            accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

            amp_backend: The mixed precision backend to use ("native" or "apex")

            amp_level: The optimization level to use (O1, O2, etc...).

            auto_lr_find: If set to True, will `initially` run a learning rate finder,
                trying to optimize initial learning for faster convergence. Sets learning
                rate in self.lr or self.learning_rate in the LightningModule.
                To use a different key, set a string instead of True with the key name.

            auto_scale_batch_size: If set to True, will `initially` run a batch size
                finder trying to find the largest batch size that fits into memory.
                The result will be stored in self.batch_size in the LightningModule.
                Additionally, can be set to either `power` that estimates the batch size through
                a power search or `binsearch` that estimates the batch size through a binary search.

            auto_select_gpus: If enabled and `gpus` is an integer, pick available
                gpus automatically. This is especially useful when
                GPUs are configured to be in "exclusive mode", such
                that only one process at a time can access them.

            benchmark: If true enables cudnn.benchmark.

            callbacks: Add a list of callbacks.

            checkpoint_callback: Callback for checkpointing.

            check_val_every_n_epoch: Check val every n train epochs.

            cluster_environment: Environment config to link up arbitrary clusters

            default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
                Default: ``os.getcwd()``.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

            deterministic: If true enables cudnn.deterministic.

            distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)

            early_stop_callback: (:class:`pytorch_lightning.callbacks.EarlyStopping`)
                Deprecated since v0.10.0 and will be removed in v1.0.

            fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).

            gpus: number of gpus to train on (int) or which GPUs to train on (list or str) applied per node

            gradient_clip_val: 0 means don't clip.

            limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)

            limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)

            limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)

            logger: Logger (or iterable collection of loggers) for experiment tracking.

            log_gpu_memory: None, 'min_max', 'all'. Might slow performance

            log_save_interval: Writes logs to disk this often

            prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

            process_position: orders the progress bar when running multiple models on same machine.

            progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
                Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.

            profiler:  To profile individual steps during training and assist in identifying bottlenecks.

            overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0

            overfit_pct: Kept for backward compatibility and marked for removal in v1.0.0.
                Forwarded to the debugging connector together with ``overfit_batches``
                (NOTE(review): presumably superseded by ``overfit_batches`` - confirm).

            precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.

            max_epochs: Stop training once this number of epochs is reached.

            min_epochs: Force training for at least these many epochs

            max_steps: Stop training after this number of steps. Disabled by default (None).

            min_steps: Force training for at least these number of steps. Disabled by default (None).

            num_nodes: number of GPU nodes for distributed training.

            num_processes: Forwarded to the accelerator connector
                (NOTE(review): presumably the number of processes for CPU-based
                distributed training - confirm).

            num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
                Set it to `-1` to run all batches in all validation dataloaders. Default: 2

            reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

            replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
                will be toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
                train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
                you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.

            resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
                This can be a URL.

            row_log_interval: How often to add logging rows (does not write to disk)

            sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

            terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

            tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

            truncated_bptt_steps: Truncated backpropagation performs backprop every k steps of a much
                longer sequence.

            val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                use int to check every n steps (batches).

            weights_summary: Prints a summary of the weights when training begins.

            weights_save_path: Where to save weights if specified. Will override default_root_dir
                    for checkpoints only. Use this if for whatever reason you need the checkpoints
                    stored in a different place than the logs written in `default_root_dir`.
                    Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
                    Defaults to `default_root_dir`.
        """
        super().__init__()

        # init connectors
        # each connector/loop receives the trainer itself; they are configured further below
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.optimizer_connector = OptimizerConnector(self)
        self.accelerator_connector = AcceleratorConnector(self)
        self.logger_connector = LoggerConnector(self)
        self.model_connector = ModelConnector(self)
        self.precision_connector = PrecisionConnector(self)
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.profile_connector = ProfilerConnector(self)
        self.checkpoint_connector = CheckpointConnector(self)
        self.slurm_connector = SLURMConnector(self)
        self.tuner = Tuner(self)
        # no accelerator backend exists yet at construction time
        self.accelerator_backend = None
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)

        # training state
        self.weights_summary = weights_summary
        self.model = None
        self.shown_warnings = set()

        # init callbacks
        # Declare attributes to be set in callback_connector on_trainer_init
        self.checkpoint_callback: Union[ModelCheckpoint,
                                        bool] = checkpoint_callback
        self.early_stop_callback: Optional[Union[EarlyStopping,
                                                 bool]] = early_stop_callback
        self.callback_connector.on_trainer_init(
            callbacks, early_stop_callback, checkpoint_callback,
            progress_bar_refresh_rate, process_position, default_root_dir,
            weights_save_path, resume_from_checkpoint)

        # hook
        # called only after the callback connector has been initialised above
        self.on_init_start()

        # init optimizer + lr scheduler related flags
        self.optimizer_connector.on_trainer_init()

        # init data flags
        self.data_connector.on_trainer_init(check_val_every_n_epoch,
                                            reload_dataloaders_every_epoch,
                                            prepare_data_per_node)

        # init training tricks
        self.training_tricks_connector.on_trainer_init(
            gradient_clip_val, track_grad_norm, accumulate_grad_batches,
            truncated_bptt_steps, terminate_on_nan)

        # init accelerator related flags
        self.accelerator_connector.on_trainer_init(
            num_processes, tpu_cores, distributed_backend, auto_select_gpus,
            gpus, num_nodes, log_gpu_memory, sync_batchnorm, benchmark,
            replace_sampler_ddp, deterministic, cluster_environment)

        # init train loop related flags
        self.train_loop.on_trainer_init(max_epochs, min_epochs, max_steps,
                                        min_steps, num_sanity_val_steps)
        self.evaluation_loop.on_trainer_init()

        # configure tuner
        self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

        # configure profiler
        self.profile_connector.on_trainer_init(profiler)

        # init logger flags
        self.logger_connector.on_trainer_init(logger, log_save_interval,
                                              row_log_interval)

        # init debugging flags
        self.debugging_connector.on_init_start(overfit_pct,
                                               limit_train_batches,
                                               limit_val_batches,
                                               limit_test_batches,
                                               val_check_interval,
                                               overfit_batches, fast_dev_run)

        # set precision
        self.precision_connector.on_trainer_init(precision, amp_level,
                                                 amp_backend)

        # Callback system
        self.on_init_end()
예제 #6
0
    def __init__(
            self,
            logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase],
                          bool] = True,
            checkpoint_callback: Union[ModelCheckpoint, bool] = True,
            early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
            callbacks: Optional[List[Callback]] = None,
            default_root_dir: Optional[str] = None,
            gradient_clip_val: float = 0,
            process_position: int = 0,
            num_nodes: int = 1,
            num_processes: int = 1,
            gpus: Optional[Union[List[int], str, int]] = None,
            auto_select_gpus: bool = False,
            tpu_cores: Optional[Union[List[int], str, int]] = None,
            log_gpu_memory: Optional[str] = None,
            progress_bar_refresh_rate: int = 1,
            overfit_batches: Union[int, float] = 0.0,
            track_grad_norm: Union[int, float, str] = -1,
            check_val_every_n_epoch: int = 1,
            fast_dev_run: bool = False,
            accumulate_grad_batches: Union[int, Dict[int, int],
                                           List[list]] = 1,
            max_epochs: int = 1000,
            min_epochs: int = 1,
            max_steps: Optional[int] = None,
            min_steps: Optional[int] = None,
            limit_train_batches: Union[int, float] = 1.0,
            limit_val_batches: Union[int, float] = 1.0,
            limit_test_batches: Union[int, float] = 1.0,
            val_check_interval: Union[int, float] = 1.0,
            log_save_interval: int = 100,
            row_log_interval: int = 50,
            distributed_backend: Optional[str] = None,
            sync_batchnorm: bool = False,
            precision: int = 32,
            weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
            weights_save_path: Optional[str] = None,
            num_sanity_val_steps: int = 2,
            truncated_bptt_steps: Optional[int] = None,
            resume_from_checkpoint: Optional[str] = None,
            profiler: Optional[Union[BaseProfiler, bool]] = None,
            benchmark: bool = False,
            deterministic: bool = False,
            reload_dataloaders_every_epoch: bool = False,
            auto_lr_find: Union[bool, str] = False,
            replace_sampler_ddp: bool = True,
            terminate_on_nan: bool = False,
            auto_scale_batch_size: Union[str, bool] = False,
            prepare_data_per_node: bool = True,
            amp_backend: str = 'native',
            amp_level: str = 'O2',  # backward compatible, todo: remove in v1.0.0
            val_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
            test_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
            train_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
            overfit_pct: Optional[float] = None,  # backward compatible, todo: remove in v1.0.0
    ):
        """Create the trainer's connectors and loops and distribute the flags to them.

        The constructor itself stores only ``weights_summary``; every other argument is
        handed to the component that owns it. Grouped by receiver:

        * callback connector: ``callbacks``, ``early_stop_callback``,
          ``checkpoint_callback``, ``progress_bar_refresh_rate``, ``process_position``,
          ``default_root_dir``, ``weights_save_path``, ``resume_from_checkpoint``
        * data connector: ``check_val_every_n_epoch``,
          ``reload_dataloaders_every_epoch``, ``prepare_data_per_node``
        * training-tricks connector: ``gradient_clip_val``, ``track_grad_norm``,
          ``accumulate_grad_batches``, ``truncated_bptt_steps``, ``terminate_on_nan``
        * accelerator connector: ``num_processes``, ``tpu_cores``,
          ``distributed_backend``, ``auto_select_gpus``, ``gpus``, ``num_nodes``,
          ``log_gpu_memory``, ``sync_batchnorm``, ``benchmark``,
          ``replace_sampler_ddp``, ``deterministic``
        * train loop: ``max_epochs``, ``min_epochs``, ``max_steps``, ``min_steps``,
          ``num_sanity_val_steps``
        * tuner: ``auto_lr_find``, ``auto_scale_batch_size``
        * profiler connector: ``profiler``
        * logger connector: ``logger``, ``log_save_interval``, ``row_log_interval``
        * debugging connector: ``overfit_pct``, ``val_percent_check``,
          ``test_percent_check``, ``train_percent_check``, ``limit_train_batches``,
          ``limit_val_batches``, ``limit_test_batches``, ``val_check_interval``,
          ``overfit_batches``, ``fast_dev_run``
        * precision connector: ``precision``, ``amp_level``, ``amp_backend``

        ``val_percent_check``, ``test_percent_check``, ``train_percent_check`` and
        ``overfit_pct`` are kept only for backward compatibility (see the removal
        notes next to them in the signature).
        """
        super().__init__()

        # init connectors
        # each connector/loop receives the trainer itself; they are configured further below
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.optimizer_connector = OptimizerConnector(self)
        self.accelerator_connector = AcceleratorConnector(self)
        self.logger_connector = LoggerConnector(self)
        self.model_connector = ModelConnector(self)
        self.precision_connector = PrecisionConnector(self)
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.profile_connector = ProfilerConnector(self)
        self.tuner = Tuner(self)
        # no accelerator backend exists yet at construction time
        self.accelerator_backend = None
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)

        # training state
        self.weights_summary = weights_summary
        self.model = None
        self.shown_warnings = set()

        # init callbacks
        self.callback_connector.on_trainer_init(
            callbacks, early_stop_callback, checkpoint_callback,
            progress_bar_refresh_rate, process_position, default_root_dir,
            weights_save_path, resume_from_checkpoint)

        # hook
        # called only after the callback connector has been initialised above
        self.on_init_start()

        # init optimizer + lr scheduler related flags
        self.optimizer_connector.on_trainer_init()

        # init data flags
        self.data_connector.on_trainer_init(check_val_every_n_epoch,
                                            reload_dataloaders_every_epoch,
                                            prepare_data_per_node)

        # init training tricks
        self.training_tricks_connector.on_trainer_init(
            gradient_clip_val, track_grad_norm, accumulate_grad_batches,
            truncated_bptt_steps, terminate_on_nan)

        # init accelerator related flags
        self.accelerator_connector.on_trainer_init(
            num_processes, tpu_cores, distributed_backend, auto_select_gpus,
            gpus, num_nodes, log_gpu_memory, sync_batchnorm, benchmark,
            replace_sampler_ddp, deterministic)

        # init train loop related flags
        self.train_loop.on_trainer_init(max_epochs, min_epochs, max_steps,
                                        min_steps, num_sanity_val_steps)
        self.evaluation_loop.on_trainer_init()

        # configure tuner
        self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

        # configure profiler
        self.profile_connector.on_trainer_init(profiler)

        # init logger flags
        self.logger_connector.on_trainer_init(logger, log_save_interval,
                                              row_log_interval)

        # init debugging flags
        self.debugging_connector.on_init_start(
            overfit_pct, val_percent_check, test_percent_check,
            train_percent_check, limit_train_batches, limit_val_batches,
            limit_test_batches, val_check_interval, overfit_batches,
            fast_dev_run)

        # set precision
        self.precision_connector.on_trainer_init(precision, amp_level,
                                                 amp_backend)

        # Callback system
        self.on_init_end()
예제 #7
0
class Trainer(
        TrainerProperties,
        TrainerIOMixin,
        TrainerCallbackHookMixin,
        TrainerModelHooksMixin,
        TrainerOptimizersMixin,
        TrainerLoggingMixin,
        TrainerTrainingTricksMixin,
        TrainerDataLoadingMixin,
        TrainerDeprecatedAPITillVer0_10,
):
    def __init__(
            self,
            logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase],
                          bool] = True,
            checkpoint_callback: Union[ModelCheckpoint, bool] = True,
            early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
            callbacks: Optional[List[Callback]] = None,
            default_root_dir: Optional[str] = None,
            gradient_clip_val: float = 0,
            process_position: int = 0,
            num_nodes: int = 1,
            num_processes: int = 1,
            gpus: Optional[Union[List[int], str, int]] = None,
            auto_select_gpus: bool = False,
            tpu_cores: Optional[Union[List[int], str, int]] = None,
            log_gpu_memory: Optional[str] = None,
            progress_bar_refresh_rate: int = 1,
            overfit_batches: Union[int, float] = 0.0,
            track_grad_norm: Union[int, float, str] = -1,
            check_val_every_n_epoch: int = 1,
            fast_dev_run: bool = False,
            accumulate_grad_batches: Union[int, Dict[int, int],
                                           List[list]] = 1,
            max_epochs: int = 1000,
            min_epochs: int = 1,
            max_steps: Optional[int] = None,
            min_steps: Optional[int] = None,
            limit_train_batches: Union[int, float] = 1.0,
            limit_val_batches: Union[int, float] = 1.0,
            limit_test_batches: Union[int, float] = 1.0,
            val_check_interval: Union[int, float] = 1.0,
            log_save_interval: int = 100,
            row_log_interval: int = 50,
            distributed_backend: Optional[str] = None,
            sync_batchnorm: bool = False,
            precision: int = 32,
            weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
            weights_save_path: Optional[str] = None,
            num_sanity_val_steps: int = 2,
            truncated_bptt_steps: Optional[int] = None,
            resume_from_checkpoint: Optional[str] = None,
            profiler: Optional[Union[BaseProfiler, bool]] = None,
            benchmark: bool = False,
            deterministic: bool = False,
            reload_dataloaders_every_epoch: bool = False,
            auto_lr_find: Union[bool, str] = False,
            replace_sampler_ddp: bool = True,
            terminate_on_nan: bool = False,
            auto_scale_batch_size: Union[str, bool] = False,
            prepare_data_per_node: bool = True,
            amp_backend: str = 'native',
            amp_level: str = 'O2',  # backward compatible, todo: remove in v1.0.0
            val_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
            test_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
            train_percent_check: Optional[float] = None,  # backward compatible, todo: remove in v0.10.0
            overfit_pct: Optional[float] = None,  # backward compatible, todo: remove in v1.0.0
    ):
        """Build the trainer's helper objects and hand every argument to them.

        This constructor does almost no work itself: it instantiates one
        connector / loop object per concern and then forwards the constructor
        arguments to that object's ``on_trainer_init`` (or ``on_init_start``)
        method. Only ``weights_summary`` is stored on the trainer directly.

        The arguments are routed as follows:

        * callback connector: ``callbacks``, ``early_stop_callback``,
          ``checkpoint_callback``, ``progress_bar_refresh_rate``,
          ``process_position``, ``default_root_dir``, ``weights_save_path``,
          ``resume_from_checkpoint``
        * data connector: ``check_val_every_n_epoch``,
          ``reload_dataloaders_every_epoch``, ``prepare_data_per_node``
        * training-tricks connector: ``gradient_clip_val``,
          ``track_grad_norm``, ``accumulate_grad_batches``,
          ``truncated_bptt_steps``, ``terminate_on_nan``
        * accelerator connector: ``num_processes``, ``tpu_cores``,
          ``distributed_backend``, ``auto_select_gpus``, ``gpus``,
          ``num_nodes``, ``log_gpu_memory``, ``sync_batchnorm``,
          ``benchmark``, ``replace_sampler_ddp``, ``deterministic``
        * train loop: ``max_epochs``, ``min_epochs``, ``max_steps``,
          ``min_steps``, ``num_sanity_val_steps``
        * tuner: ``auto_lr_find``, ``auto_scale_batch_size``
        * profiler connector: ``profiler``
        * logger connector: ``logger``, ``log_save_interval``,
          ``row_log_interval``
        * debugging connector: ``overfit_pct``, ``val_percent_check``,
          ``test_percent_check``, ``train_percent_check``,
          ``limit_train_batches``, ``limit_val_batches``,
          ``limit_test_batches``, ``val_check_interval``,
          ``overfit_batches``, ``fast_dev_run``
        * precision connector: ``precision``, ``amp_level``, ``amp_backend``

        The ``on_init_start`` hook runs right after the callbacks have been
        configured and ``on_init_end`` runs as the very last step.
        """
        super().__init__()

        # init connectors
        # Each helper receives the trainer itself; they are created before any
        # of them is asked to initialise so that all of them exist by the time
        # the first ``on_trainer_init`` call below runs.
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.optimizer_connector = OptimizerConnector(self)
        self.accelerator_connector = AcceleratorConnector(self)
        self.logger_connector = LoggerConnector(self)
        self.model_connector = ModelConnector(self)
        self.precision_connector = PrecisionConnector(self)
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.profile_connector = ProfilerConnector(self)
        self.tuner = Tuner(self)
        # no accelerator backend yet; `fit` assigns one
        self.accelerator_backend = None
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)

        # training state
        self.weights_summary = weights_summary
        self.model = None
        self.shown_warnings = set()

        # init callbacks
        self.callback_connector.on_trainer_init(
            callbacks, early_stop_callback, checkpoint_callback,
            progress_bar_refresh_rate, process_position, default_root_dir,
            weights_save_path, resume_from_checkpoint)

        # hook
        # called only after the callback connector has been initialised above
        self.on_init_start()

        # init optimizer + lr scheduler related flags
        self.optimizer_connector.on_trainer_init()

        # init data flags
        self.data_connector.on_trainer_init(check_val_every_n_epoch,
                                            reload_dataloaders_every_epoch,
                                            prepare_data_per_node)

        # init training tricks
        self.training_tricks_connector.on_trainer_init(
            gradient_clip_val, track_grad_norm, accumulate_grad_batches,
            truncated_bptt_steps, terminate_on_nan)

        # init accelerator related flags
        self.accelerator_connector.on_trainer_init(
            num_processes, tpu_cores, distributed_backend, auto_select_gpus,
            gpus, num_nodes, log_gpu_memory, sync_batchnorm, benchmark,
            replace_sampler_ddp, deterministic)

        # init train loop related flags
        self.train_loop.on_trainer_init(max_epochs, min_epochs, max_steps,
                                        min_steps, num_sanity_val_steps)
        self.evaluation_loop.on_trainer_init()

        # configure tuner
        self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

        # configure profiler
        self.profile_connector.on_trainer_init(profiler)

        # init logger flags
        self.logger_connector.on_trainer_init(logger, log_save_interval,
                                              row_log_interval)

        # init debugging flags
        # the deprecated *_percent_check / overfit_pct arguments are passed
        # along together with their replacements
        self.debugging_connector.on_init_start(
            overfit_pct, val_percent_check, test_percent_check,
            train_percent_check, limit_train_batches, limit_val_batches,
            limit_test_batches, val_check_interval, overfit_batches,
            fast_dev_run)

        # set precision
        self.precision_connector.on_trainer_init(precision, amp_level,
                                                 amp_backend)

        # Callback system
        self.on_init_end()

    def tune(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Run the optional tuning passes (batch-size scaling, LR finder).

        Args:
            model: module to tune.
            train_dataloader: forwarded to the fit setup and the batch-size
                scaler.
            val_dataloaders: forwarded to the fit setup and the batch-size
                scaler.
            datamodule: forwarded to the fit setup and the batch-size scaler.

        Each pass only runs when its trainer flag (``auto_scale_batch_size`` /
        ``auto_lr_find``) is truthy. After each pass the model's ``logger`` is
        pointed back at the trainer's logger.
        """
        # TODO: temporary, need to decide if tune or separate object

        # same preparation as a regular fit: attach data, then prepare it
        self.train_loop.setup_fit(
            model, train_dataloader, val_dataloaders, datamodule)
        self.data_connector.prepare_data(model)

        # --- batch size scaling ---
        if self.auto_scale_batch_size:
            # a plain boolean flag is replaced by the 'power' mode
            if isinstance(self.auto_scale_batch_size, bool):
                self.auto_scale_batch_size = 'power'

            scaler_kwargs = dict(
                mode=self.auto_scale_batch_size,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule,
            )
            self.tuner.scale_batch_size(model, **scaler_kwargs)
            # reset logger binding
            model.logger = self.logger

        # --- learning rate finder ---
        if self.auto_lr_find:
            self.tuner.internal_find_lr(self, model)
            # reset logger binding
            model.logger = self.logger

    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Set up the run, train through the selected accelerator, tear down.

        Args:
            model: module to train.
            train_dataloader: forwarded to the train loop's fit setup.
            val_dataloaders: forwarded to the train loop's fit setup.
            datamodule: forwarded to the train loop's fit setup.

        Returns:
            Whatever the accelerator backend's ``train()`` returned, or ``1``
            when that value is falsy.
        """
        # ---- setup -------------------------------------------------------
        self.train_loop.setup_fit(
            model, train_dataloader, val_dataloaders, datamodule)

        self.call_hook('on_fit_start', model)
        self.data_connector.prepare_data(model)

        # PL_TESTING_MODE, when present in the environment, replaces the
        # current ``testing`` flag (note: the value is then a string)
        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

        # ---- train -------------------------------------------------------
        self.accelerator_backend = \
            self.accelerator_connector.select_accelerator()
        self.accelerator_backend.setup(model)
        fit_output = self.accelerator_backend.train()
        self.accelerator_backend.teardown()

        # ---- post-training -----------------------------------------------
        self.call_hook('on_fit_end')

        # trainer-level teardown first, then the model's own
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')

        # a truthy value is always handed back so callers (and tests) can
        # tell that training ran to completion
        if not fit_output:
            return 1
        return fit_output

    def train(self):
        """Run the sanity check and then the epoch loop.

        The loop covers epochs ``current_epoch`` .. ``max_epochs - 1`` and
        ends early when ``max_steps`` has been reached, or when
        ``should_stop`` is set and both the minimum-epoch and minimum-step
        requirements are met. ``train_loop.on_train_end()`` is called on every
        normal exit, and once on the first ``KeyboardInterrupt``.
        """
        self.run_sanity_check(self.get_model())

        # switch the model to training mode with gradients on
        model = self.get_model()
        model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(model)

        self.train_loop.on_train_start()

        try:
            for epoch in range(self.current_epoch, self.max_epochs):

                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)

                self.train_loop.on_train_epoch_start(epoch)
                self.train_loop.run_training_epoch()

                # step budget exhausted: leave before touching the schedulers
                if self.max_steps and self.max_steps <= self.global_step:
                    break

                self.optimizer_connector.update_learning_rates(
                    interval='epoch')

                # minimum requirements that gate an early stop
                enough_epochs = epoch >= self.min_epochs - 1
                if self.min_steps:
                    enough_steps = self.global_step >= self.min_steps
                else:
                    enough_steps = True

                if not self.should_stop:
                    continue

                if enough_epochs and enough_steps:
                    break

                log.info(
                    'Trainer was signaled to stop but required minimum epochs'
                    f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                    ' not been met. Training will continue...')

            # reached both when the loop is exhausted and after a ``break``
            self.train_loop.on_train_end()

        except KeyboardInterrupt:
            rank_zero_warn(
                'Detected KeyboardInterrupt, attempting graceful shutdown...')

            # ctrl+c may be pressed repeatedly; shut down only the first time
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()

                self.train_loop.on_train_end()

    def run_evaluation(self, test_mode: bool = False, max_batches=None):
        """Run one evaluation pass over every evaluation dataloader.

        Args:
            test_mode: stored on the evaluation loop as ``testing`` and
                passed to its ``evaluation_step`` / ``log_epoch_metrics``.
            max_batches: passed to the evaluation loop when the dataloaders
                are fetched; the loop returns the value actually used.

        Returns:
            ``(eval_loop_results, eval_results)``: the value returned by the
            loop's ``log_epoch_metrics`` and the value returned by its
            ``evaluation_epoch_end``. ``([], [])`` when the loop decides the
            evaluation should be skipped.
        """
        loop = self.evaluation_loop

        # bookkeeping
        loop.testing = test_mode
        dataloaders, max_batches = loop.get_evaluation_dataloaders(max_batches)
        if loop.should_skip_evaluation(dataloaders, max_batches):
            return [], []

        # eval mode, gradients cleared and disabled
        model = self.get_model()
        model.zero_grad()
        model.eval()
        torch.set_grad_enabled(False)

        loop.on_evaluation_start()
        loop.setup(model, max_batches, dataloaders)

        # TODO: should this be inside the dataloader loop?
        loop.on_evaluation_epoch_start()

        for loader_idx, loader in enumerate(dataloaders):
            collected = []
            loader = self.accelerator_backend.process_dataloader(loader)
            batch_limit = loop.max_batches[loader_idx]

            for batch_idx, batch in enumerate(loader):
                # empty batches are skipped (they still consume an index)
                if batch is None:
                    continue

                # stop short when running on limited batches
                if batch_idx >= batch_limit:
                    break

                loop.on_evaluation_batch_start(batch, batch_idx, loader_idx)

                # lightning module methods
                step_output = loop.evaluation_step(
                    test_mode, batch, batch_idx, loader_idx)
                step_output = loop.evaluation_step_end(step_output)

                loop.on_evaluation_batch_end(batch, batch_idx, loader_idx)

                # clean up and log this step
                loop.evaluation_batch_end_cleanup(
                    step_output, batch_idx, loader_idx)
                loop.log_step_metrics(step_output, batch_idx)

                # keep the output for the epoch-level aggregation
                if step_output is not None:
                    collected.append(step_output)

            loop.outputs.append(collected)

        # lightning module method
        eval_results = loop.evaluation_epoch_end(
            num_dataloaders=len(dataloaders))

        # bookkeeping
        eval_loop_results = loop.log_epoch_metrics(eval_results, test_mode)
        loop.predictions.to_disk()

        loop.on_evaluation_epoch_end()

        # back to train mode with gradients enabled
        model.train()
        torch.set_grad_enabled(True)

        loop.on_evaluation_end()

        return eval_loop_results, eval_results

    def run_test(self):
        """Run the evaluation loop in test mode and return plain-number results.

        Returns:
            ``1`` when the evaluation produced no results; otherwise the
            evaluation-loop results, where every tensor value found directly
            inside a ``dict`` result has been replaced in place by the Python
            scalar obtained via ``.cpu().item()``.
        """
        # only load test dataloader for testing
        # self.reset_test_dataloader(ref_model)
        eval_loop_results, _ = self.run_evaluation(test_mode=True)

        if len(eval_loop_results) == 0:
            return 1

        # strip tensors from the results; non-dict entries are left untouched
        for entry in eval_loop_results:
            if not isinstance(entry, dict):
                continue
            entry.update({
                key: value.cpu().item()
                for key, value in entry.items()
                if isinstance(value, torch.Tensor)
            })

        return eval_loop_results

    def run_sanity_check(self, ref_model):
        """Run a short validation pass before training to surface crashes early.

        Nothing happens unless the model has a ``val_dataloader``, overrides
        ``validation_step``, and both ``num_sanity_val_steps`` and
        ``limit_val_batches`` are greater than zero.

        Args:
            ref_model: the module whose validation loop is exercised.
        """
        has_val_step = ref_model.val_dataloader is not None and is_overridden(
            'validation_step', ref_model)
        if not (has_val_step and self.num_sanity_val_steps > 0
                and self.limit_val_batches > 0):
            return

        # cap every validation dataloader at the configured number of steps
        self.reset_val_dataloader(ref_model)
        self.num_sanity_val_batches = [
            min(self.num_sanity_val_steps, n_batches)
            for n_batches in self.num_val_batches
        ]

        # hook and callback
        self.running_sanity_check = True
        self.on_sanity_check_start()

        # run eval step
        _, eval_results = self.run_evaluation(
            test_mode=False, max_batches=self.num_sanity_val_batches)

        # the evaluation is allowed to return nothing
        if eval_results is not None and len(eval_results) > 0:
            # for a list of results only the last entry is used
            if isinstance(eval_results, list):
                eval_results = eval_results[-1]

            if isinstance(eval_results, EvalResult):
                callback_metrics = eval_results.callback_metrics
            else:
                _, _, _, callback_metrics, _ = self.process_output(
                    eval_results)
            self.logger_connector.callback_metrics = callback_metrics

        self.on_sanity_check_end()
        self.running_sanity_check = False

    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def test(
        self,
        model: Optional[LightningModule] = None,
        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = 'best',
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Run the test loop, on the given model or on checkpointed weights.

        Args:
            model: module to test; when ``None`` the trainer's current model
                is used and ``ckpt_path`` decides which weights are loaded.
            test_dataloaders: dataloaders to attach to the model before
                testing. Cannot be combined with ``datamodule``.
            ckpt_path: only consulted when ``model`` is ``None``.
            verbose: stored on the trainer as ``verbose_test``.
            datamodule: attached to the model for the ``'test'`` stage.

        Returns:
            The test results, or ``None`` on processes whose ``global_rank``
            is not 0.

        Raises:
            MisconfigurationException: if both ``test_dataloaders`` and
                ``datamodule`` are supplied.
        """
        # --------------------
        # SETUP HOOK
        # --------------------
        self.verbose_test = verbose

        # only the rank-0 process goes any further
        if self.global_rank != 0:
            return

        # explicit test dataloaders and a datamodule are mutually exclusive
        if test_dataloaders and datamodule:
            raise MisconfigurationException(
                'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
            )

        # attach the datamodule first so its setup/prepare_data are on the
        # model before the test helpers below use it
        target_model = model or self.get_model()
        self.data_connector.attach_datamodule(target_model, datamodule, 'test')

        if model is None:
            results = self.__test_using_best_weights(ckpt_path,
                                                     test_dataloaders)
        else:
            results = self.__test_given_model(model, test_dataloaders)

        self.teardown('test')

        return results

    def __test_using_best_weights(self, ckpt_path, test_dataloaders):
        """Test the trainer's current model, optionally loading a checkpoint.

        Args:
            ckpt_path: ``'best'`` to use the checkpoint callback's
                ``best_model_path``, an explicit checkpoint path, or ``None``
                to test the model with the weights it currently has.
            test_dataloaders: dataloaders to attach to the model, or ``None``.

        Returns:
            The value returned by ``fit`` while in testing mode, or ``{}``
            when the resolved checkpoint path is empty.

        Raises:
            MisconfigurationException: if ``ckpt_path`` is ``'best'`` but the
                checkpoint callback has ``save_top_k <= 0``.
        """
        model = self.get_model()

        # if user requests the best checkpoint but we don't have it, error
        if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
            raise MisconfigurationException(
                'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.'
            )

        # load best weights
        if ckpt_path is not None:
            # ckpt_path is 'best' so load the best model
            if ckpt_path == 'best':
                ckpt_path = self.checkpoint_callback.best_model_path

            if len(ckpt_path) == 0:
                rank_zero_warn(
                    f'.test() found no path for the best weights, {ckpt_path}. Please '
                    f'specify a path for a checkpoint .test(ckpt_path=PATH)')
                return {}

            # NOTE: torch.load unpickles the file, so only checkpoints from a
            # trusted source should be passed in here.
            # map_location keeps every storage where it was deserialized
            # instead of moving it to the device it was saved from.
            ckpt = torch.load(ckpt_path,
                              map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt['state_dict'])

        # attach dataloaders
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(
                model, test_dataloaders=test_dataloaders)

        # run tests
        self.tested_ckpt_path = ckpt_path
        self.testing = True
        os.environ['PL_TESTING_MODE'] = '1'
        self.model = model
        try:
            results = self.fit(model)
        finally:
            # always leave testing mode, even when fit() raises; otherwise the
            # flag and the environment variable would leak into later runs
            self.testing = False
            os.environ.pop('PL_TESTING_MODE', None)

        # teardown
        if self.is_function_implemented('teardown'):
            model_ref = self.get_model()
            model_ref.teardown('test')

        return results

    def __test_given_model(self, model, test_dataloaders):
        """Test an explicitly supplied model.

        Args:
            model: the module to test; it becomes the trainer's ``model``.
            test_dataloaders: dataloaders to attach to the model, or ``None``.

        Returns:
            The value returned by ``fit`` while ``self.testing`` is set.
        """
        # attach data
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(
                model, test_dataloaders=test_dataloaders)

        # run test
        # sets up testing so we short circuit to eval
        self.testing = True
        self.model = model
        try:
            results = self.fit(model)
        finally:
            # always leave testing mode, even when fit() raises, so a later
            # fit() on this trainer is not silently run as a test
            self.testing = False

        # teardown
        if self.is_function_implemented('teardown'):
            model.teardown('test')

        return results

    def call_setup_hook(self, model):
        """Invoke ``setup`` on the datamodule, the trainer and the model.

        The stage is ``'test'`` when ``self.testing`` is truthy and ``'fit'``
        otherwise. The datamodule, when present, is only set up if its
        ``has_setup_test`` / ``has_setup_fit`` flag for that stage is falsy.

        Args:
            model: the module whose ``setup`` is called last.
        """
        # call setup after the ddp process has connected
        if self.testing:
            stage_name, done_flag = 'test', 'has_setup_test'
        else:
            stage_name, done_flag = 'fit', 'has_setup_fit'

        # skip the datamodule when it has already been set up for this stage
        if self.datamodule is not None:
            if not getattr(self.datamodule, done_flag):
                self.datamodule.setup(stage_name)

        self.setup(stage_name)
        model.setup(stage_name)

    def call_hook(self, hook_name, *args, **kwargs):
        """Dispatch a named hook to the trainer, then to the model or accelerator.

        Args:
            hook_name: name of the hook method to look up.
            *args: positional arguments passed to every hook that is called.
            **kwargs: keyword arguments passed to every hook that is called.

        Returns:
            The model hook's return value when the model overrides the hook;
            otherwise the accelerator backend hook's return value when the
            backend has such an attribute; otherwise ``None``. The trainer
            hook's own return value is discarded.
        """
        # always profile hooks
        with self.profiler.profile(hook_name):

            # the trainer's own hook goes first
            if hasattr(self, hook_name):
                getattr(self, hook_name)(*args, **kwargs)

            # then the hook on the LightningModule, if it overrides it
            lightning_module = self.get_model()
            if is_overridden(hook_name, lightning_module):
                return getattr(lightning_module, hook_name)(*args, **kwargs)

            # if the PL module doesn't have the hook then call the accelerator
            # used to auto-reduce things for the user with Results obj
            if hasattr(self.accelerator_backend, hook_name):
                return getattr(self.accelerator_backend, hook_name)(
                    *args, **kwargs)

            return None
예제 #8
0
class Trainer(
        TrainerProperties,
        TrainerIOMixin,
        TrainerCallbackHookMixin,
        TrainerModelHooksMixin,
        TrainerOptimizersMixin,
        TrainerDDPMixin,
        TrainerLoggingMixin,
        TrainerTrainingTricksMixin,
        TrainerDataLoadingMixin,
        TrainerCallbackConfigMixin,
        TrainerDeprecatedAPITillVer0_10,
):
    def __init__(
            self,
            logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase],
                          bool] = True,
            checkpoint_callback: Union[ModelCheckpoint, bool] = True,
            early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
            callbacks: Optional[List[Callback]] = None,
            default_root_dir: Optional[str] = None,
            gradient_clip_val: float = 0,
            process_position: int = 0,
            num_nodes: int = 1,
            num_processes: int = 1,
            gpus: Optional[Union[List[int], str, int]] = None,
            auto_select_gpus: bool = False,
            tpu_cores: Optional[Union[List[int], str, int]] = None,
            log_gpu_memory: Optional[str] = None,
            progress_bar_refresh_rate: int = 1,
            overfit_batches: Union[int, float] = 0.0,
            track_grad_norm: Union[int, float, str] = -1,
            check_val_every_n_epoch: int = 1,
            fast_dev_run: bool = False,
            accumulate_grad_batches: Union[int, Dict[int, int],
                                           List[list]] = 1,
            max_epochs: int = 1000,
            min_epochs: int = 1,
            max_steps: Optional[int] = None,
            min_steps: Optional[int] = None,
            limit_train_batches: Union[int, float] = 1.0,
            limit_val_batches: Union[int, float] = 1.0,
            limit_test_batches: Union[int, float] = 1.0,
            val_check_interval: Union[int, float] = 1.0,
            log_save_interval: int = 100,
            row_log_interval: int = 50,
            distributed_backend: Optional[str] = None,
            sync_batchnorm: bool = False,
            precision: int = 32,
            weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
            weights_save_path: Optional[str] = None,
            num_sanity_val_steps: int = 2,
            truncated_bptt_steps: Optional[int] = None,
            resume_from_checkpoint: Optional[str] = None,
            profiler: Optional[Union[BaseProfiler, bool]] = None,
            benchmark: bool = False,
            deterministic: bool = False,
            reload_dataloaders_every_epoch: bool = False,
            auto_lr_find: Union[bool, str] = False,
            replace_sampler_ddp: bool = True,
            terminate_on_nan: bool = False,
            auto_scale_batch_size: Union[str, bool] = False,
            prepare_data_per_node: bool = True,
            amp_backend: str = 'native',
            amp_level: str = 'O2',  # backward compatible, todo: remove in v1.0.0
            val_percent_check:
        float = None,  # backward compatible, todo: remove in v0.10.0
            test_percent_check:
        float = None,  # backward compatible, todo: remove in v0.10.0
            train_percent_check:
        float = None,  # backward compatible, todo: remove in v0.10.0
            overfit_pct:
        float = None,  # backward compatible, todo: remove in v1.0.0
    ):
        super().__init__()

        self.deterministic = deterministic
        torch.backends.cudnn.deterministic = self.deterministic
        if self.deterministic:
            # fixing non-deterministic part of horovod
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
            os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

        # init the default rank if exists
        # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
        # this way we only show it on rank 0
        if 'LOCAL_RANK' in os.environ:
            rank_zero_only.rank = int(os.environ['LOCAL_RANK'])

        # tracks internal state for debugging
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.lr_scheduler_connector = LRSchedulerConnector(self)
        self.accelerator_connector = AcceleratorConnector(self)
        self.logger_connector = LoggerConnector(self)
        self.model_connector = ModelConnector(self)
        self.initializer = Initializer(self)
        self.tuner = Tuner(self)
        self.accelerator_backend = None

        # loops
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)

        # training bookeeping
        self.total_batch_idx = 0
        self.running_loss = TensorRunningAccum(window_length=20)
        self.batch_idx = 0
        self.num_training_batches = 0
        self.num_val_batches = []
        self.num_sanity_val_batches = []
        self.num_test_batches = []
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # when true, prints test results
        self.verbose_test = True

        # when .test() is called, it sets this
        self.tested_ckpt_path = None

        # training state
        self.model = None
        self.datamodule = None
        self.testing = False
        self.prepare_data_per_node = prepare_data_per_node
        self.lr_schedulers = []
        self.optimizers = None
        self.optimizer_frequencies = []
        self.global_step = 0
        self.current_epoch = 0
        self.interrupted = False
        self.should_stop = False
        self.running_sanity_check = False
        self._state = TrainerState.INITIALIZING

        self._default_root_dir = default_root_dir or os.getcwd()
        self._weights_save_path = weights_save_path or self._default_root_dir

        # init callbacks
        self.callbacks = callbacks or []

        # configure early stop callback
        # creates a default one if none passed in
        early_stop_callback = self.configure_early_stopping(
            early_stop_callback)
        if early_stop_callback:
            self.callbacks.append(early_stop_callback)

        # configure checkpoint callback
        # it is important that this is the last callback to run
        # pass through the required args to figure out defaults
        checkpoint_callback = self.configure_checkpoint_callback(
            checkpoint_callback)
        if checkpoint_callback:
            self.callbacks.append(checkpoint_callback)

        # TODO refactor codebase (tests) to not directly reach into these callbacks
        self.checkpoint_callback = checkpoint_callback
        self.early_stop_callback = early_stop_callback

        self.on_init_start()

        # benchmarking
        self.benchmark = benchmark
        torch.backends.cudnn.benchmark = self.benchmark

        # Transfer params
        self.num_nodes = num_nodes
        self.log_gpu_memory = log_gpu_memory

        # sync-bn backend
        self.sync_batchnorm = sync_batchnorm

        self.gradient_clip_val = gradient_clip_val
        self.check_val_every_n_epoch = check_val_every_n_epoch

        if not isinstance(track_grad_norm,
                          (int, float)) and track_grad_norm != 'inf':
            raise MisconfigurationException(
                "track_grad_norm can be an int, a float or 'inf' (infinity norm)."
            )
        self.track_grad_norm = float(track_grad_norm)

        self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
        self.on_tpu = self.tpu_cores is not None

        self.tpu_id = self.tpu_cores[0] if isinstance(self.tpu_cores,
                                                      list) else None

        if num_processes != 1 and distributed_backend != "ddp_cpu":
            rank_zero_warn(
                "num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it."
            )
        self.num_processes = num_processes

        self.weights_summary = weights_summary

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.max_steps = max_steps
        self.min_steps = min_steps

        if num_sanity_val_steps == -1:
            self.num_sanity_val_steps = float('inf')
        else:
            self.num_sanity_val_steps = num_sanity_val_steps

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

        self.auto_lr_find = auto_lr_find
        self.auto_scale_batch_size = auto_scale_batch_size
        self._is_data_prepared = False
        self.replace_sampler_ddp = replace_sampler_ddp

        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.terminate_on_nan = terminate_on_nan
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            limit_train_batches = 1
            limit_val_batches = 1
            limit_test_batches = 1
            self.num_sanity_val_steps = 0
            self.max_epochs = 1
            rank_zero_info(
                'Running in fast_dev_run mode: will run a full train,'
                ' val and test loop using a single batch')

        # configure profiler
        if profiler is True:
            profiler = SimpleProfiler()
        self.profiler = profiler or PassThroughProfiler()

        # accumulated grads
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # override with environment flag
        gpus = os.environ.get('PL_TRAINER_GPUS', gpus)

        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.gpus = self.tuner.pick_multiple_gpus(gpus)
        else:
            self.gpus = gpus

        self.data_parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
        self.root_gpu = device_parser.determine_root_gpu_device(
            self.data_parallel_device_ids)
        self.root_device = torch.device("cpu")

        self.on_gpu = True if (self.data_parallel_device_ids
                               and torch.cuda.is_available()) else False

        # tpu state flags
        self.use_tpu = False
        self.tpu_local_core_rank = None
        self.tpu_global_core_rank = None

        # distributed backend choice
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend)

        # override dist backend when using tpus
        if self.on_tpu:
            self.distributed_backend = 'tpu'
            self.init_tpu()

        # init flags for SLURM+DDP to work
        self.world_size = 1
        self.interactive_ddp_procs = []
        self.configure_slurm_ddp(self.num_nodes)
        self.node_rank = self.determine_ddp_node_rank()
        self.local_rank = self.determine_local_rank()
        self.global_rank = 0

        # NVIDIA setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks,
                              self.data_parallel_device_ids)

        self._progress_bar_callback = self.configure_progress_bar(
            progress_bar_refresh_rate, process_position)

        # logging
        self.configure_logger(logger)
        self.log_save_interval = log_save_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        # TODO: remove in 0.10.0
        if overfit_pct is not None:
            rank_zero_warn(
                "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            overfit_batches = overfit_pct

        # TODO: remove in 0.10.0
        if val_percent_check is not None:
            rank_zero_warn(
                "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_val_batches = val_percent_check

        # TODO: remove in 0.10.0
        if test_percent_check is not None:
            rank_zero_warn(
                "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_test_batches = test_percent_check

        # TODO: remove in 0.10.0
        if train_percent_check is not None:
            rank_zero_warn(
                "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_train_batches = train_percent_check

        self.limit_train_batches = _determine_batch_limits(
            limit_train_batches, 'limit_train_batches')
        self.limit_val_batches = _determine_batch_limits(
            limit_val_batches, 'limit_val_batches')
        self.limit_test_batches = _determine_batch_limits(
            limit_test_batches, 'limit_test_batches')
        self.val_check_interval = _determine_batch_limits(
            val_check_interval, 'val_check_interval')
        self.overfit_batches = _determine_batch_limits(overfit_batches,
                                                       'overfit_batches')
        self.determine_data_use_amount(self.overfit_batches)

        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.autocast_original_forward = None
        self.precision = precision
        self.scaler = None

        self.amp_level = amp_level
        self.initializer.init_amp(amp_backend)

        self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv(
            'KAGGLE_URL_BASE')

        # Callback system
        self.on_init_end()

    def tune(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Run the tuning routines enabled on this trainer against ``model``.

        Performs the same preparation as the start of ``fit`` (``setup_fit``,
        the ``on_fit_start`` hook and ``data_connector.prepare_data``), then
        runs batch-size scaling when ``self.auto_scale_batch_size`` is truthy
        and the learning-rate finder when ``self.auto_lr_find`` is truthy.
        Nothing is returned.

        Args:
            model: module to tune.
            train_dataloader: optional training dataloader; forwarded to
                ``setup_fit`` and to the batch-size scaler.
            val_dataloaders: optional validation dataloader(s); forwarded to
                ``setup_fit`` and to the batch-size scaler.
            datamodule: optional datamodule; forwarded to ``setup_fit`` and to
                the batch-size scaler.
        """
        # TODO: temporary, need to decide if tune or separate object

        # setup data, etc...
        self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.call_hook('on_fit_start', model)

        # hook
        self.data_connector.prepare_data(model)

        # Run auto batch size scaling
        if self.auto_scale_batch_size:
            # a plain boolean flag (i.e. ``True``) is normalised to the 'power' mode
            if isinstance(self.auto_scale_batch_size, bool):
                self.auto_scale_batch_size = 'power'
            self.tuner.scale_batch_size(
                model,
                mode=self.auto_scale_batch_size,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule,
            )
            model.logger = self.logger  # reset logger binding

        # Run learning rate finder:
        if self.auto_lr_find:
            self.tuner.internal_find_lr(self, model)
            model.logger = self.logger  # reset logger binding

    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    @trainer_state(entering=TrainerState.RUNNING,
                   exiting=TrainerState.FINISHED)
    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Run the full fit routine for ``model`` on the selected accelerator.

        Steps, in order: attach data and validate the configuration
        (``setup_fit``), fire ``on_fit_start``, prepare data, pick up the
        testing flag from the environment, select an accelerator backend and
        run its ``setup`` / ``train`` / ``teardown``, fire ``on_fit_end`` and
        finally the ``teardown('fit')`` hooks.

        Args:
            model: module to train.
            train_dataloader: optional training dataloader, forwarded to
                ``setup_fit``.
            val_dataloaders: optional validation dataloader(s), forwarded to
                ``setup_fit``.
            datamodule: optional datamodule, forwarded to ``setup_fit``.

        Returns:
            The value returned by the accelerator backend's ``train()`` when
            it is truthy, otherwise ``1``.
        """
        results = None

        # setup data, etc...
        self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.call_hook('on_fit_start', model)

        # hook
        self.data_connector.prepare_data(model)

        # set testing if set in environ
        # NOTE: when the variable is present, ``self.testing`` becomes its
        # string value (e.g. '1'), not a bool
        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

        # -------------------------
        # TRAIN
        # -------------------------
        self.accelerator_backend = self.accelerator_connector.select_accelerator(
        )
        self.accelerator_backend.setup(model)
        results = self.accelerator_backend.train()
        self.accelerator_backend.teardown()

        # -------------------------
        # POST-Training
        # -------------------------
        # hook
        self.call_hook('on_fit_end')

        # hook
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return results or 1

    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
        """Wire ``model`` and its data sources to the trainer before a run.

        Args:
            model: module about to be fitted or tuned.
            train_dataloader: training dataloader passed on to
                ``data_connector.attach_data``.
            val_dataloaders: validation dataloader(s) passed on to
                ``data_connector.attach_data``.
            datamodule: datamodule passed on to ``data_connector.attach_data``.
        """
        # bind logger and other properties
        self.model_connector.copy_trainer_model_properties(model)

        # clean hparams
        # (only when the model exposes an ``hparams`` attribute)
        if hasattr(model, 'hparams'):
            parsing.clean_namespace(model.hparams)

        # links data to the trainer
        self.data_connector.attach_data(model, train_dataloader,
                                        val_dataloaders, datamodule)

        # check that model is configured correctly
        self.config_validator.verify_loop_configurations(model)

    def setup_training(self, model: LightningModule):
        """Prepare the trainer and ``model`` right before the training loop.

        Links the trainer to the underlying module, creates the native AMP
        gradient scaler when applicable, logs hyper-parameters, synchronises
        distributed processes, registers the SLURM signal handlers, runs the
        ``on_pretrain_routine_start`` / ``on_pretrain_routine_end`` hooks,
        prints the model summary and restores weights.

        Args:
            model: The model to set up. When ``self.data_parallel`` is truthy
                the underlying module is taken from ``model.module``.

        Raises:
            MisconfigurationException: if a summary is due (global rank zero,
                ``weights_summary`` not ``None``, not testing) and
                ``weights_summary`` is not one of ``ModelSummary.MODES``.
        """
        # --------------------------
        # Setup??
        # --------------------------
        # ``ref_model`` is the module that receives hooks and properties;
        # ``model`` (possibly a wrapper) is what the trainer tracks
        ref_model = model
        if self.data_parallel:
            ref_model = model.module

        # give model convenience properties
        ref_model.trainer = self

        # set local properties on the model
        self.model_connector.copy_trainer_model_properties(ref_model)

        # init amp. Must be done here instead of __init__ to allow ddp to work
        if self.amp_backend == AMPType.NATIVE and self.precision == 16 and not self.use_tpu:
            self.scaler = torch.cuda.amp.GradScaler()

        # log hyper-parameters
        if self.logger is not None:
            # save exp to get started
            self.logger.log_hyperparams(ref_model.hparams)
            self.logger.log_graph(ref_model)
            self.logger.save()

        # synchronise all ddp / ddp2 processes before continuing
        if self.use_ddp or self.use_ddp2:
            torch_distrib.barrier()

        # wait for all models to restore weights
        if self.on_tpu and XLA_AVAILABLE:
            # wait for all processes to catch up
            torch_xla.core.xla_model.rendezvous("pl.Trainer.setup_training")

        elif self.use_horovod:
            # wait for all processes to catch up
            hvd.join()

        # register auto-resubmit when on SLURM
        self.register_slurm_signal_handlers()

        # --------------------------
        # Pre-train
        # --------------------------
        # on pretrain routine start
        self.on_pretrain_routine_start(ref_model)
        if self.is_function_implemented('on_pretrain_routine_start'):
            ref_model.on_pretrain_routine_start()

        # print model summary
        # (global rank zero only, and never while testing)
        if self.is_global_zero and self.weights_summary is not None and not self.testing:
            if self.weights_summary in ModelSummary.MODES:
                ref_model.summarize(mode=self.weights_summary)
            else:
                raise MisconfigurationException(
                    "weights_summary can be None, " +
                    ", ".join(ModelSummary.MODES))

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # restore training and model before hpc is called
        self.restore_weights(model)

        # on pretrain routine end
        self.on_pretrain_routine_end(ref_model)
        if self.is_function_implemented('on_pretrain_routine_end'):
            ref_model.on_pretrain_routine_end()

    def train(self):
        """Run the sanity check followed by the epoch loop.

        Epochs run from ``self.current_epoch`` up to ``self.max_epochs``. The
        loop ends early when ``self.max_steps`` has been reached, or when
        ``self.should_stop`` is set and both the minimum-epoch and the
        minimum-step requirements are satisfied. ``train_loop.on_train_end``
        is called on every exit path handled here, including a
        ``KeyboardInterrupt`` (handled once). Nothing is returned.
        """
        self.run_sanity_check(self.get_model())

        # enable train mode
        model = self.get_model()
        model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(model)

        # hook
        self.train_loop.on_train_start()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):

                # reset train dataloader
                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                # run train epoch
                self.train_loop.run_training_epoch()

                # stop as soon as the step budget is used up; a falsy
                # ``max_steps`` (None or 0) disables this check
                if self.max_steps and self.max_steps <= self.global_step:

                    # hook
                    self.train_loop.on_train_end()
                    return

                # update LR schedulers
                self.lr_scheduler_connector.update_learning_rates(
                    interval='epoch')

                # early stopping
                # ``epoch`` is an index from ``range``, hence the ``- 1`` when
                # comparing against the ``min_epochs`` count
                met_min_epochs = epoch >= self.min_epochs - 1
                # a falsy ``min_steps`` means there is no step requirement
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if self.should_stop:
                    if (met_min_epochs and met_min_steps):
                        self.train_loop.on_train_end()
                        return
                    else:
                        log.info(
                            'Trainer was signaled to stop but required minimum epochs'
                            f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                            ' not been met. Training will continue...')

            # hook
            self.train_loop.on_train_end()

        except KeyboardInterrupt:
            rank_zero_warn(
                'Detected KeyboardInterrupt, attempting graceful shutdown...')

            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()

                # hook
                self.train_loop.on_train_end()

    def run_evaluation(self, test_mode: bool = False, max_batches=None):
        """Run the evaluation loop (validation or test) over all dataloaders.

        Args:
            test_mode: stored on ``evaluation_loop.testing`` and forwarded to
                the evaluation step and to epoch-metric logging.
            max_batches: batch limit(s) handed to
                ``evaluation_loop.get_evaluation_dataloaders``, which returns
                the value actually used.

        Returns:
            ``([], [])`` when the evaluation loop decides evaluation should be
            skipped; otherwise the tuple ``(eval_loop_results, eval_results)``
            where ``eval_results`` comes from ``evaluation_epoch_end`` and
            ``eval_loop_results`` from ``log_epoch_metrics``.
        """
        # bookkeeping
        self.evaluation_loop.testing = test_mode
        dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(
            max_batches)
        if self.evaluation_loop.should_skip_evaluation(dataloaders,
                                                       max_batches):
            return [], []

        # enable eval mode + no grads
        model = self.get_model()
        model.zero_grad()
        model.eval()
        torch.set_grad_enabled(False)

        # hook
        self.evaluation_loop.on_evaluation_start()

        # set up the eval loop
        self.evaluation_loop.setup(model, max_batches, dataloaders)

        # hook
        # TODO: should this be insider the dataloader loop?
        self.evaluation_loop.on_evaluation_epoch_start()

        # run validation/testing
        for dataloader_idx, dataloader in enumerate(dataloaders):
            # bookkeeping
            dl_outputs = []
            dataloader = self.accelerator_backend.process_dataloader(
                dataloader)
            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                # ``None`` batches are skipped without running any hook
                if batch is None:
                    continue

                # stop short when running on limited batches
                if batch_idx >= dl_max_batches:
                    break

                # hook
                self.evaluation_loop.on_evaluation_batch_start(
                    batch, batch_idx, dataloader_idx)

                # lightning module methods
                output = self.evaluation_loop.evaluation_step(
                    test_mode, batch, batch_idx, dataloader_idx)
                output = self.evaluation_loop.evaluation_step_end(output)

                # hook
                self.evaluation_loop.on_evaluation_batch_end(
                    batch, batch_idx, dataloader_idx)

                # clean up
                self.evaluation_loop.evaluation_batch_end_cleanup(
                    output, batch_idx, dataloader_idx)
                self.evaluation_loop.log_step_metrics(output, batch_idx)

                # track epoch level metrics
                if output is not None:
                    dl_outputs.append(output)

            # one list of step outputs per dataloader
            self.evaluation_loop.outputs.append(dl_outputs)

        # lightning module method
        eval_results = self.evaluation_loop.evaluation_epoch_end(
            num_dataloaders=len(dataloaders))

        # bookkeeping
        eval_loop_results = self.evaluation_loop.log_epoch_metrics(
            eval_results, test_mode)
        self.evaluation_loop.predictions.to_disk()

        # hook
        self.evaluation_loop.on_evaluation_epoch_end()

        # enable train mode again
        model.train()
        torch.set_grad_enabled(True)

        # hook
        self.evaluation_loop.on_evaluation_end()

        return eval_loop_results, eval_results

    def run_test(self):
        """Run the evaluation loop in test mode and return plain-Python results.

        Returns:
            ``1`` when the evaluation loop produced no results. Otherwise the
            results object returned by ``run_evaluation``; every
            ``torch.Tensor`` value inside a dict result is replaced in place
            by the scalar obtained from ``.cpu().item()``.
        """
        loop_results, _ = self.run_evaluation(test_mode=True)

        # nothing was evaluated
        if len(loop_results) == 0:
            return 1

        # unwrap tensors so callers receive plain numbers
        for entry in loop_results:
            if not isinstance(entry, dict):
                continue
            for key, value in entry.items():
                if isinstance(value, torch.Tensor):
                    entry[key] = value.cpu().item()

        return loop_results

    def train_or_test(self):
        """Dispatch to ``run_test`` when ``self.testing`` is truthy, else to ``train``.

        Returns:
            Whatever the selected routine returns.
        """
        routine = self.run_test if self.testing else self.train
        return routine()

    def run_sanity_check(self, ref_model):
        """Run a short validation pass before training to surface crashes early.

        The check only runs when ``ref_model`` has a validation dataloader and
        overrides ``validation_step``, and both ``self.num_sanity_val_steps``
        and ``self.limit_val_batches`` are greater than zero. Otherwise the
        method returns without touching any trainer state.

        Args:
            ref_model: module whose validation loop is exercised.
        """
        has_val_step = ref_model.val_dataloader is not None and is_overridden(
            'validation_step', ref_model)
        if not (has_val_step and self.num_sanity_val_steps > 0
                and self.limit_val_batches > 0):
            return

        # cap every validation dataloader at the configured number of steps
        self.reset_val_dataloader(ref_model)
        self.num_sanity_val_batches = [
            min(self.num_sanity_val_steps, n_batches)
            for n_batches in self.num_val_batches
        ]

        # flag + callback so hooks know a sanity check is in progress
        self.running_sanity_check = True
        self.on_sanity_check_start()

        _, eval_results = self.run_evaluation(
            test_mode=False, max_batches=self.num_sanity_val_batches)

        # evaluation is allowed to return nothing
        if eval_results is not None and len(eval_results) > 0:
            # of a list of results only the last entry is considered
            if isinstance(eval_results, list):
                eval_results = eval_results[-1]

            if isinstance(eval_results, EvalResult):
                callback_metrics = eval_results.callback_metrics
            else:
                _, _, _, callback_metrics, _ = self.process_output(
                    eval_results)
            self.logger_connector.callback_metrics = callback_metrics

        self.on_sanity_check_end()
        self.running_sanity_check = False

    @trainer_state(entering=TrainerState.RUNNING,
                   exiting=TrainerState.FINISHED)
    def test(
        self,
        model: Optional[LightningModule] = None,
        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = 'best',
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None,
    ):
        """Run the test loop, either on ``model`` or on the trainer's current model.

        Args:
            model: module to test. When ``None`` the trainer's current model
                is tested, optionally after loading weights from ``ckpt_path``.
            test_dataloaders: dataloader(s) to test on. Cannot be combined
                with ``datamodule``.
            ckpt_path: forwarded to the checkpoint-based test routine; only
                used when ``model`` is ``None``.
            verbose: stored on ``self.verbose_test``.
            datamodule: datamodule attached for the ``'test'`` stage.

        Returns:
            ``None`` when ``self.global_rank`` is not 0, otherwise the results
            of the selected test routine.

        Raises:
            MisconfigurationException: if both ``test_dataloaders`` and
                ``datamodule`` are supplied.
        """
        self.verbose_test = verbose

        # only global rank 0 goes any further
        if self.global_rank != 0:
            return None

        # a datamodule and explicit test dataloaders are mutually exclusive
        if test_dataloaders and datamodule:
            raise MisconfigurationException(
                'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
            )

        # attach the datamodule first so its setup/prepare_data hooks are in
        # place before the test routine below runs
        self.data_connector.attach_datamodule(model or self.get_model(),
                                              datamodule, 'test')

        if model is None:
            results = self.__test_using_best_weights(ckpt_path,
                                                     test_dataloaders)
        else:
            results = self.__test_given_model(model, test_dataloaders)

        self.teardown('test')

        return results

    def __test_using_best_weights(self, ckpt_path, test_dataloaders):
        """Test the trainer's current model, optionally after loading a checkpoint.

        Args:
            ckpt_path: ``'best'`` to load
                ``self.checkpoint_callback.best_model_path``, a path to a
                checkpoint to load, or ``None`` to test the weights the model
                currently holds.
            test_dataloaders: dataloader(s) to attach to the model; skipped
                when ``None``.

        Returns:
            The value returned by ``self.fit`` while in testing mode, or ``{}``
            when a checkpoint was requested but no path could be resolved.

        Raises:
            MisconfigurationException: if ``ckpt_path`` is ``'best'`` but the
                checkpoint callback has ``save_top_k <= 0``.
        """
        model = self.get_model()

        # if user requests the best checkpoint but we don't have it, error
        if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
            raise MisconfigurationException(
                'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.'
            )

        # load best weights
        if ckpt_path is not None:
            # ckpt_path is 'best' so load the best model
            if ckpt_path == 'best':
                ckpt_path = self.checkpoint_callback.best_model_path

            # truthiness (rather than ``len``) also accepts path-like objects,
            # which have no length
            if not ckpt_path:
                rank_zero_warn(
                    f'.test() found no path for the best weights, {ckpt_path}. Please '
                    f'specify a path for a checkpoint .test(ckpt_path=PATH)')
                return {}

            # NOTE: ``torch.load`` unpickles the file; only load checkpoints
            # from trusted sources
            ckpt = torch.load(ckpt_path,
                              map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt['state_dict'])

        # attach dataloaders
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(
                model, test_dataloaders=test_dataloaders)

        # run tests
        self.tested_ckpt_path = ckpt_path
        self.testing = True
        os.environ['PL_TESTING_MODE'] = '1'
        self.model = model
        try:
            results = self.fit(model)
        finally:
            # always leave testing mode, even when ``fit`` raises; otherwise
            # the flag and the environment variable would leak into later
            # ``fit`` calls, which read ``PL_TESTING_MODE``
            self.testing = False
            os.environ.pop('PL_TESTING_MODE', None)

        # teardown
        if self.is_function_implemented('teardown'):
            model_ref = self.get_model()
            model_ref.teardown('test')

        return results

    def __test_given_model(self, model, test_dataloaders):
        """Test an explicitly supplied ``model`` with its current weights.

        Args:
            model: module to test; it becomes ``self.model``.
            test_dataloaders: dataloader(s) to attach to the model; skipped
                when ``None``.

        Returns:
            The value returned by ``self.fit`` while in testing mode.
        """
        # attach data
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(
                model, test_dataloaders=test_dataloaders)

        # run test
        # sets up testing so we short circuit to eval
        self.testing = True
        self.model = model
        try:
            results = self.fit(model)
        finally:
            # always leave testing mode, even when ``fit`` raises, so that a
            # later ``fit`` on this trainer does not silently run the test loop
            self.testing = False

        # teardown
        if self.is_function_implemented('teardown'):
            model.teardown('test')

        return results

    def call_setup_hook(self, model):
        """Invoke the ``setup`` hooks for the current stage.

        The stage is ``'test'`` when ``self.testing`` is truthy and ``'fit'``
        otherwise. The datamodule's ``setup`` is called only when a datamodule
        is attached and its matching ``has_setup_*`` flag is falsy; the
        trainer's and the model's ``setup`` are always called, in that order.

        Args:
            model: module whose ``setup`` hook is called last.
        """
        # call setup after the ddp process has connected
        if self.testing:
            stage_name, done_flag = 'test', 'has_setup_test'
        else:
            stage_name, done_flag = 'fit', 'has_setup_fit'

        # do not set up the datamodule twice for the same stage
        if self.datamodule is not None and not getattr(self.datamodule, done_flag):
            self.datamodule.setup(stage_name)

        self.setup(stage_name)
        model.setup(stage_name)

    def call_hook(self, hook_name, *args, **kwargs):
        """Dispatch ``hook_name`` to the trainer, then to the model or the accelerator.

        Args:
            hook_name: name of the hook to invoke.
            *args: positional arguments forwarded to every hook called.
            **kwargs: keyword arguments forwarded to every hook called.

        Returns:
            The model hook's return value when the model overrides the hook;
            otherwise the accelerator backend hook's return value when the
            backend has an attribute of that name; otherwise ``None``. The
            return value of the trainer-level hook is discarded.
        """
        # every hook invocation is profiled under its own name
        with self.profiler.profile(hook_name):

            # the trainer's own hook, when present, always fires first
            if hasattr(self, hook_name):
                getattr(self, hook_name)(*args, **kwargs)

            # a hook overridden on the LightningModule takes precedence
            lightning_module = self.get_model()
            if is_overridden(hook_name, lightning_module):
                return getattr(lightning_module, hook_name)(*args, **kwargs)

            # otherwise fall back to the accelerator, which is used to
            # auto-reduce things for the user with Results obj
            if hasattr(self.accelerator_backend, hook_name):
                return getattr(self.accelerator_backend,
                               hook_name)(*args, **kwargs)

            return None
예제 #9
0
    def __init__(
            self,
            logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase],
                          bool] = True,
            checkpoint_callback: Union[ModelCheckpoint, bool] = True,
            early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
            callbacks: Optional[List[Callback]] = None,
            default_root_dir: Optional[str] = None,
            gradient_clip_val: float = 0,
            process_position: int = 0,
            num_nodes: int = 1,
            num_processes: int = 1,
            gpus: Optional[Union[List[int], str, int]] = None,
            auto_select_gpus: bool = False,
            tpu_cores: Optional[Union[List[int], str, int]] = None,
            log_gpu_memory: Optional[str] = None,
            progress_bar_refresh_rate: int = 1,
            overfit_batches: Union[int, float] = 0.0,
            track_grad_norm: Union[int, float, str] = -1,
            check_val_every_n_epoch: int = 1,
            fast_dev_run: bool = False,
            accumulate_grad_batches: Union[int, Dict[int, int],
                                           List[list]] = 1,
            max_epochs: int = 1000,
            min_epochs: int = 1,
            max_steps: Optional[int] = None,
            min_steps: Optional[int] = None,
            limit_train_batches: Union[int, float] = 1.0,
            limit_val_batches: Union[int, float] = 1.0,
            limit_test_batches: Union[int, float] = 1.0,
            val_check_interval: Union[int, float] = 1.0,
            log_save_interval: int = 100,
            row_log_interval: int = 50,
            distributed_backend: Optional[str] = None,
            sync_batchnorm: bool = False,
            precision: int = 32,
            weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
            weights_save_path: Optional[str] = None,
            num_sanity_val_steps: int = 2,
            truncated_bptt_steps: Optional[int] = None,
            resume_from_checkpoint: Optional[str] = None,
            profiler: Optional[Union[BaseProfiler, bool]] = None,
            benchmark: bool = False,
            deterministic: bool = False,
            reload_dataloaders_every_epoch: bool = False,
            auto_lr_find: Union[bool, str] = False,
            replace_sampler_ddp: bool = True,
            terminate_on_nan: bool = False,
            auto_scale_batch_size: Union[str, bool] = False,
            prepare_data_per_node: bool = True,
            amp_backend: str = 'native',
            amp_level: str = 'O2',  # backward compatible, todo: remove in v1.0.0
            val_percent_check:
        float = None,  # backward compatible, todo: remove in v0.10.0
            test_percent_check:
        float = None,  # backward compatible, todo: remove in v0.10.0
            train_percent_check:
        float = None,  # backward compatible, todo: remove in v0.10.0
            overfit_pct:
        float = None,  # backward compatible, todo: remove in v1.0.0
    ):
        super().__init__()

        self.deterministic = deterministic
        torch.backends.cudnn.deterministic = self.deterministic
        if self.deterministic:
            # fixing non-deterministic part of horovod
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
            os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

        # init the default rank if exists
        # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
        # this way we only show it on rank 0
        if 'LOCAL_RANK' in os.environ:
            rank_zero_only.rank = int(os.environ['LOCAL_RANK'])

        # tracks internal state for debugging
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.lr_scheduler_connector = LRSchedulerConnector(self)
        self.accelerator_connector = AcceleratorConnector(self)
        self.logger_connector = LoggerConnector(self)
        self.model_connector = ModelConnector(self)
        self.initializer = Initializer(self)
        self.tuner = Tuner(self)
        self.accelerator_backend = None

        # loops
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)

        # training bookeeping
        self.total_batch_idx = 0
        self.running_loss = TensorRunningAccum(window_length=20)
        self.batch_idx = 0
        self.num_training_batches = 0
        self.num_val_batches = []
        self.num_sanity_val_batches = []
        self.num_test_batches = []
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # when true, prints test results
        self.verbose_test = True

        # when .test() is called, it sets this
        self.tested_ckpt_path = None

        # training state
        self.model = None
        self.datamodule = None
        self.testing = False
        self.prepare_data_per_node = prepare_data_per_node
        self.lr_schedulers = []
        self.optimizers = None
        self.optimizer_frequencies = []
        self.global_step = 0
        self.current_epoch = 0
        self.interrupted = False
        self.should_stop = False
        self.running_sanity_check = False
        self._state = TrainerState.INITIALIZING

        self._default_root_dir = default_root_dir or os.getcwd()
        self._weights_save_path = weights_save_path or self._default_root_dir

        # init callbacks
        self.callbacks = callbacks or []

        # configure early stop callback
        # creates a default one if none passed in
        early_stop_callback = self.configure_early_stopping(
            early_stop_callback)
        if early_stop_callback:
            self.callbacks.append(early_stop_callback)

        # configure checkpoint callback
        # it is important that this is the last callback to run
        # pass through the required args to figure out defaults
        checkpoint_callback = self.configure_checkpoint_callback(
            checkpoint_callback)
        if checkpoint_callback:
            self.callbacks.append(checkpoint_callback)

        # TODO refactor codebase (tests) to not directly reach into these callbacks
        self.checkpoint_callback = checkpoint_callback
        self.early_stop_callback = early_stop_callback

        self.on_init_start()

        # benchmarking
        self.benchmark = benchmark
        torch.backends.cudnn.benchmark = self.benchmark

        # Transfer params
        self.num_nodes = num_nodes
        self.log_gpu_memory = log_gpu_memory

        # sync-bn backend
        self.sync_batchnorm = sync_batchnorm

        self.gradient_clip_val = gradient_clip_val
        self.check_val_every_n_epoch = check_val_every_n_epoch

        if not isinstance(track_grad_norm,
                          (int, float)) and track_grad_norm != 'inf':
            raise MisconfigurationException(
                "track_grad_norm can be an int, a float or 'inf' (infinity norm)."
            )
        self.track_grad_norm = float(track_grad_norm)

        self.tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
        self.on_tpu = self.tpu_cores is not None

        self.tpu_id = self.tpu_cores[0] if isinstance(self.tpu_cores,
                                                      list) else None

        if num_processes != 1 and distributed_backend != "ddp_cpu":
            rank_zero_warn(
                "num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it."
            )
        self.num_processes = num_processes

        self.weights_summary = weights_summary

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.max_steps = max_steps
        self.min_steps = min_steps

        if num_sanity_val_steps == -1:
            self.num_sanity_val_steps = float('inf')
        else:
            self.num_sanity_val_steps = num_sanity_val_steps

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

        self.auto_lr_find = auto_lr_find
        self.auto_scale_batch_size = auto_scale_batch_size
        self._is_data_prepared = False
        self.replace_sampler_ddp = replace_sampler_ddp

        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.terminate_on_nan = terminate_on_nan
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            limit_train_batches = 1
            limit_val_batches = 1
            limit_test_batches = 1
            self.num_sanity_val_steps = 0
            self.max_epochs = 1
            rank_zero_info(
                'Running in fast_dev_run mode: will run a full train,'
                ' val and test loop using a single batch')

        # configure profiler
        if profiler is True:
            profiler = SimpleProfiler()
        self.profiler = profiler or PassThroughProfiler()

        # accumulated grads
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # override with environment flag
        gpus = os.environ.get('PL_TRAINER_GPUS', gpus)

        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.gpus = self.tuner.pick_multiple_gpus(gpus)
        else:
            self.gpus = gpus

        self.data_parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
        self.root_gpu = device_parser.determine_root_gpu_device(
            self.data_parallel_device_ids)
        self.root_device = torch.device("cpu")

        self.on_gpu = True if (self.data_parallel_device_ids
                               and torch.cuda.is_available()) else False

        # tpu state flags
        self.use_tpu = False
        self.tpu_local_core_rank = None
        self.tpu_global_core_rank = None

        # distributed backend choice
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend)

        # override dist backend when using tpus
        if self.on_tpu:
            self.distributed_backend = 'tpu'
            self.init_tpu()

        # init flags for SLURM+DDP to work
        self.world_size = 1
        self.interactive_ddp_procs = []
        self.configure_slurm_ddp(self.num_nodes)
        self.node_rank = self.determine_ddp_node_rank()
        self.local_rank = self.determine_local_rank()
        self.global_rank = 0

        # NVIDIA setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks,
                              self.data_parallel_device_ids)

        self._progress_bar_callback = self.configure_progress_bar(
            progress_bar_refresh_rate, process_position)

        # logging
        self.configure_logger(logger)
        self.log_save_interval = log_save_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        # TODO: remove in 0.10.0
        if overfit_pct is not None:
            rank_zero_warn(
                "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            overfit_batches = overfit_pct

        # TODO: remove in 0.10.0
        if val_percent_check is not None:
            rank_zero_warn(
                "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_val_batches = val_percent_check

        # TODO: remove in 0.10.0
        if test_percent_check is not None:
            rank_zero_warn(
                "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_test_batches = test_percent_check

        # TODO: remove in 0.10.0
        if train_percent_check is not None:
            rank_zero_warn(
                "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_train_batches = train_percent_check

        self.limit_train_batches = _determine_batch_limits(
            limit_train_batches, 'limit_train_batches')
        self.limit_val_batches = _determine_batch_limits(
            limit_val_batches, 'limit_val_batches')
        self.limit_test_batches = _determine_batch_limits(
            limit_test_batches, 'limit_test_batches')
        self.val_check_interval = _determine_batch_limits(
            val_check_interval, 'val_check_interval')
        self.overfit_batches = _determine_batch_limits(overfit_batches,
                                                       'overfit_batches')
        self.determine_data_use_amount(self.overfit_batches)

        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.autocast_original_forward = None
        self.precision = precision
        self.scaler = None

        self.amp_level = amp_level
        self.initializer.init_amp(amp_backend)

        self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv(
            'KAGGLE_URL_BASE')

        # Callback system
        self.on_init_end()