# Example 1
    def __init__(self, Fm, Gm=None, coupling='additive', keep_input=True, implementation_fwd=0, implementation_bwd=0):
        """The ReversibleBlock

        Parameters
        ----------
            Fm : torch.nn.Module
                A torch.nn.Module encapsulating an arbitrary function

            Gm : torch.nn.Module
                A torch.nn.Module encapsulating an arbitrary function
                (If not specified a deepcopy of Fm is used as a Module)

            coupling: str
                Type of coupling ['additive', 'affine']. Default = 'additive'

            keep_input : bool
                Retain the input information, by default it can be discarded since it will be
                reconstructed upon the backward pass. Default = True

            implementation_fwd : int
                Switch between different Operation implementations for forward training. Default = 0

            implementation_bwd : int
                Switch between different Operation implementations for backward training. Default = 0

        Raises
        ------
        NotImplementedError
            If `coupling` is not one of the supported coupling types.

        """
        super(ReversibleBlock, self).__init__()

        # Dispatch to the concrete coupling implementation; both variants
        # receive the same keep_input/implementation flags.
        if coupling == 'additive':
            self.rev_block = AdditiveBlock(Fm, Gm, keep_input, implementation_fwd, implementation_bwd)
        elif coupling == 'affine':
            self.rev_block = AffineBlock(Fm, Gm, keep_input, implementation_fwd, implementation_bwd)
        else:
            raise NotImplementedError('Unknown coupling method: %s' % coupling)
# Example 2
    def __init__(self,
                 Fm,
                 Gm=None,
                 coupling='additive',
                 keep_input=False,
                 implementation_fwd=1,
                 implementation_bwd=1):
        """Construct a ReversibleBlock wrapping Fm (and optionally Gm).

        Parameters
        ----------
            Fm : torch.nn.Module
                A torch.nn.Module encapsulating an arbitrary function

            Gm : torch.nn.Module, optional
                A torch.nn.Module encapsulating an arbitrary function
                (If not specified a deepcopy of Fm is used as a Module)

            coupling: str
                Type of coupling ['additive', 'affine']. Default = 'additive'
                'affine' is currently experimental

            keep_input : bool
                Retain the input information, by default it can be discarded since it will be
                reconstructed upon the backward pass.

            implementation_fwd : int
                Switch between different Operation implementations for forward training. Default = 1
                -1 : Naive implementation without reconstruction on the backward pass (keep_input should be True)
                 0 : Memory efficient implementation, compute gradients directly on y
                 1 : Memory efficient implementation, similar to approach in Gomez et al. 2017

            implementation_bwd : int
                Switch between different Operation implementations for backward training. Default = 1
                -1 : Naive implementation without reconstruction on the backward pass (keep_input should be True)
                 0 : Memory efficient implementation, compute gradients directly on y
                 1 : Memory efficient implementation, similar to approach in Gomez et al. 2017

        Raises
        ------
        NotImplementedError
            If `coupling` is not a recognized coupling type.

        """
        super(ReversibleBlock, self).__init__()
        self.keep_input = keep_input

        # Map every supported coupling name to its implementation class and
        # instantiate the selected one; unrecognized names fail fast.
        couplings = {'additive': AdditiveBlock, 'affine': AffineBlock}
        if coupling not in couplings:
            raise NotImplementedError('Unknown coupling method: %s' % coupling)
        self.rev_block = couplings[coupling](Fm, Gm, implementation_fwd,
                                             implementation_bwd)
# Example 3
    def __init__(self,
                 Fm,
                 Gm=None,
                 coupling='additive',
                 keep_input=False,
                 keep_input_inverse=False,
                 implementation_fwd=1,
                 implementation_bwd=1,
                 adapter=None):
        """Create a ReversibleBlock with the requested coupling.

        Note
        ----
        The `implementation_fwd` and `implementation_bwd` parameters can be set to one of the following implementations:

        * -1 Naive implementation without reconstruction on the backward pass.
        * 0  Memory efficient implementation, compute gradients directly.
        * 1  Memory efficient implementation, similar to approach in Gomez et al. 2017.


        Parameters
        ----------
            Fm : :obj:`torch.nn.Module`
                A torch.nn.Module encapsulating an arbitrary function

            Gm : :obj:`torch.nn.Module`, optional
                A torch.nn.Module encapsulating an arbitrary function
                (If not specified a deepcopy of Fm is used as a Module)

            coupling : :obj:`str`, optional
                Type of coupling ['additive', 'affine']. Default = 'additive'

            keep_input : :obj:`bool`, optional
                Set to retain the input information on forward, by default it can be discarded since it will be
                reconstructed upon the backward pass.

            keep_input_inverse : :obj:`bool`, optional
                Set to retain the input information on inverse, by default it can be discarded since it will be
                reconstructed upon the backward pass.

            implementation_fwd : :obj:`int`, optional
                Switch between different Operation implementations for forward training (Default = 1).
                If using the naive implementation (-1) then `keep_input` should be True.

            implementation_bwd : :obj:`int`, optional
                Switch between different Operation implementations for backward training (Default = 1).
                If using the naive implementation (-1) then `keep_input_inverse` should be True.

            adapter : :obj:`class`, optional
                Only relevant when using the 'affine' coupling.
                Should be a class of type :obj:`torch.nn.Module` that serves as an
                optional wrapper class A for Fm and Gm which must output
                s, t = A(x) with shape(s) = shape(t) = shape(x).
                s, t are respectively the scale and shift tensors for the affine coupling.

        Attributes
        ----------
            keep_input : :obj:`bool`, optional
                Set to retain the input information on forward, by default it can be discarded since it will be
                reconstructed upon the backward pass.

            keep_input_inverse : :obj:`bool`, optional
                Set to retain the input information on inverse, by default it can be discarded since it will be
                reconstructed upon the backward pass.

        Raises
        ------
        NotImplementedError
            If an unknown coupling is given.

        """
        super(ReversibleBlock, self).__init__()
        self.keep_input = keep_input
        self.keep_input_inverse = keep_input_inverse

        # Fail fast on unsupported coupling names, then build the matching block.
        if coupling not in ('additive', 'affine'):
            raise NotImplementedError('Unknown coupling method: %s' % coupling)
        if coupling == 'affine':
            # Only the affine coupling accepts the optional adapter wrapper.
            self.rev_block = AffineBlock(Fm,
                                         Gm,
                                         adapter=adapter,
                                         implementation_fwd=implementation_fwd,
                                         implementation_bwd=implementation_bwd)
        else:
            self.rev_block = AdditiveBlock(
                Fm,
                Gm,
                implementation_fwd=implementation_fwd,
                implementation_bwd=implementation_bwd)
