def __init__(self, in_type: FieldType, permutation: List[int]):
    r"""

    Permutes the fields of the input tensor according to the input ``permutation``.

    The parameter ``permutation`` should be a list containing a permutation of the integers ``{0, 1, ..., n-1}``,
    where ``n`` is the number of fields of ``in_type`` (see :meth:`e2cnn.nn.FieldType.__len__`).
    Every integer in this range must appear exactly once.

    Args:
        in_type (FieldType): input field type
        permutation (list): permutation to apply

    """

    assert isinstance(in_type.gspace, GeneralOnR2)

    super(ReshuffleModule, self).__init__()

    n_fields = len(in_type.representations)

    # check that ``permutation`` is actually a permutation of the fields: every index in ``[0, n)`` must appear
    # exactly once. Comparing only the sets of values would wrongly accept lists with repeated entries
    # (e.g. ``[0, 0, 1]`` for 2 fields), which would silently duplicate fields in the output.
    assert sorted(permutation) == list(range(n_fields)), \
        "Error! {} is not a permutation of the {} fields of the input type".format(permutation, n_fields)

    self.space = in_type.gspace
    self.in_type = in_type

    # permute the fields in the input representation to build the output representation
    self.out_type = in_type.index_select(permutation)

    # compute the starting channel of each field in the input representation
    positions = []
    last_position = 0
    for r in in_type.representations:
        positions.append(last_position)
        last_position += r.size

    # compute the channel indices which gather the input fields in the permuted order
    indices = []
    for c in permutation:
        size = in_type.representations[c].size
        indices += list(range(positions[c], positions[c] + size))

    self.register_buffer("indices", torch.LongTensor(indices))
