Esempio n. 1
0
    def restrict(self, id: int) -> Tuple[gspaces.GSpace, Callable, Callable]:
        r"""

        Restrict this G-space to the subgroup of its fiber group selected by ``id`` and wrap the result in a
        :class:`~e2cnn.gspaces.TrivialOnR2`.

        The trivial group has a single element and therefore no proper subgroups, so ``id=1`` (which yields this
        same group) is the only meaningful input. The method exists mainly so that all G-spaces expose the same
        interface. No validation of ``id`` is done here: the value is forwarded as is to
        ``self.fibergroup.subgroup``.

        Args:
            id (int): the order of the subgroup

        Returns:
            a tuple containing

                - **gspace**: the restricted gspace

                - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

                - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)


        """
        restricted_group, back_map, subgroup_map = self.fibergroup.subgroup(id)
        restricted_space = gspaces.TrivialOnR2(fibergroup=restricted_group)
        return restricted_space, back_map, subgroup_map
Esempio n. 2
0
    def restrict(self, id: int) -> Tuple[gspaces.GSpace, Callable, Callable]:
        r"""

        Restrict this G-space to the subgroup of its fiber group selected by ``id`` and return the corresponding
        :class:`~e2cnn.gspaces.GSpace`.

        The reflection group has two elements, hence its only proper subgroup is the trivial group:

        - ``id=1`` returns an instance of :class:`~e2cnn.gspaces.TrivialOnR2`;

        - any other value accepted by ``self.fibergroup.subgroup`` (``id=2``, i.e. the whole group) returns a new
          reflection G-space with the same ``axis`` as the current one.

        Args:
            id (int): the id of the subgroup

        Returns:
            a tuple containing

                - **gspace**: the restricted gspace

                - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

                - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

        """
        restricted_group, back_map, subgroup_map = self.fibergroup.subgroup(id)

        # restricting to the trivial subgroup drops the reflection symmetry altogether
        if id == 1:
            return gspaces.TrivialOnR2(fibergroup=restricted_group), back_map, subgroup_map

        # otherwise the reflection (around the same axis) is preserved
        same_space = gspaces.Flip2dOnR2(axis=self.axis, fibergroup=restricted_group)
        return same_space, back_map, subgroup_map
Esempio n. 3
0
    def restrict(self, id: int) -> Tuple[gspaces.GSpace, Callable, Callable]:
        r"""

        Restrict this G-space to the subgroup of its fiber group selected by ``id`` and return the corresponding
        :class:`~e2cnn.gspaces.GSpace`.

        ``id`` is a positive integer :math:`M`: the number of rotations contained in the subgroup.
        When the current fiber group is :math:`C_N` (:class:`~e2cnn.group.CyclicGroup`), :math:`M` has to divide
        :math:`N`; otherwise :math:`M` may be any positive integer.

        ``id=1`` produces a :class:`~e2cnn.gspaces.TrivialOnR2`, any larger value a
        :class:`~e2cnn.gspaces.Rot2dOnR2`.

        Args:
            id (int): the number :math:`M` of rotations in the subgroup

        Returns:
            a tuple containing

                - **gspace**: the restricted gspace

                - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

                - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

        Raises:
            ValueError: if ``id`` is neither ``1`` nor greater than ``1`` (and ``self.fibergroup.subgroup`` did not
                already reject it)

        """
        restricted_group, back_map, subgroup_map = self.fibergroup.subgroup(id)

        # choose the kind of G-space from the number of rotations left in the subgroup
        if id > 1:
            space_cls = gspaces.Rot2dOnR2
        elif id == 1:
            space_cls = gspaces.TrivialOnR2
        else:
            raise ValueError(f"id {id} not recognized!")

        return space_cls(fibergroup=restricted_group), back_map, subgroup_map
