def restrict(self, id: int) -> Tuple[gspaces.GSpace, Callable, Callable]:
    r"""
    Build the :class:`~e2cnn.gspaces.GSpace` associated with the subgroup of the current fiber group
    identified by the input ``id``.

    The trivial group contains a single element, so it has no subgroup other than itself.
    The only meaningful input is ``id=1``, which yields this same group again; validating ``id`` is
    delegated to ``self.fibergroup.subgroup``.
    This method exists only so that every G-space exposes the same restriction interface.

    Args:
        id (int): the order of the subgroup

    Returns:
        a tuple containing

            - **gspace**: the restricted gspace

            - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

            - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

    """
    fiber_subgroup, back_map, subgroup_map = self.fibergroup.subgroup(id)
    restricted_space = gspaces.TrivialOnR2(fibergroup=fiber_subgroup)
    return restricted_space, back_map, subgroup_map
def restrict(self, id: int) -> Tuple[gspaces.GSpace, Callable, Callable]:
    r"""
    Build the :class:`~e2cnn.gspaces.GSpace` associated with the subgroup of the current fiber group
    identified by the input ``id``.

    The reflection group contains only two elements, so its only proper subgroup is the trivial group.
    The accepted input values are:

        - ``id=1``: returns an instance of :class:`~e2cnn.gspaces.TrivialOnR2`

        - ``id=2``: returns a new instance of the current G-space, with the same reflection axis

    Validation of ``id`` is delegated to ``self.fibergroup.subgroup``; any value other than ``1`` that it
    accepts produces a :class:`~e2cnn.gspaces.Flip2dOnR2`.

    Args:
        id (int): the order of the subgroup (``1`` or ``2``)

    Returns:
        a tuple containing

            - **gspace**: the restricted gspace

            - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

            - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

    """
    subgroup, back_map, subgroup_map = self.fibergroup.subgroup(id)

    if id == 1:
        # the trivial subgroup: no symmetry survives
        restricted_space = gspaces.TrivialOnR2(fibergroup=subgroup)
    else:
        # the whole reflection group: keep flipping around the same axis
        restricted_space = gspaces.Flip2dOnR2(axis=self.axis, fibergroup=subgroup)

    return restricted_space, back_map, subgroup_map
def restrict(self, id: int) -> Tuple[gspaces.GSpace, Callable, Callable]:
    r"""
    Build the :class:`~e2cnn.gspaces.GSpace` associated with the subgroup of the current fiber group
    identified by the input ``id``.

    ``id`` is a positive integer :math:`M` giving the number of rotations kept in the subgroup.
    If the current fiber group is :math:`C_N` (:class:`~e2cnn.group.CyclicGroup`), then :math:`M` must
    divide :math:`N`; otherwise :math:`M` can be any positive integer.

    ``id=1`` yields a :class:`~e2cnn.gspaces.TrivialOnR2`, while ``id > 1`` yields a
    :class:`~e2cnn.gspaces.Rot2dOnR2`.

    Args:
        id (int): the number :math:`M` of rotations in the subgroup

    Returns:
        a tuple containing

            - **gspace**: the restricted gspace

            - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

            - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

    Raises:
        ValueError: if ``self.fibergroup.subgroup`` accepts ``id`` but ``id`` is neither ``1`` nor greater than ``1``

    """
    subgroup, back_map, subgroup_map = self.fibergroup.subgroup(id)

    if id == 1:
        # C_1 contains only the identity: nothing is left to be equivariant to
        return gspaces.TrivialOnR2(fibergroup=subgroup), back_map, subgroup_map

    if id > 1:
        return gspaces.Rot2dOnR2(fibergroup=subgroup), back_map, subgroup_map

    raise ValueError(f"id {id} not recognized!")
