Example 1
    def process_dict_result(self, output, train=False):
        """Reduces output according to the training mode.

        Separates loss from logging and progress bar metrics
        """
        # --------------------
        # WARN DEPRECATED KEYS
        # --------------------
        # TODO: 1.0.0 remove
        if isinstance(output, dict):
            for k, v in output.items():
                if k in ['log', 'progress_bar']:
                    m = inspect.cleandoc(
                        f"The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0\n"
                        " Please use self.log(...) inside the lightningModule instead.\n"
                        " # log on a step or aggregate epoch metric to the logger and/or progress bar"
                        " (inside LightningModule)\n"
                        " self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)"
                    )
                    rank_zero_warn(m)

        # --------------------------
        # handle single scalar only
        # --------------------------
        # single scalar returned from an xx_step
        if isinstance(output, torch.Tensor):
            progress_bar_metrics = {}
            log_metrics = {}
            callback_metrics = {}
            hiddens = None
            return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens

        # ---------------
        # EXTRACT CALLBACK KEYS
        # ---------------
        # all keys not progress_bar or log are candidates for callbacks
        callback_metrics = {}
        if isinstance(output, Mapping):
            for k, v in output.items():
                if k not in ['progress_bar', 'log', 'hiddens']:
                    callback_metrics[k] = v

        if train and self._distrib_type in (DistributedType.DP,
                                            DistributedType.DDP2):
            num_gpus = self.num_gpus
            callback_metrics = self.reduce_distributed_output(
                callback_metrics, num_gpus)

        # ---------------
        # EXTRACT PROGRESS BAR KEYS
        # ---------------
        try:
            progress_output = output['progress_bar']

            # reduce progress metrics for progress bar when using dp
            if train and self._distrib_type in (DistributedType.DP,
                                                DistributedType.DDP2):
                num_gpus = self.num_gpus
                progress_output = self.reduce_distributed_output(
                    progress_output, num_gpus)

            progress_bar_metrics = progress_output
        except (KeyError, TypeError):
            progress_bar_metrics = {}

        # ---------------
        # EXTRACT LOGGING KEYS
        # ---------------
        # extract metrics to log to experiment
        try:
            log_output = output['log']

            # reduce log metrics for the logger when using dp
            if train and self._distrib_type in (DistributedType.DP,
                                                DistributedType.DDP2):
                num_gpus = self.num_gpus
                log_output = self.reduce_distributed_output(
                    log_output, num_gpus)

            log_metrics = log_output
        except (KeyError, TypeError):
            log_metrics = {}

        # ---------------
        # EXTRACT LOSS
        # ---------------
        # if output dict doesn't have the keyword loss
        # then assume the output=loss if scalar
        loss = None
        if train:
            try:
                loss = output['loss']
            except (KeyError, TypeError) as exp:
                if isinstance(output, torch.Tensor):
                    loss = output
                else:
                    raise RuntimeError(
                        'No `loss` value in the dictionary returned from `model.training_step()`.'
                    ) from exp

            # when using dp need to reduce the loss
            if self._distrib_type in (DistributedType.DP,
                                      DistributedType.DDP2):
                loss = self.reduce_distributed_output(loss, self.num_gpus)

        # ---------------
        # EXTRACT HIDDEN
        # ---------------
        hiddens = output.get('hiddens') if isinstance(output, Mapping) else None

        # use every metric passed in as a candidate for callback
        callback_metrics.update(progress_bar_metrics)
        callback_metrics.update(log_metrics)

        # detach all metrics for callbacks to prevent memory leaks
        # no .item() because it will slow things down
        callback_metrics = recursive_detach(callback_metrics)
        progress_bar_metrics = recursive_detach(progress_bar_metrics)
        log_metrics = recursive_detach(log_metrics)

        return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
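
For context, a minimal sketch of the dict a training_step would return to drive this path (names like compute_loss are placeholders, not Lightning API):

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        return {
            'loss': loss,                    # required when train=True
            'progress_bar': {'loss': loss},  # deprecated progress-bar metrics
            'log': {'train_loss': loss},     # deprecated logger metrics
            'accuracy': self.accuracy,       # any other key -> callback_metrics
            'hiddens': None,                 # optional, e.g. for truncated BPTT
        }

    # which process_dict_result unpacks as:
    # loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
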
Example 2
    def log(
        self,
        fx: str,
        name: str,
        value: _METRIC_COLLECTION,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        enable_graph: bool = False,
        sync_dist: bool = False,
        sync_dist_fn: Callable = _Sync.no_op,
        sync_dist_group: Optional[Any] = None,
        dataloader_idx: Optional[int] = None,
        batch_size: Optional[int] = None,
        metric_attribute: Optional[str] = None,
        rank_zero_only: bool = False,
    ) -> None:
        """See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
        # no metrics should be logged with graphs
        if not enable_graph:
            value = recursive_detach(value)

        # move metrics to cpu on TPU.
        if isinstance(value, torch.Tensor) and value.device.type == "xla":
            value = value.cpu()

        # storage key
        key = f"{fx}.{name}"
        # add dataloader_suffix to both key and fx
        if dataloader_idx is not None:
            key += f".{dataloader_idx}"
            fx += f".{dataloader_idx}"

        meta = _Metadata(
            fx=fx,
            name=name,
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
            reduce_fx=reduce_fx,
            enable_graph=enable_graph,
            dataloader_idx=dataloader_idx,
            metric_attribute=metric_attribute,
        )
        meta.sync = _Sync(_should=sync_dist,
                          fn=sync_dist_fn,
                          _group=sync_dist_group,
                          rank_zero_only=rank_zero_only)

        # register logged value if it doesn't exist
        if key not in self:
            self.register_key(key, meta, value)

        # check the stored metadata and the current one match
        elif meta != self[key].meta:
            raise MisconfigurationException(
                f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
            )

        if batch_size is not None:
            self.batch_size = batch_size

        self.update_metrics(key, value)
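
At the user level this collection method is reached through LightningModule.log; a minimal usage sketch, reusing the call shown in the deprecation message of Example 1 (compute_loss is a placeholder):

    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # hypothetical helper
            # log a step-level and epoch-aggregated metric to the logger and progress bar
            self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
            return loss
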
Example 3
    def process_dict_result(self, output, train=False):
        """Reduces output according to the training mode.

        Separates loss from logging and progress bar metrics
        """
        # --------------------------
        # handle single scalar only
        # --------------------------
        # single scalar returned from an xx_step
        if isinstance(output, torch.Tensor):
            progress_bar_metrics = {}
            log_metrics = {}
            callback_metrics = {}
            hiddens = None
            return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens

        # ---------------
        # EXTRACT CALLBACK KEYS
        # ---------------
        # all keys not progress_bar or log are candidates for callbacks
        callback_metrics = {}
        if output:
            for k, v in output.items():
                if k not in ['progress_bar', 'log', 'hiddens']:
                    callback_metrics[k] = v

        if train and (self.use_dp or self.use_ddp2):
            num_gpus = self.num_gpus
            callback_metrics = self.reduce_distributed_output(
                callback_metrics, num_gpus)

        # ---------------
        # EXTRACT PROGRESS BAR KEYS
        # ---------------
        try:
            progress_output = output['progress_bar']

            # reduce progress metrics for progress bar when using dp
            if train and (self.use_dp or self.use_ddp2):
                num_gpus = self.num_gpus
                progress_output = self.reduce_distributed_output(
                    progress_output, num_gpus)

            progress_bar_metrics = progress_output
        except (KeyError, TypeError):
            progress_bar_metrics = {}

        # ---------------
        # EXTRACT LOGGING KEYS
        # ---------------
        # extract metrics to log to experiment
        try:
            log_output = output['log']

            # reduce log metrics for the logger when using dp
            if train and (self.use_dp or self.use_ddp2):
                num_gpus = self.num_gpus
                log_output = self.reduce_distributed_output(
                    log_output, num_gpus)

            log_metrics = log_output
        except (KeyError, TypeError):
            log_metrics = {}

        # ---------------
        # EXTRACT LOSS
        # ---------------
        # if output dict doesn't have the keyword loss
        # then assume the output=loss if scalar
        loss = None
        if train:
            try:
                loss = output['loss']
            except (KeyError, TypeError):
                if isinstance(output, torch.Tensor):
                    loss = output
                else:
                    raise RuntimeError(
                        'No `loss` value in the dictionary returned from `model.training_step()`.'
                    )

            # when using dp need to reduce the loss
            if self.use_dp or self.use_ddp2:
                loss = self.reduce_distributed_output(loss, self.num_gpus)

        # ---------------
        # EXTRACT HIDDEN
        # ---------------
        hiddens = output.get('hiddens') if output else None

        # use every metric passed in as a candidate for callback
        callback_metrics.update(progress_bar_metrics)
        callback_metrics.update(log_metrics)

        # detach all metrics for callbacks to prevent memory leaks
        # no .item() because it will slow things down
        callback_metrics = recursive_detach(callback_metrics)

        # replace loss with checkpoint_on
        if 'loss' in callback_metrics:
            callback_metrics['checkpoint_on'] = callback_metrics['loss']
            callback_metrics['early_stop_on'] = callback_metrics['loss']
            del callback_metrics['loss']

        if 'val_loss' in callback_metrics:
            callback_metrics['checkpoint_on'] = callback_metrics['val_loss']
            callback_metrics['early_stop_on'] = callback_metrics['val_loss']
            del callback_metrics['val_loss']

        return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
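
To illustrate the remapping at the end of this variant, with illustrative values, a 'loss' entry in callback_metrics is replaced by monitor keys for checkpointing and early stopping:

    callback_metrics = {'loss': torch.tensor(0.25), 'acc': torch.tensor(0.9)}
    # after the remapping block:
    # {'acc': tensor(0.9), 'checkpoint_on': tensor(0.25), 'early_stop_on': tensor(0.25)}
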
Example 4
    def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer,
                          hiddens):
        """
        Wrap the forward step in a closure so second-order methods work.
        """
        # ---------------------------
        # FORWARD
        # ---------------------------
        with self.profiler.profile('model_forward'):
            if self.use_amp and self.use_native_amp:
                with torch.cuda.amp.autocast():
                    training_step_output = self.training_forward(
                        split_batch, batch_idx, opt_idx, hiddens)
            else:
                training_step_output = self.training_forward(
                    split_batch, batch_idx, opt_idx, hiddens)

            # ----------------------------
            # PROCESS THE RESULT
            # ----------------------------
            # format and reduce outputs accordingly
            training_step_output = self.process_output(training_step_output,
                                                       train=True)

            # TODO: temporary part of structured results PR
            training_step_output = AttributeDict(
                batch_loss=training_step_output[0],
                pbar_on_batch_end=training_step_output[1],
                log_metrics=training_step_output[2],
                callback_metrics=training_step_output[3],
                hiddens=training_step_output[4],
            )

            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
            training_step_output_for_epoch_end = recursive_detach(
                training_step_output)

        # accumulate loss
        # (if accumulate_grad_batches = 1 no effect)
        closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches

        # backward pass
        model_ref = self.get_model()
        with self.profiler.profile('model_backward'):
            # scale loss for 16 bit
            if self.precision == 16 and not self.on_tpu:
                closure_loss = model_ref.amp_scale_loss(
                    closure_loss, optimizer, opt_idx)

            # do backward pass
            model_ref.backward(self, closure_loss, optimizer, opt_idx)

            # once backward has been applied, release graph
            closure_loss = closure_loss.detach()
            training_step_output.batch_loss = training_step_output.batch_loss.detach()

        if self.use_horovod:
            # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
            optimizer.synchronize()

        # insert after step hook
        if self.is_function_implemented('on_after_backward'):
            model_ref = self.get_model()
            with self.profiler.profile('on_after_backward'):
                model_ref.on_after_backward()

        result = AttributeDict(
            loss=closure_loss,
            training_step_output=training_step_output,
            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
            hiddens=training_step_output.hiddens,
        )
        return result
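
The closure exists because second-order-style optimizers such as LBFGS re-evaluate the loss several times per step, so optimizer.step accepts a callable. A plain-PyTorch sketch of the pattern (model, criterion, inputs, targets are placeholders):

    import torch

    optimizer = torch.optim.LBFGS(model.parameters())

    def closure():
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        return loss

    optimizer.step(closure)  # LBFGS may invoke `closure` multiple times
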
Example 5
    def optimizer_closure(self,
                          split_batch,
                          batch_idx,
                          opt_idx,
                          optimizer,
                          hiddens,
                          gsm_p,
                          gsm_loss,
                          K,
                          beta=0.01):
        """
        Wrap the forward step in a closure so second-order methods work.
        """
        # ---------------------------
        # FORWARD
        # ---------------------------
        with self.profiler.profile('model_forward'):
            if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
                with torch.cuda.amp.autocast():
                    training_step_output = self.training_forward(
                        split_batch, batch_idx, opt_idx, hiddens, gsm_p,
                        gsm_loss, K)
            else:
                training_step_output = self.training_forward(
                    split_batch, batch_idx, opt_idx, hiddens, gsm_p, gsm_loss,
                    K)

            # ----------------------------
            # PROCESS THE RESULT
            # ----------------------------
            # format and reduce outputs accordingly
            training_step_output_for_epoch_end = training_step_output
            training_step_output = self.process_output(training_step_output,
                                                       train=True)

            training_step_output = AttributeDict(
                batch_loss=training_step_output[0],
                pbar_on_batch_end=training_step_output[1],
                log_metrics=training_step_output[2],
                callback_metrics=training_step_output[3],
                hiddens=training_step_output[4],
            )

            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
            training_step_output_for_epoch_end = recursive_detach(
                training_step_output_for_epoch_end)

        # accumulate loss
        # (if accumulate_grad_batches = 1 no effect)
        # TODO: check self.accumulate_grad_batches
        closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches

        # ----------------------------
        # Calculate total loss
        # ----------------------------
        # closure_loss = (1 - beta) * closure_loss + beta * gsm_loss

        # the loss will get scaled for amp. avoid any modifications to it
        untouched_loss = closure_loss.detach().clone()

        # backward pass
        model_ref = self.get_model()
        with self.profiler.profile('model_backward'):
            # scale loss for 16 bit
            if self.precision == 16 and not self.on_tpu:
                closure_loss = model_ref.amp_scale_loss(
                    closure_loss, optimizer, opt_idx)

                # enter amp context
                if not NATIVE_AMP_AVALAIBLE:
                    context = closure_loss
                    closure_loss = closure_loss.__enter__()

            # do backward pass
            model_ref.backward(self, closure_loss, optimizer, opt_idx)

            # exit amp context
            if self.precision == 16 and not NATIVE_AMP_AVALAIBLE and not self.on_tpu:
                error = context.__exit__(None, None, None)
                if error:
                    rank_zero_warn('apex unscale error')
                    raise Exception('apex unscale error')

            # once backward has been applied, release graph
            closure_loss = closure_loss.detach()
            training_step_output.batch_loss = training_step_output.batch_loss.detach()

        if self.use_horovod:
            # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
            optimizer.synchronize()

        # insert after step hook
        if self.is_function_implemented('on_after_backward'):
            model_ref = self.get_model()
            with self.profiler.profile('on_after_backward'):
                model_ref.on_after_backward()

        result = AttributeDict(
            loss=untouched_loss,
            training_step_output=training_step_output,
            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
            hiddens=training_step_output.hiddens,
        )
        return result
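
Assuming amp_scale_loss wraps apex.amp.scale_loss, the manual __enter__/__exit__ calls above emulate apex's documented context-manager form, which reads:

    from apex import amp  # NVIDIA apex, used before native AMP

    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()  # backward on the scaled loss
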
Example 6
    def train_batch(self, batch, batch_info):
        # Get the original PTL module.
        model = self.get_model()
        optimizer = self.optimizers[0]
        batch_idx = batch_info["batch_idx"]
        epoch_idx = batch_info["epoch_idx"]

        if self.is_function_implemented("on_train_batch_start", model):
            response = model.on_train_batch_start(batch=batch,
                                                  batch_idx=batch_idx,
                                                  dataloader_idx=0)
            # Skip remainder of epoch if response is -1.
            if response == -1:
                return {"signal": -1}

        args = [batch, batch_idx]
        if len(self.optimizers) > 1:
            if self.has_arg("training_step", "optimizer_idx"):
                args.append(0)

        with self.timers.record("fwd"):
            if self._is_distributed:
                # Use the DDP wrapped model (self.model).
                output = self.model(*args)
            elif self.use_gpu:
                # Using single GPU.
                # Don't copy the batch: with a single GPU it can be
                # referenced in place, and with multiple optimizers it
                # would otherwise be copied to the same device repeatedly.
                device = self.device
                batch = model.transfer_batch_to_device(batch, device=device)
                args[0] = batch
                output = model.training_step(*args)
            else:
                # Using CPU.
                output = model.training_step(*args)

        if isinstance(output, Result):
            raise ValueError("TrainResult objects are not supported. Please "
                             "return a dictionary instead.")

        # allow any mode to define training_step_end
        # do something with all the dp outputs (like softmax)
        if is_overridden("training_step_end", model):
            output = model.training_step_end(output)

        # Extract loss from output if dictionary.
        try:
            loss = output["loss"]
        except (KeyError, TypeError):
            if isinstance(output, torch.Tensor):
                loss = output
            else:
                raise RuntimeError(
                    "No `loss` value in the dictionary returned from "
                    "`model.training_step()`.")

        # If output contains tensors, detach them all.
        if isinstance(output, torch.Tensor):
            output = output.detach()
        elif isinstance(output, dict):
            output = recursive_detach(output)
        else:
            raise TypeError("training_step returned invalid type. It must "
                            "return either a Tensor or a dict.")

        untouched_loss = loss.detach().clone()

        with self.timers.record("grad"):
            if self.use_fp16:
                with self._amp.scale_loss(loss, optimizer) as scaled_loss:
                    model.backward(scaled_loss, optimizer, optimizer_idx=0)
            else:
                model.backward(loss, optimizer, optimizer_idx=0)

        if self.is_function_implemented("on_after_backward", model):
            model.on_after_backward()

        with self.timers.record("apply"):
            model.optimizer_step(epoch=epoch_idx,
                                 batch_idx=batch_idx,
                                 optimizer=optimizer,
                                 optimizer_idx=0)

        model.on_before_zero_grad(optimizer)

        model.optimizer_zero_grad(epoch=epoch_idx,
                                  batch_idx=batch_idx,
                                  optimizer=optimizer,
                                  optimizer_idx=0)

        if self.is_function_implemented("on_train_batch_end", model):
            model.on_train_batch_end(outputs=output,
                                     batch=batch,
                                     batch_idx=batch_idx,
                                     dataloader_idx=0)

        return {
            "signal": 0,
            "training_loss": untouched_loss.item(),
            "raw_output": output,
            # NUM_SAMPLES: len(batch)
        }
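
For the multi-optimizer branch above (has_arg("training_step", "optimizer_idx")), the module's training_step must accept the extra index; a minimal sketch with placeholder loss helpers:

    def training_step(self, batch, batch_idx, optimizer_idx):
        # the trainer passes which optimizer is currently active
        if optimizer_idx == 0:
            return {'loss': self.generator_loss(batch)}      # hypothetical helper
        return {'loss': self.discriminator_loss(batch)}      # hypothetical helper
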
Example 7
    def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer,
                          hiddens):
        """
        Wrap the forward step in a closure so second-order methods work.
        """
        # ---------------------------
        # FORWARD (TRAINING STEP + TRAIN STEP END)
        # ---------------------------
        with self.profiler.profile('model_forward'):
            if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
                with torch.cuda.amp.autocast():
                    training_step_output = self.training_forward(
                        split_batch, batch_idx, opt_idx, hiddens)
            else:
                training_step_output = self.training_forward(
                    split_batch, batch_idx, opt_idx, hiddens)

            # ----------------------------
            # PROCESS THE RESULT
            # ----------------------------
            # format and reduce outputs accordingly
            training_step_output_for_epoch_end = training_step_output
            is_result_obj = isinstance(training_step_output, Result)

            # don't allow EvalResult in the training_step
            if isinstance(training_step_output, EvalResult):
                raise MisconfigurationException(
                    'training_step cannot return EvalResult, '
                    'use a dict or TrainResult instead')

            # handle regular dicts
            if not is_result_obj:
                training_step_output = self.process_output(
                    training_step_output, train=True)

                training_step_output = AttributeDict(
                    batch_loss=training_step_output[0],
                    pbar_on_batch_end=training_step_output[1],
                    log_metrics=training_step_output[2],
                    callback_metrics=training_step_output[3],
                    hiddens=training_step_output[4],
                )

            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
            if isinstance(training_step_output_for_epoch_end, torch.Tensor):
                training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
            elif is_result_obj:
                training_step_output_for_epoch_end = copy(training_step_output)
                training_step_output_for_epoch_end.detach()
            else:
                training_step_output_for_epoch_end = recursive_detach(
                    training_step_output_for_epoch_end)

        # accumulate loss
        # (if accumulate_grad_batches = 1 no effect)
        closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
        closure_loss = closure_loss / self.accumulate_grad_batches

        # the loss will get scaled for amp. avoid any modifications to it
        untouched_loss = closure_loss.detach().clone()

        # backward pass
        model_ref = self.get_model()
        with self.profiler.profile('model_backward'):
            # scale loss for 16 bit
            if self.precision == 16 and not self.on_tpu:
                closure_loss = model_ref.amp_scale_loss(
                    closure_loss, optimizer, opt_idx)

                # enter amp context
                if not NATIVE_AMP_AVALAIBLE:
                    context = closure_loss
                    closure_loss = closure_loss.__enter__()

            # do backward pass
            model_ref.backward(self, closure_loss, optimizer, opt_idx)

            # exit amp context
            if self.precision == 16 and not NATIVE_AMP_AVALAIBLE and not self.on_tpu:
                error = context.__exit__(None, None, None)
                if error:
                    rank_zero_warn('apex unscale error')
                    raise Exception('apex unscale error')

            # once backward has been applied, release graph
            closure_loss = closure_loss.detach()

            if is_result_obj:
                training_step_output.detach()
            else:
                training_step_output.batch_loss = training_step_output.batch_loss.detach()

        if self.use_horovod:
            # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
            optimizer.synchronize()

        # insert after step hook
        if self.is_function_implemented('on_after_backward'):
            model_ref = self.get_model()
            with self.profiler.profile('on_after_backward'):
                model_ref.on_after_backward()

        # when in dev debugging track the losses
        self.dev_debugger.track_train_loss_history(batch_idx,
                                                   untouched_loss.detach())

        result = AttributeDict(
            loss=untouched_loss,
            training_step_output=training_step_output,
            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
            hiddens=training_step_output.hiddens,
        )
        return result
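
On the closure_loss / self.accumulate_grad_batches division: scaling each micro-batch loss by 1/N before backward makes the N accumulated gradients sum to the gradient of the mean loss, so stepping once every N batches approximates one large batch. A standalone sketch of the pattern (model, criterion, micro_batches are placeholders):

    accumulate_grad_batches = 4
    optimizer.zero_grad()
    for i, (x, y) in enumerate(micro_batches):
        loss = criterion(model(x), y)
        (loss / accumulate_grad_batches).backward()  # gradients sum across calls
        if (i + 1) % accumulate_grad_batches == 0:
            optimizer.step()
            optimizer.zero_grad()
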
Example 8
    def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
        with self.trainer.profiler.profile('model_forward'):
            args = self.build_train_args(split_batch, batch_idx, opt_idx,
                                         hiddens)
            training_step_output = self.trainer.accelerator_backend.training_step(
                args)
            training_step_output = self.trainer.call_hook(
                'training_step_end', training_step_output)

            # ----------------------------
            # PROCESS THE RESULT
            # ----------------------------
            # format and reduce outputs accordingly
            training_step_output_for_epoch_end = training_step_output
            is_result_obj = isinstance(training_step_output, Result)

            # track batch size for weighted average
            if is_result_obj:
                training_step_output.track_batch_size(len(split_batch))

            # don't allow EvalResult in the training_step
            if isinstance(training_step_output, EvalResult):
                raise MisconfigurationException(
                    'training_step cannot return EvalResult, '
                    'use a dict or TrainResult instead')

            # handle regular dicts
            if not is_result_obj:
                training_step_output = self.trainer.process_dict_result(
                    training_step_output, train=True)

                training_step_output = AttributeDict(
                    batch_loss=training_step_output[0],
                    pbar_on_batch_end=training_step_output[1],
                    log_metrics=training_step_output[2],
                    callback_metrics=training_step_output[3],
                    hiddens=training_step_output[4],
                )

            # if the user decides to finally reduce things in epoch_end, save raw output without graphs
            if isinstance(training_step_output_for_epoch_end, torch.Tensor):
                training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
            elif is_result_obj:
                training_step_output_for_epoch_end = copy(training_step_output)
                training_step_output_for_epoch_end.detach()
            else:
                training_step_output_for_epoch_end = recursive_detach(
                    training_step_output_for_epoch_end)

        # accumulate loss
        # (if accumulate_grad_batches = 1 no effect)
        closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
        closure_loss = closure_loss / self.trainer.accumulate_grad_batches

        # the loss will get scaled for amp. avoid any modifications to it
        untouched_loss = closure_loss.detach().clone()

        # result
        result = AttributeDict(
            closure_loss=closure_loss,
            loss=untouched_loss,
            training_step_output=training_step_output,
            training_step_output_for_epoch_end=training_step_output_for_epoch_end,
            hiddens=training_step_output.hiddens,
        )
        return result
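
recursive_detach appears throughout these examples to drop autograd graphs from nested outputs; a minimal sketch of its assumed behavior (not the library implementation):

    import torch

    def recursive_detach(value):
        # detach tensors anywhere inside nested dicts/lists, leave the rest as-is
        if isinstance(value, torch.Tensor):
            return value.detach()
        if isinstance(value, dict):
            return {k: recursive_detach(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return type(value)(recursive_detach(v) for v in value)
        return value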