    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):

        # track the outputs to reduce at the end of the epoch
        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
            sample_output = opt_outputs[-1]

            # decide if we need to reduce at the end of the epoch automatically
            auto_reduce_tng_result = isinstance(
                sample_output,
                Result) and sample_output.should_reduce_on_epoch_end
            hook_overridden = (is_overridden("training_epoch_end",
                                             model=self.trainer.get_model()) or
                               is_overridden("on_train_epoch_end",
                                             model=self.trainer.get_model()))

            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
            if not (hook_overridden or auto_reduce_tng_result):
                continue

            # with 1 step (no tbptt) don't use a sequence at epoch end
            if isinstance(opt_outputs,
                          list) and len(opt_outputs) == 1 and not isinstance(
                              opt_outputs[0], Result):
                opt_outputs = opt_outputs[0]

            epoch_output[opt_idx].append(opt_outputs)
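Every snippet in this listing hinges on `is_overridden`. As a rough orientation, here is a minimal sketch of its behaviour, assuming the helper simply compares a hook's code object against the default defined on the Lightning base class (an approximation, not the library implementation):

import pytorch_lightning as pl

def is_overridden_sketch(method_name, model):
    # pick the base class that matches the object we were handed
    base = pl.LightningDataModule if isinstance(model, pl.LightningDataModule) else pl.LightningModule
    if not hasattr(model, method_name) or not hasattr(base, method_name):
        return False
    instance_attr = getattr(model, method_name)
    super_attr = getattr(base, method_name)
    # a hook counts as overridden when its code object differs from the base-class default
    return instance_attr.__code__ is not super_attr.__code__
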
    def __verify_train_loop_configuration(self, model):
        # -----------------------------------
        # verify model has a training step
        # -----------------------------------
        has_training_step = is_overridden('training_step', model)
        if not has_training_step:
            raise MisconfigurationException(
                'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        # -----------------------------------
        # verify model has a train dataloader
        # -----------------------------------
        has_train_dataloader = is_overridden('train_dataloader', model)
        if not has_train_dataloader:
            raise MisconfigurationException(
                'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        # -----------------------------------
        # verify model has optimizer
        # -----------------------------------
        has_optimizers = is_overridden('configure_optimizers', model)
        if not has_optimizers:
            raise MisconfigurationException(
                'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )
    def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
        model = self.trainer.get_model()

        # with a single dataloader don't pass an array
        outputs = self.outputs
        eval_results = outputs
        if num_dataloaders == 1:
            eval_results = outputs[0]

        user_reduced = False

        if self.testing:
            if is_overridden('test_epoch_end', model=model):
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(outputs)

                eval_results = model.test_epoch_end(eval_results)
                user_reduced = True

        else:
            if is_overridden('validation_epoch_end', model=model):
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(outputs)

                eval_results = model.validation_epoch_end(eval_results)
                user_reduced = True

        if using_eval_result and not user_reduced:
            eval_results = self.__auto_reduce_result_objs(outputs)

        if not isinstance(eval_results, list):
            eval_results = [eval_results]

        return eval_results
    def reset_val_dataloader(self, model: LightningModule) -> None:
        """Resets the validation dataloader and determines the number of batches.

        Args:
            model: The current `LightningModule`
        """
        has_loader = is_overridden('val_dataloader', model)
        has_step = is_overridden('validation_step', model)
        if has_loader and has_step:
            self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
Example 5
    def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
        model = self.trainer.get_model()

        # with a single dataloader don't pass an array
        outputs = self.outputs
        eval_results = outputs
        if num_dataloaders == 1:
            eval_results = outputs[0]

        user_reduced = False

        if self.testing:
            if is_overridden('test_epoch_end', model=model):
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(
                        outputs)
                model._current_fx_name = 'test_epoch_end'
                eval_results = model.test_epoch_end(eval_results)
                user_reduced = True

        else:
            if is_overridden('validation_epoch_end', model=model):
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(
                        outputs)
                model._current_fx_name = 'validation_epoch_end'
                eval_results = model.validation_epoch_end(eval_results)
                user_reduced = True

        # capture logging
        self.trainer.logger_connector.cache_logged_metrics()
        # deprecation warning
        if eval_results is not None and user_reduced:
            step = 'test_epoch_end' if self.testing else 'validation_epoch_end'
            self.warning_cache.warn(
                f'The {step} should not return anything as of 9.1.'
                ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
            )

        if using_eval_result and not user_reduced:
            eval_results = self.__auto_reduce_result_objs(outputs)

        result = model._results
        if len(result) > 0 and eval_results is None:
            eval_results = result.get_epoch_log_metrics()

        if not isinstance(eval_results, list):
            eval_results = [eval_results]

        # track deprecated metrics
        self.trainer.logger_connector.track_metrics_deprecated(
            eval_results, using_eval_result, self.testing)

        return eval_results
Example 6
    def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
        model = self.trainer.get_model()

        # reset results
        model._results = Result()

        # with a single dataloader don't pass an array
        outputs = self.outputs
        eval_results = outputs
        if num_dataloaders == 1:
            eval_results = outputs[0]

        user_reduced = False

        if self.testing:
            if is_overridden('test_epoch_end', model=model):
                model._current_fx_name = 'test_epoch_end'
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(
                        outputs)

                eval_results = model.test_epoch_end(eval_results)
                user_reduced = True

        else:
            if is_overridden('validation_epoch_end', model=model):
                model._current_fx_name = 'validation_epoch_end'
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(
                        outputs)

                eval_results = model.validation_epoch_end(eval_results)
                user_reduced = True

        # deprecation warning
        if eval_results is not None and user_reduced:
            step = 'test_epoch_end' if self.testing else 'validation_epoch_end'
            self.warning_cache.warn(
                f'The {step} should not return anything as of 9.1.'
                ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
            )

        if using_eval_result and not user_reduced:
            eval_results = self.__auto_reduce_result_objs(outputs)

        if not isinstance(eval_results, list):
            eval_results = [eval_results]

        return eval_results
Example 7
    def __verify_train_loop_configuration(self, model):
        # -----------------------------------
        # verify model has a training step
        # -----------------------------------
        has_training_step = is_overridden('training_step', model)
        if not has_training_step:
            raise MisconfigurationException(
                'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        # -----------------------------------
        # verify model has a train dataloader
        # -----------------------------------
        has_train_dataloader = is_overridden('train_dataloader', model)
        if not has_train_dataloader:
            raise MisconfigurationException(
                'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        # -----------------------------------
        # verify model has optimizer
        # -----------------------------------
        has_optimizers = is_overridden('configure_optimizers', model)
        if not has_optimizers:
            raise MisconfigurationException(
                'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        trainer = self.trainer

        trainer.overriden_optimizer_step = is_overridden(
            'optimizer_step', model)
        trainer.overriden_optimizer_zero_grad = is_overridden(
            'optimizer_zero_grad', model)
        automatic_optimization = trainer.train_loop.automatic_optimization
        going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

        has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
        if (has_overriden_optimization_functions
            ) and going_to_accumulate_grad_batches and automatic_optimization:
            raise MisconfigurationException(
                'When overriding `LightningModule` optimizer_step or optimizer_zero_grad,'
                ' `accumulate_grad_batches` in `Trainer` should be 1.'
                ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
            )
Example 8
    def training_epoch_end(self, model, epoch_output, num_optimizers):
        if not is_overridden('training_epoch_end', model=model):
            return Result()

        # run training_epoch_end
        # refresh the result for custom logging at the epoch level
        model._current_fx_name = 'training_epoch_end'
        model._results = Result()

        epoch_output = self.__prepare_epoch_end_inputs(epoch_output)

        if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization:
            epoch_output = epoch_output[0]

        # lightningmodule hook
        epoch_output = model.training_epoch_end(epoch_output)

        model._current_fx_name = ''

        if epoch_output is not None:
            raise MisconfigurationException(
                'training_epoch_end expects a return of None. '
                'HINT: remove the return statement in training_epoch_end')

        # user can ALSO log at the end of an epoch
        new_epoch_end_logs = model._results
        return new_epoch_end_logs
Example 9
    def call_hook(self, hook_name, *args, **kwargs):
        # temporary. Don't modify evaluation behaviour
        if self.logger_connector._current_stage == "train":
            # set hook_name to model + reset Result obj
            self._reset_result_and_set_hook_fx_name(hook_name)

        # always profile hooks
        with self.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self, hook_name):
                trainer_hook = getattr(self, hook_name)
                trainer_hook(*args, **kwargs)

            # next call hook in lightningModule
            output = None
            model_ref = self.get_model()
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                output = hook_fx(*args, **kwargs)

            # if the PL module doesn't have the hook then call the accelerator
            # used to auto-reduce things for the user with Results obj
            elif hasattr(self.accelerator_backend, hook_name):
                accelerator_hook = getattr(self.accelerator_backend, hook_name)
                output = accelerator_hook(*args, **kwargs)

        # temporary. Don't modify evaluation behaviour
        if self.logger_connector._current_stage == "train":
            # capture logging
            self._cache_logged_metrics()
        return output
def test_dm_transfer_batch_to_device(tmpdir):
    class CustomBatch:

        def __init__(self, data):
            self.samples = data[0]
            self.targets = data[1]

    class CurrentTestDM(LightningDataModule):

        hook_called = False

        def transfer_batch_to_device(self, data, device):
            self.hook_called = True
            if isinstance(data, CustomBatch):
                data.samples = data.samples.to(device)
                data.targets = data.targets.to(device)
            else:
                data = super().transfer_batch_to_device(data, device)
            return data

    model = EvalModelTemplate()
    dm = CurrentTestDM()
    batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))

    trainer = Trainer(gpus=1)
    # running .fit() would require us to implement custom data loaders, we mock the model reference instead
    trainer.get_model = MagicMock(return_value=model)
    if is_overridden('transfer_batch_to_device', dm):
        model.transfer_batch_to_device = dm.transfer_batch_to_device

    trainer.accelerator_backend = GPUBackend(trainer)
    batch_gpu = trainer.accelerator_backend.batch_to_device(batch, torch.device('cuda:0'))
    expected = torch.device('cuda', 0)
    assert dm.hook_called
    assert batch_gpu.samples.device == batch_gpu.targets.device == expected
Example 11
    def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator):
        """
        Figure out what needs to be tracked/logged at the end of the epoch
        """

        # the training step outputs a list per optimizer. The list contains the outputs at each time step
        # when no TBPTT is used, then the list has 1 item per batch
        # when TBPTT IS used, then the list has n items (1 per time step)
        epoch_end_outputs = []
        for optimizer_idx_outputs in all_train_step_outputs:
            # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
            sample_output = optimizer_idx_outputs[-1]

            # pull out callback info if available (ie: Results object)
            if isinstance(sample_output, dict) and 'early_stop_on' in sample_output:
                early_stopping_accumulator.accumulate(sample_output['early_stop_on'])

            if isinstance(sample_output, dict) and 'checkpoint_on' in sample_output:
                checkpoint_accumulator.accumulate(sample_output['checkpoint_on'])

            # decide if we need to reduce at the end of the epoch automatically
            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end

            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
            if is_overridden('training_epoch_end', model=self.get_model()) or auto_reduce_tng_result:
                epoch_end_outputs.append(optimizer_idx_outputs)

        return epoch_end_outputs
Example 12
    def check_checkpoint_callback(self, should_check_val):
        # when no val loop is present or fast-dev-run, we still need to call the checkpoint callbacks
        # TODO bake this logic into the checkpoint callback
        should_activate = not is_overridden('validation_step', self.get_model()) and not should_check_val
        if should_activate:
            checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
            [c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks]
Example 13
    def call_hook(self, hook_name, *args, capture=False, **kwargs):
        # set hook_name to model + reset Result obj
        if capture:
            self._reset_result_and_set_hook_fx_name(hook_name)

        # always profile hooks
        with self.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self, hook_name):
                trainer_hook = getattr(self, hook_name)
                trainer_hook(*args, **kwargs)

            # next call hook in lightningModule
            output = None
            model_ref = self.get_model()
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                output = hook_fx(*args, **kwargs)

            # if the PL module doesn't have the hook then call the accelerator
            # used to auto-reduce things for the user with Results obj
            elif hasattr(self.accelerator_backend, hook_name):
                accelerator_hook = getattr(self.accelerator_backend, hook_name)
                output = accelerator_hook(*args, **kwargs)

        if capture:
            self._cache_logged_metrics()
        return output
Example 14
    def run_sanity_check(self, ref_model):
        using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
        should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        if should_sanity_check:
            self.reset_val_dataloader(ref_model)
            self.num_sanity_val_batches = [
                min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
            ]

            # hook and callback
            self.running_sanity_check = True
            self.on_sanity_check_start()

            # run eval step
            _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)

            # allow no returns from eval
            if eval_results is not None and len(eval_results) > 0:
                # when we get a list back, used only the last item
                if isinstance(eval_results, list):
                    eval_results = eval_results[-1]

                if isinstance(eval_results, EvalResult):
                    callback_metrics = eval_results.callback_metrics
                else:
                    _, _, _, callback_metrics, _ = self.process_dict_result(eval_results)
                self.logger_connector.callback_metrics = callback_metrics

            self.on_sanity_check_end()
            self.running_sanity_check = False
Example 15
    def validate(self, val_iterator, info):
        self.model.zero_grad()
        self.model.eval()

        torch.set_grad_enabled(False)

        model = self.get_model()
        if self.is_function_implemented("on_validation_epoch_start", model):
            model.on_validation_epoch_start()

        val_outputs = []
        for batch_idx, batch in enumerate(val_iterator):
            batch_info = {"batch_idx": batch_idx}
            batch_info.update(info)
            batch_output = self.validate_batch(batch, batch_info)
            if batch_output is not None:
                val_outputs.append(batch_output)

        processed_outputs = None
        if is_overridden("validation_epoch_end", model):
            raw_outputs = [vo["raw_output"] for vo in val_outputs]
            processed_outputs = model.validation_epoch_end(raw_outputs)

        if processed_outputs is not None:
            if isinstance(processed_outputs, torch.Tensor):
                return_output = {"val_loss": processed_outputs}
            elif isinstance(processed_outputs, Result):
                raise ValueError("Result objects are not supported. Please "
                                 "return a dictionary instead.")
            elif isinstance(processed_outputs, dict):
                return_output = processed_outputs
            else:
                raise TypeError("validation_epoch_end returned an invalid "
                                "type. It must return a Tensor, Result, "
                                "or dict.")
        else:
            # User did not override validation_epoch_end
            assert isinstance(val_outputs, list)
            # Use AverageMeterCollection util to reduce results.
            meter_collection = AverageMeterCollection()
            for v in val_outputs:
                num_samples = v.pop(NUM_SAMPLES, 1)
                raw_output = v["raw_output"]
                if isinstance(raw_output, dict):
                    meter_collection.update(raw_output, num_samples)
                elif isinstance(raw_output, torch.Tensor):
                    meter_collection.update({"val_loss": raw_output.item()},
                                            num_samples)
            return_output = meter_collection.summary()

        if self.is_function_implemented("on_validation_epoch_end", model):
            model.on_validation_epoch_end()

        # Set back to True so training will work.
        torch.set_grad_enabled(True)

        return return_output
    def can_prepare_data(self):
        should_call_dm_prepare_data = True
        if self.trainer.datamodule is not None and is_overridden(
                'prepare_data', self.trainer.datamodule):
            should_call_dm_prepare_data = not self.trainer.datamodule.has_prepared_data

        if self.trainer.prepare_data_per_node:
            return self.trainer.local_rank == 0 and should_call_dm_prepare_data
        else:
            return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data
Example 17
    def __verify_eval_loop_configuration(self, model, eval_loop_name):
        step_name = f'{eval_loop_name}_step'

        # map the dataloader name
        loader_name = f'{eval_loop_name}_dataloader'
        if eval_loop_name == 'validation':
            loader_name = 'val_dataloader'

        has_loader = is_overridden(loader_name, model)
        has_step = is_overridden(step_name, model)

        if has_loader and not has_step:
            rank_zero_warn(
                f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop'
            )
        if has_step and not has_loader:
            rank_zero_warn(
                f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop'
            )
Example 18
    def attach_datamodule(self, model, datamodule, stage):

        # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
        datamodule = datamodule or getattr(model, 'datamodule', None)

        # If we have a datamodule, attach necessary hooks + dataloaders
        if datamodule:

            # Override loader hooks
            if is_overridden('train_dataloader', datamodule):
                model.train_dataloader = datamodule.train_dataloader
            if is_overridden('val_dataloader', datamodule):
                model.val_dataloader = datamodule.val_dataloader
            if is_overridden('test_dataloader', datamodule):
                model.test_dataloader = datamodule.test_dataloader

            # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
            if is_overridden('transfer_batch_to_device', datamodule):
                model.transfer_batch_to_device = datamodule.transfer_batch_to_device

            self.trainer.datamodule = datamodule
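For context, a toy `LightningDataModule` like the one below would have its `train_dataloader`/`val_dataloader` patched onto the model by the hook-override logic above; passing it to `.fit` alongside a model that defines no loaders leaves the model's loader hooks pointing at the datamodule's methods (illustrative sketch, not from the source):

class ToyDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        ds = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(ds, batch_size=8)

    def val_dataloader(self):
        ds = TensorDataset(torch.randn(16, 32), torch.randint(0, 2, (16,)))
        return DataLoader(ds, batch_size=8)
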
Example 19
    def validate_batch(self, batch, batch_info):
        model = self.get_model()
        batch_idx = batch_info["batch_idx"]
        if is_overridden("on_validation_batch_start", model):
            model.on_validation_batch_start(batch=batch,
                                            batch_idx=batch_idx,
                                            dataloader_idx=0)
        args = [batch, batch_idx]
        with self.timers.record("eval_fwd"):
            if self._is_distributed:
                # Use the DDP wrapped model (self.model).
                output = self.model(*args)
            elif self.use_gpu:
                # Using single GPU.
                device = self.device
                batch = model.transfer_batch_to_device(batch, device=device)
                args[0] = batch
                output = model.validation_step(*args)
            else:
                # Using CPU.
                output = model.validation_step(*args)

        if isinstance(output, Result):
            raise ValueError("EvalResult objects are not supported. Please "
                             "return a dictionary instead.")

        if is_overridden("on_validation_step_end", model):
            output = model.validation_step_end(output)

        if self.is_function_implemented("on_validation_batch_end", model):
            model.on_validation_batch_end(outputs=output,
                                          batch=batch,
                                          batch_idx=batch_idx,
                                          dataloader_idx=0)
        return {
            "raw_output": output,
            # NUM_SAMPLES: len(batch)
        }
Example 20
    def configure_checkpoint_callback(self, checkpoint_callback):
        if checkpoint_callback is True:
            # when no val step is defined, use 'loss' otherwise 'val_loss'
            train_step_only = not is_overridden('validation_step',
                                                self.get_model())
            monitor_key = 'loss' if train_step_only else 'val_loss'
            checkpoint_callback = ModelCheckpoint(filepath=None,
                                                  monitor=monitor_key)
        elif checkpoint_callback is False:
            checkpoint_callback = None

        if checkpoint_callback:
            checkpoint_callback.save_function = self.save_checkpoint

        return checkpoint_callback
Example 21
    def __run_legacy_training_epoch_end(
            self,
            num_optimizers,
            epoch_output,
            model,
            is_result_obj,
            epoch_callback_metrics
    ):

        epoch_log_metrics = {}
        epoch_progress_bar_metrics = {}

        # --------------------------
        # EPOCH END STEP IF DEFINED
        # --------------------------
        if is_overridden('training_epoch_end', model=model):
            if is_result_obj:
                # with result object gather across time and training steps so each opt idx has a single result obj
                epoch_output = self.__gather_result_across_time_and_optimizers(epoch_output)

            if num_optimizers == 1:
                epoch_output = epoch_output[0]

            # run training_epoch_end
            # a list with a result per optimizer index
            model._current_fx_name = 'training_epoch_end'
            epoch_output = model.training_epoch_end(epoch_output)

            # capture logging
            self.trainer.logger_connector.cache_logged_metrics()

            if isinstance(epoch_output, Result):
                epoch_log_metrics = epoch_output.epoch_log_metrics
                epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
            else:
                _processed_outputs = self.trainer.process_dict_result(epoch_output)
                epoch_progress_bar_metrics = _processed_outputs[1]
                epoch_log_metrics = _processed_outputs[2]
                epoch_callback_metrics = _processed_outputs[3]

        # --------------------------
        # Structured Result (auto epoch end)
        # --------------------------
        elif is_result_obj:
            epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)

        return epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics
Example 22
    def training_epoch_end(self, model, epoch_output, num_optimizers):
        if not is_overridden('training_epoch_end', model=model):
            return

        # run training_epoch_end
        # refresh the result for custom logging at the epoch level
        model._current_fx_name = 'training_epoch_end'
        epoch_output = self.__prepare_epoch_end_inputs(epoch_output)

        if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization:
            epoch_output = epoch_output[0]

        # lightningmodule hook
        epoch_output = model.training_epoch_end(epoch_output)

        if epoch_output is not None:
            raise MisconfigurationException('training_epoch_end expects a return of None. '
                                            'HINT: remove the return statement in training_epoch_end')
        # capture logging
        self.trainer.logger_connector.cache_logged_metrics()
Example 23
    def run_training_epoch(self):

        # get model
        model = self.trainer.get_model()

        # modify dataloader if needed (ddp, etc...)
        train_dataloader = self.trainer.accelerator_backend.process_dataloader(
            self.trainer.train_dataloader)

        # track epoch output
        epoch_output = [[] for _ in range(self.num_optimizers)]

        # enable profiling for the dataloader
        train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(
            train_dataloader)
        dataloader_idx = 0
        should_check_val = False
        for batch_idx, (batch, is_last_batch) in train_dataloader:

            self.trainer.batch_idx = batch_idx

            # ------------------------------------
            # TRAINING_STEP + TRAINING_STEP_END
            # ------------------------------------
            with self.trainer.profiler.profile("run_training_batch"):
                batch_output = self.run_training_batch(batch, batch_idx,
                                                       dataloader_idx)

            # when returning -1 from train_step, we end epoch early
            if batch_output.signal == -1:
                break

            # only track outputs when user implements training_epoch_end
            # otherwise we will build up unnecessary memory
            epoch_end_outputs = self.process_train_step_outputs(
                batch_output.training_step_output_for_epoch_end,
                self.early_stopping_accumulator,
                self.checkpoint_accumulator,
            )

            # hook
            # TODO: add outputs to batches
            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch,
                                    batch_idx, dataloader_idx)

            # -----------------------------------------
            # SAVE METRICS TO LOGGERS
            # -----------------------------------------
            self.trainer.logger_connector.log_train_step_metrics(batch_output)

            # -----------------------------------------
            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
            # -----------------------------------------
            should_check_val = self.should_check_val_fx(
                batch_idx, is_last_batch)
            if should_check_val:
                self.trainer.run_evaluation(test_mode=False)
                # reset stage to train
                self.trainer.logger_connector.set_stage("train")

            # -----------------------------------------
            # SAVE LOGGERS (ie: Tensorboard, etc...)
            # -----------------------------------------
            self.save_loggers_on_train_batch_end()

            # update LR schedulers
            monitor_metrics = deepcopy(
                self.trainer.logger_connector.callback_metrics)
            self.update_train_loop_lr_schedulers(
                monitor_metrics=monitor_metrics)
            self.trainer.checkpoint_connector.has_trained = True

            # max steps reached, end training
            if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
                accumulation_done = self._accumulated_batches_reached()
                # Ensure accumulation across batches has completed before breaking loop
                if accumulation_done:
                    break

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if self.trainer.should_stop:
                break

            self.trainer.total_batch_idx += 1

            # stop epoch if we limited the number of training batches
            if (batch_idx + 1) >= self.trainer.num_training_batches:
                break

            # progress global step according to grads progress
            self.increment_accumulated_grad_global_step()

        # epoch end hook
        self.run_on_epoch_end_hook(epoch_output)

        # log epoch metrics
        self.trainer.logger_connector.log_train_epoch_end_metrics(
            epoch_output, self.checkpoint_accumulator,
            self.early_stopping_accumulator, self.num_optimizers)

        # when no val loop is present or fast-dev-run still need to call checkpoints
        self.check_checkpoint_callback(not (
            should_check_val or is_overridden('validation_step', model)))

        # increment the global step once
        # progress global step according to grads progress
        self.increment_accumulated_grad_global_step()
Example 24
    def train_batch(self, batch, batch_info):
        # Get the original PTL module.
        model = self.get_model()
        optimizer = self.optimizers[0]
        batch_idx = batch_info["batch_idx"]
        epoch_idx = batch_info["epoch_idx"]

        if self.is_function_implemented("on_train_batch_start", model):
            response = model.on_train_batch_start(batch=batch,
                                                  batch_idx=batch_idx,
                                                  dataloader_idx=0)
            # Skip remainder of epoch if response is -1.
            if response == -1:
                return {"signal": -1}

        args = [batch, batch_idx]
        if len(self.optimizers) > 1:
            if self.has_arg("training_step", "optimizer_idx"):
                args.append(0)

        with self.timers.record("fwd"):
            if self._is_distributed:
                # Use the DDP wrapped model (self.model).
                output = self.model(*args)
            elif self.use_gpu:
                # Using single GPU.
                # Don't copy the batch since there is a single gpu that
                # the batch could be referenced from and if there are
                # multiple optimizers the batch will wind up copying it to
                # the same device repeatedly.
                device = self.device
                batch = model.transfer_batch_to_device(batch, device=device)
                args[0] = batch
                output = model.training_step(*args)
            else:
                # Using CPU.
                output = model.training_step(*args)

        if isinstance(output, Result):
            raise ValueError("TrainResult objects are not supported. Please "
                             "return a dictionary instead.")

        # allow any mode to define training_step_end
        # do something with all the dp outputs (like softmax)
        if is_overridden("training_step_end", model):
            output = model.training_step_end(output)

        # Extract loss from output if dictionary.
        try:
            loss = output["loss"]
        except Exception:
            if isinstance(output, torch.Tensor):
                loss = output
            else:
                raise RuntimeError(
                    "No `loss` value in the dictionary returned from "
                    "`model.training_step()`.")

        # If output contains tensors, detach them all.
        if isinstance(output, torch.Tensor):
            output = output.detach()
        elif isinstance(output, dict):
            output = recursive_detach(output)
        else:
            raise TypeError("training_step returned invalid type. It must "
                            "return either a Tensor, Result, or dict.")

        untouched_loss = loss.detach().clone()

        with self.timers.record("grad"):
            if self.use_fp16:
                with self._amp.scale_loss(loss, optimizer) as scaled_loss:
                    model.backward(scaled_loss, optimizer, optimizer_idx=0)
            else:
                model.backward(loss, optimizer, optimizer_idx=0)

        if self.is_function_implemented("on_after_backward", model):
            model.on_after_backward()

        with self.timers.record("apply"):
            optimizer.step()

        model.on_before_zero_grad(optimizer)

        model.optimizer_zero_grad(epoch=epoch_idx,
                                  batch_idx=batch_idx,
                                  optimizer=optimizer,
                                  optimizer_idx=0)

        if self.is_function_implemented("on_train_batch_end", model):
            model.on_train_batch_end(outputs=output,
                                     batch=batch,
                                     batch_idx=batch_idx,
                                     dataloader_idx=0)

        return {
            "signal": 0,
            "training_loss": untouched_loss.item(),
            "raw_output": output,
            # NUM_SAMPLES: len(batch)
        }
Example 25
    def train_epoch(self, iterator, info):
        model = self.get_model()

        # Enable train mode.
        self.model.train()

        # Enable gradients.
        torch.set_grad_enabled(True)

        if self.is_function_implemented("on_train_epoch_start", model):
            model.on_train_epoch_start()

        if self.use_tqdm and self.world_rank == 0:
            desc = ""
            if info is not None and "epoch_idx" in info:
                if "num_epochs" in info:
                    desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e"
                else:
                    desc = f"{info['epoch_idx'] + 1}e"

            # TODO: Implement len for Dataset?
            total = info[NUM_STEPS]
            if total is None:
                if hasattr(iterator, "__len__"):
                    total = len(iterator)

            _progress_bar = tqdm(total=total,
                                 desc=desc,
                                 unit="batch",
                                 leave=False)

        # Output for each batch.
        epoch_outputs = []

        for batch_idx, batch in enumerate(iterator):
            batch_info = {
                "batch_idx": batch_idx,
                "global_step": self.global_step
            }
            batch_info.update(info)
            batch_output = self.train_batch(batch, batch_info=batch_info)
            # batch output for each optimizer.
            epoch_outputs.append(batch_output)

            should_stop = batch_output["signal"] == -1

            if self.use_tqdm and self.world_rank == 0:
                _progress_bar.n = batch_idx + 1
                postfix = {}
                if "training_loss" in batch_output:
                    postfix.update(loss=batch_output["training_loss"])
                _progress_bar.set_postfix(postfix)

            for s_dict, scheduler in zip(self.scheduler_dicts,
                                         self.schedulers):
                if s_dict["interval"] == SCHEDULER_STEP_BATCH:
                    scheduler.step()

            self.global_step += 1

            if should_stop:
                break

        processed_outputs = None
        if is_overridden("training_epoch_end", model):
            raw_outputs = [eo["raw_output"] for eo in epoch_outputs]
            processed_outputs = model.training_epoch_end(raw_outputs)

        if processed_outputs is not None:
            if isinstance(processed_outputs, torch.Tensor):
                return_output = {"train_loss": processed_outputs}
            elif isinstance(processed_outputs, Result):
                raise ValueError("Result objects are not supported. Please "
                                 "return a dictionary instead.")
            elif isinstance(processed_outputs, dict):
                return_output = processed_outputs
            else:
                raise TypeError("training_epoch_end returned an invalid "
                                "type. It must return a Tensor, Result, "
                                "or dict.")
        else:
            # User did not override training_epoch_end
            assert isinstance(epoch_outputs, list)
            # Use AverageMeterCollection util to reduce results.
            meter_collection = AverageMeterCollection()
            for o in epoch_outputs:
                num_samples = o.pop(NUM_SAMPLES, 1)
                raw_output = o["raw_output"]
                if isinstance(raw_output, dict):
                    meter_collection.update(raw_output, num_samples)
                elif isinstance(raw_output, torch.Tensor):
                    meter_collection.update({"train_loss": o["training_loss"]},
                                            num_samples)
            return_output = meter_collection.summary()

        if self.is_function_implemented("on_train_epoch_end", model):
            model.on_train_epoch_end(
                [eo.get("raw_output") for eo in epoch_outputs])

        for s_dict, scheduler in zip(self.scheduler_dicts, self.schedulers):
            if s_dict["interval"] == SCHEDULER_STEP_EPOCH:
                scheduler.step()

        return return_output
Example 26
    def setup(self, config):
        # Pass in config if ptl_module accepts it.
        ptl_class = self.__class__._lightning_module_cls
        if not issubclass(ptl_class, ptl.LightningModule):
            raise TypeError("Argument must be subclass of "
                            "pytorch_lightning.LightningModule. Got class {} "
                            "instead.".format(ptl_class))
        if "config" in inspect.signature(ptl_class.__init__).parameters:
            ptl_module = ptl_class(config=config)
        else:
            ptl_module = ptl_class()

        # This is needed for LightningDistributedDataParallel.
        ptl_module.testing = False

        # Call on_fit_start on instantiation.
        if self.is_function_implemented("on_fit_start", ptl_module):
            ptl_module.on_fit_start()

        # Only run data preparation once per node.
        if self.local_rank == 0 and self.is_function_implemented(
                "prepare_data", ptl_module):
            ptl_module.prepare_data()

        # Call model.setup.
        ptl_module.setup("fit")

        if not is_overridden("configure_optimizers", ptl_module):
            raise MisconfigurationException(
                "No `configure_optimizers()` method defined.")

        optimizers, self._scheduler_dicts, optimizer_frequencies = \
            self.init_optimizers(model=ptl_module)

        if len(optimizer_frequencies) > 0:
            logger.warning("Optimizer frequencies will be ignored. When "
                           "passing in multiple optimizers, you should "
                           "implement your own custom training loop.")

        lr_schedulers = []
        for scheduler in self.scheduler_dicts:
            if isinstance(scheduler, dict):
                # A scheduler dictionary is passed in.
                if "reduce_on_plateau" in scheduler and "monitor" in \
                        scheduler and scheduler["reduce_on_plateau"] is True:
                    logger.info(
                        "reduce_on_plateau and monitor will be "
                        "ignored "
                        "from the scheduler dict {}. To update a "
                        "ReduceLROnPlateau scheduler, you should use "
                        "TorchTrainer.update_schedulers.".format(scheduler))
                if "frequency" in scheduler and scheduler["frequency"] > 1:
                    logger.info("frequency will be ignored from the "
                                "scheduler dict {}.".format(scheduler))
                lr_schedulers.append(scheduler["scheduler"])
            else:
                lr_schedulers.append(scheduler)

        # Set this so register doesn't complain.
        self._scheduler_step_freq = "ptl"
        ddp_model, self._optimizers, self._schedulers = self.register(
            models=[ptl_module],
            optimizers=optimizers,
            schedulers=lr_schedulers)

        assert len(ddp_model) == 1
        self._model = ddp_model[0]

        model = self.get_model()
        if self.is_function_implemented("on_pretrain_routine_start", model):
            model.on_pretrain_routine_start()

        train_data_loader = None
        if self.__class__._train_dataloader:
            train_data_loader = self.__class__._train_dataloader
        elif self.is_function_implemented("train_dataloader", model):
            train_data_loader = model.train_dataloader()

        val_data_loader = None
        if self.__class__._val_dataloader:
            val_data_loader = self.__class__._val_dataloader
        elif self.is_function_implemented("val_dataloader", model):
            val_data_loader = model.val_dataloader()

        self.register_data(train_loader=train_data_loader,
                           validation_loader=val_data_loader)
Example 27
    def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
        # epoch output is a list. Each item in that list has all the outputs per optimizer
        # epoch_output[optimizer_idx][training_step_idx][tbptt_index]
        # remember that not using truncated backprop is equivalent with truncated back prop of len(1)

        model = self.get_model()

        epoch_log_metrics = {}
        epoch_callback_metrics = {}
        epoch_progress_bar_metrics = {}

        # -----------------------
        # Calculate epoch callback values if given
        # -----------------------
        if checkpoint_accumulator.num_values > 0:
            epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()

        if early_stopping_accumulator.num_values > 0:
            epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()

        # ------------------------
        # determine if using a result obj
        # ------------------------
        # [optimizer_idx][training_step_idx][tbptt_index]
        opt_idx_outputs = epoch_output[0]

        try:
            sample_obj = opt_idx_outputs[0][0] if isinstance(opt_idx_outputs[0], list) else opt_idx_outputs[0]
            is_result_obj = len(epoch_output) > 0 and isinstance(sample_obj, Result)
        except IndexError as e:
            is_result_obj = False

        # --------------------------
        # EPOCH END STEP IF DEFINED
        # --------------------------
        if is_overridden('training_epoch_end', model=model):
            self.global_step += 1

            if is_result_obj:
                # with result object gather across time and training steps so each opt idx has a single result obj
                epoch_output = self.__gather_result_across_time_and_optimizers(epoch_output)

            if num_optimizers == 1:
                epoch_output = epoch_output[0]

            # run training_epoch_end
            # a list with a result per optimizer index
            epoch_output = model.training_epoch_end(epoch_output)

            if isinstance(epoch_output, Result):
                epoch_log_metrics = epoch_output.epoch_log_metrics
                epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
            else:
                _processed_outputs = self.process_output(epoch_output)
                epoch_progress_bar_metrics = _processed_outputs[1]
                epoch_log_metrics = _processed_outputs[2]
                epoch_callback_metrics = _processed_outputs[3]

        # --------------------------
        # Structured Result (auto epoch end)
        # --------------------------
        elif is_result_obj:
            epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)

        # --------------------------
        # track results
        # --------------------------
        # add the metrics to the loggers
        if epoch_log_metrics and len(epoch_log_metrics) > 0:
            self.log_metrics(epoch_log_metrics, {})

        # add metrics to callbacks
        self.callback_metrics.update(epoch_callback_metrics)

        # add metrics to progress_bar
        if len(epoch_progress_bar_metrics) > 0:
            self.add_progress_bar_metrics(epoch_progress_bar_metrics)
Example 28
    def __verify_train_loop_configuration(self, model):
        # -----------------------------------
        # verify model has a training step
        # -----------------------------------
        has_training_step = is_overridden('training_step', model)
        if not has_training_step:
            raise MisconfigurationException(
                'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        # -----------------------------------
        # verify model has a train dataloader
        # -----------------------------------
        has_train_dataloader = is_overridden('train_dataloader', model)
        if not has_train_dataloader:
            raise MisconfigurationException(
                'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        # -----------------------------------
        # verify model has optimizer
        # -----------------------------------
        has_optimizers = is_overridden('configure_optimizers', model)
        if not has_optimizers:
            raise MisconfigurationException(
                'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

        trainer = self.trainer

        trainer.overriden_optimizer_step = is_overridden('optimizer_step', model)
        trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model)

        enable_pl_optimizer = trainer._enable_pl_optimizer
        automatic_optimization = trainer.train_loop.automatic_optimization
        if trainer.overriden_optimizer_step and not enable_pl_optimizer and automatic_optimization:
            rank_zero_warn(
                "When overriding `LightningModule` optimizer_step with"
                " `Trainer(..., enable_pl_optimizer=False, automatic_optimization=True, ...)`,"
                " we won't be calling `.zero_grad` we can't assume when you call your `optimizer.step()`."
                " For Lightning to take care of it, please use `Trainer(enable_pl_optimizer=True)`."
            )

        going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()

        has_overriden_optimization_functions = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad
        if (has_overriden_optimization_functions) and going_to_accumulate_grad_batches and automatic_optimization:
            raise MisconfigurationException(
                'When overriding `LightningModule` optimizer_step or optimizer_zero_grad with '
                '`Trainer(automatic_optimization=True, ...)`, `accumulate_grad_batches` should be 1.'
                ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
            )

        if enable_pl_optimizer and trainer.overriden_optimizer_zero_grad and not automatic_optimization:
            raise MisconfigurationException(
                'Overriding `LightningModule` optimizer_zero_grad with'
                ' `Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...)` is not supported.'
            )
Example 29
    def enable_validation(self) -> bool:
        """ Check if we should run validation during training. """
        model_ref = self.model_connector.get_model()
        val_loop_enabled = is_overridden(
            'validation_step', model_ref) and self.limit_val_batches > 0
        return val_loop_enabled or self.fast_dev_run
Example 30
def attach_step_and_epoch_functions(model, datamodule):
    datamodule.forward = model.forward
    for attr in dir(datamodule):
        if sum([token in attr for token in ["_step", "_epoch_end"]]) > 0:
            if not is_overridden(attr, model):
                setattr(model, attr, getattr(datamodule, attr))
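A hedged usage sketch for the helper above (class names are illustrative assumptions; `MinimalModule` is the toy module sketched earlier): hooks such as `validation_step` defined on the datamodule are copied onto the model only when the model has not overridden them itself.

class StepsDataModule(pl.LightningDataModule):
    def validation_step(self, batch, batch_idx):
        # trivial stand-in hook living on the datamodule
        return {"val_loss": torch.tensor(0.0)}

model = MinimalModule()
dm = StepsDataModule()
attach_step_and_epoch_functions(model, dm)
# the datamodule's hook was attached because the model did not override it
assert model.validation_step == dm.validation_step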