def forward(ctx,
                in_feat,
                mode=GlobalPoolingMode.AUTO,
                in_coords_key=None,
                glob_coords_key=None,
                coords_manager=None):
        """Normalize features per global-pooling group: (X - mean) / std.

        The mean and variance are computed with the ``GlobalPoolingForward``
        kernel (called with the averaging flag ``True``) from
        ``in_coords_key`` onto ``glob_coords_key``. They are applied back to
        the input rows with ``BroadcastForward``.

        Args:
            ctx: autograd context. Receives ``mode``, both coordinate keys
                and the coordinate manager, and saves ``(inv_std, norm_feat)``
                for the backward pass.
            in_feat: input feature tensor.
            mode: ``GlobalPoolingMode`` forwarded to the pooling kernel.
            in_coords_key: key of the input coordinates.
            glob_coords_key: key of the pooled coordinates. A new ``CoordsKey``
                with the same dimension ``D`` is created when ``None``.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The normalized feature tensor.

        Raises:
            AssertionError: if ``mode`` is not a ``GlobalPoolingMode``.
        """
        assert isinstance(mode, GlobalPoolingMode), \
            f"Mode must be an instance of GlobalPoolingMode, {mode}"
        if glob_coords_key is None:
            glob_coords_key = CoordsKey(in_coords_key.D)

        gpool_forward = get_minkowski_function('GlobalPoolingForward', in_feat)
        broadcast_forward = get_minkowski_function('BroadcastForward', in_feat)
        add = operation_type_to_int(OperationType.ADDITION)
        multiply = operation_type_to_int(OperationType.MULTIPLICATION)

        cpp_in_coords_key = in_coords_key.CPPCoordsKey
        cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
        cpp_coords_manager = coords_manager.CPPCoordsManager

        # \mu = 1/N \sum X. The second return value (num_nonzero) is not
        # needed here.
        mean, _ = gpool_forward(in_feat, cpp_in_coords_key,
                                cpp_glob_coords_key, cpp_coords_manager,
                                True, mode.value)
        # X - \mu
        centered_feat = broadcast_forward(in_feat, -mean, add,
                                          cpp_in_coords_key,
                                          cpp_glob_coords_key,
                                          cpp_coords_manager)

        # Variance = 1/N \sum (X - \mu) ** 2
        variance, _ = gpool_forward(centered_feat**2, cpp_in_coords_key,
                                    cpp_glob_coords_key, cpp_coords_manager,
                                    True, mode.value)

        # norm_feat = (X - \mu) / \sigma. The epsilon keeps zero-variance
        # groups from dividing by zero.
        inv_std = 1 / (variance + 1e-8).sqrt()
        norm_feat = broadcast_forward(centered_feat, inv_std, multiply,
                                      cpp_in_coords_key, cpp_glob_coords_key,
                                      cpp_coords_manager)

        ctx.mode = mode
        ctx.in_coords_key, ctx.glob_coords_key = in_coords_key, glob_coords_key
        ctx.coords_manager = coords_manager
        # For GPU tensors, must use save_for_backward (not plain ctx attributes).
        ctx.save_for_backward(inv_std, norm_feat)
        return norm_feat
 def backward(ctx, grad_out_feat):
     """Gradient of global pooling via the native ``GlobalPoolingBackward``.

     Args:
         ctx: autograd context populated by the forward pass (``in_feat``,
             ``num_nonzero``, coordinate keys, manager, ``average``).
         grad_out_feat: gradient w.r.t. the pooled output.

     Returns:
         The input-feature gradient followed by five ``None`` slots.
     """
     # Every sibling backward hands the native kernel a contiguous
     # gradient; do the same here (presumably the kernel reads the raw
     # buffer).
     if not grad_out_feat.is_contiguous():
         grad_out_feat = grad_out_feat.contiguous()

     bw_fn = get_minkowski_function('GlobalPoolingBackward', grad_out_feat)
     grad_in_feat = bw_fn(ctx.in_feat, grad_out_feat, ctx.num_nonzero,
                          ctx.in_coords_key.CPPCoordsKey,
                          ctx.out_coords_key.CPPCoordsKey,
                          ctx.coords_manager.CPPCoordsManager, ctx.average)
     return grad_in_feat, None, None, None, None, None
    def forward(ctx,
                input_features,
                average=True,
                mode=GlobalPoolingMode.AUTO,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """Global pooling via the native ``GlobalPoolingForward`` kernel.

        Args:
            ctx: autograd context. Stores the coordinate keys, manager, input
                features, ``average``, ``mode.value`` and the kernel's
                ``num_nonzero`` output for the backward pass.
            input_features: input feature tensor.
            average: flag forwarded to the kernel (presumably mean vs. sum
                pooling).
            mode: ``GlobalPoolingMode`` whose ``value`` is forwarded.
            in_coords_key: key of the input coordinates.
            out_coords_key: key of the pooled coordinates. A new ``CoordsKey``
                with the same dimension is created when ``None``.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The pooled feature tensor.

        Raises:
            AssertionError: if ``mode`` is not a ``GlobalPoolingMode``.
        """
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        assert isinstance(mode, GlobalPoolingMode), \
            f"Mode must be an instance of GlobalPoolingMode, {mode}"
        # Match the other forward functions: hand the native kernel a
        # contiguous buffer.
        if not input_features.is_contiguous():
            input_features = input_features.contiguous()

        ctx.in_coords_key = in_coords_key
        ctx.out_coords_key = out_coords_key

        ctx.in_feat = input_features
        ctx.average = average
        ctx.coords_manager = coords_manager
        ctx.mode = mode.value

        fw_fn = get_minkowski_function('GlobalPoolingForward', input_features)
        out_feat, num_nonzero = fw_fn(ctx.in_feat,
                                      ctx.in_coords_key.CPPCoordsKey,
                                      ctx.out_coords_key.CPPCoordsKey,
                                      ctx.coords_manager.CPPCoordsManager,
                                      ctx.average, ctx.mode)

        # Needed by the backward kernel.
        ctx.num_nonzero = num_nonzero

        return out_feat
    def backward(ctx, out_grad):
        """Gradient of the per-group normalization y = (X - mu) / sigma.

        Implements the standard normalization backward formula

            dX = inv_std * (dy - mean(dy) - y * mean(dy * y))

        The means are taken per global group with ``GlobalPoolingForward``
        (averaging flag ``True``) and broadcast back with
        ``BroadcastForward``.

        Args:
            ctx: autograd context holding ``mode``, both coordinate keys, the
                coordinate manager and the saved ``(inv_std, norm_feat)``.
            out_grad: gradient w.r.t. the normalized output.

        Returns:
            The input-feature gradient followed by four ``None`` slots.
        """
        # https://kevinzakka.github.io/2016/09/14/batch_normalization/
        in_coords_key, glob_coords_key = ctx.in_coords_key, ctx.glob_coords_key
        coords_manager = ctx.coords_manager

        # Read back from save_for_backward; nothing is recomputed here.
        inv_std, norm_feat = ctx.saved_tensors

        # As in the other backward passes, give the native kernels a
        # contiguous gradient.
        if not out_grad.is_contiguous():
            out_grad = out_grad.contiguous()

        gpool_forward = get_minkowski_function('GlobalPoolingForward',
                                               out_grad)
        broadcast_forward = get_minkowski_function('BroadcastForward',
                                                   out_grad)
        add = operation_type_to_int(OperationType.ADDITION)
        multiply = operation_type_to_int(OperationType.MULTIPLICATION)

        cpp_in_coords_key = in_coords_key.CPPCoordsKey
        cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
        cpp_coords_manager = coords_manager.CPPCoordsManager

        # 1/N \sum dout  (the second return value, num_nonzero, is unused)
        mean_dout, _ = gpool_forward(out_grad, cpp_in_coords_key,
                                     cpp_glob_coords_key, cpp_coords_manager,
                                     True, ctx.mode.value)

        # 1/N \sum (dout * out)
        mean_dout_feat, _ = gpool_forward(out_grad * norm_feat,
                                          cpp_in_coords_key,
                                          cpp_glob_coords_key,
                                          cpp_coords_manager, True,
                                          ctx.mode.value)

        # out * 1/N \sum (dout * out)
        feat_mean_dout_feat = broadcast_forward(norm_feat, mean_dout_feat,
                                                multiply, cpp_in_coords_key,
                                                cpp_glob_coords_key,
                                                cpp_coords_manager)

        # dout - out * mean(dout * out) - mean(dout)
        unnorm_din = broadcast_forward(out_grad - feat_mean_dout_feat,
                                       -mean_dout, add, cpp_in_coords_key,
                                       cpp_glob_coords_key, cpp_coords_manager)

        # Scale by 1 / sigma.
        norm_din = broadcast_forward(unnorm_din, inv_std, multiply,
                                     cpp_in_coords_key, cpp_glob_coords_key,
                                     cpp_coords_manager)

        return norm_din, None, None, None, None
 def backward(ctx, grad_out_feat):
     """Gradient of global max pooling via ``GlobalMaxPoolingBackward``.

     Args:
         ctx: autograd context holding ``in_feat``, ``max_index``, the
             coordinate keys and the coordinate manager.
         grad_out_feat: gradient w.r.t. the pooled output.

     Returns:
         The input-feature gradient followed by five ``None`` slots.
     """
     # Consistent with the other backward passes: the native kernel gets a
     # contiguous gradient.
     if not grad_out_feat.is_contiguous():
         grad_out_feat = grad_out_feat.contiguous()

     # Empty output buffer. The kernel's return value is ignored, so it is
     # expected to fill this tensor in place.
     grad_in_feat = grad_out_feat.new()
     bw_fn = get_minkowski_function('GlobalMaxPoolingBackward',
                                    grad_out_feat)
     bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.max_index,
           ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
           ctx.coords_manager.CPPCoordsManager)
     return grad_in_feat, None, None, None, None, None
