Example #1
    def test_dihedral(self):
        N = 8
        g = FlipRot2dOnR2(N, axis=np.pi / 3)

        r1 = FieldType(g, list(g.representations.values()))
        r2 = FieldType(g, list(g.representations.values()))
        # r1 = FieldType(g, [g.trivial_repr])
        # r2 = FieldType(g, [g.fibergroup.irrep(1, 0)])
        # r2 = FieldType(g, [irr for irr in g.fibergroup.irreps.values() if irr.size == 1])
        # r2 = FieldType(g, [g.regular_repr])

        s = 7
        # sigma = 0.6
        # fco = lambda r: 1. * r * np.pi
        # fco = lambda r: 2 * r
        sigma = None
        fco = None
        cl = R2Conv(r1,
                    r2,
                    s,
                    basisexpansion='blocks',
                    sigma=sigma,
                    frequencies_cutoff=fco,
                    bias=True)

        for _ in range(8):
            # cl.basisexpansion._init_weights()
            init.generalized_he_init(cl.weights.data, cl.basisexpansion)
            cl.eval()
            cl.check_equivariance()
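
These snippets (here and in the examples below) are unittest methods from the e2cnn test suite and omit their imports. A sketch of the imports they appear to assume, based on the public e2cnn API (verify against your installed version):

import numpy as np
import torch

from e2cnn.gspaces import Rot2dOnR2, Flip2dOnR2, FlipRot2dOnR2
from e2cnn.nn import FieldType, GeometricTensor, R2Conv
from e2cnn.nn import init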
Example #2
    def test_flip(self):
        # g = Flip2dOnR2(axis=np.pi/3)
        g = Flip2dOnR2(axis=np.pi / 2)

        r1 = FieldType(g, list(g.representations.values()))
        r2 = FieldType(g, list(g.representations.values()) * 3).sorted()

        s = 9
        # sigma = 0.6
        # fco = lambda r: 1. * r * np.pi
        # fco = lambda r: 2 * r
        sigma = None
        fco = None
        cl = R2Conv(r1,
                    r2,
                    s,
                    basisexpansion='blocks',
                    sigma=sigma,
                    frequencies_cutoff=fco,
                    bias=True)

        for _ in range(32):
            # cl.basisexpansion._init_weights()
            init.generalized_he_init(cl.weights.data, cl.basisexpansion)
            cl.eval()
            cl.check_equivariance()
Example #3
    def test_so2(self):
        N = 7
        g = Rot2dOnR2(-1, N)

        r1 = FieldType(g, list(g.representations.values()))
        r2 = FieldType(g, list(g.representations.values()))

        s = 7
        # sigma = 0.6
        # fco = lambda r: 1. * r * np.pi
        # fco = lambda r: 2 * r
        sigma = None
        fco = None
        cl = R2Conv(r1,
                    r2,
                    s,
                    basisexpansion='blocks',
                    sigma=sigma,
                    frequencies_cutoff=fco,
                    bias=True)

        for _ in range(8):
            # cl.basisexpansion._init_weights()
            init.generalized_he_init(cl.weights.data, cl.basisexpansion)
            cl.eval()
            cl.check_equivariance()
Example #4
    def test_padding_mode_reflect(self):
        g = Flip2dOnR2(axis=np.pi / 2)

        r1 = FieldType(g, [g.trivial_repr])
        r2 = FieldType(g, [g.regular_repr])

        s = 3
        cl = R2Conv(r1,
                    r2,
                    s,
                    bias=True,
                    padding=1,
                    padding_mode='reflect',
                    initialize=False)

        for _ in range(32):
            init.generalized_he_init(cl.weights.data, cl.basisexpansion)
            cl.eval()
            cl.check_equivariance()
Example #5
    def check(self, r1: FieldType, r2: FieldType):

        np.set_printoptions(precision=7, threshold=60000, suppress=True)

        assert r1.gspace == r2.gspace

        s = 7

        cl = R2Conv(r1,
                    r2,
                    s,
                    basisexpansion='blocks',
                    sigma=[0.01] + [0.6] * int(s // 2),
                    frequencies_cutoff=3.)

        ys = []
        for _ in range(1000):
            init.generalized_he_init(cl.weights.data, cl.basisexpansion)
            cl.eval()

            x = torch.randn(10, r1.size, s, s)

            xg = GeometricTensor(x, r1)
            y = cl(xg).tensor

            del xg
            del x

            ys.append(y.transpose(0, 1).reshape(r2.size, -1))

        ys = torch.cat(ys, dim=1)

        mean = ys.mean(1)
        std = ys.std(1)
        print(mean)
        print(std)

        self.assertTrue(
            torch.allclose(torch.zeros_like(mean), mean, rtol=2e-2, atol=5e-2))
        self.assertTrue(
            torch.allclose(torch.ones_like(std), std, rtol=1e-1, atol=6e-2))
Example #6
    def test_cyclic(self):
        N = 8
        g = Rot2dOnR2(N)

        r1 = FieldType(g, list(g.representations.values()))
        r2 = FieldType(g, list(g.representations.values()) * 2)
        # r1 = FieldType(g, [g.trivial_repr])
        # r2 = FieldType(g, [g.regular_repr])

        s = 7
        sigma = None
        # fco = lambda r: 1. * r * np.pi
        fco = None

        cl = R2Conv(r1,
                    r2,
                    s,
                    basisexpansion='blocks',
                    sigma=sigma,
                    frequencies_cutoff=fco,
                    bias=True)
        cl.bias.data = 20 * torch.randn_like(cl.bias.data)

        for _ in range(1):
            init.generalized_he_init(cl.weights.data, cl.basisexpansion)
            cl.eval()
            cl.check_equivariance()

        cl.train()
        for _ in range(1):
            cl.check_equivariance()

        cl.eval()

        for _ in range(5):
            init.generalized_he_init(cl.weights.data, cl.basisexpansion)
            cl.eval()
            filter = cl.filter.clone()
            cl.check_equivariance()
            self.assertTrue(torch.allclose(filter, cl.filter))
Example #7
    def test_padding_mode_circular(self):
        g = FlipRot2dOnR2(4, axis=np.pi / 2)

        r1 = FieldType(g, [g.trivial_repr])
        r2 = FieldType(g, [g.regular_repr])

        for mode in ['circular', 'reflect', 'replicate']:
            for s in [3, 5, 7]:
                padding = s // 2
                cl = R2Conv(r1,
                            r2,
                            s,
                            bias=True,
                            padding=padding,
                            padding_mode=mode,
                            initialize=False)

                for _ in range(10):
                    init.generalized_he_init(cl.weights.data,
                                             cl.basisexpansion)
                    cl.eval()
                    cl.check_equivariance()
Example #8
    def __init__(
        self,
        in_type: FieldType,
        out_type: FieldType,
        kernel_size: int,
        padding: int = 0,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        basisexpansion: str = 'blocks',
        sigma: Union[List[float], float] = None,
        frequencies_cutoff: Union[float, Callable[[float], int]] = None,
        rings: List[float] = None,
        maximum_offset: int = None,
        recompute: bool = False,
        basis_filter: Callable[[dict], bool] = None,
    ):
        r"""
        
        
        G-steerable planar convolution mapping between the input and output :class:`~e2cnn.nn.FieldType` s specified by
        the parameters ``in_type`` and ``out_type``.
        This operation is equivariant under the action of :math:`\R^2\rtimes G` where :math:`G` is the
        :attr:`e2cnn.nn.FieldType.fibergroup` of ``in_type`` and ``out_type``.
        
        Specifically, let :math:`\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}` and
        :math:`\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}` be the representations specified by the input and output
        field types.
        Then :class:`~e2cnn.nn.R2Conv` guarantees an equivariant mapping
        
        .. math::
            \kappa \star [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [\kappa \star f] \qquad\qquad \forall g \in G, u \in \R^2
            
        where the transformation of the input and output fields are given by
 
        .. math::
            [\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\
            [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\

        The equivariance of G-steerable convolutions is guaranteed by restricting the space of convolution kernels to an
        equivariant subspace.
        As proven in `3D Steerable CNNs <https://arxiv.org/abs/1807.02547>`_, this parametrizes the *most general
        equivariant convolutional map* between the input and output fields.
        For feature fields on :math:`\R^2` (e.g. images), the complete G-steerable kernel spaces for :math:`G \leq \O2`
        are derived in `General E(2)-Equivariant Steerable CNNs <https://arxiv.org/abs/1911.08251>`_.

        During training, in each forward pass the module expands the basis of G-steerable kernels with the learned weights
        before calling :func:`torch.nn.functional.conv2d`.
        When :meth:`~torch.nn.Module.eval()` is called, the filter is built once from the current trained weights and stored
        for future reuse, so that no kernel-expansion overhead remains (a short usage sketch follows this example).
 
        
        The parameters ``basisexpansion``, ``sigma``, ``frequencies_cutoff``, ``rings`` and ``maximum_offset`` are
        optional parameters used to control how the basis for the filters is built, how it is sampled on the filter
        grid and how it is expanded to build the filter. We suggest keeping the default values.
        
        
        Args:
            in_type (FieldType): the type of the input field, specifying its transformation law
            out_type (FieldType): the type of the output field, specifying its transformation law
            kernel_size (int): the size of the (square) filter
            padding(int, optional): implicit zero paddings on both sides of the input. Default: ``0``
            stride(int, optional): the stride of the kernel. Default: ``1``
            dilation(int, optional): the spacing between kernel elements. Default: ``1``
            groups (int, optional): number of blocked connections from input channels to output channels.
                                    It allows depthwise convolution. When used, the input and output types need to be
                                    divisible into ``groups`` groups, all equal to each other.
                                    Default: ``1``.
            bias (bool, optional): Whether to add a bias to the output (only to fields which contain a
                    trivial irrep) or not. Default ``True``
            basisexpansion (str, optional): the basis expansion algorithm to use
            sigma (list or float, optional): width of each ring where the bases are sampled. If only one scalar
                    is passed, it is used for all rings.
            frequencies_cutoff (callable or float, optional): function mapping the radii of the basis elements to the
                    maximum frequency accepted. If a float value is passed, the maximum frequency is equal to the
                    radius times this factor. By default (``None``), a more complex policy is used.
            rings (list, optional): radii of the rings where to sample the bases
            maximum_offset (int, optional): number of additional (aliased) frequencies in the intertwiners for finite
                    groups. By default (``None``), all additional frequencies allowed by the frequencies cut-off
                    are used.
            recompute (bool, optional): if ``True``, recomputes a new basis for the equivariant kernels.
                    By default (``False``), it caches the basis built or reuses a cached one, if found.
            basis_filter (callable, optional): function which takes as input a descriptor of a basis element
                    (as a dictionary) and returns a boolean value: whether to preserve (``True``) or discard (``False``)
                    the basis element. By default (``None``), no filtering is applied.
        
        Attributes:
            
            ~.weights (torch.Tensor): the learnable parameters which are used to expand the kernel
            ~.filter (torch.Tensor): the convolutional kernel obtained by expanding the parameters
                                    in :attr:`~e2cnn.nn.R2Conv.weights`
            ~.bias (torch.Tensor): the learnable parameters which are used to expand the bias, if ``bias=True``
            ~.expanded_bias (torch.Tensor): the equivariant bias which is summed to the output, obtained by expanding
                                    the parameters in :attr:`~e2cnn.nn.R2Conv.bias`
        
        """

        assert in_type.gspace == out_type.gspace
        assert isinstance(in_type.gspace, GeneralOnR2)

        super(R2Conv, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type
        self.out_type = out_type

        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.groups = groups

        if groups > 1:
            # Check the input and output classes can be split in `groups` groups, all equal to each other
            # first, check that the number of fields is divisible by `groups`
            assert len(in_type) % groups == 0
            assert len(out_type) % groups == 0
            in_size = len(in_type) // groups
            out_size = len(out_type) // groups

            # then, check that all groups are equal to each other, i.e. have the same types in the same order
            assert all(in_type.representations[i] == in_type.representations[i % in_size]
                       for i in range(len(in_type)))
            assert all(out_type.representations[i] == out_type.representations[i % out_size]
                       for i in range(len(out_type)))

            # finally, retrieve the type associated to a single group in input.
            # this type will be used to build a smaller kernel basis and a smaller filter
            # as in PyTorch, to build a filter for grouped convolution, we build a filter which maps from one input
            # group to all output groups. Then, PyTorch's standard convolution routine interprets this filter as `groups`
            # different filters, each mapping an input group to an output group.
            in_type = in_type.index_select(list(range(in_size)))

        if bias:
            # bias can be applied only to trivial irreps inside the representation
            # to apply a bias to a field, we learn a bias for each trivial irrep it contains
            # and then transform it with the change-of-basis matrix so it can be applied to the whole field
            # this is equivalent to transforming the field to its irreps through the inverse change of basis,
            # summing the bias only on the trivial irreps and then mapping it back with the change of basis

            # count the number of trivial irreps
            trivials = 0
            for r in self.out_type:
                for irr in r.irreps:
                    if self.out_type.fibergroup.irreps[irr].is_trivial():
                        trivials += 1

            # if there is at least 1 trivial irrep
            if trivials > 0:

                # matrix containing the columns of the change of basis which map from the trivial irreps to the
                # field representations. This matrix allows us to map the bias defined only over the trivial irreps
                # to a bias for the whole field more efficiently
                bias_expansion = torch.zeros(self.out_type.size, trivials)

                p, c = 0, 0
                for r in self.out_type:
                    pi = 0
                    for irr in r.irreps:
                        irr = self.out_type.fibergroup.irreps[irr]
                        if irr.is_trivial():
                            bias_expansion[p:p + r.size, c] = torch.tensor(
                                r.change_of_basis[:, pi])
                            c += 1
                        pi += irr.size
                    p += r.size

                self.register_buffer("bias_expansion", bias_expansion)
                self.bias = Parameter(torch.zeros(trivials),
                                      requires_grad=True)
                self.register_buffer("expanded_bias",
                                     torch.zeros(out_type.size))
            else:
                self.bias = None
                self.expanded_bias = None
        else:
            self.bias = None
            self.expanded_bias = None

        grid, basis_filter, rings, sigma, maximum_frequency = compute_basis_params(
            kernel_size, frequencies_cutoff, rings, sigma, dilation,
            basis_filter)

        # BasisExpansion: submodule which takes care of building the filter
        self._basisexpansion = None

        # notice that `in_type` is used instead of `self.in_type` such that it works also when `groups > 1`
        if basisexpansion == 'blocks':
            self._basisexpansion = BlocksBasisExpansion(
                in_type,
                out_type,
                grid,
                sigma=sigma,
                rings=rings,
                maximum_offset=maximum_offset,
                maximum_frequency=maximum_frequency,
                basis_filter=basis_filter,
                recompute=recompute)

        else:
            raise ValueError('Basis Expansion algorithm "%s" not recognized' %
                             basisexpansion)

        self.weights = Parameter(torch.zeros(self.basisexpansion.dimension()),
                                 requires_grad=True)
        self.register_buffer(
            "filter",
            torch.zeros(out_type.size, in_type.size, kernel_size, kernel_size))

        # by default, the weights are initialized with a generalized form of He's weight initialization
        init.generalized_he_init(self.weights.data, self.basisexpansion)
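
The docstring above explains that, in training mode, the kernel basis is re-expanded with the learned weights at every forward pass, while eval() builds the filter once and caches it in the filter buffer (Example #6 also asserts this caching). A minimal usage sketch of that behaviour; the group, field types and tensor sizes below are illustrative and not taken from the examples:

import torch
from e2cnn.gspaces import Rot2dOnR2
from e2cnn.nn import FieldType, GeometricTensor, R2Conv

gspace = Rot2dOnR2(8)                                    # C8-equivariant features on the plane
in_type = FieldType(gspace, [gspace.trivial_repr] * 3)   # e.g. an RGB image
out_type = FieldType(gspace, [gspace.regular_repr] * 4)

conv = R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=True)

x = GeometricTensor(torch.randn(1, in_type.size, 32, 32), in_type)

# training mode: the equivariant basis is expanded with the learned weights at every forward pass
conv.train()
y = conv(x)

# eval mode: the filter is expanded once from the current weights and cached in the `filter` buffer
conv.eval()
y = conv(x)
print(y.tensor.shape, conv.filter.shape)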
Example #9
    def __init__(
        self,
        in_type: FieldType,
        out_type: FieldType,
        kernel_size: int,
        padding: int = 0,
        output_padding: int = 0,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        basisexpansion: str = 'blocks',
        sigma: Union[List[float], float] = None,
        frequencies_cutoff: Union[float, Callable[[float], int]] = None,
        rings: List[float] = None,
        maximum_offset: int = None,
        recompute: bool = False,
        basis_filter: Callable[[dict], bool] = None,
        initialize: bool = True,
    ):
        r"""
        
        Transposed G-steerable planar convolution layer.
        
        .. warning ::
            
            Transposed convolution can produce artifacts which can harm the overall equivariance of the model.
            We suggest using :class:`~e2cnn.nn.R2Upsampling` combined with :class:`~e2cnn.nn.R2Conv` to perform
            upsampling (a sketch of this alternative follows this example).
        
        .. seealso ::
            For additional information about the parameters and the methods of this class, see :class:`e2cnn.nn.R2Conv`.
            The two modules are essentially the same, except for the type of convolution used.
            
        Args:
            in_type (FieldType): the type of the input field
            out_type (FieldType): the type of the output field
            kernel_size (int): the size of the filter
            padding(int, optional): implicit zero paddings on both sides of the input. Default: ``0``
            output_padding(int, optional): additional size added to one side of the output. Default: ``0``
            stride(int, optional): the stride of the convolving kernel. Default: ``1``
            dilation(int, optional): the spacing between kernel elements. Default: ``1``
            groups (int, optional): number of blocked connections from input channels to output channels.
                                    Default: ``1``.
            bias (bool, optional): Whether to add a bias to the output (only to fields which contain a
                    trivial irrep) or not. Default ``True``
            initialize (bool, optional): initialize the weights of the model. Default: ``True``
            
        """

        assert in_type.gspace == out_type.gspace
        assert isinstance(in_type.gspace, GeneralOnR2)

        super(R2ConvTransposed, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type
        self.out_type = out_type

        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups

        if groups > 1:
            # Check the input and output classes can be split in `groups` groups, all equal to each other
            # first, check that the number of fields is divisible by `groups`
            assert len(in_type) % groups == 0
            assert len(out_type) % groups == 0
            in_size = len(in_type) // groups
            out_size = len(out_type) // groups

            # then, check that all groups are equal to each other, i.e. have the same types in the same order
            assert all(in_type.representations[i] == in_type.representations[i % in_size]
                       for i in range(len(in_type)))
            assert all(out_type.representations[i] == out_type.representations[i % out_size]
                       for i in range(len(out_type)))

            # finally, retrieve the type associated to a single group in output.
            # this type will be used to build a smaller kernel basis and a smaller filter
            # as in PyTorch, to build a filter for grouped convolution, we build a filter which maps from all input
            # groups to one output group. Then, PyTorch's standard convolution routine interprets this filter as `groups`
            # different filters, each mapping an input group to an output group.
            out_type = out_type.index_select(list(range(out_size)))

        if bias:
            # bias can be applied only to trivial irreps inside the representation
            # to apply a bias to a field, we learn a bias for each trivial irrep it contains
            # and then transform it with the change-of-basis matrix so it can be applied to the whole field
            # this is equivalent to transforming the field to its irreps through the inverse change of basis,
            # summing the bias only on the trivial irreps and then mapping it back with the change of basis

            # count the number of trivial irreps
            trivials = 0
            for r in self.out_type:
                for irr in r.irreps:
                    if self.out_type.fibergroup.irreps[irr].is_trivial():
                        trivials += 1

            # if there is at least 1 trivial irrep
            if trivials > 0:

                # matrix containing the columns of the change of basis which map from the trivial irreps to the
                # field representations. This matrix allows us to map the bias defined only over the trivial irreps
                # to a bias for the whole field more efficiently
                bias_expansion = torch.zeros(self.out_type.size, trivials)

                p, c = 0, 0
                for r in self.out_type:
                    pi = 0
                    for irr in r.irreps:
                        irr = self.out_type.fibergroup.irreps[irr]
                        if irr.is_trivial():
                            bias_expansion[p:p + r.size, c] = torch.tensor(
                                r.change_of_basis[:, pi])
                            c += 1
                        pi += irr.size
                    p += r.size

                self.register_buffer("bias_expansion", bias_expansion)
                self.bias = Parameter(torch.zeros(trivials),
                                      requires_grad=True)
                self.register_buffer("expanded_bias",
                                     torch.zeros(out_type.size))
            else:
                self.bias = None
                self.expanded_bias = None
        else:
            self.bias = None
            self.expanded_bias = None

        grid, basis_filter, rings, sigma, maximum_frequency = compute_basis_params(
            kernel_size, frequencies_cutoff, rings, sigma, dilation,
            basis_filter)

        # BasisExpansion: submodule which takes care of building the filter
        self._basisexpansion = None

        # notice that `out_type` is used instead of `self.out_type` such that it works also when `groups > 1`
        if basisexpansion == 'blocks':
            self._basisexpansion = BlocksBasisExpansion(
                in_type,
                out_type,
                grid,
                sigma=sigma,
                rings=rings,
                maximum_offset=maximum_offset,
                maximum_frequency=maximum_frequency,
                basis_filter=basis_filter,
                recompute=recompute)

        else:
            raise ValueError('Basis Expansion algorithm "%s" not recognized' %
                             basisexpansion)

        self.weights = Parameter(torch.zeros(self.basisexpansion.dimension()),
                                 requires_grad=True)
        self.register_buffer(
            "filter",
            torch.zeros(in_type.size, out_type.size, kernel_size, kernel_size))

        if initialize:
            # by default, the weights are initialized with a generalized form of He's weight initialization
            init.generalized_he_init(self.weights.data, self.basisexpansion)
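
The warning in this docstring recommends replacing transposed convolution with an upsampling step followed by a G-steerable convolution. A rough sketch of that alternative, assuming the e2cnn classes R2Upsampling and SequentialModule, with illustrative field types and sizes:

import torch
from e2cnn.gspaces import Rot2dOnR2
from e2cnn.nn import FieldType, GeometricTensor, R2Conv, R2Upsampling, SequentialModule

gspace = Rot2dOnR2(8)
in_type = FieldType(gspace, [gspace.regular_repr] * 2)
out_type = FieldType(gspace, [gspace.regular_repr])

# upsample first, then apply an equivariant convolution instead of R2ConvTransposed
upsample = SequentialModule(
    R2Upsampling(in_type, scale_factor=2),
    R2Conv(in_type, out_type, kernel_size=3, padding=1),
)

x = GeometricTensor(torch.randn(1, in_type.size, 16, 16), in_type)
y = upsample(x)
print(y.tensor.shape)  # spatial resolution doubled: (1, out_type.size, 32, 32)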