def backward(ctx, out_grad):
        """Backward pass of per-region instance normalization.

        With ``y = (x - mean) * inv_std`` where the statistics are taken over
        each region of ``glob_coords_key``, the input gradient is

            dx = inv_std * (dy - mean(dy) - y * mean(dy * y))

        and is assembled below from global-average pooling and broadcast calls.

        Args:
            ctx: autograd context populated by the matching ``forward``.
            out_grad: gradient w.r.t. the normalized features.

        Returns:
            ``(grad_in_feat, None, None, None, None)``; only the input features
            receive a gradient.
        """
        # https://kevinzakka.github.io/2016/09/14/batch_normalization/
        saved_vars = ctx.saved_vars
        # forward stores (in_coords_key, glob_coords_key, coords_manager,
        # gpooling_mode). Unpacking exactly three names from that 4-tuple raised
        # ValueError. Older variants stored only the first three, so accept both
        # and reuse the pooling mode forward actually used when it is available.
        in_coords_key, glob_coords_key, coords_manager = saved_vars[:3]
        if len(saved_vars) > 3:
            gpooling_mode = saved_vars[3]
        else:
            gpooling_mode = PoolingMode.GLOBAL_AVG_POOLING_KERNEL

        # inv_std and norm_feat come from save_for_backward in forward; the
        # normalization is not recomputed here.
        inv_std, norm_feat = ctx.saved_tensors

        gpool_avg_forward = get_minkowski_function("GlobalPoolingForward",
                                                   out_grad)
        broadcast_forward = get_minkowski_function("BroadcastForward",
                                                   out_grad)
        manager = coords_manager._manager

        def _region_mean(feat):
            # Global pooling of feat onto glob_coords_key; the second output
            # (num_nonzero) is not needed for the gradient.
            mean, _ = gpool_avg_forward(
                feat,
                gpooling_mode,
                in_coords_key,
                glob_coords_key,
                manager,
            )
            return mean

        def _broadcast(feat, glob_feat, mode):
            # Combine per-point feat with the per-region glob_feat using mode.
            return broadcast_forward(
                feat,
                glob_feat,
                mode,
                in_coords_key,
                glob_coords_key,
                manager,
            )

        # 1/N \sum dout
        mean_dout = _region_mean(out_grad)

        # 1/N \sum (dout * out)
        mean_dout_feat = _region_mean(out_grad * norm_feat)

        # out * 1/N \sum (dout * out)
        feat_mean_dout_feat = _broadcast(
            norm_feat,
            mean_dout_feat,
            BroadcastMode.ELEMENTWISE_MULTIPLICATION,
        )

        # dout - out * mean(dout * out) - mean(dout)
        unnorm_din = _broadcast(
            out_grad - feat_mean_dout_feat,
            -mean_dout,
            BroadcastMode.ELEMENTWISE_ADDITON,
        )

        # Scale by 1 / sigma of each region.
        norm_din = _broadcast(
            unnorm_din,
            inv_std,
            BroadcastMode.ELEMENTWISE_MULTIPLICATION,
        )

        return norm_din, None, None, None, None
    def forward(
        ctx,
        in_feat: torch.Tensor,
        in_coords_key: CoordinateMapKey,
        glob_coords_key: CoordinateMapKey = None,
        coords_manager: CoordinateManager = None,
        gpooling_mode=PoolingMode.GLOBAL_AVG_POOLING_KERNEL,
    ):
        """Normalize features to zero mean / unit variance within each region.

        Statistics are obtained by global pooling from ``in_coords_key`` onto
        ``glob_coords_key`` and broadcast back onto every input point.

        Args:
            ctx: autograd context.
            in_feat: features attached to ``in_coords_key``.
            in_coords_key: key of the input coordinate map.
            glob_coords_key: key of the pooled map; when ``None`` a new
                ``CoordinateMapKey`` with the input's coordinate size is used.
            coords_manager: manager owning both maps (its ``_manager`` is
                dereferenced, so ``None`` is not usable in practice).
            gpooling_mode: mode handed to ``GlobalPoolingForward``.

        Returns:
            ``(x - mean) / sqrt(var + 1e-8)`` for every input point.
        """
        if glob_coords_key is None:
            glob_coords_key = CoordinateMapKey(
                in_coords_key.get_coordinate_size())

        pool_fn = get_minkowski_function("GlobalPoolingForward", in_feat)
        bcast_fn = get_minkowski_function("BroadcastForward", in_feat)
        # Trailing arguments shared by every pooling / broadcast call below.
        maps = (in_coords_key, glob_coords_key, coords_manager._manager)

        # mu = 1/N sum X  (the second output, num_nonzero, is unused)
        mu, _ = pool_fn(in_feat, gpooling_mode, *maps)

        # X - mu
        centered = bcast_fn(in_feat, -mu, BroadcastMode.ELEMENTWISE_ADDITON,
                            *maps)

        # Var = 1/N sum (X - mu) ** 2
        var, _ = pool_fn(centered**2, gpooling_mode, *maps)

        # norm_feat = (X - mu) / sigma; 1e-8 keeps the sqrt away from zero.
        inv_std = 1 / (var + 1e-8).sqrt()
        norm_feat = bcast_fn(centered, inv_std,
                             BroadcastMode.ELEMENTWISE_MULTIPLICATION, *maps)

        ctx.saved_vars = (in_coords_key, glob_coords_key, coords_manager,
                          gpooling_mode)
        # For GPU tensors, must use save_for_backward.
        ctx.save_for_backward(inv_std, norm_feat)
        return norm_feat
 def backward(ctx, grad_out_feat):
     """Backpropagate through local pooling via ``LocalPoolingBackward``.

     Args:
         ctx: context carrying ``input_features``, ``num_nonzero``,
             ``kernel_generator``, ``pooling_mode``, the input/output map
             keys and ``coordinate_manager`` (all read below).
         grad_out_feat: gradient w.r.t. the pooled features.

     Returns:
         A 6-tuple: the input-feature gradient followed by five ``None``.
     """
     grad_out = grad_out_feat.contiguous()
     local_pool_bw = get_minkowski_function("LocalPoolingBackward", grad_out)
     kernel = ctx.kernel_generator
     grad_in_feat = local_pool_bw(
         ctx.input_features,
         grad_out,
         ctx.num_nonzero,
         kernel.kernel_size,
         kernel.kernel_stride,
         kernel.kernel_dilation,
         kernel.region_type,
         kernel.region_offsets,
         ctx.pooling_mode,
         ctx.in_coordinate_map_key,
         ctx.out_coordinate_map_key,
         ctx.coordinate_manager._manager,
     )
     # Only the features are differentiable; every other slot is None.
     return (grad_in_feat,) + (None,) * 5
    def forward(
        ctx,
        input_features: torch.Tensor,
        input_features_global: torch.Tensor,
        operation_type: BroadcastMode,
        in_coords_key: CoordinateMapKey,
        glob_coords_key: CoordinateMapKey,
        coords_manager: CoordinateManager,
    ):
        """Combine per-point features with per-region features.

        Args:
            ctx: autograd context; receives ``saved_vars`` for backward.
            input_features: features on ``in_coords_key``.
            input_features_global: features on ``glob_coords_key``.
            operation_type: a ``BroadcastMode`` (asserted).
            in_coords_key: key of the per-point map.
            glob_coords_key: key of the global map.
            coords_manager: manager owning both maps.

        Returns:
            Whatever the ``BroadcastForward`` extension returns.
        """
        assert isinstance(operation_type, BroadcastMode)

        call_args = (
            input_features,
            input_features_global,
            operation_type,
            in_coords_key,
            glob_coords_key,
        )
        # backward needs every input, including the manager object itself.
        ctx.saved_vars = call_args + (coords_manager,)

        broadcast = get_minkowski_function("BroadcastForward", input_features)
        return broadcast(*call_args, coords_manager._manager)
