    def cache_result(self) -> None:
        """
        This function is called after every hook and stores the result object
        """
        with self.trainer.profiler.profile("cache_result"):
            model_ref = self.trainer.lightning_module

            # extract hook results
            hook_result = model_ref._results

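            # len == 1 means the Result holds only its default internal entry,
            # i.e. nothing was logged by this hook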
            if len(hook_result) == 1:
                model_ref._current_fx_name = None
                return

            info = self.info
            fx_name = info["fx_name"]

            self._internals.setdefault(fx_name, HookResultStore(fx_name))

            # attach the captured batch_size
            Result.attach_batch_size(self._batch_size, hook_result)

            hook_result = hook_result.detach()
            if self.trainer.move_metrics_to_cpu:
                hook_result = hook_result.cpu()
            elif self.trainer._distrib_type == DistributedType.DP:
                hook_result = hook_result.to(torch.device("cuda", self.trainer.root_gpu))

            self._internals[fx_name].append(hook_result, info)

            # update logged_metrics, progress_bar_metrics, callback_metrics
            if "epoch_end" in fx_name:
                self.update_logger_connector()

            self.reset_model()
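All three snippets assume names defined elsewhere in PyTorch Lightning. A plausible import block is sketched below; the module paths match the ~1.2-1.3 internals and are assumptions, since these private modules moved between releases.

# Assumed imports for the snippets on this page (paths approximate
# PyTorch Lightning ~1.2-1.3 and may differ in other releases):
from copy import deepcopy

import torch
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import HookResultStore
from pytorch_lightning.utilities import DistributedType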
Example #2
    def cache_result(self) -> None:
        """
        This function is called after every hook
        and stores the result object
        """
        with self.trainer.profiler.profile("cache_result"):
            model_ref = self.trainer.get_model()

            # extract hook results
            hook_result = model_ref._results

            if len(hook_result) == 1:
                model_ref._current_hook_fx_name = None
                model_ref._current_fx_name = ''
                return

            # extract model information
            fx_name, dataloader_idx = self.current_model_info()

            # create a store for this hook's results on first use

            if fx_name not in self._internals:
                self._internals[fx_name] = HookResultStore(fx_name)

            extra_info = {}
            if self.has_split_and_opt_idx:
                extra_info = self.extra_info

            # attach the captured batch_size
            Result.attach_batch_size(self._batch_size, hook_result)

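            # in this version Result.detach(), .cpu() and .to() are assumed to
            # mutate the Result in place (returning None), hence no reassignment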
            hook_result.detach()
            if self.trainer.move_metrics_to_cpu:
                hook_result.cpu()
            elif self.trainer.use_dp:
                hook_result.to(torch.device("cuda", self.trainer.root_gpu))

            self._internals[fx_name].append(hook_result,
                                            dataloader_idx=dataloader_idx,
                                            extra_info=extra_info)

            # update logged_metrics, progress_bar_metrics, callback_metrics
            self.update_logger_connector(fx_name)

            self.reset_model()
Example #3
    def cache_result(self) -> None:
        """
        This function is called after every hook
        and stores the result object
        """
        model_ref = self.trainer.get_model()

        # extract hook results
        hook_result = model_ref._results

        # extract model information
        fx_name, dataloader_idx = self.current_model_info()

        # add only if anything has been logged
        # (a fresh Result has length 1 by default)
        if len(hook_result) > 1:

            if fx_name not in self._internals:
                self._internals[fx_name] = HookResultStore(fx_name)

            extra_info = {}
            if self.has_split_and_opt_idx:
                extra_info = self.extra_info

            # attach the captured batch_size
            Result.attach_batch_size(self._batch_size, hook_result)

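            # deepcopy snapshots the Result so later in-place updates to
            # _results cannot alter the cached copy (assumed rationale,
            # not stated in the source)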
            self._internals[fx_name].append(deepcopy(hook_result),
                                            dataloader_idx=dataloader_idx,
                                            extra_info=extra_info)

            # update logged_metrics, progress_bar_metrics, callback_metrics
            self.update_logger_connector(fx_name)

        # reset _results, fx_name
        self.reset_model()
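For context, a minimal sketch of the call pattern these methods assume: a hook runs, logs into model_ref._results, and cache_result is invoked right afterwards. The dispatcher below is illustrative only; in particular, the `call_hook` name and `cached_results` as the attribute holding the store are assumptions, not Lightning's actual Trainer code.

# Illustrative sketch of where cache_result sits in the hook cycle.
# `trainer.logger_connector.cached_results` is an assumed attribute path.
def call_hook(trainer, hook_name: str, *args, **kwargs):
    model_ref = trainer.lightning_module
    # record which hook is logging, so results can be attributed to it
    model_ref._current_fx_name = hook_name

    hook = getattr(model_ref, hook_name, None)
    output = hook(*args, **kwargs) if callable(hook) else None

    # once the hook has run, persist whatever it logged into the per-hook store
    trainer.logger_connector.cached_results.cache_result()
    return output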