コード例 #1
0
ファイル: relu.py プロジェクト: steven-lang/e2cnn
 def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
     """Numerically check that this module commutes with the fiber action of each testing element.

     A random batch of shape ``(3, in_type.size, 10, 10)`` is wrapped as a GeometricTensor. For
     every element ``g`` of ``self.space.testing_elements`` the module output transformed by
     ``transform_fibers(g)`` is compared with the module applied to the transformed input.
     Error statistics (max, mean, variance) are printed for every element.

     Args:
         atol: absolute tolerance forwarded to ``torch.allclose``.
         rtol: relative tolerance forwarded to ``torch.allclose``.

     Returns:
         a list of ``(element, mean absolute error)`` pairs, one per testing element.

     Raises:
         AssertionError: if the two outputs differ beyond the tolerances for some element.
     """
     n_channels = self.in_type.size
     sample = GeometricTensor(torch.randn(3, n_channels, 10, 10), self.in_type)

     report = []
     for element in self.space.testing_elements:
         # transform applied after the module vs. transform applied before the module
         after = self(sample).transform_fibers(element)
         before = self(sample.transform_fibers(element))

         abs_diff = np.abs((after.tensor - before.tensor).detach().numpy()).reshape(-1)
         print(element, abs_diff.max(), abs_diff.mean(), abs_diff.var())

         assert torch.allclose(after.tensor, before.tensor, atol=atol, rtol=rtol), \
             'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}' \
                 .format(element, abs_diff.max(), abs_diff.mean(), abs_diff.var())

         report.append((element, abs_diff.mean()))

     return report