Beispiel #5
0
    def backward(ctx, grad_out_feat: torch.Tensor):
        """Backpropagate through sparse convolution via ``ConvolutionBackward``.

        Args:
            ctx: context with ``input_features``, ``kernel_weights`` and
                ``misc`` = (kernel_generator, convolution_mode, in key,
                out key, coordinate_manager).
            grad_out_feat: gradient w.r.t. the convolution output.

        Returns:
            A 7-tuple: input-feature gradient, kernel gradient, then five
            ``None``.
        """
        grad_out_feat = grad_out_feat.contiguous()
        kernel_gen, conv_mode, in_key, out_key, manager = ctx.misc

        conv_bw = get_minkowski_function("ConvolutionBackward", grad_out_feat)
        grad_in_feat, grad_kernel = conv_bw(
            ctx.input_features,
            grad_out_feat,
            ctx.kernel_weights,
            kernel_gen.kernel_size,
            kernel_gen.kernel_stride,
            kernel_gen.kernel_dilation,
            kernel_gen.region_type,
            kernel_gen.region_offsets,
            conv_mode,
            in_key,
            out_key,
            manager._manager,
        )
        # Features and weights get gradients; the other inputs do not.
        return (grad_in_feat, grad_kernel) + (None,) * 5
    def forward(
        ctx,
        input_features: torch.Tensor,
        pooling_mode: PoolingMode,
        in_coordinate_map_key: CoordinateMapKey,
        out_coordinate_map_key: CoordinateMapKey = None,
        coordinate_manager: CoordinateManager = None,
    ):
        """Pool features from the input map onto a global map.

        Args:
            ctx: autograd context; stores the inputs needed by backward.
            input_features: features on ``in_coordinate_map_key``.
            pooling_mode: mode handed to ``GlobalPoolingForward``.
            in_coordinate_map_key: key of the input map.
            out_coordinate_map_key: key of the output map; when ``None`` a new
                ``CoordinateMapKey`` with the input's coordinate size is used.
            coordinate_manager: manager owning the maps.

        Returns:
            The pooled features (``num_nonzero`` is kept on ``ctx``).
        """
        out_key = out_coordinate_map_key
        if out_key is None:
            out_key = CoordinateMapKey(
                in_coordinate_map_key.get_coordinate_size())
        feats = input_features.contiguous()

        # Everything the backward pass reads.
        ctx.input_features = feats
        ctx.in_coords_key = in_coordinate_map_key
        ctx.out_coords_key = out_key
        ctx.coordinate_manager = coordinate_manager
        ctx.pooling_mode = pooling_mode

        global_pool = get_minkowski_function("GlobalPoolingForward", feats)
        pooled, ctx.num_nonzero = global_pool(
            feats,
            pooling_mode,
            in_coordinate_map_key,
            out_key,
            coordinate_manager._manager,
        )
        return pooled
 def backward(ctx, grad_out_feat: torch.Tensor):
     """Backpropagate through pruning via ``PruningBackward``.

     Args:
         ctx: context holding ``in_coords_key``, ``out_coords_key`` and
             ``coords_manager``.
         grad_out_feat: gradient w.r.t. the pruned features.

     Returns:
         The input-feature gradient followed by ``None`` for the remaining
         four forward arguments.
     """
     # The other backward passes normalize the layout before handing the raw
     # tensor to the extension; do the same here. contiguous() returns the
     # tensor itself when it is already contiguous, so this is free then.
     grad_out_feat = grad_out_feat.contiguous()
     bw_fn = get_minkowski_function("PruningBackward", grad_out_feat)
     grad_in_feat = bw_fn(
         grad_out_feat,
         ctx.in_coords_key,
         ctx.out_coords_key,
         ctx.coords_manager._manager,
     )
     return grad_in_feat, None, None, None, None