def __init__(self, depth, widen_factor, dropout_rate, num_classes=100,
             N: int = 8,
             r: int = 1,
             f: bool = True,
             deltaorth: bool = False,
             fixparams: bool = True,
             initial_stride: int = 1,
             ):
    r"""
    Build an equivariant Wide ResNet.

    The parameter ``N`` controls rotation equivariance and the parameter ``f`` reflection equivariance.

    More precisely, ``N`` is the number of discrete rotations the model is initially equivariant to.
    ``N = 1`` means the model has no rotation equivariance from the beginning (it is only reflection
    equivariant if ``f`` is ``True``, and not equivariant at all otherwise).

    ``f`` is a boolean flag specifying whether the model should be reflection equivariant or not.
    If it is ``False``, the model is not reflection equivariant.

    ``r`` is the restriction level. The network has 3 blocks; reflections, if any, are kept throughout.

        - ``0``: no restriction. The model is equivariant to ``N`` rotations from the input to the output

        - ``1``: restriction before the last block. The model is equivariant to ``N`` rotations in the
          first 2 blocks, then it is restricted to ``N/2`` rotations until the output (i.e. in the last block).

        - ``2``: restriction after the first block. The model is equivariant to ``N`` rotations in the first
          block, then it is restricted to ``N/2`` rotations until the output (i.e. in the last 2 blocks).

        - ``3``: restriction after the first and the second block. The model is equivariant to ``N``
          rotations in the first block, to ``N/2`` rotations in the second block and to ``1`` rotation
          in the last block.

    NOTICE: whenever ``r > 0`` a restriction to ``N/2`` rotations is performed, so ``N`` needs to be even.

    Args:
        depth (int): total depth of the network; must be of the form ``6n + 4``
        widen_factor (int): multiplier ``k`` of the channel counts of the 3 blocks (``16k``, ``32k``, ``64k``)
        dropout_rate (float): dropout rate forwarded to every wide layer
        num_classes (int): number of outputs of the final linear classifier
        N (int): number of discrete rotations the model is initially equivariant to
        r (int): restriction level, one of ``0, 1, 2, 3`` (see above)
        f (bool): whether the model is reflection equivariant
        deltaorth (bool): if ``True``, equivariant convolutions are initialized with
            ``init.deltaorthonormal_init``
        fixparams (bool): stored as ``self._fixparams``; presumably read by the layer-building helpers
            (not used directly here) — TODO confirm
        initial_stride (int): stride of the first wide block

    Raises:
        AssertionError: if ``depth`` is not of the form ``6n + 4`` or ``r`` is not in ``{0, 1, 2, 3}``
        ValueError: if a restriction is requested (``r > 0``) but ``N`` is odd

    """
    super(Wide_ResNet, self).__init__()

    assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
    # number of residual units per block; `//` keeps it exact, `int` keeps it an int for float depths
    n = int((depth - 4) // 6)
    k = widen_factor

    print(f'| Wide-Resnet {depth}x{k}')

    nStages = [16, 16 * k, 32 * k, 64 * k]

    self._fixparams = fixparams

    self._layer = 0

    # number of discrete rotations to be equivariant to
    self._N = N

    # if the model is [F]lip equivariant
    self._f = f
    if self._f:
        if N != 1:
            self.gspace = gspaces.FlipRot2dOnR2(N)
        else:
            self.gspace = gspaces.Flip2dOnR2()
    else:
        if N != 1:
            self.gspace = gspaces.Rot2dOnR2(N)
        else:
            self.gspace = gspaces.TrivialOnR2()

    # level of [R]estriction:
    #   r = 0: never do restriction, i.e. initial group (either DN or CN) preserved for the whole network
    #   r = 1: restrict before the last block, i.e. initial group (either DN or CN) preserved for the first
    #          2 blocks, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last block
    #   r = 2: restrict after the first block, i.e. initial group (either DN or CN) preserved for the first
    #          block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the last 2 blocks
    #   r = 3: restrict after each block. Initial group (either DN or CN) preserved for the first
    #          block, then restrict to N/2 rotations (either D{N/2} or C{N/2}) in the second block and to 1 rotation
    #          in the last one (D1 or C1)
    assert r in [0, 1, 2, 3]
    if r > 0 and N % 2 != 0:
        # every r > 0 halves the number of rotations; with an odd N, N // 2 would silently pick the wrong
        # subgroup (e.g. N=3 -> C1) or an invalid one (N=1 -> 0) and fail far from the cause
        raise ValueError(f"N must be even when a restriction is performed (r={r}), got N={N}")
    self._r = r

    def subgroup_id(rotations):
        # id of the subgroup with `rotations` rotations: dihedral ids are (flip axis, M) tuples, where axis 0
        # keeps the current reflection axis; cyclic ids are plain integers
        return (0, rotations) if self._f else rotations

    # the input has 3 color channels (RGB).
    # Color channels are trivial fields and don't transform when the input is rotated or flipped
    r1 = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)

    # input field type of the model
    self.in_type = r1

    # in the first layer we always scale up the output channels to allow for enough independent filters
    r2 = FIELD_TYPE["regular"](self.gspace, nStages[0], fixparams=True)

    # dummy attribute keeping track of the output field type of the last submodule built, i.e. the input field type of
    # the next submodule to build
    self._in_type = r2

    # NOTE: the layer-building calls below must stay in this order, as each one reads and updates self._in_type
    self.conv1 = conv5x5(r1, r2)
    self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=initial_stride)
    if self._r >= 2:
        self.restrict1 = self._restrict_layer(subgroup_id(N // 2))
    else:
        self.restrict1 = lambda x: x

    self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
    if self._r == 3:
        self.restrict2 = self._restrict_layer(subgroup_id(1))
    elif self._r == 1:
        self.restrict2 = self._restrict_layer(subgroup_id(N // 2))
    else:
        self.restrict2 = lambda x: x

    # last layer maps to a trivial (invariant) feature map
    self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2, totrivial=True)

    self.bn = enn.InnerBatchNorm(self.layer3.out_type, momentum=0.9)
    self.relu = enn.ReLU(self.bn.out_type, inplace=True)
    self.linear = torch.nn.Linear(self.bn.out_type.size, num_classes)

    for module in self.modules():
        if isinstance(module, enn.R2Conv):
            if deltaorth:
                init.deltaorthonormal_init(module.weights, module.basisexpansion)
        elif isinstance(module, torch.nn.BatchNorm2d):
            module.weight.data.fill_(1)
            module.bias.data.zero_()
        elif isinstance(module, torch.nn.Linear):
            module.bias.data.zero_()

    print("MODEL TOPOLOGY:")
    for i, (name, _) in enumerate(self.named_modules()):
        print(f"\t{i} - {name}")
def restrict(self, id: Tuple[Union[None, float, int], int]) -> Tuple[gspaces.GSpace, Callable, Callable]:
    r"""
    Build the :class:`~e2cnn.gspaces.GSpace` associated with the subgroup of the current fiber group
    identified by the input ``id``, which is a tuple :math:`(k, M)`.

    Here, :math:`M` is a positive integer indicating the number of discrete rotations in the subgroup while
    :math:`k` is either ``None`` (no reflections) or an angle indicating the axis of reflection.
    If the current fiber group is :math:`D_N` (:class:`~e2cnn.group.DihedralGroup`), then :math:`M` needs to
    divide :math:`N` and :math:`k` needs to be an integer in :math:`\{0, \dots, \frac{N}{M}-1\}`.
    Otherwise, :math:`M` can be any positive integer while :math:`k` needs to be a real number in
    :math:`[0, \frac{2\pi}{M}]`.

    Valid combinations are:

    - (``None``, :math:`1`): restrict to no reflection and rotation symmetries

    - (``None``, :math:`M`): restrict to only the :math:`M` rotations generated by :math:`r_{2\pi/M}`.

    - (:math:`0`, :math:`1`): restrict to only reflections :math:`\langle f \rangle` around the same axis as in the current group

    - (:math:`0`, :math:`M`): restrict to reflections and :math:`M` rotations generated by :math:`r_{2\pi/M}` and :math:`f`

    If the current fiber group is :math:`D_N` (an instance of :class:`~e2cnn.group.DihedralGroup`):

    - (:math:`k`, :math:`M`): restrict to reflections :math:`\langle r_{k\frac{2\pi}{N}} f \rangle` around the axis of the current G-space rotated by :math:`k\frac{\pi}{N}` and :math:`M` rotations generated by :math:`r_{2\pi/M}`

    If the current fiber group is :math:`O(2)` (an instance of :class:`~e2cnn.group.O2`):

    - (:math:`\theta`, :math:`M`): restrict to reflections :math:`\langle r_{\theta} f \rangle` around the axis of the current G-space rotated by :math:`\frac{\theta}{2}` and :math:`M` rotations generated by :math:`r_{2\pi/M}`

    - (``None``, :math:`-1`): restrict to all (continuous) rotations

    Args:
        id (tuple): the id of the subgroup

    Returns:
        a tuple containing

            - **gspace**: the restricted gspace

            - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

            - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

    Raises:
        ValueError: if ``id`` has no reflection component and its rotation component is neither ``1``,
            greater than ``1``, nor ``-1`` (and ``self.fibergroup.subgroup`` did not already reject it)

    """
    subgroup, back_map, subgroup_map = self.fibergroup.subgroup(id)
    flip, rotations = id[0], id[1]

    if flip is None:
        # the subgroup contains no reflection: only rotations (possibly none) survive
        if rotations == 1:
            return gspaces.TrivialOnR2(fibergroup=subgroup), back_map, subgroup_map
        if rotations > 1 or rotations == -1:
            return gspaces.Rot2dOnR2(fibergroup=subgroup), back_map, subgroup_map
        raise ValueError(f"id {id} not recognized!")

    # The reflection of the subgroup is generated by r_{angle} f, whose axis is the current axis rotated by
    # angle / 2 (not by angle itself).
    if self.fibergroup.order() > 1:
        # finite dihedral group D_N: `flip` is an integer index k and the generating angle is k * 2pi / N
        angle = flip * 2.0 * np.pi / self.fibergroup.rotation_order
    else:
        # otherwise `flip` is used directly as the generating angle theta
        angle = flip
    new_axis = (self.axis + 0.5 * angle) % (2 * np.pi)

    if rotations == 1:
        return gspaces.Flip2dOnR2(fibergroup=subgroup, axis=new_axis), back_map, subgroup_map
    return gspaces.FlipRot2dOnR2(fibergroup=subgroup, axis=new_axis), back_map, subgroup_map