コード例 #1
0
ファイル: module.py プロジェクト: neuralmagic/sparseml
    def register_batch_loss_hook(
        self, hook: Callable[[int, int, int, Any, Any, Dict[str, Tensor]], None]
    ):
        """
        Add a hook that is invoked once the loss for a batch has been computed.

        The hook receives the tuple
        ``(counter, step_count, batch_size, data, pred, losses)``:

        - counter: the value passed in to the run (ex: the epoch)
        - step_count: how many items have been run so far
        - batch_size: how many elements were fed in the batch
        - data: the data output from the loader
        - pred: the model result after the forward pass
        - losses: the resulting dictionary of losses

        :param hook: callable to invoke at that point of the batch process
        :return: a removable handle that can be used to remove the hook
            when it is no longer wanted
        """
        registry = self._batch_loss_hooks
        removable = RemovableHandle(registry)
        # Key the hook by the handle's id so the handle can later locate
        # and drop this exact entry from the registry.
        registry[removable.id] = hook
        return removable
コード例 #2
0
    def register_propagate_forward_hook(
        self, hook: Callable
    ) -> RemovableHandle:
        r"""Add a forward hook that runs after :meth:`propagate` produces output.

        Every time :meth:`propagate` has computed an output, the hook is
        invoked with the signature

        .. code-block:: python

            hook(module, inputs, output) -> None or modified output

        Returning a value from the hook allows it to modify the output.
        Any keyword arguments given to :meth:`propagate` reach the hook as a
        dictionary stored in :obj:`inputs[-1]`.

        Returns a :class:`torch.utils.hooks.RemovableHandle`; calling
        :obj:`handle.remove()` unregisters the hook again.
        """
        registry = self._propagate_forward_hooks
        removable = RemovableHandle(registry)
        # The handle's id is the registry key, which is how
        # ``removable.remove()`` finds this entry later.
        registry[removable.id] = hook
        return removable