Beispiel #8
0
    def forward(
        ctx,
        input_features,
        tensor_stride=1,
        stride=1,
        kernel_size=-1,
        dilation=1,
        region_type=-1,
        region_offset=None,
        average=False,
        in_coords_key=None,
        out_coords_key=None,
        coords_manager=None,
    ):
        """Transposed pooling forward using the legacy ``CoordsKey`` API.

        Args:
            ctx: autograd context; replaced by the object ``save_ctx`` returns.
            input_features: features on ``in_coords_key``.
            tensor_stride, stride, kernel_size, dilation: normalized by
                ``prep_args`` for dimension ``in_coords_key.D``.
            region_type: must be a ``RegionType`` (asserted).
                NOTE(review): the default ``-1`` fails that assertion, so
                callers must always pass it explicitly.
            region_offset: explicit offsets; ``None`` becomes an empty
                ``torch.IntTensor``.
            average: accepted but not read in this body.
            in_coords_key: input coordinate key.
            out_coords_key: output key; defaults to a new ``CoordsKey`` with
                the same ``D`` as the input key.
            coords_manager: coordinate manager, recorded via ``save_ctx``.

        Returns:
            ``out_feat``, created empty here and handed to the extension,
            whose own return value is ignored (so it is expected to fill
            ``out_feat`` in place).
        """
        assert isinstance(region_type, RegionType)
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        D = in_coords_key.D
        assert D == out_coords_key.D
        (tensor_stride, stride, kernel_size, dilation,
         region_type) = prep_args(tensor_stride, stride, kernel_size,
                                  dilation, region_type, D)

        region_offset = (torch.IntTensor()
                         if region_offset is None else region_offset)

        ctx.in_feat = input_features
        out_feat = input_features.new()
        ctx.num_nonzero = input_features.new()
        ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                       region_type, in_coords_key, out_coords_key,
                       coords_manager)

        pool_transpose = get_minkowski_function("PoolingTransposeForward",
                                                input_features)
        int_args = [
            convert_to_int_list(value, D)
            for value in (ctx.tensor_stride, ctx.stride, ctx.kernel_size,
                          ctx.dilation)
        ]
        pool_transpose(
            ctx.in_feat,
            out_feat,
            ctx.num_nonzero,
            *int_args,
            region_type,
            region_offset,
            ctx.in_coords_key.CPPCoordsKey,
            ctx.out_coords_key.CPPCoordsKey,
            ctx.coords_man.CPPCoordsManager,
        )
        return out_feat
