Ejemplo n.º 1
0
    def __init__(self) -> None:
        super().__init__()
        # loss trackers, both created with a window length of 20
        loss_window = 20
        self.accumulated_loss = TensorRunningAccum(window_length=loss_window)
        self.running_loss = TensorRunningAccum(window_length=loss_window)
        # index of the chunk currently processed when truncated backprop through time splits the batch
        self.split_idx: int = 0
        # child loops: one for automatic and one for manual optimization
        self.optimizer_loop, self.manual_loop = OptimizerLoop(), ManualOptimization()

        # collected outputs, and the (index, split) pairs still waiting to be processed
        self._outputs: _OUTPUTS_TYPE = []
        self._remaining_splits: List[Tuple[int, Any]] = []
    def __init__(self) -> None:
        super().__init__()
        # Used through the ``TensorRunningAccum`` interface (``append``/``mean``/``reset``), so it is
        # annotated as such rather than as a plain ``Tensor``. It starts as ``None`` and has to be
        # assigned before the loss-tracking code runs (presumably by the owning loop — TODO confirm).
        self.accumulated_loss: Optional[TensorRunningAccum] = None
        # one list of step outputs per optimizer
        self.batch_outputs: Optional[List[List[STEP_OUTPUT]]] = None
        self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: Optional[int] = None
        self.optimizer_loop = OptimizerLoop()
        self.manual_loop = ManualOptimization()

        self._warning_cache: WarningCache = WarningCache()
        # lazily computed ``np.cumsum`` of the optimizer frequencies: an array, not a single int
        self._optimizer_freq_cumsum: Optional[np.ndarray] = None
        # (split_idx, split_batch) pairs still to be processed
        self._remaining_splits: Optional[List[Tuple[int, Any]]] = None
Ejemplo n.º 3
0
class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]):
    """Runs over a single batch of data.

    ``on_run_start`` splits the batch for truncated backprop through time (a single split when
    ``truncated_bptt_steps == 0``). Each ``advance`` processes one split, either with the ``OptimizerLoop``
    (automatic optimization) or the ``ManualOptimization`` loop, and keeps its non-empty outputs.
    """

    def __init__(self) -> None:
        super().__init__()
        self.accumulated_loss = TensorRunningAccum(window_length=20)
        self.running_loss = TensorRunningAccum(window_length=20)
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: int = 0
        self.optimizer_loop = OptimizerLoop()
        self.manual_loop = ManualOptimization()

        # non-empty outputs of the processed splits; handed over and cleared in ``on_run_end``
        self._outputs: _OUTPUTS_TYPE = []
        # (split_idx, split_batch) pairs that still have to be processed
        self._remaining_splits: List[Tuple[int, Any]] = []

    @property
    def done(self) -> bool:
        """Returns if all batch splits have been processed already."""
        return not self._remaining_splits

    def connect(  # type: ignore[override]
        self,
        optimizer_loop: Optional[OptimizerLoop] = None,
        manual_loop: Optional[ManualOptimization] = None,
    ) -> None:
        """Replaces the child loops used for optimization.

        Args:
            optimizer_loop: loop used for automatic optimization; the current one is kept when ``None``.
            manual_loop: loop used for manual optimization; the current one is kept when ``None``.
        """
        if optimizer_loop is not None:
            self.optimizer_loop = optimizer_loop
        if manual_loop is not None:
            self.manual_loop = manual_loop

    def reset(self) -> None:
        """Resets the loop state."""
        self._outputs = []

    # The ``type: ignore`` lives on the ``def`` line because that is where mypy reports override errors.
    def on_run_start(self, kwargs: OrderedDict) -> None:  # type: ignore[override]
        """Splits the data into tbptt splits.

        Args:
            kwargs: the kwargs passed down to the hooks; only ``kwargs["batch"]`` is read here.
        """
        batch = kwargs["batch"]
        self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))

    def advance(self, kwargs: OrderedDict) -> None:  # type: ignore[override]
        """Runs the train step together with optimization (if necessary) on the current batch split.

        Args:
            kwargs: the kwargs passed down to the hooks. ``kwargs["batch"]`` is overwritten in place with
                the current split before the child loop runs.
        """
        # replace the batch with the split batch
        self.split_idx, kwargs["batch"] = self._remaining_splits.pop(0)

        self.trainer._logger_connector.on_train_split_start(self.split_idx)

        outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None  # for mypy
        # choose which loop will run the optimization
        if self.trainer.lightning_module.automatic_optimization:
            optimizers = _get_active_optimizers(
                self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
            )
            outputs = self.optimizer_loop.run(optimizers, kwargs)
        else:
            outputs = self.manual_loop.run(kwargs)
        if outputs:
            # automatic: can be empty if all optimizers skip their batches
            # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
            # then `advance` doesn't finish and an empty dict is returned
            self._outputs.append(outputs)

    def on_run_end(self) -> _OUTPUTS_TYPE:
        """Clears the child loops' hidden states and returns the outputs collected over all splits.

        The internal buffers are reset so this loop no longer references the returned outputs.
        """
        self.optimizer_loop._hiddens = None
        # this is not necessary as the manual loop runs for only 1 iteration, but just in case
        self.manual_loop._hiddens = None
        output, self._outputs = self._outputs, []  # free memory
        self._remaining_splits = []
        return output

    def teardown(self) -> None:
        """Tears down the child loops and moves the loss accumulators' memory to the CPU."""
        self.optimizer_loop.teardown()
        self.manual_loop.teardown()
        # release memory
        if self.accumulated_loss.memory is not None:
            self.accumulated_loss.memory = self.accumulated_loss.memory.cpu()
        if self.running_loss.memory is not None:
            self.running_loss.memory = self.running_loss.memory.cpu()

    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
        """Splits a single batch into a list of sequence steps for tbptt.

        Args:
            batch: the current batch to split

        Returns:
            ``[batch]`` when ``truncated_bptt_steps == 0``, otherwise the result of the
            ``tbptt_split_batch`` LightningModule hook.
        """
        tbptt_steps = self.trainer.lightning_module.truncated_bptt_steps
        if tbptt_steps == 0:
            return [batch]

        splits = self.trainer._call_lightning_module_hook("tbptt_split_batch", batch, tbptt_steps)
        return splits

    def _update_running_loss(self, current_loss: Tensor) -> None:
        """Updates the running loss value with the current value."""
        if self.trainer.lightning_module.automatic_optimization:
            # track total loss for logging (avoid mem leaks)
            self.accumulated_loss.append(current_loss)

        accumulated_loss = self.accumulated_loss.mean()

        if accumulated_loss is not None:
            # calculate running loss for display, reusing the mean computed above instead of recomputing it
            self.running_loss.append(accumulated_loss * self.trainer.accumulate_grad_batches)

        # reset for next set of accumulated grads
        self.accumulated_loss.reset()