コード例 #2
0
ファイル: pointwise_max.py プロジェクト: steven-lang/e2cnn
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        Max pooling evaluated densely, then downsampled with a per-channel filter.

        The max is first computed with stride 1. The configured ``self.stride`` is then applied by
        a depthwise convolution (``groups`` equal to the channel count) with ``self.filter``.
        Presumably ``self.filter`` is a low-pass / anti-aliasing kernel; confirm at construction.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map

        """
        assert input.type == self.in_type

        # stride fixed to 1 here; the real stride is applied by the filtering step below
        dense = F.max_pool2d(
            input.tensor,
            self.kernel_size,
            1,
            self.padding,
            self.dilation,
            self.ceil_mode,
        )

        n_channels = dense.shape[1]
        filtered = F.conv2d(
            dense,
            self.filter,
            stride=self.stride,
            padding=self._pad,
            groups=n_channels,
        )
        return GeometricTensor(filtered, self.out_type)
コード例 #3
0
 def cat(gtensors: Iterable[GeometricTensor], *args, **kwargs) -> GeometricTensor:
     r"""
     Concatenate GeometricTensors along the channel dimension.

     The underlying tensors are joined with ``torch.cat``. The resulting field type is the sum
     (``+``) of the inputs' field types, in order. Because the types are combined channel-wise,
     the concatenation dimension defaults to ``1`` when the caller does not pass one. Plain
     ``torch.cat`` would default to ``0``, which does not agree with a summed field type.

     Args:
         gtensors: the tensors to concatenate. Any iterable (list, tuple, generator) is accepted.
         *args, **kwargs: forwarded to ``torch.cat`` (e.g. ``dim``, ``out``).

     Returns:
         a GeometricTensor wrapping the concatenated tensor with the summed field type.
     """
     # materialise once: the input is traversed twice and indexed below, which would fail
     # (or see an exhausted iterator) for a generic Iterable
     gtensors = list(gtensors)

     if not args and 'dim' not in kwargs:
         kwargs = dict(kwargs, dim=1)

     tensors_cat = torch.cat([t.tensor for t in gtensors], *args, **kwargs)
     feature_type = sum((t.type for t in gtensors[1:]), gtensors[0].type)
     return GeometricTensor(tensors_cat, feature_type)
コード例 #4
0
    def forward(self, input: Dict[str, GeometricTensor]) -> GeometricTensor:
        r"""
        
        Apply each submodule to its labelled inputs and concatenate the results along channels.

        Submodule ``i`` (attribute ``submodule_{i}``) is called with the inputs named by
        ``self._labels[i]``. Its output fills the next ``module.out_type.size`` channels of a
        pre-allocated output tensor.

        Args:
            input (dict): a dictionary mapping each label to a GeometricTensor

        Returns:
            the concatenation of the output of each module
            
        """

        # compute the output shape from the shapes of all labelled inputs
        out_shape = self.evaluate_output_shape(
            **{label: t.tensor.shape
               for label, t in input.items()})

        reference = list(input.values())[0].tensor

        # pre-allocate with the inputs' dtype and device (a hard-coded float32 would silently
        # downcast double-precision results, e.g. during equivariance checks)
        output = torch.empty(out_shape, dtype=reference.dtype, device=reference.device)

        last_channel = 0
        for i, labels in enumerate(self._labels):
            module = getattr(self, f"submodule_{i}")
            size = module.out_type.size
            # write this submodule's output into the next slice of channels
            output[:, last_channel:last_channel + size, ...] = \
                module(*[input[l] for l in labels]).tensor
            last_channel += size

        return GeometricTensor(output, self.out_type)
コード例 #5
0
ファイル: branching_module.py プロジェクト: wonjongg/e2cnn
    def _retrieve_subfiber(self, input: GeometricTensor,
                           l: str) -> GeometricTensor:
        r"""
        Extract the channels of ``input`` that belong to the label ``l``.

        The channel indices are read from the attribute ``indices_{l}``. When
        ``self._contiguous[l]`` is true they are treated as a ``(start, end)`` pair and sliced
        (a view of the input). Otherwise they are used for advanced indexing (a copy).

        Args:
            input (GeometricTensor): the input tensor
            l (str): the label to consider

        Returns:
            (GeometricTensor): the sub-tensor for label ``l``, typed with ``self.out_type[l]``

        """
        idx = getattr(self, f"indices_{l}")
        channels = slice(idx[0], idx[1]) if self._contiguous[l] else idx
        return GeometricTensor(input.tensor[:, channels, ...], self.out_type[l])
コード例 #6
0
 def forward(self, x):
     """Run the equivariant model on a raw tensor and return the output of ``self.final``.

     ``x`` is wrapped as a GeometricTensor of ``self.input_type``, passed through ``self.model``
     and ``self.pool``, unwrapped back to a plain tensor, and fed to ``self.final``.
     """
     pooled = self.pool(self.model(GeometricTensor(x, self.input_type)))
     return self.final(pooled.tensor)
コード例 #7
0
    def forward(self, input: GeometricTensor):
        r"""
        Convolve the input with the expanded filter and bias.

        In eval mode the stored ``self.filter`` / ``self.expanded_bias`` are used. In training
        mode both are rebuilt from the parameters with ``self.expand_parameters()``.

        Args:
            input (GeometricTensor): input feature field transforming according to ``in_type``

        Returns:
            output feature field transforming according to ``out_type``
            
        """
        assert input.type == self.in_type

        if self.training:
            kernel, bias = self.expand_parameters()
        else:
            kernel, bias = self.filter, self.expanded_bias

        result = conv2d(
            input.tensor,
            kernel,
            padding=self.padding,
            stride=self.stride,
            dilation=self.dilation,
            groups=self.groups,
            bias=bias,
        )
        return GeometricTensor(result, self.out_type)
コード例 #8
0
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        """Multiply the input elementwise by ``self.mask``.

        The input must have type ``in_type`` and the same sizes as the mask from dimension 2
        onward. The product itself relies on PyTorch broadcasting.
        """
        assert input.type == self.in_type
        data = input.tensor
        assert data.shape[2:] == self.mask.shape[2:]
        return GeometricTensor(data * self.mask, self.out_type)
コード例 #9
0
    def forward(self, input: GeometricTensor, inverse=False):
        r"""
        Apply the exponential of the convolution with a truncated Taylor series.

        Starting from ``p_0 = x``, each term is ``p_i = conv(p_{i-1}) / i``, where ``conv`` uses
        the expanded filter and bias. The output is ``x + p_1 + ... + p_terms``. When
        ``self.dynamic_truncation`` is non-zero, the series stops early (from the 6th term on)
        as soon as ``max |p_i|`` drops below that threshold.

        Args:
            input (GeometricTensor): input feature field transforming according to ``in_type``
            inverse (bool): if True, the filter is negated before expanding the series

        Returns:
            output feature field transforming according to ``out_type``

        """
        assert input.type == self.in_type

        if not self.training:
            _filter = self.filter
            _bias = self.expanded_bias
        else:
            # retrieve the filter and the bias
            _filter, _bias = self.expand_parameters()

        if inverse:
            # exp(-A) inverts exp(A): negate the filter.
            # NOTE(review): the bias is not negated, so with a non-zero bias this is not an
            # exact inverse of the forward pass -- confirm this is intended.
            _filter = -_filter

        def _conv(t):
            # one application of the convolution, honouring the configured padding mode
            if self.padding_mode == 'zeros':
                return conv2d(t,
                              _filter,
                              padding=self.padding,
                              stride=self.stride,
                              dilation=self.dilation,
                              groups=self.groups,
                              bias=_bias)
            return conv2d(pad(t, self._reversed_padding_repeated_twice, self.padding_mode),
                          _filter,
                          stride=self.stride,
                          dilation=self.dilation,
                          padding=(0, 0),
                          groups=self.groups,
                          bias=_bias)

        output = input.tensor
        product = input.tensor

        for i in range(1, self.terms + 1):
            product = _conv(product) / i
            output = output + product

            # optional early stop once the series terms become negligible
            if self.dynamic_truncation != 0 and i > 5:
                if product.abs().max().item() < self.dynamic_truncation:
                    break

        return GeometricTensor(output, self.out_type)
コード例 #10
0
 def forward(self, input: GeometricTensor):
     """Reorder the input channels with the precomputed index permutation ``self.indices``.

     Args:
         input (GeometricTensor): a tensor whose type must equal ``in_type``.

     Returns:
         a GeometricTensor of ``out_type`` holding the permuted channels.
     """
     assert input.type == self.in_type

     permuted = input.tensor[:, self.indices, ...]
     return GeometricTensor(permuted, self.out_type)
コード例 #11
0
    def forward(self, input: GeometricTensor):
        r"""
        Convolve the input with the expanded filter and bias.

        A plain ``torch.Tensor`` is also accepted; it is first wrapped as a GeometricTensor of
        ``in_type``. For ``padding_mode == 'zeros'`` the convolution pads implicitly. For any
        other mode the input is padded explicitly with ``pad`` and the convolution runs unpadded.

        Args:
            input (GeometricTensor): input feature field transforming according to ``in_type``
        Returns:
            output feature field transforming according to ``out_type``

        """
        if torch.is_tensor(input):
            input = GeometricTensor(input, self.in_type)
        assert input.type == self.in_type

        if self.training:
            weight, bias = self.expand_parameters()
        else:
            weight, bias = self.filter, self.expanded_bias

        if self.padding_mode != 'zeros':
            data = pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode)
            conv_padding = (0, 0)
        else:
            data = input.tensor
            conv_padding = self.padding

        result = conv2d(data,
                        weight,
                        stride=self.stride,
                        padding=conv_padding,
                        dilation=self.dilation,
                        groups=self.groups,
                        bias=bias)

        return GeometricTensor(result, self.out_type)
コード例 #12
0
    def check_equivariance(self,
                           atol: float = 1e-7,
                           rtol: float = 1e-5) -> List[Tuple[Any, float]]:
        r"""
        
        Numerically test the equivariance of this module.

        A random input of shape ``(3, in_type.size, 10, 10)`` is built. For each group element in
        :attr:`~e2cnn.nn.FieldType.testing_elements` of ``out_type``, the transformed output is
        compared with the output on the transformed input. Both transforms use
        :meth:`e2cnn.nn.GeometricTensor.transform`. Each element and its error statistics are
        printed.

        Subclasses may override this method with a custom test.

        Args:
            atol: absolute tolerance forwarded to ``np.allclose``
            rtol: relative tolerance forwarded to ``np.allclose``

        Returns:
            a list containing, for each testing element, a pair with that element and the mean
            absolute equivariance error

        Raises:
            AssertionError: if the error for some element exceeds the tolerances
        
        """
        sample = GeometricTensor(torch.randn(3, self.in_type.size, 10, 10), self.in_type)

        results = []
        for g in self.out_type.testing_elements:
            print(g)

            transformed_output = self(sample).transform(g).tensor.detach().numpy()
            output_of_transformed = self(sample.transform(g)).tensor.detach().numpy()

            abs_err = np.abs(transformed_output - output_of_transformed).reshape(-1)
            print(g, abs_err.max(), abs_err.mean(), abs_err.var())

            assert np.allclose(transformed_output, output_of_transformed, atol=atol, rtol=rtol), \
                'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'\
                    .format(g, abs_err.max(), abs_err.mean(), abs_err.var())

            results.append((g, abs_err.mean()))

        return results
コード例 #13
0
 def check_equivariance(self, atol: float = 2e-6, rtol: float = 1e-5, full_space_action: bool = True) -> List[Tuple[Any, float]]:
     """Test the equivariance of this module.

     Args:
         atol: absolute tolerance forwarded to ``torch.allclose``.
         rtol: relative tolerance forwarded to ``torch.allclose``.
         full_space_action: if True, delegate to the ``check_equivariance`` found after
             ``MultipleModule`` in the MRO. Otherwise only the fiber action
             (``transform_fibers``) of each element of ``self.gspace.testing_elements`` is tested
             on a random ``(10, in_type.size, 9, 9)`` input. Diagnostics are printed, including
             a per-(batch, channel) max-error table when the check fails.

     Returns:
         a list of ``(element, mean absolute error)`` pairs.

     Raises:
         AssertionError: if the error for some element exceeds the tolerances.
     """
     if full_space_action:
         return super(MultipleModule, self).check_equivariance(atol=atol, rtol=rtol)

     n_channels = self.in_type.size
     raw = torch.randn(10, n_channels, 9, 9)
     print(n_channels, self.out_type.size)
     print([r.name for r in self.in_type.representations])
     print([r.name for r in self.out_type.representations])
     sample = GeometricTensor(raw, self.in_type)

     results = []
     for g in self.gspace.testing_elements:
         lhs = self(sample).transform_fibers(g)
         rhs = self(sample.transform_fibers(g))

         abs_diff = np.abs((lhs.tensor - rhs.tensor).detach().numpy())
         abs_err = abs_diff.reshape(-1)
         print(g, abs_err.max(), abs_err.mean(), abs_err.var())

         ok = torch.allclose(lhs.tensor, rhs.tensor, atol=atol, rtol=rtol)
         if not ok:
             # per (batch, channel) maximum error, to locate the offending fields
             per_channel = abs_diff.reshape(lhs.tensor.shape[0], lhs.tensor.shape[1], -1).max(axis=2)

             np.set_printoptions(precision=2, threshold=200000000, suppress=True, linewidth=500)
             print(per_channel.shape)
             print(per_channel)

         assert ok, \
             'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}' \
                 .format(g, abs_err.max(), abs_err.mean(), abs_err.var())

         results.append((g, abs_err.mean()))

     return results
コード例 #14
0
ファイル: field.py プロジェクト: steven-lang/e2cnn
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        
        Apply ``dropout_field`` separately to each group of equally-sized fields.

        In eval mode the input is returned unchanged. For each field size ``s`` in
        ``self._order``, the channels listed in ``indices_{s}`` are viewed as
        ``(batch, n_fields, s, *spatial)``, passed to ``dropout_field``, and written back.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map
            
        """

        assert input.type == self.in_type

        if not self.training:
            return input

        data = input.tensor

        # in in-place mode the result is written straight back into the input tensor
        output = data if self.inplace else torch.empty_like(data)

        # iterate through all field sizes
        for s in self._order:

            indices = getattr(self, f"indices_{s}")

            # layout expected by dropout_field, and the flat channel layout to write back
            field_shape = data.shape[:1] + (self._nfields[s], s) + data.shape[2:]
            flat_shape = data.shape[:1] + (self._nfields[s] * s, ) + data.shape[2:]

            if self._contiguous[s]:
                # slicing gives a view: an in-place dropout already modifies `data`
                out = dropout_field(
                    data[:, indices[0]:indices[1], ...].view(field_shape), self.p,
                    self.training, self.inplace)
                if not self.inplace:
                    output[:, indices[0]:indices[1], ...] = out.view(flat_shape)
            else:
                # advanced indexing gives a copy, so the result must ALWAYS be written back;
                # otherwise in-place dropout on these fields would be silently lost
                out = dropout_field(data[:, indices, ...].view(field_shape), self.p,
                                    self.training, self.inplace)
                output[:, indices, ...] = out.view(flat_shape)

        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type)