Beispiel #9
0
 def backward(ctx, grad_out_feat):
     """Backpropagate through global pooling via ``GlobalPoolingBackward``.

     Args:
         ctx: context with ``input_features``, ``num_nonzero``,
             ``pooling_mode``, ``in_coords_key``, ``out_coords_key`` and
             ``coordinate_manager``.
         grad_out_feat: gradient w.r.t. the pooled features.

     Returns:
         The input-feature gradient followed by ``None`` entries for the
         non-differentiable forward arguments.
     """
     # The other backward passes normalize the layout before handing the raw
     # tensor to the extension; do the same here. contiguous() returns the
     # tensor itself when it is already contiguous, so this is free then.
     grad_out_feat = grad_out_feat.contiguous()
     bw_fn = get_minkowski_function("GlobalPoolingBackward", grad_out_feat)
     grad_in_feat = bw_fn(
         ctx.input_features,
         grad_out_feat,
         ctx.num_nonzero,
         ctx.pooling_mode,
         ctx.in_coords_key,
         ctx.out_coords_key,
         ctx.coordinate_manager._manager,
     )
     return grad_in_feat, None, None, None, None, None
Beispiel #10
0
 def backward(ctx, grad_out_feat):
     """Backward of legacy transposed pooling (``CoordsKey`` API).

     Args:
         ctx: context populated by the matching forward / ``save_ctx``.
         grad_out_feat: gradient w.r.t. the output features.

     Returns:
         An 11-tuple: the input-feature gradient followed by ten ``None``.
     """
     # Created empty; the extension's return value is ignored below, so it is
     # expected to fill this tensor in place.
     grad_in_feat = grad_out_feat.new()
     D = ctx.in_coords_key.D
     pool_transpose_bw = get_minkowski_function("PoolingTransposeBackward",
                                                grad_out_feat)
     int_args = [
         convert_to_int_list(value, D)
         for value in (ctx.tensor_stride, ctx.stride, ctx.kernel_size,
                       ctx.dilation)
     ]
     pool_transpose_bw(
         ctx.in_feat,
         grad_in_feat,
         grad_out_feat,
         ctx.num_nonzero,
         *int_args,
         ctx.region_type,
         ctx.in_coords_key.CPPCoordsKey,
         ctx.out_coords_key.CPPCoordsKey,
         ctx.coords_man.CPPCoordsManager,
     )
     # One slot per forward argument; only input_features gets a gradient.
     return (grad_in_feat,) + (None,) * 10
    def backward(
        ctx, grad_out_feat=None, grad_in_map=None, grad_out_map=None, grad_weights=None
    ):
        """Backpropagate through interpolation via ``InterpolationBackward``.

        Only ``grad_out_feat`` is used; the gradients for the returned maps
        and weights are accepted and ignored.

        Args:
            ctx: context with ``inputs`` = (in_coordinate_map_key,
                coordinate_manager) and saved (in_map, out_map, weights).
            grad_out_feat: gradient w.r.t. the interpolated features.

        Returns:
            The input-feature gradient followed by three ``None``.
        """
        grad_out_feat = grad_out_feat.contiguous()
        interp_bw = get_minkowski_function("InterpolationBackward", grad_out_feat)
        in_key, manager = ctx.inputs
        in_map, out_map, weights = ctx.saved_tensors

        grad_in_feat = interp_bw(
            grad_out_feat, in_map, out_map, weights, in_key, manager._manager
        )
        return (grad_in_feat,) + (None,) * 3
    def forward(
        ctx,
        in_feat: torch.Tensor,
        mask: torch.Tensor,
        in_coords_key: CoordinateMapKey,
        out_coords_key: CoordinateMapKey = None,
        coords_manager: CoordinateManager = None,
    ):
        """Run the ``PruningForward`` extension on ``in_feat`` with ``mask``.

        Presumably it keeps the rows selected by ``mask`` (TODO confirm
        against the kernel).

        Args:
            ctx: autograd context; receives the keys and manager for backward.
            in_feat: features on ``in_coords_key``.
            mask: selection mask handed straight to the extension.
            in_coords_key: key of the input map.
            out_coords_key: key of the output map (passed through as given).
            coords_manager: manager owning the maps.

        Returns:
            Whatever ``PruningForward`` returns.
        """
        ctx.in_coords_key, ctx.out_coords_key, ctx.coords_manager = (
            in_coords_key,
            out_coords_key,
            coords_manager,
        )

        prune = get_minkowski_function("PruningForward", in_feat)
        return prune(
            in_feat, mask, in_coords_key, out_coords_key, coords_manager._manager
        )
