Example #1
    def run_pretrain_routine(self, model):
        """
        Sanity check a few things before starting actual training
        :param model:
        :return:
        """
        ref_model = model
        if self.data_parallel:
            ref_model = model.module

        # give model convenience properties
        ref_model.trainer = self

        # set local properties on the model
        self.copy_trainer_model_properties(ref_model)

        # link up experiment object
        if self.logger is not None:
            ref_model.logger = self.logger

            # save exp to get started
            if hasattr(ref_model, "hparams"):
                self.logger.log_hyperparams(ref_model.hparams)

            self.logger.save()

        if self.use_ddp or self.use_ddp2:
            dist.barrier()

        # set up checkpoint callback
        self.configure_checkpoint_callback()

        # register auto-resubmit when on SLURM
        self.register_slurm_signal_handlers()

        # transfer data loaders from model
        self.get_dataloaders(ref_model)

        # print model summary
        if self.proc_rank == 0 and self.weights_summary is not None:
            if self.weights_summary in ['full', 'top']:
                ref_model.summarize(mode=self.weights_summary)
            else:
                m = "weights_summary can be None, 'full' or 'top'"
                raise MisconfigurationException(m)

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # restore training and model before hpc call
        self.restore_weights(model)

        # when testing requested only run test and return
        if self.testing:
            self.run_evaluation(test=True)
            return

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        ref_model.on_sanity_check_start()
        if self.get_val_dataloaders() is not None and self.nb_sanity_val_steps > 0:
            # init progress bars for validation sanity check
            pbar = tqdm.tqdm(desc='Validation sanity check', total=self.nb_sanity_val_steps,
                             leave=False, position=2 * self.process_position,
                             disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
            self.main_progress_bar = pbar
            # dummy validation progress bar
            self.val_progress_bar = tqdm.tqdm(disable=True)

            self.evaluate(model, self.get_val_dataloaders(), self.nb_sanity_val_steps, self.testing)

            # close progress bars
            self.main_progress_bar.close()
            self.val_progress_bar.close()

        # init progress bar
        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
                         file=sys.stdout)
        self.main_progress_bar = pbar

        # clear cache before training
        if self.on_gpu:
            torch.cuda.empty_cache()

        # CORE TRAINING LOOP
        self.train()
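A minimal usage sketch (not part of the listing above): run_pretrain_routine is an internal step reached through Trainer.fit() rather than called directly. The constructor arguments below mirror the attributes this version reads (weights_summary, nb_sanity_val_steps, show_progress_bar), but exact argument names vary across Lightning releases, so treat this as an assumption-laden illustration.

    from pytorch_lightning import Trainer

    # 'model' is assumed to be a LightningModule instance defined elsewhere.
    trainer = Trainer(weights_summary='top',     # print a condensed layer summary
                      nb_sanity_val_steps=5,     # run a few val batches before training
                      show_progress_bar=True)
    trainer.fit(model)  # fit() routes to run_pretrain_routine(model) before train()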
Example #2
    def run_pretrain_routine(self, model: LightningModule):
        """Sanity check a few things before starting actual training.

        Args:
            model: The model to run sanity test on.
        """
        ref_model = model
        if self.data_parallel:
            ref_model = model.module

        # give model convenience properties
        ref_model.trainer = self

        # set local properties on the model
        self.copy_trainer_model_properties(ref_model)

        # link up experiment object
        if self.logger is not None:
            ref_model.logger = self.logger

            # save exp to get started
            if hasattr(ref_model, "hparams"):
                self.logger.log_hyperparams(ref_model.hparams)

            self.logger.save()

        if self.use_ddp or self.use_ddp2:
            dist.barrier()

        # wait for all models to restore weights
        if self.on_tpu and XLA_AVAILABLE:
            # wait for all processes to catch up
            torch_xla.core.xla_model.rendezvous(
                "pl.Trainer.run_pretrain_routine")

        # set up checkpoint callback
        self.configure_checkpoint_callback()

        # register auto-resubmit when on SLURM
        self.register_slurm_signal_handlers()

        # print model summary
        if self.proc_rank == 0 and self.weights_summary is not None:
            if self.weights_summary in ['full', 'top']:
                ref_model.summarize(mode=self.weights_summary)
            else:
                m = "weights_summary can be None, 'full' or 'top'"
                raise MisconfigurationException(m)

        # track model now.
        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # restore training and model before hpc call
        self.restore_weights(model)

        # download the data and do whatever transforms we need
        self.call_prepare_data(ref_model)

        # when testing requested only run test and return
        if self.testing:
            # only load test dataloader for testing
            self.reset_test_dataloader(ref_model)
            self.run_evaluation(test=True)
            return

        # load the dataloaders
        self.reset_train_dataloader(ref_model)
        self.reset_val_dataloader(ref_model)

        # check if we should run validation during training
        self.disable_validation = ((self.num_val_batches == 0 or
                                    not self.is_overriden('validation_step'))
                                   and not self.fast_dev_run)

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        ref_model.on_sanity_check_start()
        ref_model.on_train_start()
        if not self.disable_validation and self.num_sanity_val_steps > 0:
            # init progress bars for validation sanity check
            pbar = tqdm(desc='Validation sanity check',
                        total=self.num_sanity_val_steps *
                        len(self.val_dataloaders),
                        leave=False,
                        position=2 * self.process_position,
                        disable=not self.show_progress_bar,
                        dynamic_ncols=True)
            self.main_progress_bar = pbar
            # dummy validation progress bar
            self.val_progress_bar = tqdm(disable=True)

            eval_results = self.evaluate(model, self.val_dataloaders,
                                         self.num_sanity_val_steps, False)
            _, _, _, callback_metrics, _ = self.process_output(eval_results)

            # close progress bars
            self.main_progress_bar.close()
            self.val_progress_bar.close()

            if self.enable_early_stop:
                self.early_stop_callback.check_metrics(callback_metrics)

        # init progress bar
        pbar = tqdm(leave=True,
                    position=2 * self.process_position,
                    disable=not self.show_progress_bar,
                    dynamic_ncols=True,
                    file=sys.stdout)
        self.main_progress_bar = pbar

        # clear cache before training
        if self.on_gpu:
            torch.cuda.empty_cache()

        # CORE TRAINING LOOP
        self.train()
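The validation-gating condition above is easier to read when pulled out as a small pure function. This is a sketch for illustration only, not code from the library; the parameter names simply mirror the trainer attributes used in Example #2.

    def should_disable_validation(num_val_batches: int,
                                  has_validation_step: bool,
                                  fast_dev_run: bool) -> bool:
        # Validation is skipped when there are no validation batches or no
        # validation_step override, unless fast_dev_run forces a single pass.
        return (num_val_batches == 0 or not has_validation_step) and not fast_dev_run

    # e.g. should_disable_validation(0, True, False)  -> True  (nothing to validate on)
    #      should_disable_validation(10, True, False) -> False (sanity check + val run)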
Example #3
    def fit(self,
            model: LightningModule,
            train_dataloader: Optional[DataLoader] = None,
            val_dataloaders: Optional[DataLoader] = None,
            test_dataloaders: Optional[DataLoader] = None):
        r"""
        Runs the full optimization routine.

        Args:
            model: Model to fit.

            train_dataloader: A PyTorch
                DataLoader with training samples. If the model has
                a predefined train_dataloader method, this will be skipped.

            val_dataloaders: Either a single
                PyTorch DataLoader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method, this will be skipped.

            test_dataloaders: Either a single
                PyTorch DataLoader or a list of them, specifying test samples.
                If the model has a predefined test_dataloaders method, this will be skipped.

        Example::

            # Option 1
            # Define the train_dataloader(), test_dataloader() and val_dataloader() methods
            # in the LightningModule
            # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model)

            # Option 2
            # in production cases we might want to pass different datasets to the same model
            # Recommended for PRODUCTION SYSTEMS
            train, val, test = DataLoader(...), DataLoader(...), DataLoader(...)
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model, train_dataloader=train,
                        val_dataloaders=val, test_dataloaders=test)

            # Option 1 & 2 can be mixed, for example the training set can be
            # defined as part of the model, and validation/test can then be
            # feed to .fit()
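            # hedged sketch: assumes the model defines train_dataloader() itself
            trainer.fit(model, val_dataloaders=val, test_dataloaders=test)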

        """
        # set up the passed in dataloaders (if needed)
        self.__set_fit_dataloaders(model, train_dataloader, val_dataloaders,
                                   test_dataloaders)

        # route to appropriate start method
        # when using multi-node or DDP within a node start each module in a separate process
        if self.use_ddp2:
            task = int(os.environ['SLURM_LOCALID'])
            self.ddp_train(task, model)

        elif self.use_ddp:
            if self.is_slurm_managing_tasks:
                task = int(os.environ['SLURM_LOCALID'])
                self.ddp_train(task, model)
            else:
                mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model, ))

        # 1 gpu or dp option triggers training using DP module
        # easier to avoid NCCL issues
        elif self.use_dp:
            self.dp_train(model)

        elif self.single_gpu:
            self.single_gpu_train(model)

        elif self.use_tpu:
            log.info(f'training on {self.num_tpu_cores} TPU cores')

            #  COLAB_GPU is an env var available by default in Colab environments.
            start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'
            xmp.spawn(self.tpu_train,
                      args=(model, ),
                      nprocs=self.num_tpu_cores,
                      start_method=start_method)

        # ON CPU
        else:
            # AMP on CPU is not supported; fail fast with a clear message
            if self.use_amp:
                raise MisconfigurationException(
                    'amp + cpu is not supported. Please use a GPU option')

            # CHOOSE OPTIMIZER
            # allow for lr schedulers as well
            self.optimizers, self.lr_schedulers = self.init_optimizers(
                model.configure_optimizers())

            self.run_pretrain_routine(model)

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return 1
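Which branch fit() takes is decided by how the Trainer was configured. The snippet below is a hedged usage sketch, not taken from the source; the constructor arguments (gpus, distributed_backend, num_tpu_cores) are assumptions about this Lightning era and may be named differently in other versions.

    from pytorch_lightning import Trainer

    # Each line maps to one branch of the routing logic in fit();
    # 'model' is assumed to be a LightningModule defined elsewhere.
    trainer = Trainer(gpus=2, distributed_backend='ddp')    # -> self.use_ddp
    # trainer = Trainer(gpus=2, distributed_backend='dp')   # -> self.use_dp
    # trainer = Trainer(gpus=1)                             # -> self.single_gpu
    # trainer = Trainer(num_tpu_cores=8)                    # -> self.use_tpu
    # trainer = Trainer()                                   # -> CPU path
    trainer.fit(model)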
Example #4
    def run_evaluation(self, test=False):
        # when testing make sure user defined a test step
        can_run_test_step = False
        if test:
            can_run_test_step = self.is_overriden('test_step') and self.is_overriden('test_end')
            if not can_run_test_step:
                m = ('You called .test() without defining a test_step or test_end. '
                     'Please define them and try again.')
                raise MisconfigurationException(m)

        # validate only if the model has validation_step defined
        # test only if test_step and test_end are defined
        run_val_step = self.is_overriden('validation_step')

        if run_val_step or can_run_test_step:

            # hook
            model = self.get_model()
            model.on_pre_performance_check()

            # select dataloaders
            if test:
                dataloaders = self.get_test_dataloaders()
                max_batches = self.nb_test_batches
            else:
                # val
                dataloaders = self.get_val_dataloaders()
                max_batches = self.nb_val_batches

            # cap max batches to 1 when using fast_dev_run
            if self.fast_dev_run:
                max_batches = 1

            # init validation or test progress bar
            # main progress bar will already be closed when testing so initial position is free
            position = 2 * self.process_position + (not test)
            desc = 'Testing' if test else 'Validating'
            pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
                             disable=not self.show_progress_bar, dynamic_ncols=True,
                             unit='batch')
            setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)

            # run evaluation
            eval_results = self.evaluate(self.model,
                                         dataloaders,
                                         max_batches,
                                         test)
            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
                eval_results)

            # add metrics to prog bar
            self.add_tqdm_metrics(prog_bar_metrics)

            # log metrics
            self.log_metrics(log_metrics, {})

            # track metrics for callbacks
            self.callback_metrics = callback_metrics

            # hook
            model.on_post_performance_check()

            # add model specific metrics
            tqdm_metrics = self.training_tqdm_dict
            if not test:
                self.main_progress_bar.set_postfix(**tqdm_metrics)

            # close progress bar
            if test:
                self.test_progress_bar.close()
            else:
                self.val_progress_bar.close()

        # model checkpointing
        if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
            self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
                                                  logs=self.callback_metrics)
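run_evaluation() only proceeds when the LightningModule overrides the hooks it checks for: validation_step for validation, and both test_step and test_end for testing. The class below is a hedged sketch of a module that satisfies those checks; the hook names match what is_overriden() looks for in this version, while the layer and loss are assumptions for illustration (training hooks are omitted).

    import torch
    import torch.nn.functional as F
    from pytorch_lightning import LightningModule

    class SketchModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)  # assumed toy architecture

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def validation_step(self, batch, batch_idx):
            # satisfies is_overriden('validation_step')
            x, y = batch
            return {'val_loss': F.cross_entropy(self(x), y)}

        def test_step(self, batch, batch_idx):
            # satisfies is_overriden('test_step')
            x, y = batch
            return {'test_loss': F.cross_entropy(self(x), y)}

        def test_end(self, outputs):
            # satisfies is_overriden('test_end'); aggregates per-batch results
            avg = torch.stack([o['test_loss'] for o in outputs]).mean()
            return {'test_loss': avg, 'progress_bar': {'test_loss': avg}}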