예제 #6
0
    def backward(ctx, grad_out_feat):
        """Backward of the pruning op via the native ``PruningBackward`` kernel.

        Args:
            ctx: autograd context holding both coordinate keys and the
                coordinate manager.
            grad_out_feat: gradient w.r.t. the pruned output.

        Returns:
            The input-feature gradient followed by five ``None`` slots.
        """
        dense_grad = (grad_out_feat if grad_out_feat.is_contiguous() else
                      grad_out_feat.contiguous())

        # Empty buffer that the kernel populates in place (its return value
        # is not used).
        grad_in_feat = dense_grad.new()
        prune_bw = get_minkowski_function('PruningBackward', dense_grad)
        prune_bw(grad_in_feat, dense_grad,
                 ctx.in_coords_key.CPPCoordsKey,
                 ctx.out_coords_key.CPPCoordsKey,
                 ctx.coords_manager.CPPCoordsManager)
        return (grad_in_feat, ) + (None, ) * 5
예제 #7
0
    def backward(ctx, grad_out_feat):
        """Backward of the union op via the native ``UnionBackward`` kernel.

        Args:
            ctx: autograd context holding ``in_coords_keys`` (one per input),
                ``out_coords_key`` and the coordinate manager.
            grad_out_feat: gradient w.r.t. the union output.

        Returns:
            Three leading ``None`` slots, then one gradient per entry produced
            by the kernel.
        """
        if not grad_out_feat.is_contiguous():
            grad_out_feat = grad_out_feat.contiguous()

        union_bw = get_minkowski_function('UnionBackward', grad_out_feat)
        cpp_in_keys = [in_key.CPPCoordsKey for in_key in ctx.in_coords_keys]
        per_input_grads = union_bw(grad_out_feat, cpp_in_keys,
                                   ctx.out_coords_key.CPPCoordsKey,
                                   ctx.coords_manager.CPPCoordsManager)
        return (None, None, None) + tuple(per_input_grads)
    def forward(ctx,
                input_features,
                kernel,
                tensor_stride=1,
                stride=1,
                kernel_size=-1,
                dilation=1,
                region_type=0,
                region_offset=None,
                generate_new_coords=False,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """Transposed sparse convolution via ``ConvolutionTransposeForward``.

        region_type=0 HyperCube

        Args:
            ctx: autograd context. Stores ``in_feat``, ``kernel`` and, through
                ``save_ctx``, the geometry arguments, keys and manager.
            input_features: input feature tensor. ``shape[1]`` must equal
                ``kernel.shape[1]``.
            kernel: weight tensor, laid out as
                ``(n_spatial_kernels, in_nfeat, out_nfeat)``.
            tensor_stride, stride, kernel_size, dilation: normalized by
                ``prep_args`` for ``in_coords_key.D`` dimensions.
            region_type: kernel region type.
            region_offset: optional custom offsets. An empty
                ``torch.IntTensor`` is used when ``None``.
            generate_new_coords: forwarded to the native kernel.
            in_coords_key: key of the input coordinates.
            out_coords_key: key of the output coordinates. A new ``CoordsKey``
                with the same dimension is created when ``None``.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The output feature tensor filled by the native kernel.
        """
        assert input_features.shape[1] == kernel.shape[1], \
            "The input shape " + str(list(input_features.shape)) + \
            " does not match the kernel shape " + str(list(kernel.shape))
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        dim = in_coords_key.D
        assert dim == out_coords_key.D
        assert input_features.type() == kernel.type(), \
            f"Type mismatch input: {input_features.type()} != kernel: {kernel.type()}"
        if not input_features.is_contiguous():
            input_features = input_features.contiguous()

        (tensor_stride, stride, kernel_size, dilation,
         region_type) = prep_args(tensor_stride, stride, kernel_size,
                                  dilation, region_type, dim)

        if region_offset is None:
            region_offset = torch.IntTensor()

        ctx.in_feat = input_features
        ctx.kernel = kernel
        ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                       region_type, in_coords_key, out_coords_key,
                       coords_manager)

        # Empty buffer that the native kernel fills in place.
        out_feat = input_features.new()
        conv_tr_fw = get_minkowski_function('ConvolutionTransposeForward',
                                            input_features)
        geometry = [
            convert_to_int_list(arg, dim)
            for arg in (ctx.tensor_stride, ctx.stride, ctx.kernel_size,
                        ctx.dilation)
        ]
        conv_tr_fw(ctx.in_feat, out_feat, kernel, *geometry, region_type,
                   region_offset, ctx.in_coords_key.CPPCoordsKey,
                   ctx.out_coords_key.CPPCoordsKey,
                   ctx.coords_man.CPPCoordsManager, generate_new_coords)
        return out_feat