Beispiel #13
0
    def forward(
        ctx,
        input_features: torch.Tensor,
        kernel_weights: torch.Tensor,
        kernel_generator: KernelGenerator,
        convolution_mode: ConvolutionMode,
        in_coordinate_map_key: CoordinateMapKey,
        out_coordinate_map_key: CoordinateMapKey = None,
        coordinate_manager: CoordinateManager = None,
    ):
        """Sparse convolution forward via ``ConvolutionForward``.

        Args:
            ctx: autograd context; receives features, weights and ``misc``.
            input_features: features on ``in_coordinate_map_key``.
            kernel_weights: convolution weights.
            kernel_generator: supplies kernel size/stride/dilation, region
                type/offsets and ``expand_coordinates``.
            convolution_mode: mode handed to the extension.
            in_coordinate_map_key: key of the input map.
            out_coordinate_map_key: key of the output map; when ``None`` a new
                ``CoordinateMapKey`` with the input's coordinate size is used.
            coordinate_manager: manager owning the maps.

        Returns:
            Whatever ``ConvolutionForward`` returns.
        """
        if out_coordinate_map_key is None:
            out_coordinate_map_key = CoordinateMapKey(
                in_coordinate_map_key.get_coordinate_size())

        feats = input_features.contiguous()
        ctx.input_features = feats
        ctx.kernel_weights = kernel_weights
        # Order matters: backward unpacks ctx.misc positionally.
        ctx.misc = [
            kernel_generator,
            convolution_mode,
            in_coordinate_map_key,
            out_coordinate_map_key,
            coordinate_manager,
        ]

        kg = kernel_generator
        conv = get_minkowski_function("ConvolutionForward", feats)
        return conv(
            feats,
            kernel_weights,
            kg.kernel_size,
            kg.kernel_stride,
            kg.kernel_dilation,
            kg.region_type,
            kg.region_offsets,
            kg.expand_coordinates,
            convolution_mode,
            in_coordinate_map_key,
            out_coordinate_map_key,
            coordinate_manager._manager,
        )
    def forward(
        ctx,
        input_features: torch.Tensor,
        pooling_mode: PoolingMode,
        kernel_generator: KernelGenerator,
        in_coordinate_map_key: CoordinateMapKey,
        out_coordinate_map_key: CoordinateMapKey = None,
        coordinate_manager: CoordinateManager = None,
    ):
        """Local transposed pooling via ``LocalPoolingTransposeForward``.

        Args:
            ctx: autograd context; replaced by the object ``save_ctx`` returns.
            input_features: features on ``in_coordinate_map_key``.
            pooling_mode: mode handed to the extension.
            kernel_generator: supplies kernel geometry and
                ``expand_coordinates``.
            in_coordinate_map_key: key of the input map.
            out_coordinate_map_key: key of the output map; when ``None`` a new
                ``CoordinateMapKey`` with the input's coordinate size is used.
            coordinate_manager: manager owning the maps.

        Returns:
            The output features (``num_nonzero`` is kept on ``ctx``).
        """
        if out_coordinate_map_key is None:
            out_coordinate_map_key = CoordinateMapKey(
                in_coordinate_map_key.get_coordinate_size())

        feats = input_features.contiguous()
        ctx.input_features = feats
        # save_ctx is expected to set ctx.in_coordinate_map_key,
        # ctx.out_coordinate_map_key and ctx.coordinate_manager, which are
        # read below.
        ctx = save_ctx(
            ctx,
            kernel_generator,
            in_coordinate_map_key,
            out_coordinate_map_key,
            coordinate_manager,
        )
        ctx.pooling_mode = pooling_mode

        kg = kernel_generator
        pool_transpose = get_minkowski_function("LocalPoolingTransposeForward",
                                                feats)
        out_feat, ctx.num_nonzero = pool_transpose(
            ctx.input_features,
            kg.kernel_size,
            kg.kernel_stride,
            kg.kernel_dilation,
            kg.region_type,
            kg.region_offsets,
            kg.expand_coordinates,
            pooling_mode,
            ctx.in_coordinate_map_key,
            ctx.out_coordinate_map_key,
            ctx.coordinate_manager._manager,
        )
        return out_feat
 def forward(
     ctx,
     input_features: torch.Tensor,
     tfield: torch.Tensor,
     in_coordinate_map_key: CoordinateMapKey,
     coordinate_manager: CoordinateManager = None,
 ):
     """Interpolate the features of ``in_coordinate_map_key`` at ``tfield``.

     Args:
         ctx: autograd context; stores the maps/weights and the key/manager.
         input_features: features on ``in_coordinate_map_key``.
         tfield: query positions handed to ``InterpolationForward``.
         in_coordinate_map_key: key of the source map.
         coordinate_manager: manager owning the map.

     Returns:
         ``(out_feat, in_map, out_map, weights)`` as produced by the
         extension; the last three are also saved for backward.
     """
     feats = input_features.contiguous()
     interpolate = get_minkowski_function("InterpolationForward", feats)
     out_feat, in_map, out_map, weights = interpolate(
         feats,
         tfield,
         in_coordinate_map_key,
         coordinate_manager._manager,
     )
     ctx.save_for_backward(in_map, out_map, weights)
     # Non-tensor state read back by backward.
     ctx.inputs = (in_coordinate_map_key, coordinate_manager)
     return out_feat, in_map, out_map, weights