コード例 #15
0
ファイル: relu.py プロジェクト: steven-lang/e2cnn
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""

        Apply the ReLU function elementwise to the input fields.

        When ``self._inplace`` is set, the input tensor is overwritten.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map after relu has been applied

        """
        assert input.type == self.in_type, "Error! the type of the input does not match the input type of this module"

        activated = F.relu(input.tensor, inplace=self._inplace)
        return GeometricTensor(activated, self.out_type)
コード例 #16
0
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""

        Apply the ELU function (PyTorch default ``alpha``) elementwise to the input fields.

        When ``self._inplace`` is set, the input tensor is overwritten.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map after elu has been applied

        """
        assert input.type == self.in_type

        activated = F.elu(input.tensor, inplace=self._inplace)
        return GeometricTensor(activated, self.out_type)
コード例 #17
0
ファイル: pointwise.py プロジェクト: steven-lang/e2cnn
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        
        Apply standard (element-wise) dropout to the input tensor.

        Uses ``self.p`` and ``self.inplace``; dropout is active only in training mode.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map
            
        """
        assert input.type == self.in_type

        dropped = F.dropout(input.tensor, p=self.p, training=self.training, inplace=self.inplace)
        return GeometricTensor(dropped, self.out_type)
コード例 #18
0
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        
        Apply the Norm Pooling to the input feature map.

        Fields are grouped by the key ``(size, subfield_size)``. Each field of ``size`` channels
        is split into ``size // subfield_size`` sub-fields. The L2 norm of every sub-field is
        computed, and the maximum over the sub-fields becomes the field's single output channel.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map
            
        """

        assert input.type == self.in_type

        data = input.tensor
        b, _, h, w = data.shape

        # allocate with the input's dtype/device so double or half inputs are not
        # silently cast to float32
        output = torch.empty(self.evaluate_output_shape(data.shape),
                             device=data.device,
                             dtype=data.dtype)

        for field_id, contiguous in self._contiguous.items():
            size, subfield_size = field_id
            n_subfields = size // subfield_size

            in_indices = getattr(self, f"in_indices_{field_id}")
            out_indices = getattr(self, f"out_indices_{field_id}")

            if contiguous:
                fm = data[:, in_indices[0]:in_indices[1], ...]
            else:
                fm = data[:, in_indices, ...]

            # (b, n_fields, n_subfields, subfield_size, h, w): norm over each sub-field's
            # channels, then max over sub-fields -> (b, n_fields, h, w)
            pooled, _ = fm.view(b, -1, n_subfields, subfield_size, h,
                                w).norm(dim=3).max(dim=2)

            if contiguous:
                output[:, out_indices[0]:out_indices[1], ...] = pooled
            else:
                output[:, out_indices, ...] = pooled

        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type)
