Example #1
    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Runs the full optimization routine.

        Args:
            datamodule: An instance of :class:`LightningDataModule`.

            model: Model to fit.

            train_dataloader: A PyTorch DataLoader with training samples. If the model has
                a predefined train_dataloader method, this will be skipped.

            val_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method, this will be skipped.

        """
        # bookkeeping
        self._state = TrainerState.RUNNING

        # ----------------------------
        # LINK DATA
        # ----------------------------
        # setup data, etc...
        self.train_loop.setup_fit(model, train_dataloader, val_dataloaders,
                                  datamodule)

        # hook
        self.data_connector.prepare_data(model)

        # bookkeeping
        # we reuse fit in .test() but change its behavior using this flag
        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

        # ----------------------------
        # SET UP TRAINING
        # ----------------------------
        self.accelerator_backend = self.accelerator_connector.select_accelerator()
        self.accelerator_backend.setup(model)

        # ----------------------------
        # INSPECT THESE FOR MAIN LOOPS
        # ----------------------------
        # assign training and eval functions... inspect these to see the train and eval loops :)
        self.accelerator_backend.train_loop = self.train
        self.accelerator_backend.validation_loop = self.run_evaluation
        self.accelerator_backend.test_loop = self.run_evaluation

        # ----------------------------
        # TRAIN
        # ----------------------------
        # hook
        self.call_hook('on_fit_start')

        results = self.accelerator_backend.train()
        self.accelerator_backend.teardown()

        # ----------------------------
        # POST-Training CLEAN UP
        # ----------------------------
        # hook
        self.call_hook('on_fit_end')

        # hook
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')

        # return 1 when finished
        # used for testing or when we need to know that training succeeded

        if self._state != TrainerState.INTERRUPTED:
            self._state = TrainerState.FINISHED
        return results or 1
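The snippet above has no usage example in its docstring, so here is a minimal, self-contained sketch of calling this `fit` signature. `LitRegressor` and the random tensors are hypothetical stand-ins, not part of the source above, and the imports assume a PyTorch Lightning 1.x-era release.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


train_ds = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
val_ds = TensorDataset(torch.randn(16, 8), torch.randn(16, 1))

trainer = pl.Trainer(max_epochs=1)
trainer.fit(
    LitRegressor(),
    train_dataloader=DataLoader(train_ds, batch_size=16),
    val_dataloaders=DataLoader(val_ds, batch_size=16),
)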
Example #2
    def fit(self,
            model: LightningModule,
            train_dataloader: Optional[DataLoader] = None,
            val_dataloaders: Optional[DataLoader] = None,
            test_dataloaders: Optional[DataLoader] = None):
        r"""
        Runs the full optimization routine.

        Args:
            model: Model to fit.

            train_dataloader: A PyTorch
                DataLoader with training samples. If the model has
                a predefined train_dataloader method, this will be skipped.

            val_dataloaders: Either a single
                PyTorch DataLoader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method, this will be skipped.

            test_dataloaders: Either a single
                PyTorch DataLoader or a list of them, specifying test samples.
                If the model has a predefined test_dataloaders method, this will be skipped.

        Example::

            # Option 1
            # Define the train_dataloader(), test_dataloader() and val_dataloader() functions
            # in the LightningModule
            # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model)

            # Option 2
            # in production cases we might want to pass different datasets to the same model
            # Recommended for PRODUCTION SYSTEMS
            train, val, test = DataLoader(...), DataLoader(...), DataLoader(...)
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model, train_dataloader=train,
                        val_dataloaders=val, test_dataloaders=test)

            # Option 1 & 2 can be mixed, for example the training set can be
            # defined as part of the model, and validation/test can then be
            # fed to .fit()

        """
        # bind logger and other properties
        model.logger = self.logger
        self.copy_trainer_model_properties(model)

        # set up the passed in dataloaders (if needed)
        self.__attach_dataloaders(model, train_dataloader, val_dataloaders,
                                  test_dataloaders)

        # download the data and do whatever transforms we need
        # do before any spawn calls so that the model can assign properties
        # only on proc 0 because no spawn has happened yet
        model.prepare_data()

        # route to appropriate start method
        # when using multi-node or DDP within a node start each module in a separate process
        if self.use_ddp2:
            task = int(os.environ['SLURM_LOCALID'])
            self.ddp_train(task, model)

        elif self.use_ddp:
            if self.is_slurm_managing_tasks:
                task = int(os.environ['SLURM_LOCALID'])
                self.ddp_train(task, model)
            else:
                self.__set_random_port()

                # track for predict
                self.model = model

                # train
                mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model, ))

                # load weights if not interrupted
                self.load_spawn_weights(model)
                self.model = model

        # 1 gpu or dp option triggers training using DP module
        # easier to avoid NCCL issues
        elif self.use_dp:
            self.dp_train(model)

        elif self.single_gpu:
            self.single_gpu_train(model)

        elif self.use_tpu:  # pragma: no-cover
            log.info(f'training on {self.num_tpu_cores} TPU cores')

            #  COLAB_GPU is an env var available by default in Colab environments.
            start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'

            # track for predict
            self.model = model

            # train
            xmp.spawn(self.tpu_train,
                      args=(model, ),
                      nprocs=self.num_tpu_cores,
                      start_method=start_method)

            # load weights if not interrupted
            self.load_spawn_weights(model)
            self.model = model

        # ON CPU
        else:
            # run through amp wrapper
            if self.use_amp:
                raise MisconfigurationException(
                    'amp + cpu is not supported.  Please use a GPU option')

            # CHOOSE OPTIMIZER
            # allow for lr schedulers as well
            self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
                self.init_optimizers(model.configure_optimizers())

            self.run_pretrain_routine(model)

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return 1
Example #3
    def init_optimizers(self,
                        model: LightningModule) -> Tuple[List, List, List]:
        optim_conf = model.configure_optimizers()

        if optim_conf is None:
            rank_zero_warn(
                '`LightningModule.configure_optimizers` returned `None`, '
                'this fit will run with no optimizer', UserWarning)
            optim_conf = _MockOptimizer()

        # single output, single optimizer
        if isinstance(optim_conf, Optimizer):
            return [optim_conf], [], []

        # two lists, optimizer + lr schedulers
        elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
                and isinstance(optim_conf[0], list):
            optimizers, lr_schedulers = optim_conf
            lr_schedulers = self.configure_schedulers(lr_schedulers)
            return optimizers, lr_schedulers, []

        # single dictionary
        elif isinstance(optim_conf, dict):
            optimizer = optim_conf["optimizer"]
            monitor = optim_conf.get('monitor', None)
            lr_scheduler = optim_conf.get("lr_scheduler", [])
            if lr_scheduler:
                lr_schedulers = self.configure_schedulers([lr_scheduler],
                                                          monitor)
            else:
                lr_schedulers = []
            return [optimizer], lr_schedulers, []

        # multiple dictionaries
        elif isinstance(optim_conf,
                        (list, tuple)) and isinstance(optim_conf[0], dict):
            optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
            # take the lr_scheduler only if it exists and is defined (i.e. not None)
            lr_schedulers = [
                opt_dict["lr_scheduler"] for opt_dict in optim_conf
                if opt_dict.get("lr_scheduler")
            ]
            # take the frequency only if it exists and is defined (i.e. not None)
            optimizer_frequencies = [
                opt_dict["frequency"] for opt_dict in optim_conf
                if opt_dict.get("frequency") is not None
            ]

            # clean scheduler list
            if lr_schedulers:
                lr_schedulers = self.configure_schedulers(lr_schedulers)
            # assert that if frequencies are present, they are given for all optimizers
            if optimizer_frequencies and len(optimizer_frequencies) != len(
                    optimizers):
                raise ValueError(
                    "A frequency must be given to each optimizer.")
            return optimizers, lr_schedulers, optimizer_frequencies

        # single list or tuple, multiple optimizer
        elif isinstance(optim_conf, (list, tuple)):
            return list(optim_conf), [], []

        # unknown configuration
        else:
            raise ValueError(
                'Unknown configuration for model optimizers.'
                ' Output from `model.configure_optimizers()` should either be:'
                ' * single output, single `torch.optim.Optimizer`'
                ' * single output, list of `torch.optim.Optimizer`'
                ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
                '    and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
                ' * two outputs, first being a list of `torch.optim.Optimizer` second being'
                '    a list of `torch.optim.lr_scheduler`'
                ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)'
            )
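For reference, here is a hedged sketch of `configure_optimizers` return values that the branches above would accept. The objects are illustrative only (`net` is a hypothetical module), and the trailing comments show what `init_optimizers` would return for each form.

import torch

net = torch.nn.Linear(4, 2)
opt_a = torch.optim.Adam(net.parameters(), lr=1e-3)
opt_b = torch.optim.SGD(net.parameters(), lr=1e-2)

# single optimizer                          -> ([opt_a], [], [])
conf_single = opt_a

# two lists: optimizers + lr schedulers     -> ([opt_a], <configured schedulers>, [])
conf_two_lists = ([opt_a], [torch.optim.lr_scheduler.StepLR(opt_a, step_size=10)])

# multiple dicts, optional "frequency" key  -> ([opt_a, opt_b], [], [1, 5])
conf_multi_dict = (
    {"optimizer": opt_a, "frequency": 1},
    {"optimizer": opt_b, "frequency": 5},
)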
Example #4
def lr_find(
    trainer,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = 'exponential',
    early_stop_threshold: float = 4.0,
    datamodule: Optional[LightningDataModule] = None,
    update_attr: bool = False,
):
    r"""
    ``lr_find`` enables the user to do a range test of good initial learning rates,
    to reduce the amount of guesswork in picking a good starting learning rate.

    Args:
        model: Model to do range testing for

        train_dataloader: A PyTorch
            ``DataLoader`` with training samples. If the model has
            a predefined train_dataloader method, this will be skipped.

        min_lr: minimum learning rate to investigate

        max_lr: maximum learning rate to investigate

        num_training: number of learning rates to test

        mode: Search strategy to update learning rate after each batch:

            - ``'exponential'`` (default): Will increase the learning rate exponentially.
            - ``'linear'``: Will increase the learning rate linearly.

        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.

        datamodule: An optional ``LightningDataModule`` which holds the training
            and validation dataloader(s). Note that the ``train_dataloader`` and
            ``val_dataloaders`` parameters cannot be used at the same time as
            this parameter, or a ``MisconfigurationException`` will be raised.

        update_attr: Whether to update the learning rate attribute or not.


    Example::

        # Setup model and trainer
        model = MyModelClass(hparams)
        trainer = pl.Trainer()

        # Run lr finder
        lr_finder = trainer.tuner.lr_find(model, ...)

        # Inspect results
        fig = lr_finder.plot(); fig.show()
        suggested_lr = lr_finder.suggestion()

        # Overwrite lr and create new model
        hparams.lr = suggested_lr
        model = MyModelClass(hparams)

        # Ready to train with new learning rate
        trainer.fit(model)

    """
    if trainer.fast_dev_run:
        rank_zero_warn(
            'Skipping learning rate finder since fast_dev_run is enabled.',
            UserWarning)
        return

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir,
                             'lr_find_temp_model.ckpt')

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [
        _LRCallback(num_training,
                    early_stop_threshold,
                    progress_bar_refresh_rate=1)
    ]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(
        model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.fit(model,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info('LR finder stopped early due to diverging loss.')

    # Transfer results from callback to lr finder object
    lr_finder.results.update({
        'lr': trainer.callbacks[0].lrs,
        'loss': trainer.callbacks[0].losses
    })
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(
            str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f'Learning rate set to {lr}')

    return lr_finder
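The docstring above already shows basic usage, so this is only a short, hedged sketch of the `datamodule` and `update_attr` parameters. `model` and `dm` are hypothetical placeholders for a LightningModule exposing an `lr` (or `learning_rate`) attribute and a LightningDataModule.

import pytorch_lightning as pl

trainer = pl.Trainer()

# `update_attr=True` writes the suggested lr back onto the model
# via `lightning_setattr`, as in the final branch of `lr_find` above
lr_finder = trainer.tuner.lr_find(model, datamodule=dm, update_attr=True)
print(lr_finder.suggestion())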
Example #5
    def fit(
        self,
        model: LightningModule,
        train_dataloader: Any = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Runs the full optimization routine.

        Args:
            datamodule: An instance of :class:`LightningDataModule`.

            model: Model to fit.

            train_dataloader: Either a single PyTorch DataLoader or a collection of these
                (list, dict, nested lists and dicts). In the case of multiple dataloaders, please
                see this :ref:`page <multiple-training-dataloaders>`

            val_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method, this will be skipped.

        """
        # bookkeeping
        self._state = TrainerState.RUNNING

        # bookkeeping
        # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
        if self._running_stage is None:
            self._running_stage = RunningStage.TRAINING

        # set local properties on the model
        self.model_connector.copy_trainer_model_properties(model)

        # ----------------------------
        # LINK DATA
        # ----------------------------
        # setup data, etc...
        self.train_loop.setup_fit(model, train_dataloader, val_dataloaders,
                                  datamodule)

        # hook
        self.data_connector.prepare_data(model)
        self.callback_connector._attach_model_callbacks(model, self)

        # ----------------------------
        # SET UP TRAINING
        # ----------------------------
        self.call_setup_hook(model)
        self.call_hook("on_before_accelerator_backend_setup", model)
        self.accelerator.setup(self, model)
        self.setup_trainer(model)

        # ----------------------------
        # INSPECT THE CORE LOOPS
        # ----------------------------
        #             Lightning internal flow looks like this.
        #
        #   trainer.fit(...) or trainer.test(...) or trainer.predict(...)   ||
        #                                |                                  ||
        #                        create accelerator                         ||
        #                                |                                  ||
        #                         trainer.dispatch                          ||  LIGHTNING
        #                                |                                  ||
        #    start_training or start_testing or start_predicting call       ||  FLOW
        #                        from `accelerator`                         ||
        #                                |                                  ||  DIRECTION
        #             run_train or run_test or run_predict call             ||
        #                           from `trainer`                          ||
        #                                |                                  ||
        #                             results                               \/
        # This is used to guide readers to the core loops: train, test, predict.
        # `run_predict` is the simplest to understand, use `Go to Definition` to read it :)
        # Search for `start_training` or `start_testing` or `start_predicting` in
        # `pytorch_lightning/plugins/training_type` folder to find accelerator dispatch functions.
        self.accelerator.train_loop = self.run_train
        self.accelerator.validation_loop = self.run_evaluation
        self.accelerator.test_loop = self.run_evaluation
        self.accelerator.predict_loop = self.run_predict

        # ----------------------------
        # TRAIN
        # ----------------------------
        # hook
        self.call_hook("on_fit_start")

        # plugin will set up fitting (e.g. ddp will launch child processes)
        self.pre_dispatch()

        # dispatch `start_training` or `start_testing` or `start_predicting`
        self.dispatch()

        # plugin will finalize fitting (e.g. ddp_spawn will load the trained model)
        self.post_dispatch()

        # ----------------------------
        # POST-Training CLEAN UP
        # ----------------------------
        # hook
        self.call_hook('on_fit_end')

        # hook
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        if self._state != TrainerState.INTERRUPTED:
            self._state = TrainerState.FINISHED

        self._running_stage = None

        return self.accelerator.results or 1
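The flow comment above can be condensed into a standalone toy sketch of the dispatch chain. `ToyAccelerator` and `ToyTrainer` are hypothetical names, not Lightning classes; they only illustrate the trainer -> accelerator -> trainer round trip described in the diagram.

class ToyAccelerator:
    def __init__(self, trainer):
        self.trainer = trainer

    def start_training(self):
        # the accelerator / training-type plugin hands control back to the trainer core loop
        return self.trainer.run_train()


class ToyTrainer:
    def __init__(self):
        self.accelerator = ToyAccelerator(self)

    def fit(self):
        # trainer.fit -> trainer.dispatch -> accelerator.start_training -> trainer.run_train
        return self.dispatch()

    def dispatch(self):
        return self.accelerator.start_training()

    def run_train(self):
        return "results"


assert ToyTrainer().fit() == "results"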
Example #6
    def backward(self, loss, optimizer, optimizer_idx):
        return LightningModule.backward(self, loss, optimizer, optimizer_idx)
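A hedged sketch of what such an override could look like in a user module, matching the delegation above; `LitModel` and the loss scaling are illustrative only, not part of the source.

from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def backward(self, loss, optimizer, optimizer_idx):
        # tweak the loss, then fall back to the default backward pass
        super().backward(loss * 0.5, optimizer, optimizer_idx)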
Example #7
    def _evaluate(
        self,
        model: LightningModule,
        dataloaders: List[DataLoader],
        max_batches: Union[int, List[int]],
        test_mode: bool = False
    ):
        """Run evaluation code.

        Args:
            model: The model to evaluate.
            dataloaders: A list of PyTorch dataloaders.
            max_batches: An integer or list of integers with length of the number of dataloaders. Each
                entry is the number of batches to process in the corresponding dataloader.
            test_mode: If ``True``, run the test loop (test-step hooks); otherwise run validation.
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        outputs = []

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        # run validation
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []

            # on TPU we have to wrap it under the ParallelLoader
            if self.use_tpu:
                device = xm.xla_device(self.tpu_id)
                dataloader = xla_pl.ParallelLoader(dataloader, [device])
                dataloader = dataloader.per_device_loader(device)

            # each dataloader has a max num batches
            dl_max_batches = max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue

                # stop short when on fast_dev_run (sets max_batch=1)
                if batch_idx >= dl_max_batches:
                    break

                # callbacks
                if test_mode:
                    self.on_test_batch_start()
                else:
                    self.on_validation_batch_start()

                # -----------------
                # RUN EVALUATION STEP
                # -----------------
                if self.use_amp and NATIVE_AMP_AVALAIBLE:
                    with torch.cuda.amp.autocast():
                        output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
                else:
                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

                # on dp / ddp2 might still want to do something with the batch parts
                if test_mode:
                    if self.is_overridden('test_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('test_step_end'):
                            output = model_ref.test_step_end(output)
                    self.on_test_batch_end()
                else:
                    if self.is_overridden('validation_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('validation_step_end'):
                            output = model_ref.validation_step_end(output)
                    self.on_validation_batch_end()

                # track outputs for collation
                dl_outputs.append(output)

            outputs.append(dl_outputs)

        eval_results = {}

        # with a single dataloader don't pass an array
        if len(dataloaders) == 1:
            outputs = outputs[0]

        # give model a chance to do something with the outputs (and method defined)
        if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
            model = model.module

        if test_mode:
            if self.is_overridden('test_end', model=model):
                # TODO: remove in v1.0.0
                eval_results = model.test_end(outputs)
                rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
                               ' Use `test_epoch_end` instead.', DeprecationWarning)

            elif self.is_overridden('test_epoch_end', model=model):
                eval_results = model.test_epoch_end(outputs)

        else:
            if self.is_overridden('validation_end', model=model):
                # TODO: remove in v1.0.0
                eval_results = model.validation_end(outputs)
                rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
                               ' Use `validation_epoch_end` instead.', DeprecationWarning)

            elif self.is_overridden('validation_epoch_end', model=model):
                eval_results = model.validation_epoch_end(outputs)

        # aggregate ddp stats across processes
        has_content = eval_results is not None and len(eval_results) > 0
        if has_content and (self.use_ddp or self.use_ddp2):
            self.reduce_eval_ddp(eval_results)

        # enable train mode again
        model.train()

        # re-enable gradients
        torch.set_grad_enabled(True)

        return eval_results
Example #8
    def lr_find(self,
                model: LightningModule,
                train_dataloader: Optional[DataLoader] = None,
                val_dataloaders: Optional[DataLoader] = None,
                min_lr: float = 1e-8,
                max_lr: float = 1,
                num_training: int = 100,
                mode: str = 'exponential',
                early_stop_threshold: float = 4.0,
                num_accumulation_steps=None):
        r"""
        lr_find enables the user to do a range test of good initial learning rates,
        to reduce the amount of guesswork in picking a good starting learning rate.

        Args:
            model: Model to do range testing for

            train_dataloader: A PyTorch
                DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            min_lr: minimum learning rate to investigate

            max_lr: maximum learning rate to investigate

            num_training: number of learning rates to test

            mode: search strategy, either 'linear' or 'exponential'. If set to
                'linear' the learning rate will be searched by linearly increasing
                after each batch. If set to 'exponential', will increase learning
                rate exponentially.

            early_stop_threshold: threshold for stopping the search. If the
                loss at any point is larger than early_stop_threshold*best_loss
                then the search is stopped. To disable, set to None.

            num_accumulation_steps: deprecated, number of batches to calculate loss over.
                Set trainer argument ``accumulate_grad_batches`` instead.

        Example::

            # Setup model and trainer
            model = MyModelClass(hparams)
            trainer = pl.Trainer()

            # Run lr finder
            lr_finder = trainer.lr_find(model, ...)

            # Inspect results
            fig = lr_finder.plot(); fig.show()
            suggested_lr = lr_finder.suggestion()

            # Overwrite lr and create new model
            hparams.lr = suggested_lr
            model = MyModelClass(hparams)

            # Ready to train with new learning rate
            trainer.fit(model)

        """
        if num_accumulation_steps is not None:
            rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated"
                           " since v0.7.6 and will be removed in 0.9. Please"
                           " set trainer argument `accumulate_grad_batches` instead.",
                           DeprecationWarning)

        save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')

        self.__lr_finder_dump_params(model)

        # Prevent going into infinite loop
        self.auto_lr_find = False

        # Initialize lr finder object (stores results)
        lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

        # Use special lr logger callback
        self.callbacks = [_LRCallback(num_training,
                                      early_stop_threshold,
                                      progress_bar_refresh_rate=1)]

        # No logging
        self.logger = DummyLogger()

        # Max step set to number of iterations
        self.max_steps = num_training

        # Disable standard progress bar for fit
        if self.progress_bar_callback:
            self.progress_bar_callback.disable()

        # Disable standard checkpoint & early stopping
        self.checkpoint_callback = False
        self.early_stop_callback = None
        self.enable_early_stop = False

        # Required for saving the model
        self.optimizers, self.schedulers = [], []
        self.model = model

        # Dump model checkpoint
        self.save_checkpoint(str(save_path))

        # Configure optimizer and scheduler
        optimizers, _, _ = self.init_optimizers(model)

        if len(optimizers) != 1:
            raise MisconfigurationException(
                f'`model.configure_optimizers()` returned {len(optimizers)} optimizers, but'
                ' the learning rate finder only works with a single optimizer')
        model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])

        # Fit, lr & loss logged in callback
        self.fit(model,
                 train_dataloader=train_dataloader,
                 val_dataloaders=val_dataloaders)

        # Prompt if we stopped early
        if self.global_step != num_training:
            log.info('LR finder stopped early due to diverging loss.')

        # Transfer results from callback to lr finder object
        lr_finder.results.update({'lr': self.callbacks[0].lrs,
                                  'loss': self.callbacks[0].losses})
        lr_finder._total_batch_idx = self.total_batch_idx  # for debug purpose

        # Reset model state
        self.restore(str(save_path), on_gpu=self.on_gpu)
        os.remove(save_path)

        # Finish by resetting variables so trainer is ready to fit model
        self.__lr_finder_restore_params(model)
        if self.progress_bar_callback:
            self.progress_bar_callback.enable()

        return lr_finder
Example #9
    def call_teardown_hook(self, model: LightningModule) -> None:
        state = self._teardown_state
        self.profiler.teardown(stage=state)
        self.teardown(stage=state)
        model.teardown(stage=state)
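For context, a hedged sketch of the model-side `teardown` hook that the helper above invokes; the class name and the print are illustrative only.

from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def teardown(self, stage):
        # `stage` mirrors the `_teardown_state` above, e.g. "fit" or "test"
        print(f"teardown called for stage={stage}")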
Example #10
    def fit(
            self,
            model: LightningModule,
            train_dataloader: Optional[DataLoader] = None,
            val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
    ):
        r"""
        Runs the full optimization routine.

        Args:
            model: Model to fit.

            train_dataloader: A PyTorch
                DataLoader with training samples. If the model has
                a predefined train_dataloader method, this will be skipped.

            val_dataloaders: Either a single
                PyTorch DataLoader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method, this will be skipped.

        Example::

            # Option 1
            # Define the train_dataloader() and val_dataloader() functions
            # in the LightningModule
            # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model)

            # Option 2
            # in production cases we might want to pass different datasets to the same model
            # Recommended for PRODUCTION SYSTEMS
            train, val = DataLoader(...), DataLoader(...)
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model, train_dataloader=train, val_dataloaders=val)

            # Option 1 & 2 can be mixed, for example the training set can be
            # defined as part of the model, and validation can then be fed to .fit()

        """
        # bind logger and other properties
        model.logger = self.logger
        self.copy_trainer_model_properties(model)

        # clean hparams
        if hasattr(model, 'hparams'):
            parsing.clean_namespace(model.hparams)

        # set up the passed in dataloaders (if needed)
        self.__attach_dataloaders(model, train_dataloader, val_dataloaders)

        # check that model is configured correctly
        self.check_model_configuration(model)

        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
        if self.can_prepare_data():
            model.prepare_data()
            self._is_data_prepared = True

        # Run auto batch size scaling
        if self.auto_scale_batch_size:
            if isinstance(self.auto_scale_batch_size, bool):
                self.auto_scale_batch_size = 'power'
            self.scale_batch_size(model, mode=self.auto_scale_batch_size)
            model.logger = self.logger  # reset logger binding

        # Run learning rate finder:
        if self.auto_lr_find:
            self._run_lr_finder_internally(model)
            model.logger = self.logger  # reset logger binding

        # route to appropriate start method
        # when using multi-node or DDP within a node start each module in a separate process
        if self.use_ddp2:
            if self.is_slurm_managing_tasks:
                task = int(os.environ['SLURM_LOCALID'])

            # torchelastic or general non_slurm ddp2
            elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
                task = int(os.environ['LOCAL_RANK'])

            self.ddp_train(task, model)
        elif self.use_ddp:
            if self.is_slurm_managing_tasks:
                task = int(os.environ['SLURM_LOCALID'])
                self.ddp_train(task, model)

            # torchelastic or general non_slurm ddp
            elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
                task = int(os.environ['LOCAL_RANK'])
                self.ddp_train(task, model)

            elif self.distributed_backend == 'cpu_ddp':
                self.__set_random_port()
                self.model = model
                mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

            elif self.distributed_backend == 'ddp_spawn':
                model.share_memory()

                # spin up peers
                mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))

            elif self.distributed_backend == 'ddp':
                self.spawn_ddp_children(model)

        # 1 gpu or dp option triggers training using DP module
        # easier to avoid NCCL issues
        elif self.use_dp:
            self.dp_train(model)

        elif self.use_horovod:
            self.horovod_train(model)

        elif self.single_gpu:
            self.single_gpu_train(model)

        elif self.use_tpu:  # pragma: no-cover
            rank_zero_info(f'training on {self.tpu_cores} TPU cores')

            #  COLAB_GPU is an env var available by default in Colab environments.
            start_method = 'fork' if self.on_colab_kaggle else 'spawn'

            # track for predict
            self.model = model

            # train
            if self.tpu_id is not None:
                self.tpu_train(self.tpu_id, model)
            else:
                xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method)

            # load weights if not interrupted
            self.load_spawn_weights(model)
            self.model = model

        # ON CPU
        else:
            # run through amp wrapper
            if self.use_amp:
                raise MisconfigurationException('amp + cpu is not supported.  Please use a GPU option')

            # CHOOSE OPTIMIZER
            # allow for lr schedulers as well
            self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

            self.run_pretrain_routine(model)

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return 1
Example #11
    def _evaluate(self,
                  model: LightningModule,
                  dataloaders,
                  max_batches: int,
                  test_mode: bool = False):
        """Run evaluation code.

        Args:
            model: The model to evaluate.
            dataloaders: A list of PyTorch dataloaders.
            max_batches: Maximum number of batches to run per dataloader.
            test_mode: If ``True``, run the test loop (test-step hooks); otherwise run validation.
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        outputs = []

        # run validation
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []

            # on TPU we have to wrap it under the ParallelLoader
            if self.use_tpu:
                device = xm.xla_device()
                dataloader = xla_pl.ParallelLoader(dataloader, [device])
                dataloader = dataloader.per_device_loader(device)

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue

                # stop short when on fast_dev_run (sets max_batch=1)
                if batch_idx >= max_batches:
                    break

                # -----------------
                # RUN EVALUATION STEP
                # -----------------
                output = self.evaluation_forward(model, batch, batch_idx,
                                                 dataloader_idx, test_mode)

                # on dp / ddp2 might still want to do something with the batch parts
                if test_mode:
                    if self.is_overriden('test_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('test_step_end'):
                            output = model_ref.test_step_end(output)
                else:
                    if self.is_overriden('validation_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('validation_step_end'):
                            output = model_ref.validation_step_end(output)

                # track outputs for collation
                dl_outputs.append(output)

                # batch done
                if self.progress_bar_refresh_rate >= 1 and batch_idx % self.progress_bar_refresh_rate == 0:
                    if test_mode:
                        self.test_progress_bar.update(
                            self.progress_bar_refresh_rate)
                    else:
                        self.val_progress_bar.update(
                            self.progress_bar_refresh_rate)
                        self.main_progress_bar.update(
                            self.progress_bar_refresh_rate)
            outputs.append(dl_outputs)

        eval_results = {}

        # with a single dataloader don't pass an array
        if len(dataloaders) == 1:
            outputs = outputs[0]

        # give model a chance to do something with the outputs (and method defined)
        if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
            model = model.module

        if test_mode:
            if self.is_overriden('test_end', model=model):
                # TODO: remove in v1.0.0
                eval_results = model.test_end(outputs)
                rank_zero_warn(
                    'Method `test_end` was deprecated in 0.7.0 and will be removed in 1.0.0.'
                    ' Use `test_epoch_end` instead.', DeprecationWarning)

            elif self.is_overriden('test_epoch_end', model=model):
                eval_results = model.test_epoch_end(outputs)

        else:
            if self.is_overriden('validation_end', model=model):
                # TODO: remove in v1.0.0
                eval_results = model.validation_end(outputs)
                rank_zero_warn(
                    'Method `validation_end` was deprecated in 0.7.0 and will be removed in 1.0.0.'
                    ' Use `validation_epoch_end` instead.', DeprecationWarning)

            elif self.is_overriden('validation_epoch_end', model=model):
                eval_results = model.validation_epoch_end(outputs)

        # enable train mode again
        model.train()

        # re-enable gradients
        torch.set_grad_enabled(True)

        return eval_results
Example #12
    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Runs the full optimization routine.

        Args:
            datamodule: An instance of :class:`LightningDataModule`.

            model: Model to fit.

            train_dataloader: A PyTorch DataLoader with training samples. If the model has
                a predefined train_dataloader method, this will be skipped.

            val_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method, this will be skipped.

        """
        self._state = TrainerState.RUNNING

        # setup data, etc...
        self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.data_connector.prepare_data(model)

        # set testing if set in environ
        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

        # -------------------------
        # TRAIN
        # -------------------------
        self.accelerator_backend = self.accelerator_connector.select_accelerator()
        self.accelerator_backend.setup(model)

        # hook
        self.call_hook('on_fit_start')

        results = self.accelerator_backend.train()
        self.accelerator_backend.teardown()

        # -------------------------
        # POST-Training
        # -------------------------
        # hook
        self.call_hook('on_fit_end')

        # hook
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')

        # return 1 when finished
        # used for testing or when we need to know that training succeeded

        if self._state != TrainerState.INTERRUPTED:
            self._state = TrainerState.FINISHED
        return results or 1
Example #13
    def on_epoch_end(self, trainer: Trainer, model: LightningModule):
        metrics = self.get_metrics(trainer, model)
        assert metrics["foo"] == self.trainer.current_epoch
        assert metrics["foo_2"] == self.trainer.current_epoch
        model.on_epoch_end_called = True
Example #14
    def _evaluate(
        self,
        model: LightningModule,
        dataloaders: List[DataLoader],
        max_batches: Union[int, List[int]],
        test_mode: bool = False
    ):
        """Run evaluation code.

        Args:
            model: The model to evaluate.
            dataloaders: A list of PyTorch dataloaders.
            max_batches: An integer or list of integers with length of the number of dataloaders. Each
                entry is the number of batches to process in the corresponding dataloader.
            test_mode: If ``True``, run the test loop (test-step hooks); otherwise run validation.
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        outputs = []

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        # --------------------------
        # ON_EVAL_EPOCH_START hook
        # --------------------------
        self.__call_eval_loop_hook_start(test_mode)

        # run validation
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []

            # on TPU we have to wrap it under the ParallelLoader
            if self.use_tpu:
                device = xm.xla_device(self.tpu_id)
                dataloader = xla_pl.ParallelLoader(dataloader, [device])
                dataloader = dataloader.per_device_loader(device)

            # each dataloader has a max num batches
            dl_max_batches = max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue

                # stop short when on fast_dev_run (sets max_batch=1)
                if batch_idx >= dl_max_batches:
                    break

                # callbacks
                if test_mode:
                    self.on_test_batch_start()
                else:
                    self.on_validation_batch_start()

                # -----------------
                # RUN EVALUATION STEP
                # -----------------
                if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
                    with torch.cuda.amp.autocast():
                        output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
                else:
                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

                # allow only EvalResult when using structured results (from val_step)
                if isinstance(output, Result) and not isinstance(output, EvalResult):
                    m = 'only EvalResults or dicts are allowed from validation_step'
                    raise MisconfigurationException(m)

                # on dp / ddp2 might still want to do something with the batch parts
                if test_mode:
                    if self.is_overridden('test_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('test_step_end'):
                            output = model_ref.test_step_end(output)
                    self.on_test_batch_end()
                else:
                    if self.is_overridden('validation_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('validation_step_end'):
                            output = model_ref.validation_step_end(output)
                    self.on_validation_batch_end()

                # track outputs for collation
                if output is not None:
                    dl_outputs.append(output)

                self.__eval_add_step_metrics(output)

            outputs.append(dl_outputs)

        # ---------------------
        # EVAL_EPOCH_END
        # ---------------------
        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
        eval_results = self.__run_eval_epoch_end(test_mode, outputs, dataloaders, using_eval_result)

        # log callback metrics
        self.__update_callback_metrics(eval_results, using_eval_result)

        # enable train mode again
        model.train()

        # re-enable gradients
        torch.set_grad_enabled(True)

        # --------------------------
        # ON_EVAL_EPOCH_END hook
        # --------------------------
        self.__call_eval_loop_hook_end(test_mode)

        return eval_results
Example #15
    def init_optimizers(self,
                        model: LightningModule) -> Tuple[List, List, List]:
        optim_conf = model.configure_optimizers()
        if optim_conf is None:
            rank_zero_warn(
                '`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer',
                UserWarning,
            )
            optim_conf = _MockOptimizer()

        optimizers, lr_schedulers, optimizer_frequencies = [], [], []
        monitor = None

        # single output, single optimizer
        if isinstance(optim_conf, Optimizer):
            optimizers = [optim_conf]
        # two lists, optimizer + lr schedulers
        elif isinstance(optim_conf,
                        (list, tuple)) and len(optim_conf) == 2 and isinstance(
                            optim_conf[0], list):
            opt, sch = optim_conf
            optimizers = opt
            lr_schedulers = sch if isinstance(sch, list) else [sch]
        # single dictionary
        elif isinstance(optim_conf, dict):
            optimizers = [optim_conf["optimizer"]]
            monitor = optim_conf.get('monitor', None)
            lr_schedulers = [optim_conf["lr_scheduler"]
                             ] if "lr_scheduler" in optim_conf else []
        # multiple dictionaries
        elif isinstance(optim_conf, (list, tuple)) and all(
                isinstance(d, dict) for d in optim_conf):
            optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
            lr_schedulers = [
                opt_dict["lr_scheduler"] for opt_dict in optim_conf
                if "lr_scheduler" in opt_dict
            ]
            optimizer_frequencies = [
                opt_dict["frequency"] for opt_dict in optim_conf
                if opt_dict.get("frequency", None) is not None
            ]
            # assert that if frequencies are present, they are given for all optimizers
            if optimizer_frequencies and len(optimizer_frequencies) != len(
                    optimizers):
                raise ValueError(
                    "A frequency must be given to each optimizer.")
        # single list or tuple, multiple optimizer
        elif isinstance(optim_conf, (list, tuple)):
            optimizers = list(optim_conf)
        # unknown configuration
        else:
            raise MisconfigurationException(
                'Unknown configuration for model optimizers.'
                ' Output from `model.configure_optimizers()` should either be:\n'
                ' * `torch.optim.Optimizer`\n'
                ' * [`torch.optim.Optimizer`]\n'
                ' * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n'
                ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
                ' * A list of the previously described dict format, with an optional "frequency" key (int)'
            )

        lr_schedulers = self.configure_schedulers(lr_schedulers,
                                                  monitor=monitor)
        _validate_scheduler_optimizer(optimizers, lr_schedulers)

        return optimizers, lr_schedulers, optimizer_frequencies
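To complement the branch handling above, a hedged sketch of the single-dict format with the top-level ``monitor`` key that this version reads; ``net`` is a hypothetical module used only to build the objects.

import torch

net = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# parsed by the `isinstance(optim_conf, dict)` branch above:
# optimizers = [opt], lr_schedulers = [ReduceLROnPlateau], monitor = "val_loss"
conf = {
    "optimizer": opt,
    "lr_scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(opt),
    "monitor": "val_loss",
}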