def __init__(
        self,
        in_type: FieldType,
        out_type: FieldType,
        kernel_size: int,
        padding: int = 0,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        basisexpansion: str = 'blocks',
        sigma: Union[List[float], float] = None,
        frequencies_cutoff: Union[float, Callable[[float], int]] = None,
        rings: List[float] = None,
        maximum_offset: int = None,
        recompute: bool = False,
        basis_filter: Callable[[dict], bool] = None,
        initialize: bool = True,
):
    r"""

    G-steerable planar convolution mapping between the input and output :class:`~e2cnn.nn.FieldType` s specified by
    the parameters ``in_type`` and ``out_type``.
    This operation is equivariant under the action of :math:`\R^2\rtimes G` where :math:`G` is the
    :attr:`e2cnn.nn.FieldType.fibergroup` of ``in_type`` and ``out_type``.

    Specifically, let :math:`\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}` and
    :math:`\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}` be the representations specified by the input and output
    field types.
    Then :class:`~e2cnn.nn.R2Conv` guarantees an equivariant mapping

    .. math::
        \kappa \star [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [\kappa \star f] \qquad\qquad \forall g \in G, u \in \R^2

    where the transformation of the input and output fields are given by

    .. math::
        [\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\
        [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\

    The equivariance of G-steerable convolutions is guaranteed by restricting the space of convolution kernels to an
    equivariant subspace.
    As proven in `3D Steerable CNNs <https://arxiv.org/abs/1807.02547>`_, this parametrizes the *most general
    equivariant convolutional map* between the input and output fields.
    For feature fields on :math:`\R^2` (e.g. images), the complete G-steerable kernel spaces for :math:`G \leq \O2`
    are derived in `General E(2)-Equivariant Steerable CNNs <https://arxiv.org/abs/1911.08251>`_.

    During training, in each forward pass the module expands the basis of G-steerable kernels with learned weights
    before calling :func:`torch.nn.functional.conv2d`.
    When :meth:`~torch.nn.Module.eval()` is called, the filter is built with the current trained weights and stored
    for future reuse such that no overhead of expanding the kernel remains.

    The parameters ``basisexpansion``, ``sigma``, ``frequencies_cutoff``, ``rings`` and ``maximum_offset`` are
    optional parameters used to control how the basis for the filters is built, how it is sampled on the filter grid
    and how it is expanded to build the filter. We suggest keeping these default values.

    Args:
        in_type (FieldType): the type of the input field, specifying its transformation law
        out_type (FieldType): the type of the output field, specifying its transformation law
        kernel_size (int): the size of the (square) filter
        padding(int, optional): implicit zero paddings on both sides of the input. Default: ``0``
        stride(int, optional): the stride of the kernel. Default: ``1``
        dilation(int, optional): the spacing between kernel elements. Default: ``1``
        groups (int, optional): number of blocked connections from input channels to output channels.
                                It allows depthwise convolution. When used, the input and output types need to be
                                divisible in ``groups`` groups, all equal to each other.
                                Default: ``1``.
        bias (bool, optional): Whether to add a bias to the output (only to fields which contain a
                trivial irrep) or not. Default ``True``
        basisexpansion (str, optional): the basis expansion algorithm to use (only ``'blocks'`` is supported)
        sigma (list or float, optional): width of each ring where the bases are sampled. If only one scalar
                is passed, it is used for all rings.
        frequencies_cutoff (callable or float, optional): function mapping the radii of the basis elements to the
                maximum frequency accepted. If a float value is passed, the maximum frequency is equal to the
                radius times this factor. By default (``None``), a more complex policy is used.
        rings (list, optional): radii of the rings where to sample the bases
        maximum_offset (int, optional): number of additional (aliased) frequencies in the intertwiners for finite
                groups. By default (``None``), all additional frequencies allowed by the frequencies cut-off
                are used.
        recompute (bool, optional): if ``True``, recomputes a new basis for the equivariant kernels.
                By default (``False``), it caches the basis built or reuses a cached one, if it is found.
        basis_filter (callable, optional): function which takes as input a descriptor of a basis element
                (as a dictionary) and returns a boolean value: whether to preserve (``True``) or discard (``False``)
                the basis element. By default (``None``), no filtering is applied.
        initialize (bool, optional): initialize the weights of the model with a generalized form of He's weight
                initialization. If ``False``, the weights are left at zero. Default: ``True``

    Attributes:

        ~.weights (torch.Tensor): the learnable parameters which are used to expand the kernel
        ~.filter (torch.Tensor): the convolutional kernel obtained by expanding the parameters
                                in :attr:`~e2cnn.nn.R2Conv.weights`
        ~.bias (torch.Tensor): the learnable parameters which are used to expand the bias, if ``bias=True``
        ~.expanded_bias (torch.Tensor): the equivariant bias which is summed to the output, obtained by expanding
                                the parameters in :attr:`~e2cnn.nn.R2Conv.bias`

    """

    assert in_type.gspace == out_type.gspace
    assert isinstance(in_type.gspace, GeneralOnR2)

    super(R2Conv, self).__init__()

    self.space = in_type.gspace
    self.in_type = in_type
    self.out_type = out_type

    self.kernel_size = kernel_size
    self.stride = stride
    self.dilation = dilation
    self.padding = padding
    self.groups = groups

    if groups > 1:
        # Check the input and output classes can be split in `groups` groups, all equal to each other
        # first, check that the number of fields is divisible by `groups`
        assert len(in_type) % groups == 0
        assert len(out_type) % groups == 0
        in_size = len(in_type) // groups
        out_size = len(out_type) // groups

        # then, check that all groups are equal to each other, i.e. have the same types in the same order
        assert all(in_type.representations[i] == in_type.representations[i % in_size]
                   for i in range(len(in_type)))
        assert all(out_type.representations[i] == out_type.representations[i % out_size]
                   for i in range(len(out_type)))

        # finally, retrieve the type associated to a single group in input.
        # this type will be used to build a smaller kernel basis and a smaller filter
        # as in PyTorch, to build a filter for grouped convolution, we build a filter which maps from one input
        # group to all output groups. Then, PyTorch's standard convolution routine interprets this filter as
        # `groups` different filters, each mapping an input group to an output group.
        in_type = in_type.index_select(list(range(in_size)))

    if bias:
        # bias can be applied only to trivial irreps inside the representation
        # to apply bias to a field we learn a bias for each trivial irrep it contains
        # and, then, we transform it with the change of basis matrix to be able to apply it to the whole field
        # this is equivalent to transform the field to its irreps through the inverse change of basis,
        # sum the bias only to the trivial irrep and then map it back with the change of basis

        # count the number of trivial irreps
        trivials = 0
        for r in self.out_type:
            for irr in r.irreps:
                if self.out_type.fibergroup.irreps[irr].is_trivial():
                    trivials += 1

        # if there is at least 1 trivial irrep
        if trivials > 0:

            # matrix containing the columns of the change of basis which map from the trivial irreps to the
            # field representations. This matrix allows us to map the bias defined only over the trivial irreps
            # to a bias for the whole field more efficiently
            bias_expansion = torch.zeros(self.out_type.size, trivials)

            p, c = 0, 0
            for r in self.out_type:
                pi = 0
                for irr in r.irreps:
                    irr = self.out_type.fibergroup.irreps[irr]
                    if irr.is_trivial():
                        bias_expansion[p:p + r.size, c] = torch.tensor(r.change_of_basis[:, pi])
                        c += 1
                    pi += irr.size
                p += r.size

            self.register_buffer("bias_expansion", bias_expansion)
            self.bias = Parameter(torch.zeros(trivials), requires_grad=True)
            # the expanded bias covers every output channel (`bias_expansion` has `self.out_type.size` rows)
            self.register_buffer("expanded_bias", torch.zeros(self.out_type.size))
        else:
            self.bias = None
            self.expanded_bias = None
    else:
        self.bias = None
        self.expanded_bias = None

    grid, basis_filter, rings, sigma, maximum_frequency = compute_basis_params(
        kernel_size, frequencies_cutoff, rings, sigma, dilation, basis_filter
    )

    # BasisExpansion: submodule which takes care of building the filter
    self._basisexpansion = None

    # notice that `in_type` is used instead of `self.in_type` such that it works also when `groups > 1`
    if basisexpansion == 'blocks':
        self._basisexpansion = BlocksBasisExpansion(
            in_type, out_type, grid,
            sigma=sigma,
            rings=rings,
            maximum_offset=maximum_offset,
            maximum_frequency=maximum_frequency,
            basis_filter=basis_filter,
            recompute=recompute,
        )
    else:
        raise ValueError('Basis Expansion algorithm "%s" not recognized' % basisexpansion)

    self.weights = Parameter(torch.zeros(self.basisexpansion.dimension()), requires_grad=True)
    # conv2d weight layout: (out_channels, in_channels // groups, kH, kW)
    self.register_buffer("filter", torch.zeros(out_type.size, in_type.size, kernel_size, kernel_size))

    if initialize:
        # by default, the weights are initialized with a generalized form of He's weight initialization
        init.generalized_he_init(self.weights.data, self.basisexpansion)