コード例 #19
0
ファイル: vectorfield.py プロジェクト: wonjongg/e2cnn
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        
        Apply the VectorField non-linearity to the input feature map.

        The channels are grouped into fields of ``R = self._rotations`` channels. For every field
        the maximum activation ``m`` and its index ``k`` are taken. The output is the 2D vector
        ``relu(m) * (cos(2*pi*k/R), sin(2*pi*k/R))``, stored in consecutive (even, odd) output
        channels.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map
            
        """

        assert input.type == self.in_type

        data = input.tensor
        b, _, h, w = data.shape

        # split the channel dimension in 2 dimensions, separating fields
        fm = data.view(b, -1, self._rotations, h, w)

        # evaluate the base rotation associated with the group action
        base_angle = 2 * np.pi / self._rotations

        # for each field, retrieve the maximum activation (and the argmax)
        max_activations, argmaxes = torch.max(fm, 2)
        max_activations = torch.relu_(max_activations)

        # compute the angles from the argmax, in the input's dtype (not a hard-coded float32)
        max_angles = argmaxes.to(dtype=data.dtype) * base_angle

        # build the output tensor with the input's dtype/device
        output = torch.empty(b,
                             self.out_type.size,
                             h,
                             w,
                             dtype=data.dtype,
                             device=data.device)

        # output vector = activation * (cos, sin) of the argmax angle
        output[:, ::2, ...] = torch.cos(max_angles) * max_activations
        output[:, 1::2, ...] = torch.sin(max_angles) * max_activations

        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type)
コード例 #20
0
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        
        Apply 2D adaptive average pooling to ``self.output_size``, channel by channel.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map
            
        """
        assert input.type == self.in_type

        pooled = F.adaptive_avg_pool2d(input.tensor, self.output_size)
        return GeometricTensor(pooled, self.out_type)