Esempio n. 4
0
    def __init__(self, depth, widen_factor, dropout_rate, num_classes=100,
                 N: int = 8,
                 r: int = 1,
                 f: bool = True,
                 deltaorth: bool = False,
                 fixparams: bool = True,
                 initial_stride: int = 1,
                 ):
        r"""
        
        Build an equivariant Wide ResNet.
        
        The parameter ``N`` controls rotation equivariance and the parameter ``f`` reflection equivariance.
        
        More precisely, ``N`` is the number of discrete rotations the model is initially equivariant to.
        ``N = 1`` means the model has no rotation equivariance from the beginning (it is only reflection
        equivariant if ``f`` is ``True``, and has a trivial symmetry group otherwise).
        
        ``f`` is a boolean flag specifying whether the model should be reflection equivariant or not.
        If it is ``False``, the model is not reflection equivariant.
        
        ``r`` is the restriction level:
        
        - ``0``: no restriction. The model is equivariant to ``N`` rotations from the input to the output

        - ``1``: restriction before the last block. The model is equivariant to ``N`` rotations before the last block
               (i.e. in the first 2 blocks). Then it is restricted to ``N/2`` rotations until the output.
        
        - ``2``: restriction after the first block. The model is equivariant to ``N`` rotations in the first block.
               Then it is restricted to ``N/2`` rotations until the output (i.e. in the last 2 blocks).
               
        - ``3``: restriction after the first and the second block. The model is equivariant to ``N`` rotations in the first
               block. It is restricted to ``N/2`` rotations before the second block and to ``1`` rotations before the last
               block.
        
        NOTICE: if restriction to ``N/2`` is performed, ``N`` needs to be even!
        The constructor does not check this: ``N // 2`` is computed with integer division.
        
        Args:
            depth: total depth of the network; must be of the form ``6n + 4`` (checked with an ``assert``).
                ``n`` is forwarded to each of the three ``_wide_layer`` calls.
            widen_factor: multiplier ``k`` applied to the base widths of the three stages (``16k``, ``32k``, ``64k``)
            dropout_rate: forwarded unchanged to every ``_wide_layer`` call
            num_classes (int): number of output features of the final linear layer
            N (int): number of discrete rotations the model is initially equivariant to
            r (int): restriction level, one of ``0``, ``1``, ``2``, ``3`` (checked with an ``assert``)
            f (bool): whether the model is initially reflection equivariant
            deltaorth (bool): if ``True``, the weights of every ``enn.R2Conv`` submodule are initialized with
                ``init.deltaorthonormal_init``; otherwise they are left as created
            fixparams (bool): stored in ``self._fixparams``; it is not used directly in this constructor
                (presumably read by the layer-building helpers - to be confirmed)
            initial_stride (int): stride used by the first block (``layer1``)
        
        """
        super(Wide_ResNet, self).__init__()
        
        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        n = int((depth - 4) / 6)
        k = widen_factor
        
        print(f'| Wide-Resnet {depth}x{k}')
        
        # widths of the first convolution and of the three stages
        nStages = [16, 16 * k, 32 * k, 64 * k]
        
        self._fixparams = fixparams
        
        self._layer = 0
        
        # number of discrete rotations to be equivariant to
        self._N = N
        
        # if the model is [F]lip equivariant
        self._f = f
        # choose the initial symmetry group from (f, N); N == 1 means no rotations
        if self._f:
            if N != 1:
                self.gspace = gspaces.FlipRot2dOnR2(N)
            else:
                self.gspace = gspaces.Flip2dOnR2()
        else:
            if N != 1:
                self.gspace = gspaces.Rot2dOnR2(N)
            else:
                self.gspace = gspaces.TrivialOnR2()

        # level of [R]estriction:
        #   r = 0: never do restriction, i.e. initial group (either DN or CN) preserved for the whole network
        #   r = 1: restrict before the last block, i.e. initial group (either DN or CN) preserved for the first
        #          2 blocks, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last block
        #   r = 2: restrict after the first block, i.e. initial group (either DN or CN) preserved for the first
        #          block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last 2 blocks
        #   r = 3: restrict after each block. Initial group (either DN or CN) preserved for the first
        #          block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the second block and to 1 rotation
        #          in the last one (D1 or C1)
        assert r in [0, 1, 2, 3]
        self._r = r
        
        # the input has 3 color channels (RGB).
        # Color channels are trivial fields and don't transform when the input is rotated or flipped
        r1 = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)
        
        # input field type of the model
        self.in_type = r1
        
        # in the first layer we always scale up the output channels to allow for enough independent filters
        # (fixparams is hard-coded to True here, independently of the ``fixparams`` argument)
        r2 = FIELD_TYPE["regular"](self.gspace, nStages[0], fixparams=True)
        
        # dummy attribute keeping track of the output field type of the last submodule built, i.e. the input field type of
        # the next submodule to build
        # NOTE(review): ``_wide_layer`` and ``_restrict_layer`` are defined elsewhere; they presumably read and update
        # ``self._in_type`` (and possibly ``self.gspace``), which would make the order of the statements below
        # significant - confirm before reordering anything.
        self._in_type = r2
        
        self.conv1 = conv5x5(r1, r2)
        self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=initial_stride)
        # optional restriction between the first and the second block (r = 2 or r = 3)
        if self._r >= 2:
            # integer division: an odd N is silently rounded down
            N_new = N//2
            # with reflections the subgroup id is a pair (flip axis id, number of rotations), otherwise just the
            # number of rotations
            id = (0, N_new) if self._f else N_new
            self.restrict1 = self._restrict_layer(id)
        else:
            # no restriction here: identity placeholder so that the attribute can always be called
            self.restrict1 = lambda x: x
        
        self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
        # optional restriction between the second and the last block
        if self._r == 3:
            # second restriction: keep only 1 rotation (plus the reflection, if f is set)
            id = (0, 1) if self._f else 1
            self.restrict2 = self._restrict_layer(id)
        elif self._r == 1:
            # first and only restriction: halve the number of rotations
            N_new = N // 2
            id = (0, N_new) if self._f else N_new
            self.restrict2 = self._restrict_layer(id)
        else:
            # r = 0 (never restrict) or r = 2 (already restricted after the first block): identity placeholder
            self.restrict2 = lambda x: x
        
        # last layer maps to a trivial (invariant) feature map
        self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2, totrivial=True)
        
        self.bn = enn.InnerBatchNorm(self.layer3.out_type, momentum=0.9)
        self.relu = enn.ReLU(self.bn.out_type, inplace=True)
        # plain (non-equivariant) classifier on top of the final features
        self.linear = torch.nn.Linear(self.bn.out_type.size, num_classes)
        
        # parameter initialization:
        #   - equivariant convolutions are re-initialized only if ``deltaorth`` is set
        #   - plain BatchNorm2d modules get unit weight and zero bias
        #   - Linear modules only get their bias zeroed (weights are left untouched)
        for name, module in self.named_modules():
            if isinstance(module, enn.R2Conv):
                if deltaorth:
                    init.deltaorthonormal_init(module.weights, module.basisexpansion)
            elif isinstance(module, torch.nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, torch.nn.Linear):
                module.bias.data.zero_()
        
        # dump the index and the name of every registered submodule to stdout
        print("MODEL TOPOLOGY:")
        for i, (name, mod) in enumerate(self.named_modules()):
            print(f"\t{i} - {name}")