Beispiel #16
0
 def backward(ctx, grad_out_feat: torch.Tensor):
     """Backpropagate through transposed convolution.

     Args:
         ctx: context with ``input_features``, ``kernel_weights``,
             ``kernel_generator``, the input/output map keys and
             ``coordinate_manager``.
         grad_out_feat: gradient w.r.t. the output features.

     Returns:
         ``(grad_in_feat, grad_kernel, None, None, None, None)``.
     """
     # The other backward passes (e.g. the ConvolutionBackward one) normalize
     # the layout before calling the extension; do the same here.
     # contiguous() returns the tensor itself when it is already contiguous.
     grad_out_feat = grad_out_feat.contiguous()
     bw_fn = get_minkowski_function("ConvolutionTransposeBackward",
                                    grad_out_feat)
     grad_in_feat, grad_kernel = bw_fn(
         ctx.input_features,
         grad_out_feat,
         ctx.kernel_weights,
         ctx.kernel_generator.kernel_size,
         ctx.kernel_generator.kernel_stride,
         ctx.kernel_generator.kernel_dilation,
         ctx.kernel_generator.region_type,
         ctx.kernel_generator.region_offsets,
         ctx.in_coordinate_map_key,
         ctx.out_coordinate_map_key,
         ctx.coordinate_manager._manager,
     )
     return (
         grad_in_feat,
         grad_kernel,
         None,
         None,
         None,
         None,
     )
    def backward(ctx, grad_out_feat):
        """Backpropagate through broadcasting via ``BroadcastBackward``.

        Args:
            ctx: context whose ``saved_vars`` hold (input_features,
                input_features_global, operation_type, in_coords_key,
                glob_coords_key, coords_manager).
            grad_out_feat: gradient w.r.t. the broadcast result.

        Returns:
            Gradients for both feature tensors followed by four ``None``.
        """
        # contiguous() hands back the same tensor when it already is contiguous.
        grad_out_feat = grad_out_feat.contiguous()

        feats, glob_feats, op, in_key, glob_key, manager = ctx.saved_vars

        broadcast_bw = get_minkowski_function("BroadcastBackward", grad_out_feat)
        grad_feats, grad_glob_feats = broadcast_bw(
            feats,
            glob_feats,
            grad_out_feat,
            op,
            in_key,
            glob_key,
            manager._manager,
        )
        # The two feature tensors get gradients; the remaining inputs do not.
        return (grad_feats, grad_glob_feats) + (None,) * 4