コード例 #21
0
ファイル: pointwise.py プロジェクト: wonjongg/e2cnn
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""

        Apply the pointwise function ``self._function`` to the input fields.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map after the non-linearities have been applied

        """
        assert input.type == self.in_type

        activated = self._function(input.tensor)
        return GeometricTensor(activated, self.out_type)
コード例 #22
0
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""

        Apply a Swish-style gated activation elementwise on the input fields.

        Computes ``x * sigmoid(x * softplus(self.beta)) / 1.1``. The softplus keeps the
        effective gate slope positive.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map after the activation has been applied

        """
        assert input.type == self.in_type, "Error! the type of the input does not match the input type of this module"

        x = input.tensor
        # the in-place ops act only on freshly created temporaries, never on `x` itself
        gate = torch.sigmoid_(x * F.softplus(self.beta))
        activated = (x * gate).div_(1.1)
        return GeometricTensor(activated, self.out_type)
コード例 #23
0
ファイル: gpool.py プロジェクト: steven-lang/e2cnn
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        
        Apply Group Pooling to the input feature map.

        For each field size ``s``, the fields listed in ``in_indices_{s}`` are viewed as
        ``(b, n_fields, s, h, w)``. The maximum over the ``s`` channels of each field is written
        to the output channels listed in ``out_indices_{s}``.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map
            
        """

        assert input.type == self.in_type

        data = input.tensor
        b, _, h, w = data.shape

        # allocate with the input's dtype/device so double or half inputs are not
        # silently cast to float32
        output = torch.empty(self.evaluate_output_shape(data.shape),
                             device=data.device,
                             dtype=data.dtype)

        for s, contiguous in self._contiguous.items():

            in_indices = getattr(self, "in_indices_{}".format(s))
            out_indices = getattr(self, "out_indices_{}".format(s))

            if contiguous:
                fm = data[:, in_indices[0]:in_indices[1], ...]
            else:
                fm = data[:, in_indices, ...]

            # split the channel dimension in 2 dimensions, separating fields
            fm = fm.view(b, -1, s, h, w)

            max_activations, _ = torch.max(fm, 2)

            if contiguous:
                output[:, out_indices[0]:out_indices[1], ...] = max_activations
            else:
                output[:, out_indices, ...] = max_activations

        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type)
コード例 #24
0
ファイル: pointwise_max.py プロジェクト: steven-lang/e2cnn
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        
        Apply standard channel-wise 2D max pooling with the module's configured parameters.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map
            
        """
        assert input.type == self.in_type

        pooled = F.max_pool2d(
            input.tensor,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
        )
        return GeometricTensor(pooled, self.out_type)