Esempio n. 5
0
    def restrict(self, id: Tuple[Union[None, float, int], int]) -> Tuple[gspaces.GSpace, Callable, Callable]:
        r"""

        Restrict this G-space to the subgroup of its fiber group selected by ``id`` and return the corresponding
        :class:`~e2cnn.gspaces.GSpace`.

        ``id`` is a tuple :math:`(k, M)`: :math:`M` is a positive integer, the number of discrete rotations in the
        subgroup, while :math:`k` is either ``None`` (no reflections) or a value identifying the reflection axis.
        If the current fiber group is :math:`D_N` (:class:`~e2cnn.group.DihedralGroup`), then :math:`M` needs to divide
        :math:`N` and :math:`k` needs to be an integer in :math:`\{0, \dots, \frac{N}{M}-1\}`.
        Otherwise, :math:`M` can be any positive integer while :math:`k` needs to be a real number in
        :math:`[0, \frac{2\pi}{M}]`.

        Valid combinations are:

        - (``None``, :math:`1`): restrict to no reflection and rotation symmetries

        - (``None``, :math:`M`): restrict to only the :math:`M` rotations generated by :math:`r_{2\pi/M}`.

        - (:math:`0`, :math:`1`): restrict to only reflections :math:`\langle f \rangle` around the same axis as in the current group

        - (:math:`0`, :math:`M`): restrict to reflections and :math:`M` rotations generated by :math:`r_{2\pi/M}` and :math:`f`

        If the current fiber group is :math:`D_N` (an instance of :class:`~e2cnn.group.DihedralGroup`):

        - (:math:`k`, :math:`M`): restrict to reflections :math:`\langle r_{k\frac{2\pi}{N}} f \rangle` around the axis of the current G-space rotated by :math:`k\frac{\pi}{N}` and :math:`M` rotations generated by :math:`r_{2\pi/M}`

        If the current fiber group is :math:`O(2)` (an instance of :class:`~e2cnn.group.O2`):

        - (:math:`\theta`, :math:`M`): restrict to reflections :math:`\langle r_{\theta} f \rangle` around the axis of the current G-space rotated by :math:`\frac{\theta}{2}` and :math:`M` rotations generated by :math:`r_{2\pi/M}`

        - (``None``, :math:`-1`): restrict to all (continuous) rotations

        Args:
            id (tuple): the id of the subgroup

        Returns:
            a tuple containing

                - **gspace**: the restricted gspace

                - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

                - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

        Raises:
            ValueError: if no reflection is requested and the number of rotations is neither ``1``, greater than
                ``1`` nor ``-1`` (and ``self.fibergroup.subgroup`` did not already reject the id)

        """

        restricted_group, back_map, subgroup_map = self.fibergroup.subgroup(id)

        # --- subgroups without reflections -------------------------------------------------------------------
        if id[0] is None:
            if id[1] == 1:
                return gspaces.TrivialOnR2(fibergroup=restricted_group), back_map, subgroup_map
            if id[1] > 1 or id[1] == -1:
                return gspaces.Rot2dOnR2(fibergroup=restricted_group), back_map, subgroup_map
            raise ValueError(f"id {id} not recognized!")

        # --- subgroups containing a reflection ---------------------------------------------------------------
        # The group element generating the reflection corresponds to a rotation by twice the angle of the new
        # flip axis. For a fiber group of order greater than 1 ``id[0]`` counts steps of 2*pi / rotation_order;
        # otherwise it is used directly as an angle.
        if self.fibergroup.order() > 1:
            angle = id[0] * 2.0 * np.pi / self.fibergroup.rotation_order
        else:
            angle = id[0]

        # new flip axis: the current axis rotated by half of that angle, wrapped into [0, 2*pi)
        flipped_axis = (self.axis + 0.5 * angle) % (2 * np.pi)

        if id[1] == 1:
            return gspaces.Flip2dOnR2(fibergroup=restricted_group, axis=flipped_axis), back_map, subgroup_map
        return gspaces.FlipRot2dOnR2(fibergroup=restricted_group, axis=flipped_axis), back_map, subgroup_map