예제 #9
0
    def backward(ctx, grad_out_feat):
        """Backward of the broadcast op via the native ``BroadcastBackward``.

        Args:
            ctx: autograd context holding ``in_feat``, ``in_feat_glob``,
                ``op``, both coordinate keys and the coordinate manager.
            grad_out_feat: gradient w.r.t. the broadcast output.

        Returns:
            Gradients for the local and the global feature inputs, followed
            by four ``None`` slots.
        """
        grad_out = (grad_out_feat if grad_out_feat.is_contiguous() else
                    grad_out_feat.contiguous())

        # Both buffers start empty. The kernel's return value is ignored, so
        # it is expected to fill them in place.
        grad_local = grad_out.new()
        grad_global = grad_out.new()
        broadcast_bw = get_minkowski_function('BroadcastBackward', grad_out)
        broadcast_bw(ctx.in_feat, grad_local, ctx.in_feat_glob, grad_global,
                     grad_out, ctx.op, ctx.in_coords_key.CPPCoordsKey,
                     ctx.glob_coords_key.CPPCoordsKey,
                     ctx.coords_manager.CPPCoordsManager)
        return (grad_local, grad_global) + (None, ) * 4
예제 #10
0
 def backward(ctx, grad_out_feat):
     """Backward of transposed pooling via ``PoolingTransposeBackward``.

     Args:
         ctx: autograd context holding ``in_feat``, ``num_nonzero``, the
             geometry saved by ``save_ctx``, both keys and ``coords_man``.
         grad_out_feat: gradient w.r.t. the pooling output.

     Returns:
         The input-feature gradient followed by ten ``None`` slots.
     """
     # Same guard as the max-pooling backward: the native kernel gets a
     # contiguous gradient buffer.
     if not grad_out_feat.is_contiguous():
         grad_out_feat = grad_out_feat.contiguous()

     # Empty buffer that the kernel fills in place.
     grad_in_feat = grad_out_feat.new()
     D = ctx.in_coords_key.D
     bw_fn = get_minkowski_function('PoolingTransposeBackward',
                                    grad_out_feat)
     bw_fn(ctx.in_feat, grad_in_feat, grad_out_feat, ctx.num_nonzero,
           convert_to_int_list(ctx.tensor_stride, D),
           convert_to_int_list(ctx.stride, D),
           convert_to_int_list(ctx.kernel_size, D),
           convert_to_int_list(ctx.dilation, D), ctx.region_type,
           ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
           ctx.coords_man.CPPCoordsManager)
     return grad_in_feat, None, None, None, None, None, None, None, None, None, None