コード例 #25
0
    def forward(self, input: GeometricTensor):
        r"""
        Apply the transposed convolution with the expanded filter and bias.

        In eval mode the stored ``self.filter`` / ``self.expanded_bias`` are used. In training
        mode both are rebuilt with ``self.expand_parameters()``.

        Args:
            input (GeometricTensor): input feature field transforming according to ``in_type``

        Returns:
            output feature field transforming according to ``out_type``
        """
        assert input.type == self.in_type

        if self.training:
            kernel, bias = self.expand_parameters()
        else:
            kernel, bias = self.filter, self.expanded_bias

        result = conv_transpose2d(
            input.tensor,
            kernel,
            padding=self.padding,
            output_padding=self.output_padding,
            stride=self.stride,
            dilation=self.dilation,
            groups=self.groups,
            bias=bias,
        )
        return GeometricTensor(result, self.out_type)
コード例 #26
0
ファイル: pointwise_avg.py プロジェクト: stegmuel/e2cnn
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""

        Pool the input with a strided depthwise convolution by ``self.filter``.

        Each channel is convolved independently (``groups`` equals the channel count).
        Presumably ``self.filter`` is a (blurred) averaging kernel; confirm at construction.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map

        """

        assert input.type == self.in_type

        data = input.tensor

        # take the channel count from the raw tensor, as the max-pool variant does,
        # rather than relying on a `shape` alias on GeometricTensor
        output = F.conv2d(data,
                          self.filter,
                          stride=self.stride,
                          padding=self.padding,
                          groups=data.shape[1])

        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type)
コード例 #27
0
ファイル: disentangle_module.py プロジェクト: wonjongg/e2cnn
 def forward(self, input: GeometricTensor) -> GeometricTensor:
     """Apply each representation's change of basis to all fields of that representation.

     For every representation name in ``self._order``, the fields at ``fiber_indices_{name}``
     are viewed as ``(b, n_fields, field_size, w, h)``. The matrix ``change_of_basis_{name}``
     is applied to each field via ``einsum``, and the result is written to the same channels
     of the output.

     Args:
         input (GeometricTensor): a tensor whose type must equal ``in_type``.

     Returns:
         a GeometricTensor of ``out_type``.
     """
     assert input.type == self.in_type

     data = input.tensor
     b, c, w, h = data.shape
     output = torch.empty_like(data)

     for repr_name in self._order:
         is_contiguous = self._contiguous[repr_name]
         indices = getattr(self, f"fiber_indices_{repr_name}")
         cob = getattr(self, f"change_of_basis_{repr_name}")

         # contiguous fields are sliced, the others are gathered by index
         channels = slice(indices[0], indices[1]) if is_contiguous else indices

         fields = data[:, channels, ...].view(b, self._nfields[repr_name], self._sizes[repr_name], w, h)

         # TODO: can we exploit the fact the change of basis is a permutation matrix?
         transformed = torch.einsum("oi,bciwh->bcowh", (cob, fields)).reshape(b, -1, w, h)

         output[:, channels, ...] = transformed

     return GeometricTensor(output, self.out_type)
コード例 #28
0
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        Apply ``self._function`` to every channel and to its negation, doubling the channels.

        Output channel ``2k`` holds ``f(x_k)`` and channel ``2k + 1`` holds ``f(-x_k)``.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            a GeometricTensor of ``out_type`` with twice as many channels as the input
        """

        assert input.type == self.in_type

        data = input.tensor
        b, c, w, h = data.shape

        # build the output tensor with the input's dtype/device (not a hard-coded float32)
        output = torch.empty(b,
                             2 * c,
                             w,
                             h,
                             dtype=data.dtype,
                             device=data.device)

        # even channels: non-linearity applied to the value
        output[:, ::2, ...] = self._function(data)

        # odd channels: non-linearity applied to the value with the sign inverted
        output[:, 1::2, ...] = self._function(-data)

        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type)