def __init__(
        self,
        in_type: FieldType,
        out_type: FieldType,
        kernel_size: int,
        padding: int = 0,
        output_padding: int = 0,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        basisexpansion: str = 'blocks',
        sigma: Union[List[float], float] = None,
        frequencies_cutoff: Union[float, Callable[[float], int]] = None,
        rings: List[float] = None,
        maximum_offset: int = None,
        recompute: bool = False,
        basis_filter: Callable[[dict], bool] = None,
        initialize: bool = True,
):
    r"""

    Transposed G-steerable planar convolution layer.

    .. warning ::

        Transposed convolution can produce artifacts which can harm the overall equivariance of the model.
        We suggest using :class:`~e2cnn.nn.R2Upsampling` combined with :class:`~e2cnn.nn.R2Conv` to perform
        upsampling.

    .. seealso ::
        For additional information about the parameters and the methods of this class, see
        :class:`e2cnn.nn.R2Conv`.
        The two modules are essentially the same, except for the type of convolution used.

    Args:
        in_type (FieldType): the type of the input field
        out_type (FieldType): the type of the output field
        kernel_size (int): the size of the filter
        padding(int, optional): implicit zero paddings on both sides of the input. Default: ``0``
        output_padding(int, optional): additional size added to one side of each dimension of the output
                shape (see :class:`torch.nn.ConvTranspose2d`). Default: ``0``
        stride(int, optional): the stride of the convolving kernel. Default: ``1``
        dilation(int, optional): the spacing between kernel elements. Default: ``1``
        groups (int, optional): number of blocked connections from input channels to output channels.
                When used, the input and output types need to be divisible in ``groups`` groups, all equal
                to each other. Default: ``1``.
        bias (bool, optional): Whether to add a bias to the output (only to fields which contain a
                trivial irrep) or not. Default ``True``
        initialize (bool, optional): initialize the weights of the model. Default: ``True``

    """

    assert in_type.gspace == out_type.gspace
    assert isinstance(in_type.gspace, GeneralOnR2)

    super(R2ConvTransposed, self).__init__()

    self.space = in_type.gspace
    self.in_type = in_type
    self.out_type = out_type

    self.kernel_size = kernel_size
    self.stride = stride
    self.dilation = dilation
    self.padding = padding
    self.output_padding = output_padding
    self.groups = groups

    if groups > 1:
        # Check the input and output classes can be split in `groups` groups, all equal to each other
        # first, check that the number of fields is divisible by `groups`
        assert len(in_type) % groups == 0
        assert len(out_type) % groups == 0
        in_size = len(in_type) // groups
        out_size = len(out_type) // groups

        # then, check that all groups are equal to each other, i.e. have the same types in the same order
        assert all(in_type.representations[i] == in_type.representations[i % in_size]
                   for i in range(len(in_type)))
        assert all(out_type.representations[i] == out_type.representations[i % out_size]
                   for i in range(len(out_type)))

        # finally, retrieve the type associated to a single group in output.
        # this type will be used to build a smaller kernel basis and a smaller filter
        # as in PyTorch, to build a filter for grouped convolution, we build a filter which maps from all input
        # groups to one output group. Then, PyTorch's standard convolution routine interprets this filter as
        # `groups` different filters, each mapping an input group to an output group.
        # NOTE: from here on, `out_type` describes ONE group; use `self.out_type` for the full output.
        out_type = out_type.index_select(list(range(out_size)))

    if bias:
        # bias can be applied only to trivial irreps inside the representation
        # to apply bias to a field we learn a bias for each trivial irrep it contains
        # and, then, we transform it with the change of basis matrix to be able to apply it to the whole field
        # this is equivalent to transform the field to its irreps through the inverse change of basis,
        # sum the bias only to the trivial irrep and then map it back with the change of basis

        # count the number of trivial irreps
        trivials = 0
        for r in self.out_type:
            for irr in r.irreps:
                if self.out_type.fibergroup.irreps[irr].is_trivial():
                    trivials += 1

        # if there is at least 1 trivial irrep
        if trivials > 0:

            # matrix containing the columns of the change of basis which map from the trivial irreps to the
            # field representations. This matrix allows us to map the bias defined only over the trivial irreps
            # to a bias for the whole field more efficiently
            bias_expansion = torch.zeros(self.out_type.size, trivials)

            p, c = 0, 0
            for r in self.out_type:
                pi = 0
                for irr in r.irreps:
                    irr = self.out_type.fibergroup.irreps[irr]
                    if irr.is_trivial():
                        bias_expansion[p:p + r.size, c] = torch.tensor(r.change_of_basis[:, pi])
                        c += 1
                    pi += irr.size
                p += r.size

            self.register_buffer("bias_expansion", bias_expansion)
            self.bias = Parameter(torch.zeros(trivials), requires_grad=True)
            # the bias is applied to ALL output channels, so its size must be the full output size
            # (`self.out_type.size`), matching the rows of `bias_expansion`; `out_type` may only describe a
            # single group when `groups > 1`
            self.register_buffer("expanded_bias", torch.zeros(self.out_type.size))
        else:
            self.bias = None
            self.expanded_bias = None
    else:
        self.bias = None
        self.expanded_bias = None

    grid, basis_filter, rings, sigma, maximum_frequency = compute_basis_params(
        kernel_size, frequencies_cutoff, rings, sigma, dilation, basis_filter
    )

    # BasisExpansion: submodule which takes care of building the filter
    self._basisexpansion = None

    # notice that `out_type` is used instead of `self.out_type` such that it works also when `groups > 1`
    if basisexpansion == 'blocks':
        self._basisexpansion = BlocksBasisExpansion(
            in_type, out_type, grid,
            sigma=sigma,
            rings=rings,
            maximum_offset=maximum_offset,
            maximum_frequency=maximum_frequency,
            basis_filter=basis_filter,
            recompute=recompute,
        )
    else:
        raise ValueError('Basis Expansion algorithm "%s" not recognized' % basisexpansion)

    self.weights = Parameter(torch.zeros(self.basisexpansion.dimension()), requires_grad=True)
    # conv_transpose2d weight layout: (in_channels, out_channels // groups, kH, kW)
    self.register_buffer("filter", torch.zeros(in_type.size, out_type.size, kernel_size, kernel_size))

    if initialize:
        # by default, the weights are initialized with a generalized form of He's weight initialization
        init.generalized_he_init(self.weights.data, self.basisexpansion)