예제 #11
0
    def backward(ctx, grad_out_feat):
        """Backward of max pooling via the native ``MaxPoolingBackward`` kernel.

        Args:
            ctx: autograd context holding ``in_feat``, ``max_index`` recorded
                in forward, the geometry saved by ``save_ctx``, both keys and
                ``coords_man``.
            grad_out_feat: gradient w.r.t. the pooled output.

        Returns:
            The input-feature gradient followed by nine ``None`` slots.
        """
        grad_out = (grad_out_feat if grad_out_feat.is_contiguous() else
                    grad_out_feat.contiguous())

        # Empty buffer that the kernel fills in place.
        grad_in_feat = grad_out.new()
        dim = ctx.in_coords_key.D
        max_pool_bw = get_minkowski_function('MaxPoolingBackward', grad_out)
        geometry = [
            convert_to_int_list(value, dim)
            for value in (ctx.tensor_stride, ctx.stride, ctx.kernel_size,
                          ctx.dilation)
        ]
        max_pool_bw(ctx.in_feat, grad_in_feat, grad_out, ctx.max_index,
                    *geometry, ctx.region_type,
                    ctx.in_coords_key.CPPCoordsKey,
                    ctx.out_coords_key.CPPCoordsKey,
                    ctx.coords_man.CPPCoordsManager)
        return (grad_in_feat, ) + (None, ) * 9