コード例 #29
0
    def forward(self, input: GeometricTensor):
        r"""
        
        Upsample the input feature map with ``interpolate``.

        ``scale_factor`` and ``mode`` are always forwarded. ``align_corners`` is forwarded only
        when it was configured (not None).

        Args:
            input (GeometricTensor): input feature map

        Returns:
             the upsampled feature map
             
        """
        assert input.type == self.in_type

        options = {'scale_factor': self._scale_factor, 'mode': self._mode}
        if self._align_corners is not None:
            options['align_corners'] = self._align_corners

        return GeometricTensor(interpolate(input.tensor, **options), self.out_type)
コード例 #30
0
ファイル: inner.py プロジェクト: drewm1980/e2cnn
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        
        Normalize each group of equally-sized fields with its own normalization module.

        For field size ``s``, the channels at ``indices_{s}`` are viewed as
        ``(b, n_fields, s, h, w)``. They are passed to the module stored as attribute
        ``batch_norm_[s]`` (presumably a 3D batch-norm; confirm), and the result is written back
        to the same channels.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map
            
        """
        assert input.type == self.in_type

        data = input.tensor
        b, c, h, w = data.shape
        output = torch.empty_like(data)

        for s, contiguous in self._contiguous.items():
            indices = getattr(self, f"indices_{s}")
            batchnorm = getattr(self, f'batch_norm_[{s}]')

            # contiguous fields are sliced, the others are gathered by index
            channels = slice(indices[0], indices[1]) if contiguous else indices

            fields = data[:, channels, :, :].view(b, -1, s, h, w)
            output[:, channels, :, :] = batchnorm(fields).view(b, -1, h, w)

        return GeometricTensor(output, self.out_type)