class TrainingBatchLoop(Loop):
    """Runs over a single batch of data.

    ``on_run_start`` splits the batch for truncated backprop through time (a single split when
    ``truncated_bptt_steps == 0``). Each ``advance`` runs one split through the ``OptimizerLoop``
    (automatic optimization) or the ``ManualOptimization`` loop, and the results are gathered per
    optimizer in ``batch_outputs``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Used through the ``TensorRunningAccum`` interface (``append``/``mean``/``reset`` in
        # ``_update_running_loss``). It starts as ``None`` and must be assigned before that method runs
        # (presumably by the owning loop — TODO confirm).
        self.accumulated_loss: Optional[TensorRunningAccum] = None
        # one list of step outputs per optimizer; created in ``reset`` and released at the end of ``run``
        self.batch_outputs: Optional[List[List[STEP_OUTPUT]]] = None
        self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: Optional[int] = None
        self.optimizer_loop = OptimizerLoop()
        self.manual_loop = ManualOptimization()

        self._warning_cache: WarningCache = WarningCache()
        # lazily computed ``np.cumsum`` of the optimizer frequencies: an array, not a single int
        self._optimizer_freq_cumsum: Optional[np.ndarray] = None
        # (split_idx, split_batch) pairs still to be processed; ``None`` before the first run and after teardown
        self._remaining_splits: Optional[List[Tuple[int, Any]]] = None

    @property
    def done(self) -> bool:
        """Returns if all batch splits have been processed already.

        ``_remaining_splits`` is ``None`` before ``on_run_start`` and after ``teardown``. That state is
        treated as "nothing left to process" instead of failing with ``len(None)``.
        """
        return not self._remaining_splits

    @property
    def optimizer_freq_cumsum(self) -> np.ndarray:
        """Returns the cumulative sum of the optimizer frequencies.

        Computed from ``trainer.optimizer_frequencies`` on first access and cached afterwards.
        """
        if self._optimizer_freq_cumsum is None:
            self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
        return self._optimizer_freq_cumsum

    def connect(
        self, optimizer_loop: Optional["Loop"] = None, manual_loop: Optional[ManualOptimization] = None
    ) -> None:
        """Replaces the child loops used for optimization.

        Args:
            optimizer_loop: loop used for automatic optimization; the current one is kept when ``None``.
            manual_loop: loop used for manual optimization; the current one is kept when ``None``.
        """
        if optimizer_loop is not None:
            self.optimizer_loop = optimizer_loop
        if manual_loop is not None:
            self.manual_loop = manual_loop

    def run(self, batch: Any, batch_idx: int) -> AttributeDict:
        """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.

        Args:
            batch: the current batch to run the train step on
            batch_idx: the index of the current batch

        Returns:
            An ``AttributeDict``. ``signal=-1`` means one of the start hooks returned ``-1``. Otherwise
            ``signal=0`` and ``training_step_output`` holds the per-optimizer outputs.
        """
        if batch is None:
            self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
            return AttributeDict(signal=0, training_step_output=[[]])

        # hook
        self.trainer.logger_connector.on_batch_start()
        response = self.trainer.call_hook("on_batch_start")
        if response == -1:
            return AttributeDict(signal=-1)

        # hook
        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
        if response == -1:
            return AttributeDict(signal=-1)

        self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()

        super().run(batch, batch_idx)
        output = AttributeDict(signal=0, training_step_output=self.batch_outputs)
        self.batch_outputs = None  # free memory
        return output

    def reset(self) -> None:
        """Resets the loop state."""
        self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))]

    def on_run_start(self, batch: Any, batch_idx: int) -> None:
        """Splits the data into tbptt splits.

        Args:
            batch: the current batch to run the trainstep on
            batch_idx: the index of the current batch
        """
        void(batch_idx)
        self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))

    def advance(self, batch: Any, batch_idx: int) -> None:
        """Runs the train step together with optimization (if necessary) on the current batch split.

        Args:
            batch: the current batch to run the training on (this is not the split!)
            batch_idx: the index of the current batch
        """
        void(batch)
        split_idx, split_batch = self._remaining_splits.pop(0)
        self.split_idx = split_idx

        # let logger connector extract current batch size
        self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)

        if self.trainer.lightning_module.automatic_optimization:
            # in automatic optimization, hand over execution to the OptimizerLoop
            optimizers = [optimizer for _, optimizer in self.get_active_optimizers(batch_idx)]
            optimizer_outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
            # combine outputs from each optimizer
            for k, outputs in enumerate(optimizer_outputs):
                self.batch_outputs[k].extend(outputs)
        else:
            # in manual optimization, hand over execution to the ManualOptimization loop
            result = self.manual_loop.run(split_batch, batch_idx)
            if result is not None and result.loss is not None:
                self.batch_outputs[0].append(result.drop_closure_loss())

    def on_run_end(self) -> None:
        """Clears the hidden states kept by the child loops."""
        self.optimizer_loop._hiddens = None
        # this is not necessary as the manual loop runs for only 1 iteration, but just in case
        self.manual_loop._hiddens = None

    def teardown(self) -> None:
        """Drops the references to the remaining batch splits."""
        # release memory
        self._remaining_splits = None

    def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
        """Gets the number of active optimizers based on their frequency."""
        return len(self.get_active_optimizers(batch_idx))

    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
        """Splits a single batch into a list of sequence steps for tbptt.

        Args:
            batch: the current batch to split

        Returns:
            ``[batch]`` when ``truncated_bptt_steps == 0``, otherwise the splits produced by the
            LightningModule's ``tbptt_split_batch`` (profiled under ``"tbptt_split_batch"``).
        """
        tbptt_steps = self.trainer.lightning_module.truncated_bptt_steps
        if tbptt_steps == 0:
            return [batch]

        model_ref = self.trainer.lightning_module
        with self.trainer.profiler.profile("tbptt_split_batch"):
            splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
        return splits

    def _update_running_loss(self, current_loss: Tensor) -> None:
        """Updates the running loss value with the current value."""
        if self.trainer.lightning_module.automatic_optimization:
            # track total loss for logging (avoid mem leaks)
            self.accumulated_loss.append(current_loss)

        accumulated_loss = self.accumulated_loss.mean()

        if accumulated_loss is not None:
            # calculate running loss for display, reusing the mean computed above instead of recomputing it
            self.running_loss.append(accumulated_loss * self.trainer.accumulate_grad_batches)

        # reset for next set of accumulated grads
        self.accumulated_loss.reset()

    def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]:
        """Returns the currently active optimizers. When multiple optimizers are used with different frequencies,
        only one of the optimizers is active at a time.

        Args:
            batch_idx: index of the current batch. It is only used when optimizer frequencies are configured.
                ``None`` is treated as ``0`` (the start of the frequency cycle) instead of failing on ``None % n``.

        Returns:
            A list of tuples (opt_idx, optimizer) of currently active optimizers.
        """
        if not self.trainer.optimizer_frequencies:
            # call training_step once per optimizer
            return list(enumerate(self.trainer.optimizers))

        if batch_idx is None:
            batch_idx = 0

        optimizers_loop_length = self.optimizer_freq_cumsum[-1]
        current_place_in_loop = batch_idx % optimizers_loop_length

        # find optimizer index by looking for the first {item > current_place} in the cumsum list
        opt_idx = int(np.argmax(self.optimizer_freq_cumsum > current_place_in_loop))
        return [(opt_idx, self.trainer.optimizers[opt_idx])]