예제 #12
0
    def forward(ctx, in_coords_keys, out_coords_key, coords_manager, *in_feats):
        """Union of several sparse feature sets via the native ``UnionForward``.

        Args:
            ctx: autograd context. Stores the input keys, the output key and
                the coordinate manager.
            in_coords_keys: one coordinate key per feature tensor.
            out_coords_key: key of the merged output coordinates.
            coords_manager: coordinate manager that owns all keys.
            *in_feats: at least two feature tensors (made contiguous before
                the kernel call).

        Returns:
            Whatever ``UnionForward`` returns for the contiguous inputs.

        Raises:
            AssertionError: if fewer than two feature tensors are given.
        """
        assert isinstance(in_feats, (list, tuple)), \
            "Input must be a list or a set of Tensors"
        assert len(in_feats) > 1, \
            "input must be a set with at least 2 Tensors"

        dense_feats = [feat.contiguous() for feat in in_feats]

        ctx.in_coords_keys = in_coords_keys
        ctx.out_coords_key = out_coords_key
        ctx.coords_manager = coords_manager

        union_fw = get_minkowski_function('UnionForward', dense_feats[0])
        cpp_in_keys = [key.CPPCoordsKey for key in in_coords_keys]
        return union_fw(dense_feats, cpp_in_keys, out_coords_key.CPPCoordsKey,
                        coords_manager.CPPCoordsManager)
예제 #13
0
    def forward(ctx,
                input_features,
                tensor_stride=1,
                stride=1,
                kernel_size=-1,
                dilation=1,
                region_type=0,
                region_offset=None,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """Sparse max pooling via the native ``MaxPoolingForward`` kernel.

        Args:
            ctx: autograd context. Stores ``in_feat``, the ``max_index``
                buffer and, through ``save_ctx``, the geometry, keys and
                manager.
            input_features: input feature tensor.
            tensor_stride, stride, kernel_size, dilation: normalized by
                ``prep_args`` for ``in_coords_key.D`` dimensions.
            region_type: must be a ``RegionType`` member. NOTE(review): the
                int default ``0`` does not satisfy this; callers must pass it
                explicitly.
            region_offset: optional custom offsets. An empty
                ``torch.IntTensor`` is used when ``None``.
            in_coords_key: key of the input coordinates.
            out_coords_key: key of the output coordinates. A new ``CoordsKey``
                with the same dimension is created when ``None``.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The pooled feature tensor filled by the native kernel.

        Raises:
            AssertionError: on a non-``RegionType`` ``region_type`` or a
                dimension mismatch between the two keys.
        """
        # Report the offending value: the bare assert gave no hint, and the
        # default ``0`` always fails here.
        assert isinstance(region_type, RegionType), \
            f"region_type must be an instance of RegionType, {region_type}"
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        assert in_coords_key.D == out_coords_key.D
        if not input_features.is_contiguous():
            input_features = input_features.contiguous()

        tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
            tensor_stride, stride, kernel_size, dilation, region_type,
            in_coords_key.D)

        if region_offset is None:
            region_offset = torch.IntTensor()

        ctx.in_feat = input_features
        ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                       region_type, in_coords_key, out_coords_key,
                       coords_manager)

        D = in_coords_key.D
        out_feat = input_features.new()
        # Integer buffer filled by the kernel and reused by the backward pass
        # (presumably the argmax index per output entry).
        max_index = input_features.new().int()

        ctx.max_index = max_index

        fw_fn = get_minkowski_function('MaxPoolingForward', input_features)
        fw_fn(input_features, out_feat, max_index,
              convert_to_int_list(ctx.tensor_stride, D),
              convert_to_int_list(ctx.stride, D),
              convert_to_int_list(ctx.kernel_size, D),
              convert_to_int_list(ctx.dilation, D), region_type, region_offset,
              ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
              ctx.coords_man.CPPCoordsManager)
        return out_feat