def __init__(self, in_type: FieldType, gates: List = None, drop_gates: bool = True, **kwargs):
    r"""

    Gated non-linearities.
    This module applies a bias and a sigmoid function of the gates fields and, then, multiplies each gated
    field by one of the gates.

    The input representation of the gated fields is preserved by this operation while the gate fields are
    discarded (unless ``drop_gates=False``).

    The gates and the gated fields are provided in one unique input tensor and, therefore, ``in_type`` should be
    the field type of the fiber containing both gates and gated fields.
    Moreover, the parameter ``gates`` needs to be set with a list as long as the total number of fields,
    containing in position ``i`` the string ``"gate"`` if the ``i``-th field is a gate or the string
    ``"gated"`` if the ``i``-th field is a gated field. No other strings are allowed.
    By default (``gates = None``), the first half of the fields is assumed to contain the gates (and, so, these
    fields have to be trivial fields) while the second one is assumed to contain the gated fields.

    In any case, the number of gates and the number of gated fields have to match (therefore, the number of
    fields has to be an even number).

    Args:
        in_type (FieldType): the input field type
        gates (list, optional): list of strings specifying which field in input is a gate and which is a gated
                field
        drop_gates (bool, optional): if ``True`` (default), drop the trivial fields after using them to compute
                the gates. If ``False``, the gates are stacked with the gated fields in the output
        **kwargs: accepted but not used by this constructor

    """

    assert isinstance(in_type.gspace, GeneralOnR2)

    if gates is None:
        assert len(in_type) % 2 == 0

        g = len(in_type) // 2
        gates = [GATES_ID] * g + [GATED_ID] * g

    assert len(gates) == len(in_type)

    super(GatedNonLinearity1, self).__init__()

    self.space = in_type.gspace
    self.in_type = in_type
    self.drop_gates = drop_gates

    self._contiguous = {}
    _input_indices = defaultdict(lambda: [])
    _output_indices = defaultdict(lambda: [])

    # number of gated fields of each size
    self._nfields = defaultdict(int)

    self.branching = None

    for g, r in zip(gates, in_type.representations):
        if g == GATES_ID:
            assert r.is_trivial(), \
                "Error! Representation \"{}\" can't be a \"gate\"".format(r.name)
        elif g == GATED_ID:
            assert GATED_ID in r.supported_nonlinearities, \
                'Error! Representation "{}" does not support "gated" non-linearity'.format(r.name)
        else:
            raise ValueError('Error! "{}" type not recognized'.format(g))

    ngates = len([g for g in gates if g == GATES_ID])
    ngated = len([g for g in gates if g == GATED_ID])

    assert ngates == ngated, \
        'Error! Number of gates ({}) does not match the number of gated non-linearities required ({})' \
        .format(ngates, ngated)

    if self.drop_gates:
        # only gated fields are preserved
        # therefore, the output representation is computed from the input one, removing the gates
        self.out_type = in_type.index_select([i for i, g in enumerate(gates) if g == GATED_ID])
    else:
        self.out_type = in_type

    in_last_position = 0
    out_last_position = 0
    last_key = None

    # group fields by their kind (gate or gated) and, for gated fields, their size; check if fields of the same
    # group are contiguous and retrieve the indices of the fields.
    # (`field_key` is used instead of `type` to avoid shadowing the builtin)
    for g, r in zip(gates, in_type.representations):
        if g == GATES_ID:
            field_key = g
        else:
            field_key = r.size
            self._nfields[r.size] += 1

        if field_key != last_key:
            # a group is contiguous only if all its fields appear in a single run
            if field_key not in self._contiguous:
                self._contiguous[field_key] = True
            else:
                self._contiguous[field_key] = False
        last_key = field_key

        _input_indices[field_key] += list(range(in_last_position, in_last_position + r.size))
        in_last_position += r.size

        if g != GATES_ID or not self.drop_gates:
            # since gates are discarded in output, the position on the output fiber is shifted
            # only when a field that is kept in output is met
            _output_indices[field_key] += list(range(out_last_position, out_last_position + r.size))
            out_last_position += r.size

    _input_indices = dict(_input_indices)
    _output_indices = dict(_output_indices)

    for t, contiguous in self._contiguous.items():
        # output indices exist for gated fields always, and for gates only if they are kept in output
        has_output = t != GATES_ID or not self.drop_gates

        if contiguous:
            # for contiguous fields, only the first and last indices are kept
            _input_indices[t] = torch.LongTensor([min(_input_indices[t]), max(_input_indices[t]) + 1])
            if has_output:
                _output_indices[t] = torch.LongTensor([min(_output_indices[t]), max(_output_indices[t]) + 1])
        else:
            # otherwise, transform the list of indices into a tensor
            _input_indices[t] = torch.LongTensor(_input_indices[t])
            if has_output:
                _output_indices[t] = torch.LongTensor(_output_indices[t])

        # register the indices tensors as buffers of this module
        self.register_buffer('input_indices_{}'.format(t), _input_indices[t])
        if has_output:
            self.register_buffer('output_indices_{}'.format(t), _output_indices[t])

    # gates need to be distinguished from gated fields
    _gates_indices = _input_indices.pop(GATES_ID)
    self.register_buffer('gates_indices', _gates_indices)

    # build a sorted list of the fields groups, such that every time they are iterated through in the same order
    self._order = sorted(_input_indices.keys())

    # the bias for the gates
    self.bias = Parameter(torch.randn(1, ngates, 1, 1, dtype=torch.float), requires_grad=True)
