Example #1
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.gspace = Rot2dOnR2(4)
        self.input_type = FieldType(self.gspace, 3*[self.gspace.trivial_repr]) 

        self.small_type = FieldType(self.gspace, 4*[self.gspace.regular_repr])
        self.mid_type = FieldType(self.gspace, 16*[self.gspace.regular_repr])

        self.model = nn.Sequential(
            R2Conv(self.input_type, self.small_type, kernel_size=3, padding=1, bias=False),
            InnerBatchNorm(self.small_type),
            ReLU(self.small_type),

            R2Conv(self.small_type, self.small_type, kernel_size=3, padding=1, bias=False),
            InnerBatchNorm(self.small_type),
            ReLU(self.small_type),

            R2Conv(self.small_type, self.small_type, kernel_size=3, padding=1, bias=False),
            InnerBatchNorm(self.small_type),
            ReLU(self.small_type),

            R2Conv(self.small_type, self.mid_type, kernel_size=3, padding=1, bias=False),
            InnerBatchNorm(self.mid_type),
            ReLU(self.mid_type),

            R2Conv(self.mid_type, self.small_type, kernel_size=3, padding=1, bias=False),
            InnerBatchNorm(self.small_type),
            ReLU(self.small_type),
        )

        self.pool = GroupPooling(self.small_type)
        pool_out = self.pool.out_type.size
        self.final = nn.Conv2d(pool_out, 1, kernel_size=3, padding=1)
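
The constructor above only assembles the layers. A minimal, self-contained sketch of how such a model processes a batch, assuming only the standard e2cnn imports (the reduced layer stack below is illustrative, not the exact model above):

import torch
from e2cnn import gspaces, nn as enn

# a reduced version of the block above, plus the forward logic it implies
gspace = gspaces.Rot2dOnR2(4)
input_type = enn.FieldType(gspace, 3 * [gspace.trivial_repr])
feat_type = enn.FieldType(gspace, 4 * [gspace.regular_repr])

model = torch.nn.Sequential(
    enn.R2Conv(input_type, feat_type, kernel_size=3, padding=1, bias=False),
    enn.InnerBatchNorm(feat_type),
    enn.ReLU(feat_type),
)
pool = enn.GroupPooling(feat_type)
final = torch.nn.Conv2d(pool.out_type.size, 1, kernel_size=3, padding=1)

x = torch.randn(2, 3, 32, 32)           # plain image batch
x = enn.GeometricTensor(x, input_type)  # attach the input field type
y = final(pool(model(x)).tensor)        # equivariant features -> invariant pooling -> ordinary conv
print(y.shape)                          # torch.Size([2, 1, 32, 32])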
Example #2
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.gspace = Rot2dOnR2(-1, maximum_frequency=2)
        self.input_type = FieldType(self.gspace, 3*[self.gspace.trivial_repr]) 

        self.small_type = FieldType(self.gspace, 4*list(self.gspace.irreps.values()))
        self.mid_type = FieldType(self.gspace, 16*list(self.gspace.irreps.values()))

        self.model = nn.Sequential(
            R2Conv(self.input_type, self.small_type, kernel_size=3, padding=1, bias=False),
            GNormBatchNorm(self.small_type),
            NormNonLinearity(self.small_type),

            R2Conv(self.small_type, self.small_type, kernel_size=3, padding=1, bias=False),
            GNormBatchNorm(self.small_type),
            NormNonLinearity(self.small_type),
            
            R2Conv(self.small_type, self.small_type, kernel_size=3, padding=1, bias=False),
            GNormBatchNorm(self.small_type),
            NormNonLinearity(self.small_type),

            R2Conv(self.small_type, self.mid_type, kernel_size=3, padding=1, bias=False),
            GNormBatchNorm(self.mid_type),
            NormNonLinearity(self.mid_type),

            R2Conv(self.mid_type, self.small_type, kernel_size=3, padding=1, bias=False),
            GNormBatchNorm(self.small_type),
            NormNonLinearity(self.small_type),
        )

        self.pool = NormPool(self.small_type)
        pool_out = self.pool.out_type.size
        self.final = nn.Conv2d(pool_out, 1, kernel_size=3, padding=1)
Example #3
    def __init__(self, in_type: FieldType, id):
        r"""
        
        Restricts the type of the input to the subgroup identified by ``id``.
        
        It computes the output type in the constructor and wraps the underlying tensor (:class:`torch.Tensor`) in input
        with the output type during the forward pass.
        
        This module only acts as a wrapper for :meth:`e2cnn.nn.FieldType.restrict`
        (or :meth:`e2cnn.nn.GeometricTensor.restrict`).
        The accepted values of ``id`` depend on the underlying ``gspace`` in the input type ``in_type``; check the
        documentation of the method :meth:`e2cnn.gspaces.GSpace.restrict` of the gspace used for
        further information.
        
        
        .. seealso::
            :meth:`e2cnn.nn.FieldType.restrict`, :meth:`e2cnn.nn.GeometricTensor.restrict`,
            :meth:`e2cnn.gspaces.GSpace.restrict`
        
        Args:
            in_type (FieldType): the input field type
            id: a valid id for a subgroup of the space associated with the input type
            
        """
        assert isinstance(in_type, FieldType)
        assert isinstance(in_type.gspace, GeneralOnR2)

        super(RestrictionModule, self).__init__()

        self._id = id
        self.in_type = in_type
        self.out_type = in_type.restrict(id)
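
A minimal usage sketch, assuming the constructor above belongs to e2cnn.nn.RestrictionModule and that, for a discrete rotation gspace, a valid subgroup ``id`` is an integer dividing the rotation order:

from e2cnn import gspaces, nn as enn

# C8-equivariant features, restricted to the C4 subgroup (id = 4)
gspace = gspaces.Rot2dOnR2(8)
feat_type = enn.FieldType(gspace, 2 * [gspace.regular_repr])

restriction = enn.RestrictionModule(feat_type, 4)
print(restriction.in_type)   # two regular fields of C8
print(restriction.out_type)  # the same channels, seen as C4 representations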
Example #4
    @staticmethod
    def _transform_fiber_representation(in_type: FieldType) -> FieldType:
        r"""
        
        Compute the output representation from the input one after applying the concatenated non-linearity.
        
        Args:
            in_type (FieldType): the input field type

        Returns:
            (FieldType): the new output field type
            
        """
        transformed = {}

        # transform each different input Representation
        for repr in in_type._unique_representations:
            transformed[
                repr] = ConcatenatedNonLinearity._transform_representation(
                    repr)

        new_representations = []

        # concatenate the new representations
        for repr in in_type.representations:
            new_representations.append(transformed[repr])

        return FieldType(in_type.gspace, new_representations)
Example #5
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.gspace = Rot2dOnR2(-1, maximum_frequency=2)
        self.input_type = FieldType(self.gspace, 3*[self.gspace.trivial_repr]) 

        layers = []

        irreps = [ v for k, v in self.gspace.irreps.items() if k != self.gspace.trivial_repr.name]

        trivials = FieldType(self.gspace, [self.gspace.trivial_repr]*10)
        gates = FieldType(self.gspace, len(irreps) * [self.gspace.trivial_repr]*10)
        gated = FieldType(self.gspace, irreps*10).sorted()
        gate = gates + gated

        self.small_type = trivials + gate

        layers.append(
            R2Conv(self.input_type, self.small_type, kernel_size=3, padding=1, bias=False)
        )
        layers.append( 
            MultipleModule(layers[-1].out_type,
            labels=[
                 *(["trivial"] * (len(trivials) + len(gates)) + ["gated"] * len(gated))
            ],
            modules=[
                (InnerBatchNorm(trivials + gates), 'trivial'),
                (NormBatchNorm(gated), 'gated')
            ])
        )
        layers.append(
            MultipleModule(layers[-1].out_type,
            labels=[
                *(["trivial"] * len(trivials) + ["gate"] * len(gate))
            ], 
            modules=[
                (ReLU(trivials), 'trivial'),
                (GatedNonLinearity1(gate), 'gate')
            ])

        )

        self.model = nn.Sequential(*layers)

        self.pool = NormPool(layers[-1].out_type)
        pool_out = self.pool.out_type.size
        self.final = nn.Conv2d(pool_out, 1, kernel_size=3, padding=1)
Example #6
    def __init__(
        self,
        modules: List[Tuple[EquivariantModule, Union[str, List[str]]]],
    ):
        r"""
        
        Applies different modules to multiple tensors in input.
        
        ``modules`` contains a list of pairs, each containing an :class:`~e2cnn.nn.EquivariantModule` and a label (or a
        list of labels).
        
        This module takes as input a dictionary mapping labels to tensors.
        Then, each module in ``modules`` is applied to the tensors associated to its labels.
        Finally, output tensors are stacked together.
        
        .. todo ::
            Technically this is not an EquivariantModule as the input is not a single tensor.
            Either fix EquivariantModule to support multiple inputs and outputs or set this as just subclass of
            torch.nn.Module.
        
        Args:
            modules (list): list of modules to apply to the labeled input tensors
            
        """

        super(MergeModule, self).__init__()

        labels = []

        for i in range(len(modules)):
            if isinstance(modules[i][1], str):
                modules[i] = (modules[i][0], [modules[i][1]])
            else:
                assert isinstance(modules[i][1], list)
                for s in modules[i][1]:
                    assert isinstance(s, str)

            labels += modules[i][1]

        self.in_type = None
        self.gspace = modules[0][0].in_type.gspace

        out_repr = []

        for module, labels in modules:
            if isinstance(module.in_type, tuple):
                assert all(t.gspace == self.gspace for t in module.in_type)
            else:
                assert module.in_type.gspace == self.gspace

            out_repr += module.out_type.representations

        self.out_type = FieldType(self.gspace, out_repr)

        self._labels = []
        # add the input modules as sub-modules
        for i, (module, labels) in enumerate(modules):
            self._labels.append(labels)
            self.add_module('submodule_{}'.format(i), module)
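
A construction-only sketch combining standard e2cnn modules; MergeModule refers to the class defined above (whether it is re-exported from e2cnn.nn is an assumption), and the labels are arbitrary tags for the input tensors:

from e2cnn import gspaces, nn as enn

gspace = gspaces.Rot2dOnR2(4)
trivial = enn.FieldType(gspace, 2 * [gspace.trivial_repr])
regular = enn.FieldType(gspace, 2 * [gspace.regular_repr])

# each sub-module is applied to the input tensor carrying its label;
# the outputs are then stacked into a single field type
merge = MergeModule([
    (enn.ReLU(trivial), 'trivial'),
    (enn.InnerBatchNorm(regular), 'regular'),
])
print(merge.out_type)  # 2 trivial fields followed by 2 regular fields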
Example #7
    def get_features(self, in_channels, n_features):
        features = super().get_features(in_channels, n_features)

        for module, (n_in, n_out) in features.items():
            if module == 'encoder1':
                features[module] = (
                    FieldType(self.gspace, n_in * [self.gspace.trivial_repr]),
                    FieldType(self.gspace, n_out * [self.gspace.regular_repr]),
                )
            else:
                features[module] = (
                    FieldType(self.gspace, n_in * [self.gspace.regular_repr]),
                    FieldType(self.gspace, n_out * [self.gspace.regular_repr]),
                )
        return features
Example #8
    def __init__(self, in_type: FieldType):
        r"""
        
        Disentangles the representations in the field type of the input.
        
        This module only acts as a wrapper for :func:`e2cnn.group.disentangle`.
        In the constructor, it disentangles each representation in the input type to build the output type and
        pre-compute the change of basis matrices needed to transform each input field.
        
        During the forward pass, each field in the input tensor is transformed with the change of basis corresponding
        to its representation.
        
        Args:
            in_type (FieldType): the input field type
            
        """
        assert isinstance(in_type, FieldType)
        assert isinstance(in_type.gspace, GeneralOnR2)
        
        super(DisentangleModule, self).__init__()
        
        self.in_type = in_type

        disentangled_representations = {}
        
        _change_of_basis_matrices = {}
        self._sizes = {}
        
        for r in self.in_type._unique_representations:
            self._sizes[r.name] = r.size
            cob, reprs = disentangle(r)
            disentangled_representations[r.name] = reprs
            _change_of_basis_matrices[r.name] = torch.FloatTensor(cob)
            self.register_buffer('change_of_basis_{}'.format(r.name), _change_of_basis_matrices[r.name])

        out_reprs = []
        self._nfields = defaultdict(int)
        for r in self.in_type.representations:
            self._nfields[r.name] += 1
            out_reprs += disentangled_representations[r.name]

        self.out_type = FieldType(self.in_type.gspace, out_reprs)

        grouped_indices = indexes_from_labels(self.in_type, [r.name for r in self.in_type.representations])
        
        self._order = []
        self._contiguous = {}
        
        for repr_name, (contiguous, fields_indices, fiber_indices) in grouped_indices.items():
            self._order.append(repr_name)
            self._contiguous[repr_name] = contiguous
            
            if contiguous:
                fiber_indices = (min(fiber_indices), max(fiber_indices)+1)
            
            fiber_indices = torch.LongTensor(fiber_indices)
            self.register_buffer(f"fiber_indices_{repr_name}", fiber_indices)
Example #9
    def __init__(self, in_type: FieldType, **kwargs):
        r"""
        
        Module that implements Norm Pooling.
        For each input field, an output one is built by taking the norm of that field; as a result, the output
        field transforms according to a trivial representation.
        
        Args:
            in_type (FieldType): the input field type
            
        """
        assert isinstance(in_type.gspace, GeneralOnR2)

        super(NormPool, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type

        # build the output representation substituting each input field with a trivial representation
        self.out_type = FieldType(self.space,
                                  [self.space.trivial_repr] * len(in_type))

        # indices of the channels corresponding to fields belonging to each group in the input representation
        _in_indices = defaultdict(lambda: [])
        # indices of the channels corresponding to fields belonging to each group in the output representation
        _out_indices = defaultdict(lambda: [])

        # whether each group of fields is contiguous or not
        self._contiguous = {}

        # group fields by their size and
        #   - check if fields of the same size are contiguous
        #   - retrieve the indices of the fields
        indices = indexes_from_labels(
            in_type, [r.size for r in in_type.representations])

        for s, (contiguous, fields, idxs) in indices.items():
            self._contiguous[s] = contiguous
            if contiguous:
                # for contiguous fields, only the first and last indices are kept
                _in_indices[s] = torch.LongTensor([min(idxs), max(idxs) + 1])
                _out_indices[s] = torch.LongTensor(
                    [min(fields), max(fields) + 1])
            else:
                # otherwise, transform the list of indices into a tensor
                _in_indices[s] = torch.LongTensor(idxs)
                _out_indices[s] = torch.LongTensor(fields)

            # register the indices tensors as buffers of this module
            self.register_buffer('in_indices_{}'.format(s), _in_indices[s])
            self.register_buffer('out_indices_{}'.format(s), _out_indices[s])
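
A short usage sketch, assuming the standard e2cnn imports; it only checks the field type and shape produced by the module:

import torch
from e2cnn import gspaces, nn as enn

gspace = gspaces.Rot2dOnR2(8)
feat_type = enn.FieldType(gspace, 3 * [gspace.regular_repr])

pool = enn.NormPool(feat_type)

x = enn.GeometricTensor(torch.randn(1, feat_type.size, 16, 16), feat_type)
y = pool(x)
print(y.type)          # 3 trivial fields
print(y.tensor.shape)  # torch.Size([1, 3, 16, 16])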
Example #10
    def __init__(self,
                 in_type: FieldType,
                 permutation: List[int]
                 ):
        r"""
        
        Permutes the fields of the input tensor according to the input ``permutation``.
        
        The parameter ``permutation`` should be a list containing a permutation of the integers ``{0, 1, ..., n-1}``,
        where ``n`` is the number of fields of ``in_type`` (see :meth:`e2cnn.nn.FieldType.__len__`).
        
        
        Args:
            in_type (FieldType): input field type
            permutation (list): permutation to apply
            
        """
    
        assert isinstance(in_type.gspace, GeneralOnR2)
        
        super(ReshuffleModule, self).__init__()
        
        # check if it is actually a permutation of the fields
        assert sorted(permutation) == list(range(len(in_type.representations)))
        
        self.space = in_type.gspace
        self.in_type = in_type
        
        # permute the fields in the input representation to build the output representation
        self.out_type = in_type.index_select(permutation)
        
        # compute the starting position of each field in the input representation
        positions = []
        last_position = 0
        for r in in_type.representations:
            positions.append(last_position)
            last_position += r.size

        # compute the indices for the permutation
        
        indices = []
        for c in permutation:
            size = in_type.representations[c].size
            
            indices += list(range(positions[c], positions[c]+size))
        
        self.register_buffer("indices", torch.LongTensor(indices))
Example #11
    def __init__(self, in_type: FieldType, **kwargs):
        r"""
        
        VectorField non-linearities.
        This non-linearity only supports the regular representation of cyclic group :math:`C_N`, i.e. the group of
        :math:`N` discrete rotations.
        For each input field, the output one is built by taking the rotation associated with the highest
        activation; then, a 2-dimensional vector with an angle with respect to the x-axis equal to that rotation and a
        length equal to its activation is set in the output field.
        
        Args:
            in_type (FieldType): the input field type
            
        """
        assert isinstance(in_type.gspace, GeneralOnR2)

        assert in_type.gspace.fibergroup.order() > 1

        for r in in_type.representations:
            assert 'vectorfield' in r.supported_nonlinearities,\
                'Error! Representation "{}" does not support "vector-field" non-linearity'.format(r.name)

            assert r.name == 'regular' and r.size == in_type.gspace.fibergroup.order(), r.name

        super(VectorFieldNonLinearity, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type

        # build the output representation substituting each input field with a rotation representation with frequency 1
        self.out_type = FieldType(
            self.space, [self.space.representations['irrep_1']] * len(in_type))

        # the number of rotations associated with the group action
        self._rotations = self.space.fibergroup.order()
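
A short usage sketch, assuming the standard e2cnn imports; each regular field of C8 is mapped to a 2-dimensional, frequency-1 vector field:

import torch
from e2cnn import gspaces, nn as enn

gspace = gspaces.Rot2dOnR2(8)
in_type = enn.FieldType(gspace, 2 * [gspace.regular_repr])

vf = enn.VectorFieldNonLinearity(in_type)
print(vf.out_type)  # 2 frequency-1 (2-dimensional) fields

x = enn.GeometricTensor(torch.randn(4, in_type.size, 16, 16), in_type)
print(vf(x).tensor.shape)  # torch.Size([4, 4, 16, 16])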
Example #12
    def __init__(self,
                 in_type: FieldType,
                 gates: List = None,
                 drop_gates: bool = True,
                 **kwargs):
        r"""
        
        Gated non-linearities.
        This module applies a bias and a sigmoid function of the gates fields and, then, multiplies each gated
        field by one of the gates.
        
        The input representation of the gated fields is preserved by this operation while the gate fields are
        discarded.
        
        The gates and the gated fields are provided in one unique input tensor and, therefore, ``in_type`` should
        be the representation of the fiber containing both gates and gated fields.
        Moreover, the parameter ``gates`` needs to be set with a list as long as the total number of fields,
        containing in position ``i`` the string ``"gate"`` if the ``i``-th field is a gate or the string ``"gated"``
        if the ``i``-th field is a gated field. No other strings are allowed.
        By default (``gates = None``), the first half of the fields is assumed to contain the gates (and, so, these
        fields have to be trivial fields) while the second one is assumed to contain the gated fields.
        
        In any case, the number of gates and the number of gated fields have to match (therefore, the number of
        fields has to be an even number).
        
        Args:
            in_type (FieldType): the input field type
            gates (list, optional): list of strings specifying which field in input is a gate and which is a gated field
            drop_gates (bool, optional): if ``True`` (default), drop the trivial fields after using them to compute
                    the gates. If ``False``, the gates are stacked with the gated fields in the output
            
        """

        assert isinstance(in_type.gspace, GeneralOnR2)

        if gates is None:
            assert len(in_type) % 2 == 0

            g = len(in_type) // 2
            gates = [GATES_ID] * g + [GATED_ID] * g

        assert len(gates) == len(in_type)

        super(GatedNonLinearity1, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type

        self.drop_gates = drop_gates

        self._contiguous = {}
        _input_indices = defaultdict(lambda: [])
        _output_indices = defaultdict(lambda: [])

        self._nfields = defaultdict(int)

        self.branching = None

        for g, r in zip(gates, in_type.representations):
            if g == GATES_ID:
                # assert GATES_ID in r.supported_nonlinearities, \
                assert r.is_trivial(), \
                    "Error! Representation \"{}\" can't be a \"gate\"".format(r.name)
            elif g == GATED_ID:
                assert GATED_ID in r.supported_nonlinearities, \
                    'Error! Representation "{}" does not support "gated" non-linearity'.format(r.name)
            else:
                raise ValueError('Error! "{}" type not recognized'.format(g))

        ngates = len([g for g in gates if g == GATES_ID])
        ngated = len([g for g in gates if g == GATED_ID])

        assert ngates == ngated, \
            'Error! Number of gates ({}) does not match the number of gated non-linearities required ({})' \
            .format(ngates, ngated)

        if self.drop_gates:
            # only gated fields are preserved
            # therefore, the output representation is computed from the input one, removing the gates
            self.out_type = in_type.index_select(
                [i for i, g in enumerate(gates) if g == GATED_ID])
        else:
            self.out_type = in_type

        in_last_position = 0
        out_last_position = 0
        last_type = None

        # group fields by their type (gated or gate) and their size, check if fields of the same type are
        # contiguous and retrieve the indices of the fields
        for g, r in zip(gates, in_type.representations):
            if g == GATES_ID:
                type = g
            else:
                type = r.size
                self._nfields[r.size] += 1

            if type != last_type:
                if not type in self._contiguous:
                    self._contiguous[type] = True
                else:
                    self._contiguous[type] = False
            last_type = type

            _input_indices[type] += list(
                range(in_last_position, in_last_position + r.size))
            in_last_position += r.size

            if g != GATES_ID or not self.drop_gates:
                # since gates are discarded in output, the position on the output fiber is shifted
                # only when a gated field is met
                _output_indices[type] += list(
                    range(out_last_position, out_last_position + r.size))
                out_last_position += r.size

        _input_indices = dict(_input_indices)
        # if self.drop_gates:
        _output_indices = dict(_output_indices)
        # else:
        #     self._output_indices = self._input_indices

        for t, contiguous in self._contiguous.items():
            if contiguous:
                # for contiguous fields, only the first and last indices are kept
                _input_indices[t] = torch.LongTensor(
                    [min(_input_indices[t]),
                     max(_input_indices[t]) + 1])
                if t != GATES_ID or not self.drop_gates:
                    _output_indices[t] = torch.LongTensor(
                        [min(_output_indices[t]),
                         max(_output_indices[t]) + 1])
            else:
                # otherwise, transform the list of indices into a tensor
                _input_indices[t] = torch.LongTensor(_input_indices[t])

                if t != GATES_ID or not self.drop_gates:
                    _output_indices[t] = torch.LongTensor(_output_indices[t])

            # register the indices tensors as buffers of this module
            self.register_buffer('input_indices_{}'.format(t),
                                 _input_indices[t])
            if t != GATES_ID or not self.drop_gates:
                self.register_buffer('output_indices_{}'.format(t),
                                     _output_indices[t])

        # gates need to be distinguished from gated fields
        _gates_indices = _input_indices.pop(GATES_ID)
        self.register_buffer('gates_indices', _gates_indices)

        # build a sorted list of the fields groups, such that every time they are iterated through in the same order
        self._order = sorted(_input_indices.keys())

        # the bias for the gates
        self.bias = Parameter(torch.randn(1, ngates, 1, 1, dtype=torch.float),
                              requires_grad=True)
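
A minimal usage sketch matching the default ``gates = None`` convention described in the docstring (first half trivial gates, second half gated fields), assuming the standard e2cnn imports:

import torch
from e2cnn import gspaces, nn as enn

gspace = gspaces.Rot2dOnR2(-1, maximum_frequency=2)

# three trivial gates followed by three gated frequency-1 fields
gates = enn.FieldType(gspace, 3 * [gspace.trivial_repr])
gated = enn.FieldType(gspace, 3 * [gspace.irrep(1)])
in_type = gates + gated

gate_nl = enn.GatedNonLinearity1(in_type)  # gates default to the first half of the fields
print(gate_nl.out_type)                    # only the three gated fields survive (drop_gates=True)

x = enn.GeometricTensor(torch.randn(1, in_type.size, 8, 8), in_type)
print(gate_nl(x).tensor.shape)             # torch.Size([1, 6, 8, 8])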
Example #13
    def __init__(
        self,
        in_type: FieldType,
        out_type: FieldType,
        kernel_size: int,
        padding: int = 0,
        output_padding: int = 0,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        basisexpansion: str = 'blocks',
        sigma: Union[List[float], float] = None,
        frequencies_cutoff: Union[float, Callable[[float], int]] = None,
        rings: List[float] = None,
        maximum_offset: int = None,
        recompute: bool = False,
        basis_filter: Callable[[dict], bool] = None,
        initialize: bool = True,
    ):
        r"""
        
        Transposed G-steerable planar convolution layer.
        
        .. warning ::
            
            Transposed convolution can produce artifacts which can harm the overall equivariance of the model.
            We suggest using :class:`~e2cnn.nn.R2Upsampling` combined with :class:`~e2cnn.nn.R2Conv` to perform
            upsampling.
        
        .. seealso ::
            For additional information about the parameters and the methods of this class, see :class:`e2cnn.nn.R2Conv`.
            The two modules are essentially the same, except for the type of convolution used.
            
        Args:
            in_type (FieldType): the type of the input field
            out_type (FieldType): the type of the output field
            kernel_size (int): the size of the filter
            padding(int, optional): implicit zero paddings on both sides of the input. Default: ``0``
            output_padding(int, optional): additional size added to one side of each dimension in the output shape. Default: ``0``
            stride(int, optional): the stride of the convolving kernel. Default: ``1``
            dilation(int, optional): the spacing between kernel elements. Default: ``1``
            groups (int, optional): number of blocked connections from input channels to output channels.
                                    Default: ``1``.
            bias (bool, optional): Whether to add a bias to the output (only to fields which contain a
                    trivial irrep) or not. Default ``True``
            initialize (bool, optional): initialize the weights of the model. Default: ``True``
            
        """

        assert in_type.gspace == out_type.gspace
        assert isinstance(in_type.gspace, GeneralOnR2)

        super(R2ConvTransposed, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type
        self.out_type = out_type

        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups

        if groups > 1:
            # Check the input and output classes can be split in `groups` groups, all equal to each other
            # first, check that the number of fields is divisible by `groups`
            assert len(in_type) % groups == 0
            assert len(out_type) % groups == 0
            in_size = len(in_type) // groups
            out_size = len(out_type) // groups

            # then, check that all groups are equal to each other, i.e. have the same types in the same order
            assert all(
                in_type.representations[i] == in_type.representations[i %
                                                                      in_size]
                for i in range(len(in_type)))
            assert all(out_type.representations[i] == out_type.representations[
                i % out_size] for i in range(len(out_type)))

            # finally, retrieve the type associated to a single group in output.
            # this type will be used to build a smaller kernel basis and a smaller filter
            # as in PyTorch, to build a filter for grouped convolution, we build a filter which maps from all input
            # groups to one output group. Then, PyTorch's standard convolution routine interprets this filter as `groups`
            # different filters, each mapping an input group to an output group.
            out_type = out_type.index_select(list(range(out_size)))

        if bias:
            # bias can be applied only to trivial irreps inside the representation
            # to apply bias to a field we learn a bias for each trivial irreps it contains
            # and, then, we transform it with the change of basis matrix to be able to apply it to the whole field
            # this is equivalent to transform the field to its irreps through the inverse change of basis,
            # sum the bias only to the trivial irrep and then map it back with the change of basis

            # count the number of trivial irreps
            trivials = 0
            for r in self.out_type:
                for irr in r.irreps:
                    if self.out_type.fibergroup.irreps[irr].is_trivial():
                        trivials += 1

            # if there is at least 1 trivial irrep
            if trivials > 0:

                # matrix containing the columns of the change of basis which map from the trivial irreps to the
                # field representations. This matrix allows us to map the bias defined only over the trivial irreps
                # to a bias for the whole field more efficiently
                bias_expansion = torch.zeros(self.out_type.size, trivials)

                p, c = 0, 0
                for r in self.out_type:
                    pi = 0
                    for irr in r.irreps:
                        irr = self.out_type.fibergroup.irreps[irr]
                        if irr.is_trivial():
                            bias_expansion[p:p + r.size, c] = torch.tensor(
                                r.change_of_basis[:, pi])
                            c += 1
                        pi += irr.size
                    p += r.size

                self.register_buffer("bias_expansion", bias_expansion)
                self.bias = Parameter(torch.zeros(trivials),
                                      requires_grad=True)
                self.register_buffer("expanded_bias",
                                     torch.zeros(out_type.size))
            else:
                self.bias = None
                self.expanded_bias = None
        else:
            self.bias = None
            self.expanded_bias = None

        grid, basis_filter, rings, sigma, maximum_frequency = compute_basis_params(
            kernel_size, frequencies_cutoff, rings, sigma, dilation,
            basis_filter)

        # BasisExpansion: submodule which takes care of building the filter
        self._basisexpansion = None

        # notice that `out_type` is used instead of `self.out_type` such that it works also when `groups > 1`
        if basisexpansion == 'blocks':
            self._basisexpansion = BlocksBasisExpansion(
                in_type,
                out_type,
                grid,
                sigma=sigma,
                rings=rings,
                maximum_offset=maximum_offset,
                maximum_frequency=maximum_frequency,
                basis_filter=basis_filter,
                recompute=recompute)

        else:
            raise ValueError('Basis Expansion algorithm "%s" not recognized' %
                             basisexpansion)

        self.weights = Parameter(torch.zeros(self.basisexpansion.dimension()),
                                 requires_grad=True)
        self.register_buffer(
            "filter",
            torch.zeros(in_type.size, out_type.size, kernel_size, kernel_size))

        if initialize:
            # by default, the weights are initialized with a generalized form of He's weight initialization
            init.generalized_he_init(self.weights.data, self.basisexpansion)
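
Following the warning above, a hedged sketch of the suggested alternative (R2Upsampling followed by a standard R2Conv); the exact keyword arguments of R2Upsampling are an assumption:

from e2cnn import gspaces, nn as enn

gspace = gspaces.Rot2dOnR2(8)
in_type = enn.FieldType(gspace, 4 * [gspace.regular_repr])
out_type = enn.FieldType(gspace, 2 * [gspace.regular_repr])

# upsample first, then convolve with an ordinary equivariant R2Conv
upsample_block = enn.SequentialModule(
    enn.R2Upsampling(in_type, scale_factor=2),
    enn.R2Conv(in_type, out_type, kernel_size=3, padding=1),
)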
Example #14
    def __init__(self, in_type: FieldType, **kwargs):
        r"""
        
        Module that implements *group pooling*.
        This module only supports permutation representations such as regular representation,
        quotient representation or trivial representation (though, in the last case, this module
        acts as identity).
        For each input field, an output field is built by taking the maximum activation within that field; as a result,
        the output field transforms according to a trivial representation.
        
        .. seealso::
            :attr:`~e2cnn.group.Group.regular_representation`,
            :attr:`~e2cnn.group.Group.quotient_representation`
        
        Args:
            in_type (FieldType): the input field type
            
        """
        assert isinstance(in_type.gspace, GeneralOnR2)

        for r in in_type.representations:
            assert 'pointwise' in r.supported_nonlinearities,\
                'Error! Representation "{}" does not support "pointwise" non-linearity'.format(r.name)

        super(GroupPooling, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type

        # build the output representation substituting each input field with a trivial representation
        self.out_type = FieldType(self.space,
                                  [self.space.trivial_repr] * len(in_type))

        # indices of the channels corresponding to fields belonging to each group in the input representation
        _in_indices = defaultdict(lambda: [])
        # indices of the channels corresponding to fields belonging to each group in the output representation
        _out_indices = defaultdict(lambda: [])

        # whether each group of fields is contiguous or not
        self._contiguous = {}

        # group fields by their size and
        #   - check if fields of the same size are contiguous
        #   - retrieve the indices of the fields
        indices = indexes_from_labels(
            in_type, [r.size for r in in_type.representations])

        for s, (contiguous, fields, idxs) in indices.items():
            self._contiguous[s] = contiguous
            if contiguous:
                # for contiguous fields, only the first and last indices are kept
                _in_indices[s] = torch.LongTensor([min(idxs), max(idxs) + 1])
                _out_indices[s] = torch.LongTensor(
                    [min(fields), max(fields) + 1])
            else:
                # otherwise, transform the list of indices into a tensor
                _in_indices[s] = torch.LongTensor(idxs)
                _out_indices[s] = torch.LongTensor(fields)

            # register the indices tensors as buffers of this module
            self.register_buffer('in_indices_{}'.format(s), _in_indices[s])
            self.register_buffer('out_indices_{}'.format(s), _out_indices[s])
Example #15
    def __init__(self,
                 in_type: FieldType,
                 gates: List = None,
                 drop_gates: bool = True,
                 **kwargs):
        r"""
        
        Induced Gated non-linearities.
        
        .. todo::
            complete documentation!
        
        .. note::
            Make sure all induced gates and gated fields are induced from the same subgroup
            
        
        Args:
            in_type (FieldType): the input field type
            gates (list, optional): list of strings specifying which field in input is a gate and which is a gated field
            drop_gates (bool, optional): if ``True`` (default), drop the trivial fields after using them to compute
                    the gates. If ``False``, the gates are stacked with the gated fields in the output
            
        """

        assert isinstance(in_type.gspace, GeneralOnR2)

        if gates is None:
            assert len(in_type) % 2 == 0

            g = len(in_type) // 2
            gates = [GATES_ID] * g + [GATED_ID] * g

        assert len(gates) == len(in_type)

        super(InducedGatedNonLinearity1, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type

        self.drop_gates = drop_gates

        self._contiguous = {}
        _input_indices = defaultdict(lambda: [])
        _output_indices = defaultdict(lambda: [])

        self._nfields = defaultdict(int)

        self.branching = None

        for g, r in zip(gates, in_type.representations):
            if g == GATES_ID:
                assert "induced_gate" in r.supported_nonlinearities, \
                    "Error! Representation \"{}\" can't be a \"gate\"".format(r.name)
            elif g == GATED_ID:
                for nl in r.supported_nonlinearities:
                    if nl.startswith("induced_gated"):
                        break
                else:
                    raise ValueError(
                        'Error! Representation "{}" does not support "gated" non-linearity'
                        .format(r.name))
            else:
                raise ValueError('Error! "{}" type not recognized'.format(g))

        ngates = len([g for g in gates if g == GATES_ID])
        ngated = len([g for g in gates if g == GATED_ID])

        assert ngates == ngated, \
            'Error! Number of gates ({}) does not match the number of gated non-linearities required ({})' \
            .format(ngates, ngated)

        quotient_size = None
        for g, r in zip(gates, in_type):
            if g == GATES_ID:
                if quotient_size is None:
                    quotient_size = r.size
                else:
                    assert r.size == quotient_size

        subfield_sizes = {}
        for g, r in zip(gates, in_type):
            if g == GATED_ID:
                subfield_size = None
                for nl in r.supported_nonlinearities:
                    if nl.startswith('induced_gated'):
                        assert subfield_size is None, "Error! The representation supports multiple " \
                                                      "sub-fields of different sizes"
                        subfield_size = int(nl.split('_')[-1])
                        assert r.size % subfield_size == 0
                        assert r.size // subfield_size == quotient_size
                        subfield_sizes[r.name] = subfield_size

        self.quotient_size = quotient_size

        if self.drop_gates:
            # only gated fields are preserved
            # therefore, the output representation is computed from the input one, removing the gates
            self.out_type = in_type.index_select(
                [i for i, g in enumerate(gates) if g == GATED_ID])
        else:
            self.out_type = in_type

        in_last_position = 0
        out_last_position = 0
        last_type = None

        # group fields by their type (gated or gate) and their size, check if fields of the same type are
        # contiguous and retrieve the indices of the fields
        for g, r in zip(gates, in_type.representations):
            if g == GATES_ID:
                type = g
            else:
                type = r.size, subfield_sizes[r.name]
                self._nfields[type] += 1

            if type != last_type:
                if not type in self._contiguous:
                    self._contiguous[type] = True
                else:
                    self._contiguous[type] = False
            last_type = type

            _input_indices[type] += list(
                range(in_last_position, in_last_position + r.size))
            in_last_position += r.size

            if g != GATES_ID or not self.drop_gates:
                # since gates are discarded in output, the position on the output fiber is shifted
                # only when a gated field is met
                _output_indices[type] += list(
                    range(out_last_position, out_last_position + r.size))
                out_last_position += r.size

        _input_indices = dict(_input_indices)
        # if self.drop_gates:
        _output_indices = dict(_output_indices)
        # else:
        #     self._output_indices = self._input_indices

        for t, contiguous in self._contiguous.items():
            if contiguous:
                # for contiguous fields, only the first and last indices are kept
                _input_indices[t] = torch.LongTensor(
                    [min(_input_indices[t]),
                     max(_input_indices[t]) + 1])
                if t != GATES_ID or not self.drop_gates:
                    _output_indices[t] = torch.LongTensor(
                        [min(_output_indices[t]),
                         max(_output_indices[t]) + 1])
            else:
                # otherwise, transform the list of indices into a tensor
                _input_indices[t] = torch.LongTensor(_input_indices[t])

                if t != GATES_ID or not self.drop_gates:
                    _output_indices[t] = torch.LongTensor(_output_indices[t])

            # register the indices tensors as buffers of this module
            self.register_buffer('input_indices_{}'.format(t),
                                 _input_indices[t])
            if t != GATES_ID or not self.drop_gates:
                self.register_buffer('output_indices_{}'.format(t),
                                     _output_indices[t])

        # gates need to be distinguished from gated fields
        _gates_indices = _input_indices.pop(GATES_ID)
        self.register_buffer('gates_indices', _gates_indices)

        # build a sorted list of the fields groups, such that every time they are iterated through in the same order
        self._order = sorted(_input_indices.keys())

        # the bias for the gates
        self.bias = Parameter(torch.randn(1,
                                          ngates,
                                          1,
                                          1,
                                          1,
                                          dtype=torch.float),
                              requires_grad=True)
Example #16
    def __init__(self,
                 in_type: FieldType,
                 labels: List[str],
                 reshuffle: int = 0):
        r"""
        
        Splits the input tensor in multiple branches identified by the input ``labels``.
        A label is associated with each field of the input field type.
        During forward, fields are grouped by the labels and the input tensor is split accordingly,
        returning a dictionary mapping labels to tensors.
        
        If ``reshuffle`` is set to a positive integer, this module first builds a copy of the input tensor sorting the
        fields according to the value set:
        
        - 1: fields are sorted by their labels
        
        - 2: fields are sorted by their labels and, then, by their size
        
        - 3: fields are sorted by their labels, by their size and, then, by their type
        
        In this way, fields that need to be retrieved together are contiguous and it is possible to exploit slicing
        to split the tensor.
        By default, ``reshuffle = 0`` which means that no sorting is performed and, so, if input
        fields are not contiguous this layer will use indexing to retrieve sub-tensors.
        
        .. todo ::
            Technically this is not an EquivariantModule as the output is not a single tensor.
            Either fix EquivariantModule to support multiple inputs and outputs or set this as just subclass of
            torch.nn.Module.
        
        Args:
            in_type (FieldType): the input field type
            labels (list): the list of labels to group the fields
            reshuffle (int, optional): set how to reshuffle the input fields before splitting the tensor.
                                       By default (``0``) no reshuffling is done.
            
        """

        assert isinstance(in_type.gspace, GeneralOnR2)

        assert 0 <= reshuffle <= 3

        super(BranchingModule, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type
        self.reshuffle_layer = None

        self._labels = set(labels)

        total_fields = len(in_type.representations)

        assert total_fields == len(labels), \
            'Error! Number of labels ({}) does not match number of fields ({})'.format(len(labels), total_fields)

        # If the user required to sort the input representation build
        # a ReshuffleLayer to apply at the beginning of the forward
        if reshuffle > 0:

            # fields are sorted, in order of priority, by the non-linearity applied, their size and their name
            # according to the reshuffle set
            keys = []
            c = 0
            for l in labels:
                # build an array containing the sorting keys and the fields' positions
                keys.append((l, in_type.representations[c].size,
                             in_type.representations[c].name, c))
                c += 1

            # sort the keys list to build the fields permutation
            keys = sorted(keys, key=lambda x: x[:reshuffle])
            permutation = [k[3] for k in keys]

            # if the fields were already sorted, it is useless to add the ReshuffleLayer
            if permutation != list(range(len(in_type.representations))):
                # build the ReshuffleLayer
                self.reshuffle_layer = ReshuffleModule(self.in_type,
                                                       permutation)

                # add the reshuffle layer to the sub-modules
                self.add_module('reshuffle', self.reshuffle_layer)

                # set the input representation to consider for the non-linearities to the sorted one
                # (i.e. the output of the ReshuffleLayer)
                in_type = self.reshuffle_layer.out_type

                # permute the non-linearities list accordingly
                labels = [labels[p] for p in permutation]

        # for each label, build the representation of the sub-fiber on which it acts
        self.out_type = in_type.group_by_labels(labels)

        # check which non-linearity has all its fields consecutive
        self._contiguous = {}

        last_label = None
        for l in labels:
            if l != last_label:
                if not l in self._contiguous:
                    self._contiguous[l] = True
                else:
                    self._contiguous[l] = False

            last_label = l

        _input_indices = defaultdict(lambda: [])

        fields = defaultdict(lambda: [])

        # for each label, compute:
        #   - the set of indices on the fiber of its fields and
        #   - the indices of the fields belonging to it
        c = 0
        last_position = 0
        for l in labels:
            # append the indices of the current field
            _input_indices[l] += list(
                range(last_position,
                      last_position + in_type.representations[c].size))

            # append the index of the current field to the list of fields belonging to this label
            fields[l].append(c)

            # move on the fiber
            last_position += in_type.representations[c].size

            # move to the next field
            c += 1

        for l, contiguous in self._contiguous.items():
            if contiguous:
                # for labels with contiguous fields, only the first and the last indices are preserved
                _input_indices[l] = torch.LongTensor(
                    [min(_input_indices[l]),
                     max(_input_indices[l]) + 1])

            else:
                # for the others, the indices list is transformed into a PyTorch tensor
                _input_indices[l] = torch.LongTensor(_input_indices[l])

            # register the indices tensors as buffers of this module
            self.register_buffer('indices_{}'.format(l), _input_indices[l])
Example #17
    def __init__(
        self,
        in_type: FieldType,
        out_type: FieldType,
        kernel_size: int,
        padding: int = 0,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        basisexpansion: str = 'blocks',
        sigma: Union[List[float], float] = None,
        frequencies_cutoff: Union[float, Callable[[float], int]] = None,
        rings: List[float] = None,
        maximum_offset: int = None,
        recompute: bool = False,
        basis_filter: Callable[[dict], bool] = None,
    ):
        r"""
        
        
        G-steerable planar convolution mapping between the input and output :class:`~e2cnn.nn.FieldType` s specified by
        the parameters ``in_type`` and ``out_type``.
        This operation is equivariant under the action of :math:`\R^2\rtimes G` where :math:`G` is the
        :attr:`e2cnn.nn.FieldType.fibergroup` of ``in_type`` and ``out_type``.
        
        Specifically, let :math:`\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}` and
        :math:`\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}` be the representations specified by the input and output
        field types.
        Then :class:`~e2cnn.nn.R2Conv` guarantees an equivariant mapping
        
        .. math::
            \kappa \star [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [\kappa \star f] \qquad\qquad \forall g \in G, u \in \R^2
            
        where the transformation of the input and output fields are given by
 
        .. math::
            [\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\
            [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\

        The equivariance of G-steerable convolutions is guaranteed by restricting the space of convolution kernels to an
        equivariant subspace.
        As proven in `3D Steerable CNNs <https://arxiv.org/abs/1807.02547>`_, this parametrizes the *most general
        equivariant convolutional map* between the input and output fields.
        For feature fields on :math:`\R^2` (e.g. images), the complete G-steerable kernel spaces for :math:`G \leq \O2`
        are derived in `General E(2)-Equivariant Steerable CNNs <https://arxiv.org/abs/1911.08251>`_.

        During training, in each forward pass the module expands the basis of G-steerable kernels with learned weights
        before calling :func:`torch.nn.functional.conv2d`.
        When :meth:`~torch.nn.Module.eval()` is called, the filter is built with the current trained weights and stored
        for future reuse such that no overhead of expanding the kernel remains.
 
        
        The parameters ``basisexpansion``, ``sigma``, ``frequencies_cutoff``, ``rings`` and ``maximum_offset`` are
        optional parameters used to control how the basis for the filters is built, how it is sampled on the filter
        grid and how it is expanded to build the filter. We suggest keeping the default values.
        
        
        Args:
            in_type (FieldType): the type of the input field, specifying its transformation law
            out_type (FieldType): the type of the output field, specifying its transformation law
            kernel_size (int): the size of the (square) filter
            padding(int, optional): implicit zero paddings on both sides of the input. Default: ``0``
            stride(int, optional): the stride of the kernel. Default: ``1``
            dilation(int, optional): the spacing between kernel elements. Default: ``1``
            groups (int, optional): number of blocked connections from input channels to output channels.
                                    It allows depthwise convolution. When used, the input and output types need to be
                                    divisible into ``groups`` groups, all equal to each other.
                                    Default: ``1``.
            bias (bool, optional): Whether to add a bias to the output (only to fields which contain a
                    trivial irrep) or not. Default ``True``
            basisexpansion (str, optional): the basis expansion algorithm to use
            sigma (list or float, optional): width of each ring where the bases are sampled. If only one scalar
                    is passed, it is used for all rings.
            frequencies_cutoff (callable or float, optional): function mapping the radii of the basis elements to the
                    maximum frequency accepted. If a float value is passed, the maximum frequency is equal to the
                    radius times this factor. By default (``None``), a more complex policy is used.
            rings (list, optional): radii of the rings where to sample the bases
            maximum_offset (int, optional): number of additional (aliased) frequencies in the intertwiners for finite
                    groups. By default (``None``), all additional frequencies allowed by the frequencies cut-off
                    are used.
            recompute (bool, optional): if ``True``, recomputes a new basis for the equivariant kernels.
                    By default (``False``), it caches the basis built or reuses a cached one, if it is found.
            basis_filter (callable, optional): function which takes as input a descriptor of a basis element
                    (as a dictionary) and returns a boolean value: whether to preserve (``True``) or discard (``False``)
                    the basis element. By default (``None``), no filtering is applied.
        
        Attributes:
            
            ~.weights (torch.Tensor): the learnable parameters which are used to expand the kernel
            ~.filter (torch.Tensor): the convolutional kernel obtained by expanding the parameters
                                    in :attr:`~e2cnn.nn.R2Conv.weights`
            ~.bias (torch.Tensor): the learnable parameters which are used to expand the bias, if ``bias=True``
            ~.expanded_bias (torch.Tensor): the equivariant bias which is summed to the output, obtained by expanding
                                    the parameters in :attr:`~e2cnn.nn.R2Conv.bias`
        
        """

        assert in_type.gspace == out_type.gspace
        assert isinstance(in_type.gspace, GeneralOnR2)

        super(R2Conv, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type
        self.out_type = out_type

        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.groups = groups

        if groups > 1:
            # Check the input and output classes can be split in `groups` groups, all equal to each other
            # first, check that the number of fields is divisible by `groups`
            assert len(in_type) % groups == 0
            assert len(out_type) % groups == 0
            in_size = len(in_type) // groups
            out_size = len(out_type) // groups

            # then, check that all groups are equal to each other, i.e. have the same types in the same order
            assert all(
                in_type.representations[i] == in_type.representations[i %
                                                                      in_size]
                for i in range(len(in_type)))
            assert all(out_type.representations[i] == out_type.representations[
                i % out_size] for i in range(len(out_type)))

            # finally, retrieve the type associated to a single group in input.
            # this type will be used to build a smaller kernel basis and a smaller filter
            # as in PyTorch, to build a filter for grouped convolution, we build a filter which maps from one input
            # group to all output groups. Then, PyTorch's standard convolution routine interprets this filter as `groups`
            # different filters, each mapping an input group to an output group.
            in_type = in_type.index_select(list(range(in_size)))

        if bias:
            # bias can be applied only to trivial irreps inside the representation
            # to apply bias to a field we learn a bias for each trivial irreps it contains
            # and, then, we transform it with the change of basis matrix to be able to apply it to the whole field
            # this is equivalent to transform the field to its irreps through the inverse change of basis,
            # sum the bias only to the trivial irrep and then map it back with the change of basis

            # count the number of trivial irreps
            trivials = 0
            for r in self.out_type:
                for irr in r.irreps:
                    if self.out_type.fibergroup.irreps[irr].is_trivial():
                        trivials += 1

            # if there is at least 1 trivial irrep
            if trivials > 0:

                # matrix containing the columns of the change of basis which map from the trivial irreps to the
                # field representations. This matrix allows us to map the bias defined only over the trivial irreps
                # to a bias for the whole field more efficiently
                bias_expansion = torch.zeros(self.out_type.size, trivials)

                p, c = 0, 0
                for r in self.out_type:
                    pi = 0
                    for irr in r.irreps:
                        irr = self.out_type.fibergroup.irreps[irr]
                        if irr.is_trivial():
                            bias_expansion[p:p + r.size, c] = torch.tensor(
                                r.change_of_basis[:, pi])
                            c += 1
                        pi += irr.size
                    p += r.size

                self.register_buffer("bias_expansion", bias_expansion)
                self.bias = Parameter(torch.zeros(trivials),
                                      requires_grad=True)
                self.register_buffer("expanded_bias",
                                     torch.zeros(out_type.size))
            else:
                self.bias = None
                self.expanded_bias = None
        else:
            self.bias = None
            self.expanded_bias = None

        grid, basis_filter, rings, sigma, maximum_frequency = compute_basis_params(
            kernel_size, frequencies_cutoff, rings, sigma, dilation,
            basis_filter)

        # BasisExpansion: submodule which takes care of building the filter
        self._basisexpansion = None

        # notice that `in_type` is used instead of `self.in_type` such that it works also when `groups > 1`
        if basisexpansion == 'blocks':
            self._basisexpansion = BlocksBasisExpansion(
                in_type,
                out_type,
                grid,
                sigma=sigma,
                rings=rings,
                maximum_offset=maximum_offset,
                maximum_frequency=maximum_frequency,
                basis_filter=basis_filter,
                recompute=recompute)

        else:
            raise ValueError('Basis Expansion algorithm "%s" not recognized' %
                             basisexpansion)

        self.weights = Parameter(torch.zeros(self.basisexpansion.dimension()),
                                 requires_grad=True)
        self.register_buffer(
            "filter",
            torch.zeros(out_type.size, in_type.size, kernel_size, kernel_size))

        # by default, the weights are initialized with a generalized form of He's weight initialization
        init.generalized_he_init(self.weights.data, self.basisexpansion)
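
A short usage sketch of the layer documented above, assuming the standard e2cnn imports; check_equivariance (inherited from e2cnn.nn.EquivariantModule) numerically tests the equivariance property stated in the docstring:

import torch
from e2cnn import gspaces, nn as enn

# C8-equivariant convolution mapping an RGB image to 16 regular feature fields
gspace = gspaces.Rot2dOnR2(8)
in_type = enn.FieldType(gspace, 3 * [gspace.trivial_repr])
out_type = enn.FieldType(gspace, 16 * [gspace.regular_repr])

conv = enn.R2Conv(in_type, out_type, kernel_size=5, padding=2)

x = enn.GeometricTensor(torch.randn(1, 3, 32, 32), in_type)
print(conv(x).tensor.shape)  # torch.Size([1, 128, 32, 32])

# numerically verify equivariance on random inputs
conv.check_equivariance()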
Example #18
    def __init__(self, in_type: FieldType, **kwargs):
        r"""
        
        Module that implements Induced Norm Pooling.
        This module requires the input fields to be associated to an induced representation from a representation
        which supports 'norm' non-linearities.
        
        For each input field, an output one is built by taking the maximum norm of all its sub-fields.
        
        Args:
            in_type (FieldType): the input field type
            
        """
        assert isinstance(in_type.gspace, GeneralOnR2)

        super(InducedNormPool, self).__init__()

        for r in in_type.representations:
            assert any(nl.startswith('induced_norm') for nl in r.supported_nonlinearities), \
                'Error! Representation "{}" does not support "induced_norm" non-linearity'.format(r.name)

        self.space = in_type.gspace
        self.in_type = in_type

        # build the output representation substituting each input field with a trivial representation
        self.out_type = FieldType(self.space,
                                  [self.space.trivial_repr] * len(in_type))

        # whether each group of fields is contiguous or not
        self._contiguous = {}

        # group fields by their size and the size of the subfields and
        #   - check if fields of the same size are contiguous
        #   - retrieve the indices of the fields

        # indices of the channels corresponding to fields belonging to each group in the input representation
        _in_indices = defaultdict(lambda: [])
        # indices of the channels corresponding to fields belonging to each group in the output representation
        _out_indices = defaultdict(lambda: [])

        # number of fields of each size
        self._nfields = defaultdict(int)

        # whether each group of fields is contiguous or not
        self._contiguous = {}

        position = 0
        last_id = None
        for i, r in enumerate(self.in_type.representations):
            subfield_size = None
            for nl in r.supported_nonlinearities:
                if nl.startswith('induced_norm'):
                    assert subfield_size is None, "Error! The representation supports multiple " \
                                                  "sub-fields of different sizes"
                    subfield_size = int(nl.split('_')[-1])
                    assert r.size % subfield_size == 0

            id = (r.size, subfield_size)

            if id != last_id:
                self._contiguous[id] = not id in self._contiguous

            last_id = id

            _in_indices[id] += list(range(position, position + r.size))
            _out_indices[id] += [i]
            self._nfields[id] += 1
            position += r.size

        for id, contiguous in self._contiguous.items():
            if contiguous:
                # for contiguous fields, only the first and last indices are kept
                _in_indices[id] = torch.LongTensor(
                    [min(_in_indices[id]),
                     max(_in_indices[id]) + 1])
                _out_indices[id] = torch.LongTensor(
                    [min(_out_indices[id]),
                     max(_out_indices[id]) + 1])
            else:
                # otherwise, transform the list of indices into a tensor
                _in_indices[id] = torch.LongTensor(_in_indices[id])
                _out_indices[id] = torch.LongTensor(_out_indices[id])

            # register the indices tensors as buffers of this module
            self.register_buffer('in_indices_{}'.format(id), _in_indices[id])
            self.register_buffer('out_indices_{}'.format(id), _out_indices[id])