예제 #14
0
    def forward(ctx, in_feat, mask, in_coords_key, out_coords_key,
                coords_manager):
        """Prune features with a boolean ``mask`` via ``PruningForward``.

        Args:
            ctx: autograd context. Stores both coordinate keys and the
                coordinate manager.
            in_feat: input feature tensor.
            mask: CPU ``torch.BoolTensor`` with one entry per row of
                ``in_feat``.
            in_coords_key: key of the input coordinates.
            out_coords_key: key of the pruned output coordinates.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The output feature tensor filled by the native kernel.

        Raises:
            AssertionError: if ``mask`` is not a CPU bool tensor or its length
                differs from ``in_feat.size(0)``.
        """
        # Validate the mask type first. Otherwise a non-tensor mask fails
        # with an AttributeError from ``.size`` instead of this message.
        assert isinstance(mask,
                          torch.BoolTensor), "Mask must be a cpu bool tensor."
        assert in_feat.size(0) == mask.size(0), \
            f"Mask length {mask.size(0)} does not match the number of features {in_feat.size(0)}"
        if not in_feat.is_contiguous():
            in_feat = in_feat.contiguous()
        if not mask.is_contiguous():
            mask = mask.contiguous()

        ctx.in_coords_key = in_coords_key
        ctx.out_coords_key = out_coords_key
        ctx.coords_manager = coords_manager

        # Empty buffer that the kernel fills in place.
        out_feat = in_feat.new()

        fw_fn = get_minkowski_function('PruningForward', in_feat)
        fw_fn(in_feat, out_feat, mask, ctx.in_coords_key.CPPCoordsKey,
              ctx.out_coords_key.CPPCoordsKey,
              ctx.coords_manager.CPPCoordsManager)
        return out_feat