def __init__(self, in_type: FieldType, gates: List = None, drop_gates: bool = True, **kwargs):
    r"""

    Induced Gated non-linearities.

    The gates and the gated fields are provided in one unique input tensor described by ``in_type``.
    The constructor checks that:

    - every gate field supports the ``"induced_gate"`` non-linearity and all gate fields have the same size
      (referred to as the *quotient size*);
    - every gated field supports exactly one ``"induced_gated..."`` non-linearity, whose name ends with
      ``_<k>``, where ``k`` is the size of its sub-fields; the field size must be divisible by ``k`` and
      contain exactly *quotient size* sub-fields;
    - the number of gates matches the number of gated fields.

    .. note::
        Make sure all the induced gate and gated fields are induced from the same subgroup;
        this is not checked here.

    Args:
        in_type (FieldType): the input field type
        gates (list, optional): list of strings specifying which field in input is a gate and which is a gated
                field. By default (``None``), the first half of the fields are gates and the second half are
                gated fields.
        drop_gates (bool, optional): if ``True`` (default), drop the gate fields after using them to compute the
                gates. If ``False``, the gates are stacked with the gated fields in the output
        **kwargs: accepted but not used by this constructor

    """

    assert isinstance(in_type.gspace, GeneralOnR2)

    if gates is None:
        assert len(in_type) % 2 == 0

        g = len(in_type) // 2
        gates = [GATES_ID] * g + [GATED_ID] * g

    assert len(gates) == len(in_type)

    super(InducedGatedNonLinearity1, self).__init__()

    self.space = in_type.gspace
    self.in_type = in_type
    self.drop_gates = drop_gates

    self._contiguous = {}
    _input_indices = defaultdict(lambda: [])
    _output_indices = defaultdict(lambda: [])

    # number of gated fields of each (size, sub-field size) group
    self._nfields = defaultdict(int)

    self.branching = None

    for g, r in zip(gates, in_type.representations):
        if g == GATES_ID:
            assert "induced_gate" in r.supported_nonlinearities, \
                "Error! Representation \"{}\" can't be a \"gate\"".format(r.name)
        elif g == GATED_ID:
            if not any(nl.startswith("induced_gated") for nl in r.supported_nonlinearities):
                raise ValueError(
                    'Error! Representation "{}" does not support "gated" non-linearity'.format(r.name)
                )
        else:
            raise ValueError('Error! "{}" type not recognized'.format(g))

    ngates = len([g for g in gates if g == GATES_ID])
    ngated = len([g for g in gates if g == GATED_ID])

    assert ngates == ngated, \
        'Error! Number of gates ({}) does not match the number of gated non-linearities required ({})' \
        .format(ngates, ngated)

    # all gate fields must share the same size
    quotient_size = None
    for g, r in zip(gates, in_type):
        if g == GATES_ID:
            if quotient_size is None:
                quotient_size = r.size
            else:
                assert r.size == quotient_size, \
                    "Error! Gate \"{}\" has size {} but previous gates have size {}".format(
                        r.name, r.size, quotient_size)

    # the sub-field size of each gated field is encoded as the suffix of its "induced_gated..." non-linearity
    subfield_sizes = {}
    for g, r in zip(gates, in_type):
        if g == GATED_ID:
            subfield_size = None
            for nl in r.supported_nonlinearities:
                if nl.startswith('induced_gated'):
                    assert subfield_size is None, "Error! The representation supports multiple " \
                                                  "sub-fields of different sizes"
                    subfield_size = int(nl.split('_')[-1])
            assert r.size % subfield_size == 0, \
                "Error! The size of \"{}\" ({}) is not divisible by its sub-field size ({})".format(
                    r.name, r.size, subfield_size)
            assert r.size // subfield_size == quotient_size, \
                "Error! \"{}\" contains {} sub-fields but the gates have size {}".format(
                    r.name, r.size // subfield_size, quotient_size)
            subfield_sizes[r.name] = subfield_size

    self.quotient_size = quotient_size

    if self.drop_gates:
        # only gated fields are preserved
        # therefore, the output representation is computed from the input one, removing the gates
        self.out_type = in_type.index_select([i for i, g in enumerate(gates) if g == GATED_ID])
    else:
        self.out_type = in_type

    in_last_position = 0
    out_last_position = 0
    last_key = None

    # group fields by their kind (gate or gated) and, for gated fields, their (size, sub-field size); check if
    # fields of the same group are contiguous and retrieve the indices of the fields.
    # (`field_key` is used instead of `type` to avoid shadowing the builtin)
    for g, r in zip(gates, in_type.representations):
        if g == GATES_ID:
            field_key = g
        else:
            field_key = r.size, subfield_sizes[r.name]
            self._nfields[field_key] += 1

        if field_key != last_key:
            # a group is contiguous only if all its fields appear in a single run
            if field_key not in self._contiguous:
                self._contiguous[field_key] = True
            else:
                self._contiguous[field_key] = False
        last_key = field_key

        _input_indices[field_key] += list(range(in_last_position, in_last_position + r.size))
        in_last_position += r.size

        if g != GATES_ID or not self.drop_gates:
            # since gates are discarded in output, the position on the output fiber is shifted
            # only when a field that is kept in output is met
            _output_indices[field_key] += list(range(out_last_position, out_last_position + r.size))
            out_last_position += r.size

    _input_indices = dict(_input_indices)
    _output_indices = dict(_output_indices)

    for t, contiguous in self._contiguous.items():
        # output indices exist for gated fields always, and for gates only if they are kept in output
        has_output = t != GATES_ID or not self.drop_gates

        if contiguous:
            # for contiguous fields, only the first and last indices are kept
            _input_indices[t] = torch.LongTensor([min(_input_indices[t]), max(_input_indices[t]) + 1])
            if has_output:
                _output_indices[t] = torch.LongTensor([min(_output_indices[t]), max(_output_indices[t]) + 1])
        else:
            # otherwise, transform the list of indices into a tensor
            _input_indices[t] = torch.LongTensor(_input_indices[t])
            if has_output:
                _output_indices[t] = torch.LongTensor(_output_indices[t])

        # register the indices tensors as buffers of this module
        self.register_buffer('input_indices_{}'.format(t), _input_indices[t])
        if has_output:
            self.register_buffer('output_indices_{}'.format(t), _output_indices[t])

    # gates need to be distinguished from gated fields
    _gates_indices = _input_indices.pop(GATES_ID)
    self.register_buffer('gates_indices', _gates_indices)

    # build a sorted list of the fields groups, such that every time they are iterated through in the same order
    self._order = sorted(_input_indices.keys())

    # the bias for the gates
    self.bias = Parameter(torch.randn(1, ngates, 1, 1, 1, dtype=torch.float), requires_grad=True)