# Example 4
class ReversibleBlock(nn.Module):
    """A reversible module pairing two sub-networks in an invertible coupling.

    The input to :meth:`forward` (and :meth:`inverse`) may have its storage
    released after the call when ``keep_input`` is False, since the reversible
    coupling can reconstruct it during the backward pass.
    """

    def __init__(self,
                 Fm,
                 Gm=None,
                 coupling='additive',
                 keep_input=False,
                 implementation_fwd=1,
                 implementation_bwd=1):
        """Build the block with the chosen coupling implementation.

        Parameters
        ----------
            Fm : torch.nn.Module
                A torch.nn.Module encapsulating an arbitrary function

            Gm : torch.nn.Module
                A torch.nn.Module encapsulating an arbitrary function
                (If not specified a deepcopy of Fm is used as a Module)

            coupling: str
                Type of coupling ['additive', 'affine']. Default = 'additive'

            keep_input : bool
                Retain the input information, by default it can be discarded since it will be
                reconstructed upon the backward pass.

            implementation_fwd : int
                Switch between different Operation implementations for forward training. Default = 1
                -1 : Naive implementation without reconstruction on the backward pass (keep_input should be True)
                 0 : Memory efficient implementation, compute gradients directly on y
                 1 : Memory efficient implementation, similar to approach in Gomez et al. 2017

            implementation_bwd : int
                Switch between different Operation implementations for backward training. Default = 1
                -1 : Naive implementation without reconstruction on the backward pass (keep_input should be True)
                 0 : Memory efficient implementation, compute gradients directly on y
                 1 : Memory efficient implementation, similar to approach in Gomez et al. 2017

        Raises
        ------
        NotImplementedError
            If `coupling` is not a recognized coupling type.

        """
        super(ReversibleBlock, self).__init__()
        self.keep_input = keep_input

        if coupling == 'additive':
            self.rev_block = AdditiveBlock(Fm, Gm, implementation_fwd,
                                           implementation_bwd)
        elif coupling == 'affine':
            self.rev_block = AffineBlock(Fm, Gm, implementation_fwd,
                                         implementation_bwd)
        else:
            raise NotImplementedError('Unknown coupling method: %s' % coupling)

    @staticmethod
    def _free_tensor_storage(tensor):
        """Release the tensor's underlying storage (version-dependent API)."""
        if pytorch_version_one_and_above:
            # PyTorch 1.0+ way to clear storage
            tensor.storage().resize_(0)
        else:
            # PyTorch 0.4 way to clear storage
            tensor.data.set_()

    def forward(self, x):
        """Apply the coupling; optionally free `x`'s storage afterwards."""
        y = self.rev_block(x)
        # The input can be reconstructed on the backward pass, so its storage
        # may be dropped to save memory.
        if not self.keep_input:
            self._free_tensor_storage(x)
        return y

    def inverse(self, y):
        """Invert the coupling; optionally free `y`'s storage afterwards."""
        x = self.rev_block.inverse(y)
        # Same memory-saving trick as in forward, applied to the inverse input.
        if not self.keep_input:
            self._free_tensor_storage(y)
        return x
# Example 5
def test_legacy_affine_coupling():
    """Constructing an AffineBlock the legacy way (no Gm) must not raise."""
    with warnings.catch_warnings():
        # The legacy constructor path emits a DeprecationWarning; silence it
        # so the test only fails on real errors.
        warnings.simplefilter('ignore', DeprecationWarning)
        AffineBlock(Fm=SubModule())