Example #1
class LoggerConnector:

    def __init__(self, trainer, log_gpu_memory: Optional[str] = None):
        self.trainer = trainer
        self.log_gpu_memory = log_gpu_memory
        self._callback_metrics = MetricsHolder()
        self._evaluation_callback_metrics = MetricsHolder(to_float=True)
        self._logged_metrics = MetricsHolder()
        self._progress_bar_metrics = MetricsHolder(to_float=True)
        self.eval_loop_results = []
        self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage}
        self._cached_results[None] = EpochResultStore(trainer, None)
        self._callback_hook_validator = CallbackHookNameValidator()

    @property
    def callback_metrics(self) -> Dict:
        return self.get_metrics("callback_metrics")

    @callback_metrics.setter
    def callback_metrics(self, callback_metrics: Dict) -> None:
        self.set_metrics("callback_metrics", callback_metrics)

    @property
    def evaluation_callback_metrics(self) -> Dict:
        return self.get_metrics("evaluation_callback_metrics")

    @evaluation_callback_metrics.setter
    def evaluation_callback_metrics(self, evaluation_callback_metrics: Dict) -> None:
        self.set_metrics("evaluation_callback_metrics", evaluation_callback_metrics)

    @property
    def logged_metrics(self) -> Dict:
        return self.get_metrics("logged_metrics")

    @logged_metrics.setter
    def logged_metrics(self, logged_metrics: Dict) -> None:
        self.set_metrics("logged_metrics", logged_metrics)

    @property
    def progress_bar_metrics(self) -> Dict:
        return self.get_metrics("progress_bar_metrics")

    @progress_bar_metrics.setter
    def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:
        self.set_metrics("progress_bar_metrics", progress_bar_metrics)

    @property
    def cached_results(self) -> Union[EpochResultStore, None]:
        return self._cached_results.get(self.trainer._running_stage)  # type: ignore

    def get_metrics(self, key: str) -> Dict:
        metrics_holder = getattr(self, f"_{key}", None)
        model_ref = self.trainer.lightning_module
        metrics_holder.convert(
            self.trainer._device_type == DeviceType.TPU,
            model_ref.device if model_ref is not None else model_ref,
        )
        return metrics_holder.metrics

    def set_metrics(self, key: str, val: Dict) -> None:
        metrics_holder = getattr(self, f"_{key}", None)
        metrics_holder.reset(val)

    def reset(self) -> None:
        self.cached_results.reset()

    def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None:
        self._callback_hook_validator.check_logging_in_callbacks(
            current_hook_fx_name=hook_fx_name, on_step=on_step, on_epoch=on_epoch
        )

    def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders):
        # Todo: required argument `testing` is not used
        model = self.trainer.lightning_module
        # set dataloader_idx only if multiple ones
        model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
        # track batch_size
        self.cached_results._batch_size = Result.extract_batch_size(batch)

    def on_train_split_start(self, split_idx: int, opt_idx: int, split_batch) -> None:
        self.cached_results._split_idx = split_idx
        self.cached_results._opt_idx = opt_idx
        self.cached_results._batch_size = Result.extract_batch_size(split_batch)

    def on_train_batch_end(self) -> None:
        self.cached_results._split_idx = None
        self.cached_results._opt_idx = None
        self.cached_results._batch_size = None

    def cache_logged_metrics(self):
        self._cached_results[self.trainer._running_stage].cache_result()

    def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool):
        # logging
        self.configure_logger(logger)
        # todo: IDE is complaining; these should be initialized in the Trainer init at least as placeholders
        # and assigned the desired values here
        self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
        self.trainer.log_every_n_steps = log_every_n_steps
        self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
        self.trainer.split_idx = None

    @property
    def should_flush_logs(self):
        should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
        return should_flush or self.trainer.should_stop

    @property
    def should_update_logs(self):
        should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
        return should_log_every_n_steps or self.trainer.should_stop

    def configure_logger(self, logger):
        if logger is True:
            version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)

            # default logger
            self.trainer.logger = TensorBoardLogger(
                save_dir=self.trainer.default_root_dir, version=version, name='lightning_logs'
            )
        elif logger is False:
            self.trainer.logger = None
        else:
            if isinstance(logger, Iterable):
                self.trainer.logger = LoggerCollection(logger)
            else:
                self.trainer.logger = logger

    def cache_training_step_metrics(self, opt_closure_result):
        """
        This function is responsible to update
        logger_connector internals metrics holder based for depreceated logging
        """
        using_results_obj = isinstance(opt_closure_result.training_step_output, Result)

        # temporary dict to collect metrics
        logged_metrics_tmp = {}
        pbar_metrics_tmp = {}
        callback_metrics_tmp = {}

        if using_results_obj:
            batch_log_metrics = opt_closure_result.training_step_output.get_batch_log_metrics(
                include_forked_originals=False
            )
            logged_metrics_tmp.update(batch_log_metrics)

            batch_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics(
                include_forked_originals=False
            )
            pbar_metrics_tmp.update(batch_pbar_metrics)

            forked_metrics = opt_closure_result.training_step_output.get_forked_metrics()
            callback_metrics_tmp.update(forked_metrics)
            callback_metrics_tmp.update(logged_metrics_tmp)

        else:
            batch_log_metrics = opt_closure_result.training_step_output.log_metrics
            logged_metrics_tmp.update(batch_log_metrics)

            callback_metrics = opt_closure_result.training_step_output.callback_metrics
            callback_metrics_tmp.update(callback_metrics)

            batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
            pbar_metrics_tmp.update(batch_pbar_metrics)

        # track progress bar metrics
        if len(pbar_metrics_tmp) > 0:
            self.add_progress_bar_metrics(pbar_metrics_tmp)

        self._callback_metrics.update(callback_metrics_tmp)

        # save legacy log metrics
        self._logged_metrics.update(logged_metrics_tmp)
        self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp)

    def log_metrics(self, metrics, grad_norm_dic, step=None):
        """Logs the metric dict passed in.
        If `step` parameter is None and `step` key is presented is metrics,
        uses metrics["step"] as a step

        Args:
            metrics (dict): Metric values
            grad_norm_dic (dict): Gradient norms
            step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
            log_train_step_metrics (bool): Used to track if `log_metrics` function is being called in during training
                steps. In training steps, we will log metrics on step: `total_nb_idx` (for accumulated gradients)
                and global_step for the rest.
        """
        # add gpu memory
        if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
            mem_map = memory.get_memory_profile(self.log_gpu_memory)
            metrics.update(mem_map)

        # add norms
        metrics.update(grad_norm_dic)

        # turn all tensors to scalars
        scalar_metrics = self.trainer.metrics_to_scalars(metrics)

        if "step" in scalar_metrics and step is None:
            step = scalar_metrics.pop("step")

        elif step is None:
            # added metrics by Lightning for convenience
            scalar_metrics['epoch'] = self.trainer.current_epoch
            step = self.trainer.global_step

        # log actual metrics
        if self.trainer.logger is not None:
            if self.trainer.is_global_zero:
                self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
                self.trainer.logger.save()

            # track the logged metrics
            self.logged_metrics.update(scalar_metrics)
            self.trainer.dev_debugger.track_logged_metrics_history(scalar_metrics)

    def add_progress_bar_metrics(self, metrics):
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()

            self._progress_bar_metrics.metrics[k] = v

        self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

    def track_metrics_deprecated(self, deprecated_eval_results):
        self._track_callback_metrics(deprecated_eval_results)
        self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results)

    def evaluation_epoch_end(self, testing):
        # Todo: required argument `testing` is not used
        # reset dataloader idx
        model_ref = self.trainer.lightning_module
        model_ref._current_dataloader_idx = None

        # setting `has_batch_loop_finished` to True
        # will perform Results reduction across the entire epoch.
        self.cached_results.has_batch_loop_finished = True

    def add_to_eval_loop_results(self, dl_idx, has_been_initialized):
        callback_metrics = deepcopy(self.evaluation_callback_metrics)
        for key in list(callback_metrics.keys()):
            if "dataloader_idx" in key:
                if f"dataloader_idx_{dl_idx}" not in key:
                    # remove dl_idx from self.callback_metrics not belonging to this dataset.
                    del callback_metrics[key]
        if has_been_initialized:
            self.eval_loop_results[dl_idx].update(callback_metrics)
        else:
            self.eval_loop_results.append(callback_metrics)

    def prepare_eval_loop_results(self):
        num_dataloaders = self.trainer.evaluation_loop.num_dataloaders
        has_been_initialized = len(self.eval_loop_results) == num_dataloaders
        for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
            self.add_to_eval_loop_results(dl_idx, has_been_initialized)

    def get_evaluate_epoch_results(self):
        if not self.trainer.running_sanity_check:
            # log all the metrics as a single dict
            metrics_to_log = self.cached_results.get_epoch_log_metrics()
            if len(metrics_to_log) > 0:
                self.log_metrics(metrics_to_log, {})

        self.prepare_eval_loop_results()

        # log results of test
        if self.trainer.testing and self.trainer.is_global_zero and self.trainer.verbose_test:
            print('-' * 80)
            for result_idx, results in enumerate(self.eval_loop_results):
                print(f'DATALOADER:{result_idx} TEST RESULTS')
                pprint({
                    k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
                    for k, v in results.items()
                })
                print('-' * 80)

        results = self.eval_loop_results

        # clear mem
        self.eval_loop_results = []
        return results

    def _track_callback_metrics(self, eval_results):
        if len(eval_results) > 0 and (eval_results[0] is None or not isinstance(eval_results[0], Result)):
            return

        flat = {}
        if isinstance(eval_results, list):
            for eval_result in eval_results:
                # with a scalar return, auto set it to "val_loss" for callbacks
                if isinstance(eval_result, torch.Tensor):
                    flat = {'val_loss': eval_result}
                elif isinstance(eval_result, dict):
                    flat = flatten_dict(eval_result)

                # map the 'val_loss' magic word to the checkpoint and early-stopping callbacks
                if 'val_loss' in flat:
                    flat['checkpoint_on'] = flat['val_loss']
                    flat['early_stop_on'] = flat['val_loss']
                self.trainer.logger_connector.callback_metrics.update(flat)
                if self.trainer.testing:
                    self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
        else:
            # with a scalar return, auto set it to "val_loss" for callbacks
            if isinstance(eval_results, torch.Tensor):
                flat = {'val_loss': eval_results}
            else:
                flat = flatten_dict(eval_results)

            # map the 'val_loss' magic word to the checkpoint and early-stopping callbacks
            if 'val_loss' in flat:
                flat['checkpoint_on'] = flat['val_loss']
                flat['early_stop_on'] = flat['val_loss']

            self.trainer.logger_connector.callback_metrics.update(flat)
            if self.trainer.testing:
                self.trainer.logger_connector.evaluation_callback_metrics.update(flat)

    def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
        # eval loop returns all metrics
        dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}

        # add metrics to prog bar
        self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)

        # log metrics
        if len(log_metrics) > 0:
            self.trainer.logger_connector.log_metrics(log_metrics, {})

        # track metrics for callbacks (all prog bar, logged and callback metrics)
        callback_metrics.update(log_metrics)
        callback_metrics.update(prog_bar_metrics)
        self.trainer.logger_connector.callback_metrics.update(callback_metrics)
        if self.trainer.testing:
            self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)

        if len(dataloader_result_metrics) > 0:
            self.eval_loop_results.append(dataloader_result_metrics)

    def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
        if self.trainer.running_sanity_check:
            return

        if eval_results is not None and len(eval_results) > 0:

            # in eval, the user may return something at every validation step without final reduction
            if not isinstance(eval_results, list):
                eval_results = [eval_results]

            num_loaders: int = self.trainer.evaluation_loop.num_dataloaders
            prog_bar_metrics, log_metrics, callback_metrics = {}, {}, {}

            for result_idx, result in enumerate(eval_results):
                _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result)

                if num_loaders > 1:
                    self.__process_eval_epoch_end_results_and_log_legacy_update(
                        prog_bar_metrics, log_metrics, callback_metrics
                    )

            if num_loaders == 1:
                self.__process_eval_epoch_end_results_and_log_legacy_update(
                    prog_bar_metrics, log_metrics, callback_metrics
                )

    def on_train_epoch_end(self):
        # inform cached logger connector epoch finished
        self.cached_results.has_batch_loop_finished = True

    def log_train_epoch_end_metrics(
        self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers
    ):
        # epoch output is a list. Each item in that list has all the outputs per optimizer
        # epoch_output[optimizer_idx][training_step_idx][tbptt_index]
        # remember that not using truncated backprop is equivalent to truncated backprop with length 1

        model = self.trainer.lightning_module

        epoch_callback_metrics = {}

        # -----------------------
        # Calculate epoch callback values if given
        # -----------------------
        if checkpoint_accumulator.num_values > 0:
            epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()

        if early_stopping_accumulator.num_values > 0:
            epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()

        # ------------------------
        # determine if using a result obj
        # ------------------------
        # [optimizer_idx][training_step_idx][tbptt_index]
        opt_idx_outputs = epoch_output[0]

        # TODO: deprecate 1.0
        try:
            sample_obj = opt_idx_outputs[0][0] if isinstance(opt_idx_outputs[0], list) else opt_idx_outputs[0]
            is_result_obj = len(epoch_output) > 0 and isinstance(sample_obj, Result)
            is_1_0_result = is_result_obj and 'extra' in sample_obj
        except IndexError:
            is_result_obj = False
            is_1_0_result = False

        # ------------------
        # NEW 1.0.0 PATH
        # ------------------
        if is_1_0_result:
            # lightning module hook
            self.training_epoch_end(model, epoch_output, num_optimizers)

            # log/aggregate metrics automatically
            epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)

        # TODO: deprecate 1.0
        else:
            out = self.__run_legacy_training_epoch_end(
                num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics
            )
            epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out

        # it will perform reduction over epoch and return log metrics
        cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics()
        cached_epoch_pbar_metrics = self.cached_results.get_epoch_pbar_metrics()

        # update
        epoch_log_metrics.update(cached_epoch_log_metrics)
        epoch_progress_bar_metrics.update(cached_epoch_pbar_metrics)

        # --------------------------
        # track results
        # --------------------------
        # add the metrics to the loggers and callbacks
        if epoch_log_metrics and len(epoch_log_metrics) > 0:
            self.log_metrics(epoch_log_metrics, {})
            self._callback_metrics.update(epoch_log_metrics)

        # add metrics to callbacks
        self._callback_metrics.update(epoch_callback_metrics)

        # add metrics to progress_bar and callbacks
        if len(epoch_progress_bar_metrics) > 0:
            self.add_progress_bar_metrics(epoch_progress_bar_metrics)
            self._callback_metrics.update(epoch_progress_bar_metrics)

        # reset epoch loop result for next epoch
        self.cached_results.reset()

    def training_epoch_end(self, model, epoch_output, num_optimizers):
        if not is_overridden('training_epoch_end', model=model):
            return

        # run training_epoch_end
        # refresh the result for custom logging at the epoch level
        model._current_fx_name = 'training_epoch_end'
        epoch_output = self.__prepare_epoch_end_inputs(epoch_output)

        if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization:
            epoch_output = epoch_output[0]

        # lightningmodule hook
        epoch_output = model.training_epoch_end(epoch_output)

        if epoch_output is not None:
            raise MisconfigurationException(
                'training_epoch_end expects a return of None. '
                'HINT: remove the return statement in training_epoch_end'
            )
        # capture logging
        self.trainer.logger_connector.cache_logged_metrics()

    def __run_legacy_training_epoch_end(
        self, num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics
    ):

        epoch_log_metrics = {}
        epoch_progress_bar_metrics = {}

        # --------------------------
        # EPOCH END STEP IF DEFINED
        # --------------------------
        if is_overridden('training_epoch_end', model=model):
            if is_result_obj:
                # with result object gather across time and training steps so each opt idx has a single result obj
                epoch_output = self.__gather_result_across_time_and_optimizers(epoch_output)

            if num_optimizers == 1:
                epoch_output = epoch_output[0]

            # run training_epoch_end
            # a list with a result per optimizer index
            model._current_fx_name = 'training_epoch_end'
            epoch_output = model.training_epoch_end(epoch_output)

            # capture logging
            self.trainer.logger_connector.cache_logged_metrics()

            if isinstance(epoch_output, Result):
                epoch_log_metrics = epoch_output.epoch_log_metrics
                epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
            else:
                _processed_outputs = self.trainer.process_dict_result(epoch_output)
                epoch_progress_bar_metrics = _processed_outputs[1]
                epoch_log_metrics = _processed_outputs[2]
                epoch_callback_metrics = _processed_outputs[3]

        # --------------------------
        # Structured Result (auto epoch end)
        # --------------------------
        elif is_result_obj:
            epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)

        return epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics

    def __auto_reduce_results_on_epoch_end(self, epoch_output):
        epoch_log_metrics = {}
        epoch_progress_bar_metrics = {}
        for opt_outputs in epoch_output:
            # reduce across time first
            time_reduced_outputs = []
            for tbptt_outs in opt_outputs:
                tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
                if len(tbptt_outs) > 1:
                    time_reduced_outputs.append(tbptt_outs)

            if len(time_reduced_outputs) == 0:
                continue

            # reduce across training steps
            opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)

            # with manual opt need 1 + metrics because meta is always there
            if opt_outputs.minimize is not None:
                opt_outputs.minimize = opt_outputs.minimize.mean()
            epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
            epoch_progress_bar_metrics.update(opt_outputs.epoch_pbar_metrics)

        return epoch_log_metrics, epoch_progress_bar_metrics

    def __prepare_epoch_end_inputs(self, epoch_output):
        """
        Pulls out only the "extra" information for epoch end

        Returns:
            a single list, each element per optimizer then batch then time
        """
        gathered_epoch_outputs = []
        for opt_outputs in epoch_output:
            # gather across time first
            time_gathered_outputs = []
            for tbptt_outs in opt_outputs:
                result = []
                for x in tbptt_outs:
                    out = x.extra
                    out['loss'] = x.minimize
                    result.append(out)

                # when time = 0, pass in the literal dict instead of array
                if len(result) == 1:
                    result = result[0]
                time_gathered_outputs.append(result)

            gathered_epoch_outputs.append(time_gathered_outputs)

        return gathered_epoch_outputs

    def __gather_result_across_time_and_optimizers(self, epoch_output):
        """
        Gather results into a single padded tensor per metric where each tensor is gathered across
        time and across training steps.

        Returns:
            a list where each element is a Result with the tensors gathered
        """
        gathered_epoch_outputs = []
        for opt_outputs in epoch_output:
            # gather across time first
            time_gathered_outputs = []
            for tbptt_outs in opt_outputs:
                tbptt_outs = tbptt_outs[0].__class__.gather(tbptt_outs)
                time_gathered_outputs.append(tbptt_outs)

            # gather across training steps
            # each metric has dimensions (training_steps, seq_len) (seq_len=1 when no tbptt is used)
            gathered_opt_output = time_gathered_outputs[0].__class__.padded_gather(time_gathered_outputs)
            gathered_epoch_outputs.append(gathered_opt_output)

        return gathered_epoch_outputs

    def log_train_step_metrics(self, batch_output):
        if self.trainer.train_loop.should_accumulate() and self.trainer.train_loop.automatic_optimization:
            return
        _, batch_log_metrics = self.cached_results.update_logger_connector()
        # when metrics should be logged
        if self.should_update_logs or self.trainer.fast_dev_run is True:
            # logs user requested information to logger
            grad_norm_dic = batch_output.grad_norm_dic
            if grad_norm_dic is None:
                grad_norm_dic = {}
            if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0:
                self.log_metrics(batch_log_metrics, grad_norm_dic)
                self._callback_metrics.update(batch_log_metrics)
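

# A minimal, self-contained sketch (hypothetical helper, not part of LoggerConnector) of how
# `log_metrics` above resolves the logging step: if `step` is None and the scalar metrics carry
# a "step" entry, that value is popped and used; otherwise the current epoch is added for
# convenience and the trainer's global step is used.
def _resolve_logging_step(scalar_metrics: dict, step, current_epoch: int, global_step: int):
    if "step" in scalar_metrics and step is None:
        step = scalar_metrics.pop("step")
    elif step is None:
        scalar_metrics['epoch'] = current_epoch
        step = global_step
    return scalar_metrics, step


# Example: {"loss": 0.5, "step": 12} logs at step 12, while {"loss": 0.5} gains an "epoch" key
# and logs at the trainer's global step.
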
def test_call_back_validator(tmpdir):

    funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')])

    callbacks_func = [
        'on_after_backward',
        'on_batch_end',
        'on_batch_start',
        'on_before_zero_grad',
        'on_epoch_end',
        'on_epoch_start',
        'on_fit_end',
        'on_fit_start',
        'on_init_end', 'on_init_start',
        'on_keyboard_interrupt',
        'on_load_checkpoint',
        'on_pretrain_routine_end',
        'on_pretrain_routine_start',
        'on_sanity_check_end',
        'on_sanity_check_start',
        'on_save_checkpoint',
        'on_test_batch_end',
        'on_test_batch_start',
        'on_test_end',
        'on_test_epoch_end',
        'on_test_epoch_start',
        'on_test_start',
        'on_train_batch_end',
        'on_train_batch_start',
        'on_train_end',
        'on_train_epoch_end',
        'on_train_epoch_start',
        'on_train_start',
        'on_validation_batch_end',
        'on_validation_batch_start',
        'on_validation_end',
        'on_validation_epoch_end',
        'on_validation_epoch_start',
        'on_validation_start',
        'setup',
        'teardown',
    ]

    not_supported = [
        "on_fit_end",
        "on_fit_start",
        "on_init_end",
        "on_init_start",
        "on_keyboard_interrupt",
        "on_load_checkpoint",
        "on_pretrain_routine_end",
        "on_pretrain_routine_start",
        "on_sanity_check_end",
        "on_sanity_check_start",
        "on_save_checkpoint",
        "on_test_end",
        "on_train_end",
        "on_validation_end",
        "setup",
        "teardown",
    ]

    assert funcs_name == callbacks_func, """Detected new callback function.
        Need to add its logging permission to CallbackHookNameValidator and update this test"""

    validator = CallbackHookNameValidator()

    for func_name in funcs_name:
        # This summarizes where and what can currently be logged using the `self.log` function.
        is_stage = "train" in func_name or "test" in func_name or "validation" in func_name
        is_start = "start" in func_name or "batch" in func_name
        on_step = is_stage and is_start
        on_epoch = True
        # creating allowed condition
        allowed = (
            is_stage
            or "batch" in func_name
            or "epoch" in func_name
            or "grad" in func_name
            or "backward" in func_name
        )
        allowed = (
            allowed
            and "pretrain" not in func_name
            and func_name not in ["on_train_end", "on_test_end", "on_validation_end"]
        )
        if allowed:
            validator.check_logging_in_callbacks(current_hook_fx_name=func_name,
                                                 on_step=on_step,
                                                 on_epoch=on_epoch)
            if not is_start and is_stage:
                with pytest.raises(MisconfigurationException, match="function supports only"):
                    validator.check_logging_in_callbacks(current_hook_fx_name=func_name,
                                                         on_step=True,
                                                         on_epoch=on_epoch)
        else:
            assert func_name in not_supported
            with pytest.raises(MisconfigurationException, match="function doesn't support"):
                validator.check_logging_in_callbacks(current_hook_fx_name=func_name,
                                                     on_step=on_step,
                                                     on_epoch=on_epoch)

        result = validator.check_logging_in_callbacks(current_hook_fx_name=None,
                                                      on_step=None,
                                                      on_epoch=None)
        assert result is None
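

# Hedged, standalone restatement of the rule the test above encodes (illustrative only; the
# authoritative checks live in CallbackHookNameValidator): `self.log` is allowed from
# stage/batch/epoch/grad/backward hooks, excluding the pretrain hooks and the three *_end hooks
# that close a stage.
def _logging_allowed(func_name: str) -> bool:
    is_stage = any(s in func_name for s in ("train", "test", "validation"))
    allowed = (
        is_stage
        or "batch" in func_name
        or "epoch" in func_name
        or "grad" in func_name
        or "backward" in func_name
    )
    return (
        allowed
        and "pretrain" not in func_name
        and func_name not in ("on_train_end", "on_test_end", "on_validation_end")
    )


assert _logging_allowed("on_train_batch_end")
assert not _logging_allowed("on_fit_start")
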
class LoggerConnector:
    def __init__(self, trainer, log_gpu_memory: Optional[str] = None):
        self.trainer = trainer
        self.log_gpu_memory = log_gpu_memory
        self._callback_metrics = MetricsHolder()
        self._evaluation_callback_metrics = MetricsHolder(to_float=True)
        self._logged_metrics = MetricsHolder()
        self._progress_bar_metrics = MetricsHolder(to_float=True)
        self.eval_loop_results = []
        self._cached_results = {
            stage: EpochResultStore(trainer)
            for stage in RunningStage
        }
        self._cached_results[None] = EpochResultStore(trainer)
        self._callback_hook_validator = CallbackHookNameValidator()

    @property
    def callback_metrics(self) -> Dict:
        return self.get_metrics("callback_metrics")

    @callback_metrics.setter
    def callback_metrics(self, callback_metrics: Dict) -> None:
        self.set_metrics("callback_metrics", callback_metrics)

    @property
    def evaluation_callback_metrics(self) -> Dict:
        return self.get_metrics("evaluation_callback_metrics")

    @evaluation_callback_metrics.setter
    def evaluation_callback_metrics(self,
                                    evaluation_callback_metrics: Dict) -> None:
        self.set_metrics("evaluation_callback_metrics",
                         evaluation_callback_metrics)

    @property
    def logged_metrics(self) -> Dict:
        return self.get_metrics("logged_metrics")

    @logged_metrics.setter
    def logged_metrics(self, logged_metrics: Dict) -> None:
        self.set_metrics("logged_metrics", logged_metrics)

    @property
    def progress_bar_metrics(self) -> Dict:
        return self.get_metrics("progress_bar_metrics")

    @progress_bar_metrics.setter
    def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:
        self.set_metrics("progress_bar_metrics", progress_bar_metrics)

    @property
    def cached_results(self) -> Union[EpochResultStore, None]:
        return self._cached_results.get(self.trainer._running_stage)

    def get_metrics(self, key: str) -> Dict:
        metrics_holder: MetricsHolder = getattr(self, f"_{key}")
        model = self.trainer.lightning_module
        metrics_holder.convert(model.device if model is not None else None)
        return metrics_holder.metrics

    def set_metrics(self, key: str, val: Dict) -> None:
        metrics_holder: MetricsHolder = getattr(self, f"_{key}")
        metrics_holder.reset(val)

    def reset(self) -> None:
        self.cached_results.reset()

    def check_logging_in_callbacks(self,
                                   hook_fx_name,
                                   on_step: bool = None,
                                   on_epoch: bool = None) -> None:
        self._callback_hook_validator.check_logging_in_callbacks(
            current_hook_fx_name=hook_fx_name,
            on_step=on_step,
            on_epoch=on_epoch)

    def on_evaluation_batch_start(self, batch, dataloader_idx,
                                  num_dataloaders):
        model = self.trainer.lightning_module
        # set dataloader_idx only if multiple ones
        model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
        # track batch_size
        self.cached_results._batch_size = Result.extract_batch_size(batch)

    def on_train_split_start(self, split_idx: int, opt_idx: int,
                             split_batch) -> None:
        self.cached_results._split_idx = split_idx
        self.cached_results._opt_idx = opt_idx
        self.cached_results._batch_size = Result.extract_batch_size(
            split_batch)

    def on_train_batch_end(self) -> None:
        self.cached_results._split_idx = None
        self.cached_results._opt_idx = None
        self.cached_results._batch_size = None

    def cache_logged_metrics(self):
        self._cached_results[self.trainer._running_stage].cache_result()

    def on_trainer_init(self, logger, flush_logs_every_n_steps: int,
                        log_every_n_steps: int, move_metrics_to_cpu: bool):
        # logging
        self.configure_logger(logger)
        self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
        self.trainer.log_every_n_steps = log_every_n_steps
        self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
        self.trainer.split_idx = None

    @property
    def should_flush_logs(self):
        should_flush = (self.trainer.global_step +
                        1) % self.trainer.flush_logs_every_n_steps == 0
        return should_flush or self.trainer.should_stop

    @property
    def should_update_logs(self):
        should_log_every_n_steps = (self.trainer.global_step +
                                    1) % self.trainer.log_every_n_steps == 0
        return should_log_every_n_steps or self.trainer.should_stop

    def configure_logger(self, logger):
        if logger is True:
            version = os.environ.get('PL_EXP_VERSION',
                                     self.trainer.slurm_job_id)

            # default logger
            self.trainer.logger = TensorBoardLogger(
                save_dir=self.trainer.default_root_dir,
                version=version,
                name='lightning_logs')
        elif logger is False:
            self.trainer.logger = None
        else:
            if isinstance(logger, Iterable):
                self.trainer.logger = LoggerCollection(logger)
            else:
                self.trainer.logger = logger

    def cache_training_step_metrics(self, opt_closure_result):
        """
        This function is responsible to update
        logger_connector internals metrics holder based for depreceated logging
        """
        using_results_obj = isinstance(opt_closure_result.training_step_output,
                                       Result)

        # temporary dict to collect metrics
        logged_metrics_tmp = {}
        pbar_metrics_tmp = {}
        callback_metrics_tmp = {}

        if using_results_obj:
            batch_log_metrics = opt_closure_result.training_step_output.get_batch_log_metrics(
                include_forked_originals=False)
            logged_metrics_tmp.update(batch_log_metrics)

            batch_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics(
                include_forked_originals=False)
            pbar_metrics_tmp.update(batch_pbar_metrics)

            forked_metrics = opt_closure_result.training_step_output.get_forked_metrics()
            callback_metrics_tmp.update(forked_metrics)
            callback_metrics_tmp.update(logged_metrics_tmp)

        else:
            batch_log_metrics = opt_closure_result.training_step_output.log_metrics
            logged_metrics_tmp.update(batch_log_metrics)

            batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
            pbar_metrics_tmp.update(batch_pbar_metrics)

        # track progress bar metrics
        if len(pbar_metrics_tmp) > 0:
            self.add_progress_bar_metrics(pbar_metrics_tmp)

        self._callback_metrics.update(callback_metrics_tmp)
        self._logged_metrics.update(logged_metrics_tmp)

    def log_metrics(self, metrics, grad_norm_dic, step=None):
        """Logs the metric dict passed in.
        If the `step` parameter is None and a `step` key is present in metrics,
        uses metrics["step"] as the step.

        Args:
            metrics (dict): Metric values
            grad_norm_dic (dict): Gradient norms
            step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
        """
        # add gpu memory
        if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
            mem_map = memory.get_memory_profile(self.log_gpu_memory)
            metrics.update(mem_map)

        # add norms
        metrics.update(grad_norm_dic)

        # turn all tensors to scalars
        scalar_metrics = self.trainer.metrics_to_scalars(metrics)

        if "step" in scalar_metrics and step is None:
            step = scalar_metrics.pop("step")

        elif step is None:
            # added metrics by Lightning for convenience
            scalar_metrics['epoch'] = self.trainer.current_epoch
            step = self.trainer.global_step

        # log actual metrics
        if self.trainer.logger is not None:
            if self.trainer.is_global_zero:
                self.trainer.logger.agg_and_log_metrics(scalar_metrics,
                                                        step=step)
                self.trainer.logger.save()

            # track the logged metrics
            self.logged_metrics.update(scalar_metrics)
            self.trainer.dev_debugger.track_logged_metrics_history(
                scalar_metrics)

    def add_progress_bar_metrics(self, metrics):
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()

            self._progress_bar_metrics.metrics[k] = v

        self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

    def evaluation_epoch_end(self):
        # reset dataloader idx
        model_ref = self.trainer.lightning_module
        model_ref._current_dataloader_idx = None

        # setting `has_batch_loop_finished` to True
        # will perform Results reduction across the entire epoch.
        self.cached_results.has_batch_loop_finished = True

    def add_to_eval_loop_results(self, dl_idx, has_been_initialized):
        callback_metrics = deepcopy(self.evaluation_callback_metrics)
        for key in list(callback_metrics.keys()):
            if "dataloader_idx" in key:
                if f"dataloader_idx_{dl_idx}" not in key:
                    # remove dl_idx from self.callback_metrics not belonging to this dataset.
                    del callback_metrics[key]
        if has_been_initialized:
            self.eval_loop_results[dl_idx].update(callback_metrics)
        else:
            self.eval_loop_results.append(callback_metrics)

    def prepare_eval_loop_results(self):
        num_dataloaders = self.trainer.evaluation_loop.num_dataloaders
        has_been_initialized = len(self.eval_loop_results) == num_dataloaders
        for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
            self.add_to_eval_loop_results(dl_idx, has_been_initialized)

    def get_evaluate_epoch_results(self):
        if not self.trainer.sanity_checking:
            # log all the metrics as a single dict
            metrics_to_log = self.cached_results.get_epoch_log_metrics()
            if len(metrics_to_log) > 0:
                self.log_metrics(metrics_to_log, {})

        self.prepare_eval_loop_results()

        # log results of evaluation
        if (self.trainer.state != TrainerState.FITTING
                and self.trainer.evaluating and self.trainer.is_global_zero
                and self.trainer.verbose_evaluate):
            print('-' * 80)
            for result_idx, results in enumerate(self.eval_loop_results):
                print(
                    f'DATALOADER:{result_idx} {self.trainer._running_stage.upper()} RESULTS'
                )
                pprint({
                    k: (v.item() if v.numel() == 1 else v.tolist())
                    if isinstance(v, torch.Tensor) else v
                    for k, v in results.items()
                })
                print('-' * 80)

        results = self.eval_loop_results

        # clear mem
        self.eval_loop_results = []
        return results

    def on_train_epoch_end(self):
        # inform cached logger connector epoch finished
        self.cached_results.has_batch_loop_finished = True

    def log_train_epoch_end_metrics(self, epoch_output, num_optimizers):
        # epoch output is a list. Each item in that list has all the outputs per optimizer
        # epoch_output[optimizer_idx][training_step_idx][tbptt_index]
        # remember that not using truncated backprop is equivalent to truncated backprop with length 1

        model = self.trainer.lightning_module

        # lightning module hook
        self.training_epoch_end(model, epoch_output, num_optimizers)

        # log/aggregate metrics automatically
        epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(
            epoch_output)

        # it will perform reduction over epoch and return log metrics
        cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics()
        cached_epoch_pbar_metrics = self.cached_results.get_epoch_pbar_metrics()

        # update
        epoch_log_metrics.update(cached_epoch_log_metrics)
        epoch_progress_bar_metrics.update(cached_epoch_pbar_metrics)

        # --------------------------
        # track results
        # --------------------------
        # add the metrics to the loggers and callbacks
        if epoch_log_metrics and len(epoch_log_metrics) > 0:
            self.log_metrics(epoch_log_metrics, {})
            self._callback_metrics.update(epoch_log_metrics)

        # add metrics to progress_bar and callbacks
        if len(epoch_progress_bar_metrics) > 0:
            self.add_progress_bar_metrics(epoch_progress_bar_metrics)
            self._callback_metrics.update(epoch_progress_bar_metrics)

        # reset epoch loop result for next epoch
        self.cached_results.reset()

    def training_epoch_end(self, model, epoch_output, num_optimizers):
        if not is_overridden('training_epoch_end', model=model):
            return

        # run training_epoch_end
        # refresh the result for custom logging at the epoch level
        model._current_fx_name = 'training_epoch_end'
        epoch_output = self.__prepare_epoch_end_inputs(epoch_output)

        if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization:
            epoch_output = epoch_output[0]

        # lightningmodule hook
        epoch_output = model.training_epoch_end(epoch_output)

        if epoch_output is not None:
            raise MisconfigurationException(
                'training_epoch_end expects a return of None. '
                'HINT: remove the return statement in training_epoch_end')
        # capture logging
        self.trainer.logger_connector.cache_logged_metrics()

    def __auto_reduce_results_on_epoch_end(self, epoch_output):
        epoch_log_metrics = {}
        epoch_progress_bar_metrics = {}
        for opt_outputs in epoch_output:
            # reduce across time first
            time_reduced_outputs = []
            for tbptt_outs in opt_outputs:
                tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(
                    tbptt_outs)
                if len(tbptt_outs) > 1:
                    time_reduced_outputs.append(tbptt_outs)

            if len(time_reduced_outputs) == 0:
                continue

            # reduce across training steps
            opt_outputs = time_reduced_outputs[
                0].__class__.reduce_on_epoch_end(time_reduced_outputs)

            # with manual opt need 1 + metrics because meta is always there
            if opt_outputs.minimize is not None:
                opt_outputs.minimize = opt_outputs.minimize.mean()
            epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
            epoch_progress_bar_metrics.update(opt_outputs.epoch_pbar_metrics)

        return epoch_log_metrics, epoch_progress_bar_metrics

    def __prepare_epoch_end_inputs(self, epoch_output):
        """
        Pulls out only the "extra" information for epoch end

        Returns:
            a single list, each element per optimizer then batch then time
        """
        gathered_epoch_outputs = []
        for opt_outputs in epoch_output:
            # gather across time first
            time_gathered_outputs = []
            for tbptt_outs in opt_outputs:
                result = []
                for x in tbptt_outs:
                    out = x.extra
                    out['loss'] = x.minimize
                    result.append(out)

                # when time = 0, pass in the literal dict instead of array
                if len(result) == 1:
                    result = result[0]
                time_gathered_outputs.append(result)

            gathered_epoch_outputs.append(time_gathered_outputs)

        return gathered_epoch_outputs

    def log_train_step_metrics(self, batch_output):
        if self.trainer.train_loop.should_accumulate() and self.trainer.train_loop.automatic_optimization:
            return
        _, batch_log_metrics = self.cached_results.update_logger_connector()
        # when metrics should be logged
        if self.should_update_logs or self.trainer.fast_dev_run is True:
            # logs user requested information to logger
            grad_norm_dic = batch_output.grad_norm_dic
            if grad_norm_dic is None:
                grad_norm_dic = {}
            if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0:
                self.log_metrics(batch_log_metrics, grad_norm_dic)
                self._callback_metrics.update(batch_log_metrics)
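

# Illustrative helper (an assumption, not part of the connector): eval callback metrics are keyed
# per dataloader with a "dataloader_idx_<i>" suffix, and `add_to_eval_loop_results` keeps only the
# entries belonging to the requested dataloader (plus any un-suffixed metrics).
def _filter_metrics_for_dataloader(callback_metrics: dict, dl_idx: int) -> dict:
    return {
        k: v
        for k, v in callback_metrics.items()
        if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k
    }


# Example: with dl_idx=0, "val_acc/dataloader_idx_1" is dropped while "val_acc/dataloader_idx_0"
# and a plain "val_loss" are kept.
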
class LoggerConnector:
    def __init__(self, trainer):
        self.trainer = trainer
        self.callback_metrics = {}
        self.logged_metrics = {}
        self.progress_bar_metrics = {}
        self.eval_loop_results = []
        self._stages = sorted([s.value for s in LoggerStages])
        self._cached_results = {
            stage: EpochResultStore(trainer, stage)
            for stage in self._stages
        }
        self._callback_hook_validator = CallbackHookNameValidator()
        self._current_stage = None

    def cached_results(
            self,
            stage_or_testing: Union[str,
                                    bool]) -> Union[EpochResultStore, None]:
        """ Function to access cached_results using str or bool. Bool is used only for testing"""
        stage_or_testing = str(stage_or_testing)
        if stage_or_testing in self._stages:
            return self._cached_results[stage_or_testing]
        if stage_or_testing in LOOKUP_TABLE:
            # Access using trainer.testing
            stage = LOOKUP_TABLE[stage_or_testing]
            return self._cached_results[stage]
        raise MisconfigurationException(
            f"Provide stage_or_testing {stage_or_testing} doesn't belong either to {self._stages}"
            f" or {LOOKUP_TABLE.keys()}")

    def set_stage(self, stage_or_testing: str, reset: bool = False) -> None:
        self._current_stage = self._determine_stage(stage_or_testing)
        if reset:
            self.cached_results(stage_or_testing).reset()

    def check_logging_in_callbacks(self,
                                   hook_fx_name,
                                   on_step: bool = None,
                                   on_epoch: bool = None) -> None:
        self._callback_hook_validator.check_logging_in_callbacks(
            current_hook_fx_name=hook_fx_name,
            on_step=on_step,
            on_epoch=on_epoch)

    def on_evaluation_batch_start(self, testing, batch, dataloader_idx,
                                  num_dataloaders):
        # reset the result of the PL module
        model = self.trainer.get_model()
        model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

        # track batch_size
        self.cached_results(testing)._batch_size = Result.extract_batch_size(
            batch)

    def on_batch_start(self, split_idx: int, opt_idx: int,
                       split_batch) -> None:
        self._cached_results["train"]._split_idx = split_idx
        self._cached_results["train"]._opt_idx = opt_idx
        self._cached_results["train"]._batch_size = Result.extract_batch_size(
            split_batch)

    def on_train_batch_end(self) -> None:
        self._cached_results["train"]._split_idx = None
        self._cached_results["train"]._opt_idx = None
        self._cached_results["train"]._batch_size = None

    def _determine_stage(self, stage_or_testing: Union[str, bool]) -> str:
        stage_or_testing = str(stage_or_testing)
        stages = self._stages
        if stage_or_testing in stages:
            return stage_or_testing
        if stage_or_testing in LOOKUP_TABLE:
            # Access using trainer.testing
            return LOOKUP_TABLE[stage_or_testing]
        raise MisconfigurationException(
            f"Provide stage_or_testing {stage_or_testing} doesn't belong either to {stages}"
            f" or {LOOKUP_TABLE.keys()}")

    def cache_logged_metrics(self) -> Union[EpochResultStore, None]:
        if self._current_stage is not None:
            self._cached_results[self._current_stage].cache_result()

    def on_trainer_init(self, logger, flush_logs_every_n_steps,
                        log_every_n_steps):
        # logging
        self.configure_logger(logger)
        # todo: IDE is complaining; these should be initialized in the Trainer init at least as placeholders
        #  and assigned the desired values here
        self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
        self.trainer.log_every_n_steps = log_every_n_steps

    def configure_logger(self, logger):
        if logger is True:
            version = os.environ.get('PL_EXP_VERSION',
                                     self.trainer.slurm_job_id)

            # default logger
            self.trainer.logger = TensorBoardLogger(
                save_dir=self.trainer.default_root_dir,
                version=version,
                name='lightning_logs')
        elif logger is False:
            self.trainer.logger = None
        else:
            if isinstance(logger, Iterable):
                self.trainer.logger = LoggerCollection(logger)
            else:
                self.trainer.logger = logger

    def log_metrics(self, metrics, grad_norm_dic, step=None):
        """Logs the metric dict passed in.
        If the `step` parameter is None and a `step` key is present in metrics,
        uses metrics["step"] as the step.

        Args:
            metrics (dict): Metric values
            grad_norm_dic (dict): Gradient norms
            step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
        """
        # add gpu memory
        if self.trainer.on_gpu and self.trainer.log_gpu_memory:
            mem_map = memory.get_memory_profile(self.trainer.log_gpu_memory)
            metrics.update(mem_map)

        # add norms
        metrics.update(grad_norm_dic)

        # turn all tensors to scalars
        scalar_metrics = self.trainer.metrics_to_scalars(metrics)

        if "step" in scalar_metrics and step is None:
            step = scalar_metrics.pop("step")

        elif step is None:
            # added metrics by Lightning for convenience
            scalar_metrics['epoch'] = self.trainer.current_epoch
            step = self.trainer.global_step

        # log actual metrics
        if self.trainer.logger is not None:
            if self.trainer.is_global_zero:
                self.trainer.logger.agg_and_log_metrics(scalar_metrics,
                                                        step=step)
                self.trainer.logger.save()

            # track the logged metrics
            self.logged_metrics.update(scalar_metrics)
            self.trainer.dev_debugger.track_logged_metrics_history(
                scalar_metrics)

    def add_progress_bar_metrics(self, metrics):
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()

            self.progress_bar_metrics[k] = v

        self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

    def on_evaluation_epoch_end(self, deprecated_eval_results, epoch_logs,
                                using_eval_result, test_mode):
        self._track_callback_metrics(deprecated_eval_results,
                                     using_eval_result)

        # TODO: deprecate parts of this for 1.0 (when removing results)
        self.__process_eval_epoch_end_results_and_log_legacy(
            deprecated_eval_results, test_mode)

        self._log_on_evaluation_epoch_end_metrics(epoch_logs)

        # get the final loop results
        eval_loop_results = self._get_evaluate_epoch_results(test_mode)
        return eval_loop_results

    def _get_evaluate_epoch_results(self, test_mode):
        # log results of test
        if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
            print('-' * 80)
            for result_idx, results in enumerate(self.eval_loop_results):
                print(f'DATALOADER:{result_idx} TEST RESULTS')
                pprint(results)
                print('-' * 80)

        results = self.eval_loop_results

        # clear mem
        self.eval_loop_results = []
        return results

    def _log_on_evaluation_epoch_end_metrics(self, epoch_logs):
        step_metrics = self.trainer.evaluation_loop.step_metrics

        num_loaders = len(step_metrics)

        # clear mem
        self.trainer.evaluation_loop.step_metrics = []

        if self.trainer.running_sanity_check:
            return

        # track all metrics we want to log
        metrics_to_log = []

        # ---------------------------
        # UPDATE EPOCH LOGGED METRICS
        # ---------------------------
        # (ie: in methods at the val_epoch_end level)
        # union the epoch logs with whatever was returned from loaders and reduced
        epoch_logger_metrics = epoch_logs.get_epoch_log_metrics()
        epoch_pbar_metrics = epoch_logs.get_epoch_pbar_metrics()

        self.logged_metrics.update(epoch_logger_metrics)
        self.add_progress_bar_metrics(epoch_pbar_metrics)

        # enable the metrics to be monitored
        self.callback_metrics.update(epoch_logger_metrics)
        self.callback_metrics.update(epoch_pbar_metrics)

        if len(epoch_logger_metrics) > 0:
            metrics_to_log.append(epoch_logger_metrics)

        # --------------------------------
        # UPDATE  METRICS PER DATALOADER
        # --------------------------------
        # each dataloader aggregated metrics
        # now we log all of them
        for dl_idx, dl_metrics in enumerate(step_metrics):
            if len(dl_metrics) == 0:
                # Ensure custom logged metrics are included if not included with step metrics
                if len(epoch_logger_metrics) > 0:
                    self.eval_loop_results.append(epoch_logger_metrics)
                continue

            reduced_epoch_metrics = dl_metrics[
                0].__class__.reduce_on_epoch_end(dl_metrics)
            # track the metrics
            logger_metrics = reduced_epoch_metrics.get_epoch_log_metrics()
            pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics()
            forked_metrics = reduced_epoch_metrics.get_forked_metrics()

            # make the keys 'k/dl'
            logger_metrics = self.__rename_keys_by_dataloader_idx(
                logger_metrics, dl_idx, num_loaders)
            pbar_metrics = self.__rename_keys_by_dataloader_idx(
                pbar_metrics, dl_idx, num_loaders)
            forked_metrics = self.__rename_keys_by_dataloader_idx(
                forked_metrics, dl_idx, num_loaders)

            self.logged_metrics.update(logger_metrics)
            self.add_progress_bar_metrics(pbar_metrics)

            # enable the metrics to be monitored
            self.callback_metrics.update(logger_metrics)
            self.callback_metrics.update(pbar_metrics)

            # forked metrics were dropped, enable them for callbacks
            self.callback_metrics.update(forked_metrics)

            # track the final results for the dataloader
            self.add_to_eval_loop_results(dl_idx, num_loaders)

            # actually log
            if len(logger_metrics) > 0:
                metrics_to_log.append(logger_metrics)

        # log all the metrics as a single dict
        metrics_to_log = dict(ChainMap(*metrics_to_log))
        if len(metrics_to_log) > 0:
            self.log_metrics(metrics_to_log, {})

    def add_to_eval_loop_results(self, dl_idx, num_loaders):
        callback_metrics = deepcopy(self.callback_metrics)
        if num_loaders == 1:
            if len(self.eval_loop_results) > 0:
                self.eval_loop_results[0].update(callback_metrics)
            else:
                self.eval_loop_results.append(callback_metrics)
            return

        for key in list(callback_metrics.keys()):
            if "dataloader_idx" in key:
                if f"dataloader_idx_{dl_idx}" not in key:
                    # drop metrics belonging to a different dataloader from this copy
                    del callback_metrics[key]
        self.eval_loop_results.append(callback_metrics)

    def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx,
                                        num_loaders):
        if num_loaders == 1:
            return metrics

        result = {
            f'{k}/dataloader_idx_{dataloader_idx}': v
            for k, v in metrics.items()
        }
        return result

    def _track_callback_metrics(self, eval_results, using_eval_result):
        if (len(eval_results) > 0
                and (eval_results[0] is None
                     or not isinstance(eval_results[0], Result))):
            return

        if using_eval_result:
            if isinstance(eval_results, list):
                for eval_result in eval_results:
                    self.trainer.logger_connector.callback_metrics.update(
                        eval_result.callback_metrics)
            else:
                self.trainer.logger_connector.callback_metrics.update(
                    eval_results.callback_metrics)
        else:
            flat = {}
            if isinstance(eval_results, list):
                for eval_result in eval_results:
                    # with a scalar return, auto set it to "val_loss" for callbacks
                    if isinstance(eval_result, torch.Tensor):
                        flat = {'val_loss': eval_result}
                    elif isinstance(eval_result, dict):
                        flat = flatten_dict(eval_result)

                    # map the 'val_loss' magic word onto the checkpoint + early stopping callback keys
                    if 'val_loss' in flat:
                        flat['checkpoint_on'] = flat['val_loss']
                        flat['early_stop_on'] = flat['val_loss']
                    self.trainer.logger_connector.callback_metrics.update(flat)
            else:
                # with a scalar return, auto set it to "val_loss" for callbacks
                if isinstance(eval_results, torch.Tensor):
                    flat = {'val_loss': eval_results}
                else:
                    flat = flatten_dict(eval_results)

                # map the 'val_loss' magic word onto the checkpoint + early stopping callback keys
                if 'val_loss' in flat:
                    flat['checkpoint_on'] = flat['val_loss']
                    flat['early_stop_on'] = flat['val_loss']
                self.trainer.logger_connector.callback_metrics.update(flat)

    def __process_eval_epoch_end_results_and_log_legacy_update(
            self, prog_bar_metrics, log_metrics, callback_metrics):
        # eval loop returns all metrics
        dataloader_result_metrics = {
            **prog_bar_metrics,
            **log_metrics,
            **callback_metrics
        }

        # add metrics to prog bar
        self.trainer.logger_connector.add_progress_bar_metrics(
            prog_bar_metrics)

        # log metrics
        if len(log_metrics) > 0:
            self.trainer.logger_connector.log_metrics(log_metrics, {})

        # track metrics for callbacks (all prog bar, logged and callback metrics)
        self.trainer.logger_connector.callback_metrics.update(callback_metrics)
        self.trainer.logger_connector.callback_metrics.update(log_metrics)
        self.trainer.logger_connector.callback_metrics.update(prog_bar_metrics)

        if len(dataloader_result_metrics) > 0:
            self.eval_loop_results.append(dataloader_result_metrics)

    def __process_eval_epoch_end_results_and_log_legacy(
            self, eval_results, test_mode):
        if self.trainer.running_sanity_check:
            return

        if eval_results is not None and len(eval_results) > 0:

            # in eval, the user may return something at every validation step without final reduction
            if not isinstance(eval_results, list):
                eval_results = [eval_results]

            num_loaders: int = self.trainer.evaluation_loop.num_dataloaders
            prog_bar_metrics, log_metrics, callback_metrics = {}, {}, {}

            for result_idx, result in enumerate(eval_results):
                if isinstance(result, EvalResult):
                    prog_bar_metrics = result.epoch_pbar_metrics
                    log_metrics = result.epoch_log_metrics
                    callback_metrics = result.callback_metrics

                    # in testing we don't need the callback metrics
                    if test_mode:
                        callback_metrics = {}
                else:
                    _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(
                        result)

                if num_loaders > 1:
                    self.__process_eval_epoch_end_results_and_log_legacy_update(
                        prog_bar_metrics, log_metrics, callback_metrics)

            if num_loaders == 1:
                self.__process_eval_epoch_end_results_and_log_legacy_update(
                    prog_bar_metrics, log_metrics, callback_metrics)

    def on_train_epoch_end(self, epoch_output):
        pass

    def log_train_epoch_end_metrics(self, epoch_output, checkpoint_accumulator,
                                    early_stopping_accumulator,
                                    num_optimizers):
        # epoch output is a list. Each item in that list has all the outputs per optimizer
        # epoch_output[optimizer_idx][training_step_idx][tbptt_index]
        # remember that not using truncated backprop is equivalent to truncated backprop with a length of 1
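        # e.g. with 2 optimizers, 3 training batches and no truncated backprop
        # (an illustrative sketch, names are placeholders):
        #   epoch_output = [
        #       [[out_b0], [out_b1], [out_b2]],  # optimizer 0
        #       [[out_b0], [out_b1], [out_b2]],  # optimizer 1
        #   ]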

        model = self.trainer.get_model()

        epoch_callback_metrics = {}

        # -----------------------
        # Calculate epoch callback values if given
        # -----------------------
        if checkpoint_accumulator.num_values > 0:
            epoch_callback_metrics[
                'checkpoint_on'] = checkpoint_accumulator.mean()

        if early_stopping_accumulator.num_values > 0:
            epoch_callback_metrics[
                'early_stop_on'] = early_stopping_accumulator.mean()

        # ------------------------
        # determine if using a result obj
        # ------------------------
        # [optimizer_idx][training_step_idx][tbptt_index]
        opt_idx_outputs = epoch_output[0]

        # TODO: deprecate 1.0
        try:
            sample_obj = opt_idx_outputs[0][0] if isinstance(
                opt_idx_outputs[0], list) else opt_idx_outputs[0]
            is_result_obj = len(epoch_output) > 0 and isinstance(
                sample_obj, Result)
            is_1_0_result = is_result_obj and 'extra' in sample_obj
        except IndexError:
            is_result_obj = False
            is_1_0_result = False

        # ------------------
        # NEW 1.0.0 PATH
        # ------------------
        if is_1_0_result:
            # lightning module hook
            epoch_end_log_result = self.training_epoch_end(
                model, epoch_output, num_optimizers)

            # log/aggregate metrics automatically
            epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(
                epoch_output)
            epoch_log_metrics.update(
                epoch_end_log_result.get_epoch_log_metrics())
            epoch_progress_bar_metrics.update(
                epoch_end_log_result.get_epoch_pbar_metrics())

        # TODO: deprecate 1.0
        else:
            out = self.__run_legacy_training_epoch_end(num_optimizers,
                                                       epoch_output, model,
                                                       is_result_obj,
                                                       epoch_callback_metrics)
            epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out

        # --------------------------
        # track results
        # --------------------------
        # add the metrics to the loggers and callbacks
        if len(epoch_log_metrics) > 0:
            self.log_metrics(epoch_log_metrics, {})
            self.callback_metrics.update(epoch_log_metrics)

        # add metrics to callbacks
        self.callback_metrics.update(epoch_callback_metrics)

        # add metrics to progress_bar and callbacks
        if len(epoch_progress_bar_metrics) > 0:
            self.add_progress_bar_metrics(epoch_progress_bar_metrics)
            self.callback_metrics.update(epoch_progress_bar_metrics)

    def training_epoch_end(self, model, epoch_output, num_optimizers):
        if not is_overridden('training_epoch_end', model=model):
            return Result()

        # run training_epoch_end
        # refresh the result for custom logging at the epoch level
        model._current_fx_name = 'training_epoch_end'
        model._results = Result()

        epoch_output = self.__prepare_epoch_end_inputs(epoch_output)

        if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization:
            epoch_output = epoch_output[0]

        # lightningmodule hook
        epoch_output = model.training_epoch_end(epoch_output)

        model._current_fx_name = ''

        if epoch_output is not None:
            raise MisconfigurationException(
                'training_epoch_end expects a return of None. '
                'HINT: remove the return statement in training_epoch_end')

        # user can ALSO log at the end of an epoch
        new_epoch_end_logs = model._results
        return new_epoch_end_logs

    def __run_legacy_training_epoch_end(self, num_optimizers, epoch_output,
                                        model, is_result_obj,
                                        epoch_callback_metrics):

        epoch_log_metrics = {}
        epoch_progress_bar_metrics = {}

        # --------------------------
        # EPOCH END STEP IF DEFINED
        # --------------------------
        if is_overridden('training_epoch_end', model=model):
            if is_result_obj:
                # with a result object, gather across time and training steps so each opt idx has a single result obj
                epoch_output = self.__gather_result_across_time_and_optimizers(
                    epoch_output)

            if num_optimizers == 1:
                epoch_output = epoch_output[0]

            # run training_epoch_end
            # a list with a result per optimizer index
            epoch_output = model.training_epoch_end(epoch_output)

            if isinstance(epoch_output, Result):
                epoch_log_metrics = epoch_output.epoch_log_metrics
                epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
            else:
                _processed_outputs = self.trainer.process_dict_result(
                    epoch_output)
                epoch_progress_bar_metrics = _processed_outputs[1]
                epoch_log_metrics = _processed_outputs[2]
                epoch_callback_metrics = _processed_outputs[3]

        # --------------------------
        # Structured Result (auto epoch end)
        # --------------------------
        elif is_result_obj:
            epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(
                epoch_output)

        return epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics

    def __auto_reduce_results_on_epoch_end(self, epoch_output):
        epoch_log_metrics = {}
        epoch_progress_bar_metrics = {}
        for opt_outputs in epoch_output:
            # reduce across time first
            time_reduced_outputs = []
            for tbptt_outs in opt_outputs:
                tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(
                    tbptt_outs)
                if len(tbptt_outs) > 1:
                    time_reduced_outputs.append(tbptt_outs)

            if len(time_reduced_outputs) == 0:
                continue

            # reduce across training steps
            opt_outputs = time_reduced_outputs[
                0].__class__.reduce_on_epoch_end(time_reduced_outputs)

            # with manual optimization `minimize` is None, so the Result must carry at least one logged metric besides meta
            if opt_outputs.minimize is not None:
                opt_outputs.minimize = opt_outputs.minimize.mean()
            epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
            epoch_progress_bar_metrics.update(opt_outputs.epoch_pbar_metrics)

        return epoch_log_metrics, epoch_progress_bar_metrics

    def __prepare_epoch_end_inputs(self, epoch_output):
        """
        Pulls out only the "extra" information (plus the loss) for epoch end

        Returns:
            a nested list indexed by optimizer, then batch, then time step
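
            For example, with one optimizer and no truncated backprop this is
            roughly ``[[{'loss': ..., **extras_b0}, {'loss': ..., **extras_b1}]]``
            (an illustrative sketch of the shape, not literal output).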
        """
        gathered_epoch_outputs = []
        for opt_outputs in epoch_output:
            # gather across time first
            time_gathered_outputs = []
            for tbptt_outs in opt_outputs:
                result = []
                for x in tbptt_outs:
                    out = x.extra
                    out['loss'] = x.minimize
                    result.append(out)

                # with a single time step (no truncated backprop), pass the dict itself instead of a one-element list
                if len(result) == 1:
                    result = result[0]
                time_gathered_outputs.append(result)

            gathered_epoch_outputs.append(time_gathered_outputs)

        return gathered_epoch_outputs

    def __gather_result_across_time_and_optimizers(self, epoch_output):
        """
        Gather results into a single padded tensor per metric where each tensor is gathered across
        time and across training steps.

        Returns:
            a list where each element is a Result with the tensors gathered
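
            For example, a scalar metric logged at every training step would end
            up as a tensor of shape ``(num_training_steps, tbptt_len)`` inside
            each Result (an illustrative sketch; shorter sequences are padded).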
        """
        gathered_epoch_outputs = []
        for opt_outputs in epoch_output:
            # gather across time first
            time_gathered_outputs = []
            for tbptt_outs in opt_outputs:
                tbptt_outs = tbptt_outs[0].__class__.gather(tbptt_outs)
                time_gathered_outputs.append(tbptt_outs)

            # gather across training steps
            # each metric has dimensions (training_steps, seq_len) (seq_len=1 when no tbptt is used)
            gathered_opt_output = time_gathered_outputs[
                0].__class__.padded_gather(time_gathered_outputs)
            gathered_epoch_outputs.append(gathered_opt_output)

        return gathered_epoch_outputs

    def log_train_step_metrics(self, batch_output):
        # when metrics should be logged
        should_log_metrics = ((self.trainer.global_step + 1) %
                              self.trainer.log_every_n_steps == 0
                              or self.trainer.should_stop)
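        # e.g. with log_every_n_steps=50 this triggers at global_step 49, 99, 149, ...
        # (metrics are also logged when the trainer is stopping or in fast_dev_run)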
        if should_log_metrics or self.trainer.fast_dev_run:
            # logs user requested information to logger
            metrics = batch_output.batch_log_metrics
            grad_norm_dic = batch_output.grad_norm_dic
            if len(metrics) > 0 or len(grad_norm_dic) > 0:
                self.log_metrics(metrics, grad_norm_dic)
                self.callback_metrics.update(metrics)