예제 #15
0
    def forward(ctx,
                input_features,
                tensor_stride=1,
                stride=1,
                kernel_size=-1,
                dilation=1,
                region_type=-1,
                region_offset=None,
                average=False,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """Transposed sparse pooling via ``PoolingTransposeForward``.

        Args:
            ctx: autograd context. Stores ``in_feat``, the ``num_nonzero``
                buffer and, through ``save_ctx``, the geometry, keys and
                manager.
            input_features: input feature tensor.
            tensor_stride, stride, kernel_size, dilation: normalized by
                ``prep_args`` for ``in_coords_key.D`` dimensions.
            region_type: must be a ``RegionType`` member. NOTE(review): the
                int default ``-1`` does not satisfy this; callers must pass it
                explicitly.
            region_offset: optional custom offsets. An empty
                ``torch.IntTensor`` is used when ``None``.
            average: accepted but not forwarded to the kernel in this body.
            in_coords_key: key of the input coordinates.
            out_coords_key: key of the output coordinates. A new ``CoordsKey``
                with the same dimension is created when ``None``.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The output feature tensor filled by the native kernel.

        Raises:
            AssertionError: on a non-``RegionType`` ``region_type`` or a
                dimension mismatch between the two keys.
        """
        # Report the offending value: the bare assert gave no hint, and the
        # default ``-1`` always fails here.
        assert isinstance(region_type, RegionType), \
            f"region_type must be an instance of RegionType, {region_type}"
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        assert in_coords_key.D == out_coords_key.D
        # Like the max-pooling forward, hand the native kernel a contiguous
        # buffer.
        if not input_features.is_contiguous():
            input_features = input_features.contiguous()

        tensor_stride, stride, kernel_size, dilation, region_type = prep_args(
            tensor_stride, stride, kernel_size, dilation, region_type,
            in_coords_key.D)

        if region_offset is None:
            region_offset = torch.IntTensor()

        ctx.in_feat = input_features
        out_feat = input_features.new()
        # Filled by the kernel and reused by the backward pass.
        ctx.num_nonzero = input_features.new()
        ctx = save_ctx(ctx, tensor_stride, stride, kernel_size, dilation,
                       region_type, in_coords_key, out_coords_key,
                       coords_manager)
        D = in_coords_key.D
        fw_fn = get_minkowski_function('PoolingTransposeForward',
                                       input_features)
        fw_fn(ctx.in_feat, out_feat, ctx.num_nonzero,
              convert_to_int_list(ctx.tensor_stride, D),
              convert_to_int_list(ctx.stride, D),
              convert_to_int_list(ctx.kernel_size, D),
              convert_to_int_list(ctx.dilation, D), region_type, region_offset,
              ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
              ctx.coords_man.CPPCoordsManager)
        return out_feat
예제 #16
0
    def forward(ctx, input_features, input_features_global, operation_type,
                in_coords_key, glob_coords_key, coords_manager):
        """Apply ``operation_type`` between features and global features.

        Delegates to the native ``BroadcastForward`` kernel.

        Args:
            ctx: autograd context. Stores the integer op code, both
                (contiguous) feature tensors, both keys and the manager.
            input_features: per-point feature tensor.
            input_features_global: global feature tensor with the same
                ``shape[1]`` and tensor type as ``input_features``.
            operation_type: an ``OperationType`` member.
            in_coords_key: key of the input coordinates.
            glob_coords_key: key of the global coordinates.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The output tensor returned by ``BroadcastForward``.
        """
        assert input_features.shape[1] == input_features_global.shape[1]
        assert input_features.type() == input_features_global.type()
        assert isinstance(operation_type, OperationType)

        local_feat = (input_features if input_features.is_contiguous() else
                      input_features.contiguous())
        glob_feat = (input_features_global
                     if input_features_global.is_contiguous() else
                     input_features_global.contiguous())
        op_code = operation_type_to_int(operation_type)

        ctx.op = op_code
        ctx.in_feat = local_feat
        ctx.in_feat_glob = glob_feat
        ctx.in_coords_key = in_coords_key
        ctx.glob_coords_key = glob_coords_key
        ctx.coords_manager = coords_manager

        broadcast_fw = get_minkowski_function('BroadcastForward', local_feat)
        return broadcast_fw(local_feat, glob_feat, op_code,
                            in_coords_key.CPPCoordsKey,
                            glob_coords_key.CPPCoordsKey,
                            coords_manager.CPPCoordsManager)
예제 #17
0
    def forward(ctx,
                input_features,
                in_coords_key=None,
                out_coords_key=None,
                coords_manager=None):
        """Global max pooling via the native ``GlobalMaxPoolingForward`` kernel.

        Args:
            ctx: autograd context. Stores the coordinate keys, the manager,
                the input features and the ``max_index`` buffer for the
                backward pass.
            input_features: input feature tensor.
            in_coords_key: key of the input coordinates.
            out_coords_key: key of the pooled coordinates. A new ``CoordsKey``
                with the same dimension is created when ``None``.
            coords_manager: coordinate manager that owns both keys.

        Returns:
            The pooled feature tensor filled by the native kernel.
        """
        if out_coords_key is None:
            out_coords_key = CoordsKey(in_coords_key.D)
        # Like the other forward functions, hand the native kernel a
        # contiguous buffer.
        if not input_features.is_contiguous():
            input_features = input_features.contiguous()

        ctx.in_coords_key = in_coords_key
        ctx.out_coords_key = out_coords_key

        ctx.in_feat = input_features
        out_feat = input_features.new()

        # Integer buffer filled by the kernel and reused by the backward pass
        # (presumably the argmax index per output entry).
        max_index = input_features.new().int()

        ctx.max_index = max_index
        ctx.coords_manager = coords_manager

        fw_fn = get_minkowski_function('GlobalMaxPoolingForward',
                                       input_features)
        fw_fn(ctx.in_feat, out_feat, ctx.max_index,
              ctx.in_coords_key.CPPCoordsKey, ctx.out_coords_key.CPPCoordsKey,
              ctx.coords_manager.CPPCoordsManager)
        return out_feat