Example 1
    def __getitem__(self, idx):
        data_ranges = {
            0: [-1, 1],
            1: [0, 2],
            2: [-2, 0],
            3: [-2, -2],
        }
        l = np.random.randint(0, 4, size=[2])

        data = generate_sparse_data([16, 64, 64], [16 * 64 * 64 // 2],
                                    3,
                                    data_range=data_ranges[l[0]],
                                    with_dense=False)
        data2 = generate_sparse_data([16, 64, 64], [16 * 64 * 64 // 2],
                                     3,
                                     data_range=data_ranges[l[1]],
                                     with_dense=False)

        features = np.ascontiguousarray(data["features"]).astype(np.float32)
        indices = np.ascontiguousarray(
            data["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
        features2 = np.ascontiguousarray(data2["features"]).astype(np.float32)
        indices2 = np.ascontiguousarray(
            data2["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
        features = np.ascontiguousarray(np.concatenate([features, features2]))
        indices = np.ascontiguousarray(np.concatenate([indices, indices2]))
        return features, indices, l
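Example 1 only builds per-sample tensors; to feed them to spconv the samples still have to be merged into one batch with the batch index written into the first indices column. The sketch below is a hypothetical collate_sparse helper written against the spconv.pytorch API (SparseConvTensor); it is not part of the original code, just an illustration of how a __getitem__ like the one above is typically consumed.

import numpy as np
import torch
import spconv.pytorch as spconv


def collate_sparse(samples, spatial_shape=(16, 64, 64)):
    # Hypothetical collate function for the Dataset above: each sample is
    # (features, indices, labels) with indices laid out as [batch, z, y, x].
    feats, inds, labels = [], [], []
    for batch_idx, (f, i, lab) in enumerate(samples):
        i = i.copy()
        i[:, 0] = batch_idx  # rewrite the batch column for this sample
        feats.append(f)
        inds.append(i)
        labels.append(lab)
    features = torch.from_numpy(np.concatenate(feats))
    indices = torch.from_numpy(np.concatenate(inds))
    x = spconv.SparseConvTensor(features, indices, list(spatial_shape), len(samples))
    return x, np.stack(labels)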
Example 2
    def testSpDeConv3d(self):
        np.random.seed(484)
        devices = ["cuda:0", "cpu:0"]
        shapes = [[19, 18, 17]]
        batchsizes = [1, 2]

        in_channels = [64]
        out_channels = [32, 48, 64]
        ksizes = [2, 3]
        strides = [2, 3]
        paddings = [0, 1, 2]
        dilations = [1, 2, 3]

        for dev, shape, bs, IC, OC, k, s, p, d in params_grid(
                devices, shapes, batchsizes, in_channels, out_channels, ksizes,
                strides, paddings, dilations):
            if all([s > 1, d > 1]):
                continue  # stride > 1 combined with dilation > 1 is not supported.
            device = torch.device(dev)
            num_points = [1000] * bs

            sparse_dict = generate_sparse_data(shape, num_points, IC)

            features = np.ascontiguousarray(sparse_dict["features"]).astype(np.float32)
            indices = np.ascontiguousarray(sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
            features_dense = sparse_dict["features_dense"].astype(np.float32)
            filters = np.random.uniform(0, 1, size=[k, k, k, IC, OC]).astype(np.float32)
            indices_t = torch.from_numpy(indices).int().to(device)
            features_t = torch.from_numpy(features).to(device)
            features_t.requires_grad = True
            features_dense_t = torch.from_numpy(features_dense).to(device)
            features_dense_t.requires_grad = True
            net = SparseDeConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d).to(device)
            net_ref = DeConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d).to(device)
            filters_t = torch.from_numpy(filters).to(device)
            net_ref.net[0].weight.data[:] = filters_t.permute(3, 4, 0, 1, 2).contiguous()
            net.net[0].weight.data[:] = filters_t
            out_ref = net_ref(features_dense_t)
            out = net(features_t, indices_t, bs).dense()
            dout = np.random.uniform(-0.2, 0.2, out_ref.shape).astype(features.dtype)
            dout_t = torch.from_numpy(dout).to(device)
            out.backward(dout_t)
            out_ref.backward(dout_t)
            din_dense = features_dense_t.grad.detach().permute(0, 2, 3, 4, 1).contiguous()
            din_sparse = gather_nd(din_dense, indices_t.long())
            din = features_t.grad.detach()
            din_np = din.cpu().numpy()
            din_sparse_np = din_sparse.cpu().numpy()
            self.assertAllClose(din_np, din_sparse_np, atol=1e-4)
            for layer, layer_ref in zip(net.net, net_ref.net):
                dw = layer.weight.grad.detach().cpu().numpy()
                dw_ref = layer_ref.weight.grad.detach().cpu().numpy()
                dw = dw.transpose(3, 4, 0, 1, 2)
                self.assertAllClose(dw, dw_ref, atol=1e-4)

            out_np = out.detach().cpu().numpy()
            out_ref_np = out_ref.detach().cpu().numpy()
            self.assertAllClose(out_np, out_ref_np, atol=1e-4)
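The helper classes SparseDeConv3dTestTorch and DeConv3dTestTorch are defined elsewhere in the test module. As an orientation, a minimal pair might look like the sketch below (an assumption based on the spconv.pytorch API, not the actual helpers): the dense reference uses nn.ConvTranspose3d, whose (IC, OC, kD, kH, kW) weight layout is what the permute(3, 4, 0, 1, 2) of the (k, k, k, IC, OC) filter above maps onto.

import torch.nn as nn
import spconv.pytorch as spconv


class SparseDeConv3dSketch(nn.Module):
    # Hypothetical stand-in for SparseDeConv3dTestTorch.
    def __init__(self, shape, in_channels, out_channels, k, s, p, d):
        super().__init__()
        self.shape = shape
        self.net = spconv.SparseSequential(
            spconv.SparseConvTranspose3d(in_channels, out_channels, k,
                                         stride=s, padding=p, dilation=d,
                                         bias=False))

    def forward(self, features, coors, batch_size):
        x = spconv.SparseConvTensor(features, coors.int(), self.shape, batch_size)
        return self.net(x)


class DeConv3dSketch(nn.Module):
    # Hypothetical stand-in for DeConv3dTestTorch; weight shape is (IC, OC, k, k, k).
    def __init__(self, in_channels, out_channels, k, s, p, d):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, k, stride=s,
                               padding=p, dilation=d, bias=False))

    def forward(self, x):
        return self.net(x)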
Example 3
    def testSpMaxPool3d(self):
        np.random.seed(485)
        devices = ["cuda:0", "cpu:0"]
        shapes = [[19, 18, 17]]
        batchsizes = [1, 2]

        in_channels = [64]
        out_channels = [64]
        ksizes = [2, 3]
        strides = [1, 2, 3]
        paddings = [0, 1]
        dilations = [1, 2, 3]

        for dev, shape, bs, IC, OC, k, s, p, d in params_grid(
                devices, shapes, batchsizes, in_channels, out_channels, ksizes,
                strides, paddings, dilations):
            if all([s > 1, d > 1]):
                continue  # stride > 1 combined with dilation > 1 is not supported.
            device = torch.device(dev)
            num_points = [1000] * bs
            # when the data contains negative values, sparse max pooling does not
            # match dense max pooling, because empty voxels densify to zero.
            sparse_dict = generate_sparse_data(shape, num_points, IC, data_range=[0.1, 1])

            features = np.ascontiguousarray(sparse_dict["features"]).astype(np.float32)
            indices = np.ascontiguousarray(sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
            features_dense = sparse_dict["features_dense"].astype(np.float32)
            filters = np.random.uniform(0, 1, size=[k, k, k, IC, OC]).astype(np.float32)
            indices_t = torch.from_numpy(indices).int().to(device)
            features_t = torch.from_numpy(features).to(device)
            features_t.requires_grad = True
            features_dense_t = torch.from_numpy(features_dense).to(device)
            features_dense_t.requires_grad = True
            net = SparseMaxPoolTestTorch(1, 3, shape, k, s, p, d).to(device)
            net_ref = MaxPool3dTestTorch(1, 3, shape, k, s, p, d).to(device)

            out_ref = net_ref(features_dense_t)
            out = net(features_t, indices_t, bs)
            outids = out.indices
            outfeatures = out.features
            out_dense = out.dense(channels_first=False)
            out = out_dense.permute(0, 4, 1, 2, 3).contiguous()

            dout_sparse = np.random.uniform(-0.2, 0.2, outfeatures.shape).astype(features.dtype)
            dout_sparse_t = torch.from_numpy(dout_sparse).to(device)
            dout_t = scatter_nd(outids.long(), dout_sparse_t, list(out_dense.shape))
            dout_t = dout_t.permute(0, 4, 1, 2, 3).contiguous()
            out.backward(dout_t)
            out_ref.backward(dout_t)
            din_dense = features_dense_t.grad.detach().permute(0, 2, 3, 4, 1).contiguous()
            din_sparse = gather_nd(din_dense, indices_t.long())
            din = features_t.grad.detach()
            din_np = din.cpu().numpy()
            din_sparse_np = din_sparse.cpu().numpy()
            self.assertAllClose(din_np, din_sparse_np, atol=1e-4)

            out_np = out.detach().cpu().numpy()
            out_ref_np = out_ref.detach().cpu().numpy()
            self.assertAllClose(out_np, out_ref_np, atol=1e-4)
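The data_range=[0.1, 1] restriction matters because densifying a sparse tensor fills empty voxels with zeros: if every active value inside a pooling window is negative, dense max pooling returns 0 while sparse max pooling returns the negative maximum of the active values. A short numeric illustration (not from the original test):

import numpy as np

dense_window = np.array([-0.5, 0.0, 0.0, 0.0])  # one active voxel (-0.5), three empty ones
print(dense_window.max())  # 0.0  -> what dense max pooling sees after densification
print(np.max([-0.5]))      # -0.5 -> what sparse max pooling sees (active voxels only)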
Example 4
def main():
    # function for develop.
    np.random.seed(484)
    devices = ["cuda:0"]
    shapes = [[50, 30, 30]]
    batchsizes = [3]

    in_channels = [256]
    out_channels = [256]
    ksizes = [3]
    strides = [1]
    paddings = [0]
    dilations = [1]

    for dev, shape, bs, IC, OC, k, s, p, d in params_grid(
            devices, shapes, batchsizes, in_channels, out_channels, ksizes,
            strides, paddings, dilations):
        if all([s > 1, d > 1]):
            continue
        device = torch.device(dev)
        num_points = [5000] * bs

        sparse_dict = generate_sparse_data(shape, num_points, IC)

        features = np.ascontiguousarray(sparse_dict["features"]).astype(
            np.float32)
        indices = np.ascontiguousarray(
            sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
        features_dense = sparse_dict["features_dense"].astype(np.float32)
        filters = np.random.uniform(0, 1, size=[k, k, k, IC,
                                                OC]).astype(np.float32)
        # indices must stay int32; only the features are cast to half precision.
        indices_t = torch.from_numpy(indices).int().to(device)
        features_t = torch.from_numpy(features).to(device).half()

        features_dense_t = torch.from_numpy(features_dense).to(device).half()
        net = SparseConv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
                                    d).to(device).half()
        net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
                                  d).to(device).half()
        filters_t = torch.from_numpy(filters).to(device).half()
        net_ref.net[0].weight.data[:] = filters_t.permute(4, 3, 0, 1,
                                                          2).contiguous()
        net.net[0].weight.data[:] = filters_t
        out_ref = net_ref(features_dense_t)
        times = []
        for i in range(30):
            t = time.time()
            out = net(features_t, indices_t, bs)
            torch.cuda.synchronize()
            times.append(time.time() - t)
        # print((net.grid == -1).float().sum(), net.grid.numel())
        # print("spconv time", time.time() - t)
        print("spconv time", np.mean(times[2:]))
        out = net(features_t, indices_t, bs).dense()
        print(
            np.linalg.norm(out.detach().cpu().numpy() -
                           out_ref.detach().cpu().numpy()))
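The loop above times each forward call with time.time() plus torch.cuda.synchronize() and averages everything after the first two warmup iterations. An equivalent measurement with CUDA events could look like the generic helper below (not part of the original file):

import torch


def cuda_time_ms(fn, warmup=2, iters=30):
    # Average GPU time of fn() in milliseconds using CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for i in range(warmup + iters):
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        if i >= warmup:
            times.append(start.elapsed_time(end))
    return sum(times) / len(times)


# usage inside main(): cuda_time_ms(lambda: net(features_t, indices_t, bs))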
Example 5
    def testSpCpConv3d(self):
        np.random.seed(484)
        devices = ["cuda:0", "cpu:0"]
        shapes = [[20, 20, 20]]
        batchsizes = [1, 2]

        in_channels = [64]
        out_channels = [32, 48, 64]
        ksizes = [2]
        strides = [2]
        paddings = [0, 1, 2]
        dilations = [1, 2, 3]

        for dev, shape, bs, IC, OC, k, s in params_grid(
                devices, shapes, batchsizes, in_channels, out_channels, ksizes,
                strides):
            device = torch.device(dev)
            num_points = [1000] * bs

            sparse_dict = generate_sparse_data(shape, num_points, IC)

            features = np.ascontiguousarray(sparse_dict["features"]).astype(np.float32)
            indices = np.ascontiguousarray(sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
            features_dense = sparse_dict["features_dense"].astype(np.float32)
            filters = np.random.uniform(0, 1, size=[k, k, k, IC, OC]).astype(np.float32)
            indices_t = torch.from_numpy(indices).int().to(device)
            indices_scn_t = torch.from_numpy(indices[:, [1, 2, 3, 0]]).int().to(device)
            features_t = torch.from_numpy(features).to(device)
            features_t.requires_grad = True
            features_ref_t = torch.from_numpy(features).to(device)
            features_ref_t.requires_grad = True

            net_ref = SCNCoupleDeConvTest(1, 3, shape, IC, OC, k, s).to(device)
            net = SparseCoupleDeConvTest(1, 3, shape, IC, OC, k, s).to(device)
            net_ref.net[0].weight.data[:] = net.net[0].weight.data[:].view(*net_ref.net[0].weight.shape)
            net_ref.net[1].weight.data[:] = net.net[1].weight.data[:].view(*net_ref.net[1].weight.shape)
            out_ref = net_ref(features_ref_t, indices_scn_t, bs)
            out = net(features_t, indices_t, bs)
            dout = np.random.uniform(-0.2, 0.2, out_ref.shape).astype(features.dtype)
            dout_t = torch.from_numpy(dout).to(device)
            out.backward(dout_t)
            out_ref.backward(dout_t)
            din = features_t.grad.detach()
            din_ref = features_ref_t.grad.detach()
            din_np = din.cpu().numpy()
            din_ref_np = din_ref.cpu().numpy()
            self.assertAllClose(din_ref_np, din_np, atol=1e-4)
            for layer, layer_ref in zip(net.net, net_ref.net):
                dw = layer.weight.grad.detach().cpu().numpy()
                dw_ref = layer_ref.weight.grad.detach().cpu().view(*dw.shape).numpy()
                self.assertAllClose(dw, dw_ref, atol=1e-4)

            out_np = out.detach().cpu().numpy()
            out_ref_np = out_ref.detach().cpu().numpy()
            self.assertAllClose(out_np, out_ref_np, atol=1e-4)
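SparseCoupleDeConvTest and SCNCoupleDeConvTest are again helpers defined elsewhere in the test module. The spconv side of such a couple is usually a SparseConv3d and a SparseInverseConv3d that share an indice_key, so the inverse convolution restores exactly the input sparsity pattern; the sketch below assumes that API and is only illustrative.

import torch.nn as nn
import spconv.pytorch as spconv


class CoupleDeConvSketch(nn.Module):
    # Hypothetical stand-in for SparseCoupleDeConvTest.
    def __init__(self, shape, in_channels, out_channels, k, s):
        super().__init__()
        self.shape = shape
        self.net = spconv.SparseSequential(
            spconv.SparseConv3d(in_channels, out_channels, k, stride=s,
                                bias=False, indice_key="cp0"),
            # reuses the indice pairs recorded under "cp0", so the output
            # has the same active sites as the original input.
            spconv.SparseInverseConv3d(out_channels, in_channels, k,
                                       indice_key="cp0", bias=False),
        )

    def forward(self, features, coors, batch_size):
        x = spconv.SparseConvTensor(features, coors.int(), self.shape, batch_size)
        return self.net(x).features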
Example 6
def main_subm(algo, dtype=torch.float32):
    # function for develop.
    np.random.seed(484)
    torch.manual_seed(50051)
    # devices = ["cuda:0"]
    devices = ["cuda:0"]
    shapes = [[400, 400, 15]]
    batchsizes = [2]

    in_channels = [32]
    out_channels = [64]
    ksizes = [(3, 3, 3)]
    strides = [1]
    paddings = [1]
    dilations = [1]
    for dev, shape, bs, IC, OC, k, s, p, d in params_grid(
            devices, shapes, batchsizes, in_channels, out_channels, ksizes,
            strides, paddings, dilations):
        if all([s > 1, d > 1]):
            continue
        device = torch.device(dev)
        num_points = [120000] * bs

        sparse_dict = generate_sparse_data(shape, num_points, IC)

        features = np.ascontiguousarray(sparse_dict["features"]).astype(
            np.float32)
        indices = np.ascontiguousarray(
            sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
        features_dense = sparse_dict["features_dense"].astype(np.float32)
        filters = np.random.uniform(0, 1, size=[k[0], 1, 1, IC,
                                                OC]).astype(np.float32)
        # indices must stay int32; only the features are cast to the target dtype.
        indices_t = torch.from_numpy(indices).int().to(device)
        features_t = torch.from_numpy(features).to(device).to(dtype)

        features_dense_t = torch.from_numpy(features_dense).to(device).to(
            dtype)
        net = SubMConv3dTestTorch(1, 3, shape, IC, OC, k, s, p, d,
                                  algo=algo).to(device).to(dtype)
        net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
                                  d).to(device).to(dtype)
        filters_t = torch.from_numpy(filters).to(device).to(dtype)
        net_ref.net[0].weight.data[:] = filters_t.permute(4, 3, 0, 1,
                                                          2).contiguous()
        net.net[0].weight.data[:] = filters_t
        out_ref = net_ref(features_dense_t)
        times = []
        for i in range(20):
            t = time.time()
            out = net(features_t, indices_t, bs)
            torch.cuda.synchronize()
            times.append(time.time() - t)
        # print((net.grid == -1).float().sum(), net.grid.numel())
        # print("spconv time", time.time() - t)
        print("spconv time", np.mean(times[10:]))
        out = net(features_t, indices_t, bs)
        # print(out.indices)
        out = out.dense()
        out_numpy = out.detach().cpu().numpy()
        # print(
        #     np.linalg.norm(out.detach().cpu().numpy() -
        #                    out_ref.detach().cpu().numpy()))
        print(out_numpy.min(), out_numpy.max(), out_numpy.mean(),
              out_numpy.sum())
    return out_numpy
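The benchmark exercises SubMConv3d, whose defining property is that it never creates new active sites: the output indices are the input indices and only the feature channels change. A minimal, self-contained sketch of that property (assuming the spconv.pytorch API; the voxel coordinates are made up for illustration):

import torch
import spconv.pytorch as spconv

# two active voxels, 32 input channels, spatial shape matching the benchmark above
indices = torch.tensor([[0, 1, 2, 3], [0, 5, 6, 7]], dtype=torch.int32).cuda()
features = torch.randn(2, 32).cuda()
x = spconv.SparseConvTensor(features, indices, [400, 400, 15], batch_size=1)

subm = spconv.SubMConv3d(32, 64, kernel_size=3, padding=1, bias=False).cuda()
y = subm(x)
assert torch.equal(y.indices, x.indices)  # same active sites as the input
print(y.features.shape)                   # torch.Size([2, 64])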
Example 7
    def testSpConv3d(self):
        np.random.seed(484)
        torch.manual_seed(48848)
        devices = ["cuda:0"]
        shapes = [[19, 18, 17]]
        batchsizes = [1, 2]

        in_channels = [32]
        out_channels = [32, 48, 64]
        ksizes = [2, 3]
        strides = [1, 2, 3]
        paddings = [0, 1, 2]
        dilations = [1, 2, 3]
        algos = [
            ConvAlgo.Native, ConvAlgo.MaskImplicitGemm,
            ConvAlgo.MaskSplitImplicitGemm
        ]
        algos = [ConvAlgo.MaskSplitImplicitGemm]  # overrides the full list above; only this algorithm is exercised

        for dev, shape, bs, IC, OC, k, s, p, d, al in params_grid(
                devices, shapes, batchsizes, in_channels, out_channels, ksizes,
                strides, paddings, dilations, algos):
            if all([s > 1, d > 1]):
                continue  # stride > 1 combined with dilation > 1 is not supported.
            print(k, s, p, d)
            device = torch.device(dev)
            num_points = [1000] * bs
            dtype = torch.float32
            net = SparseConv3dTestTorch(1,
                                        3,
                                        shape,
                                        IC,
                                        OC,
                                        k,
                                        s,
                                        p,
                                        d,
                                        algo=al).to(device).to(dtype)
            net_ref = Conv3dTestTorch(1, 3, shape, IC, OC, k, s, p,
                                      d).to(device).to(dtype)

            sparse_dict = generate_sparse_data(shape, num_points, IC)

            features = np.ascontiguousarray(sparse_dict["features"]).astype(
                np.float32)
            indices = np.ascontiguousarray(
                sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)
            features_dense = sparse_dict["features_dense"].astype(np.float32)
            indices_t = torch.from_numpy(indices).int().to(device)
            features_t = torch.from_numpy(features).to(device).to(dtype)
            features_t.requires_grad = True
            features_dense_t = torch.from_numpy(features_dense).to(device).to(
                dtype)
            features_dense_t.requires_grad = True
            if net.algo == ConvAlgo.Native:
                if FILTER_HWIO:
                    filters = np.random.uniform(-1, 1,
                                                size=[k, k, k, IC,
                                                      OC]).astype(np.float32)
                else:
                    filters = np.random.uniform(-1, 1,
                                                size=[k, k, k, OC,
                                                      IC]).astype(np.float32)
                filters_t = torch.from_numpy(filters).to(device).to(dtype)
                if FILTER_HWIO:
                    net_ref.net[0].weight.data[:] = filters_t.permute(
                        4, 3, 0, 1, 2).contiguous()
                else:
                    net_ref.net[0].weight.data[:] = filters_t.permute(
                        3, 4, 0, 1, 2).contiguous()
            else:
                filters = np.random.uniform(-1, 1,
                                            size=[OC, k, k, k,
                                                  IC]).astype(np.float32)
                filters_t = torch.from_numpy(filters).to(device).to(dtype)
                net_ref.net[0].weight.data[:] = filters_t.permute(
                    0, 4, 1, 2, 3).contiguous()

            net.net[0].weight.data[:] = filters_t
            out_ref = net_ref(features_dense_t)
            out = net(features_t, indices_t, bs).dense()
            out_np = out.detach().cpu().numpy()
            out_ref_np = out_ref.detach().cpu().numpy()
            self.assertAllClose(out_np, out_ref_np, atol=1e-4)

            dout = np.random.uniform(-0.2, 0.2,
                                     out_ref.shape).astype(features.dtype)
            dout_t = torch.from_numpy(dout).to(device)
            out.backward(dout_t)
            out_ref.backward(dout_t)
            din_dense = features_dense_t.grad.detach().permute(0, 2, 3, 4,
                                                               1).contiguous()
            din_sparse = gather_nd(din_dense, indices_t.long())
            din = features_t.grad.detach()

            din_np = din.cpu().numpy()
            din_sparse_np = din_sparse.cpu().numpy()
            for layer, layer_ref in zip(net.net, net_ref.net):
                dw = layer.weight.grad.detach().cpu().numpy()
                dw_ref = layer_ref.weight.grad.detach().cpu().numpy()
                if net.algo == ConvAlgo.Native:
                    if FILTER_HWIO:
                        dw = dw.transpose(4, 3, 0, 1, 2)
                    else:
                        dw = dw.transpose(3, 4, 0, 1, 2)
                else:
                    # OHWI -> OIHW
                    dw = dw.transpose(0, 4, 1, 2, 3)

                self.assertAllClose(dw, dw_ref, atol=1e-4)
            self.assertAllClose(din_np, din_sparse_np, atol=1e-4)
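The test handles two weight layouts: the Native algorithm keeps filters in (k, k, k, IC, OC) or (k, k, k, OC, IC) order depending on FILTER_HWIO, while the implicit-GEMM algorithms keep them in (OC, k, k, k, IC), hence the different permutes before copying into the dense reference. Selecting the algorithm explicitly at construction time looks roughly like the sketch below, assuming SparseConv3d accepts the algo argument the way SparseConv3dTestTorch forwards it and that ConvAlgo is importable from spconv.core as in spconv v2:

import spconv.pytorch as spconv
from spconv.core import ConvAlgo

conv_native = spconv.SparseConv3d(32, 64, 3, stride=2, padding=1, bias=False,
                                  algo=ConvAlgo.Native)
conv_igemm = spconv.SparseConv3d(32, 64, 3, stride=2, padding=1, bias=False,
                                 algo=ConvAlgo.MaskImplicitGemm)
# Same module API, different underlying kernels and weight layouts; the test
# above permutes the filters accordingly before comparing outputs and gradients.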
Example 8
def dev_subm_inds_v2(subm: bool = False, run_conv: bool = True):
    limit_input_n = 16384
    limit_input_n = None  # overrides the cap above so the full input is used
    np.random.seed(484)

    with (PACKAGE_ROOT.parent / "test/data/test_spconv.pkl").open("rb") as f:
        voxels_np, indices_np, spatial_shape = pickle.load(f)
        from spconv.test_utils import generate_sparse_data
        voxels_np = voxels_np[:limit_input_n]
        indices_np = indices_np[:limit_input_n]

        # the generated data below replaces the pickled data loaded above
        spatial_shape = [19, 18, 17]
        sparse_dict = generate_sparse_data(spatial_shape, [1024], 128)

        voxels_np = np.ascontiguousarray(sparse_dict["features"]).astype(
            np.float32)
        indices_np = np.ascontiguousarray(
            sparse_dict["indices"][:, [3, 0, 1, 2]]).astype(np.int32)

        voxels = tv.from_numpy(voxels_np).cuda()
        indices = tv.from_numpy(indices_np).cuda()
        indices_th = torch.from_numpy(indices_np).cuda()
    print(spatial_shape, indices_np.shape)
    ndim = 3
    if subm:
        ksize = [3, 3, 3]
        kv = np.prod(ksize)
        padding = [1] * ndim
        stride = [1] * ndim
        dilation = [1] * ndim
        out_padding = [0] * ndim
    else:
        ksize = [2, 2, 2]
        kv = np.prod(ksize)
        padding = [0] * ndim
        stride = [1] * ndim
        dilation = [1] * ndim
        out_padding = [0] * ndim
    out_inds, pair_ref, indice_num_per_loc = ops.get_indice_pairs(
        indices_th, 1, spatial_shape, ConvAlgo.Native, ksize, stride, padding,
        dilation, out_padding, subm)
    indice_num_per_loc_np = indice_num_per_loc.cpu().numpy()
    indice_pairs_np = pair_ref.cpu().numpy()
    algo = ConvAlgo.MaskSplitImplicitGemm
    if algo == ConvAlgo.MaskImplicitGemm:
        num_split = 1
    else:
        num_split = 2
    for i in range(5):
        res = ops.get_indice_pairs_implicit_gemm(indices_th, 1, spatial_shape,
                                                 algo, ksize, stride, padding,
                                                 dilation, out_padding, subm)
    out_inds = res[0]
    num_inds_per_loc = res[1]
    pair_fwd = res[2]
    pair_fwd_x = pair_fwd.cpu().numpy().reshape(-1)
    pair_fwd_x[pair_fwd_x == -1] = 0
    loc_num_np = (pair_fwd_x > 0).reshape(kv, -1).sum(1)
    print(loc_num_np)
    print(indice_num_per_loc_np)

    pair_bwd = res[3]
    pair_mask_fwd_splits = res[4]
    pair_mask_bwd_splits = res[5]
    mask_argsort_fwd_splits = res[6]
    mask_argsort_bwd_splits = res[7]
    masks = res[8]
    pair_mask_fwd_splits_tv = [
        ops.torch_tensor_to_tv(t, dtype=tv.uint32)
        for t in pair_mask_fwd_splits
    ]
    valid_location_bitcount = [
        SpconvOps.count_bits(t) for t in pair_mask_fwd_splits_tv
    ]
    valid_location_count = sum(
        [t.cpu().numpy().sum() for t in valid_location_bitcount])
    reduce_length = 32
    split_mask_valid_count = sum([
        reduce_mask_count(t.cpu().numpy(), reduce_length)
        for t in pair_mask_fwd_splits_tv
    ])
    if subm:
        print("SUBM", valid_location_count, split_mask_valid_count,
              pair_fwd.numel())
    else:
        print("REGULAR", valid_location_count, split_mask_valid_count,
              pair_fwd.numel())
    # return

    if run_conv:
        C = 64
        K = 64
        desps = CONV.desps
        mask_output_fwd = torch.zeros([2, div_up(out_inds.shape[0], 32)],
                                      dtype=torch.int32,
                                      device=indices_th.device)
        mask_output_bwd = torch.zeros([2, div_up(indices.dim(0), 32)],
                                      dtype=torch.int32,
                                      device=indices_th.device)

        for desp in desps:
            if desp.algo != GemmAlgo.Simt.value:
                continue
            # if desp.op_type == ConvOpType.kBackwardWeight.value:
            #     continue
            # if desp.tile_shape !
            if desp.dtype_a == dtypes.int8.tv_dtype:
                inp = np.random.randint(-1, 1, size=[voxels_np.shape[0],
                                                     C]).astype(np.int8)
                weight = np.random.randint(-1, 1, size=[K, *ksize,
                                                        C]).astype(np.int8)
                output = np.random.randint(-1, 1, size=[
                    out_inds.shape[0], K
                ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_output))
            else:
                inp = np.random.uniform(-1, 1, size=[
                    voxels_np.shape[0], C
                ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_input))
                weight = np.random.uniform(-1, 1, size=[K, *ksize, C]).astype(
                    dtypes.get_npdtype_from_tvdtype(desp.dtype_weight))
                output = np.random.uniform(-1, 1, size=[
                    out_inds.shape[0], K
                ]).astype(dtypes.get_npdtype_from_tvdtype(desp.dtype_output))
            weight_ref = weight.transpose(1, 2, 3, 0, 4)
            weight_ref = np.ascontiguousarray(weight_ref).reshape(-1, K, C)
            if desp.op_type == ConvOpType.kBackwardInput.value:
                inp_tv = tv.zeros(inp.shape, desp.dtype_input, 0)
            else:
                inp_tv = tv.from_numpy(inp).cuda()
            if desp.op_type == ConvOpType.kBackwardWeight.value:
                weight_tv = tv.zeros(weight.shape, desp.dtype_weight, 0)
            else:
                weight_tv = tv.from_numpy(weight).cuda()
            # _ = tv.zeros([5000, 10], tv.float32, 0)
            if desp.op_type == ConvOpType.kForward.value:
                output_tv = tv.zeros(output.shape, desp.dtype_output, 0)
            else:
                output_tv = tv.from_numpy(output).cuda()
            torch.cuda.synchronize()
            t = time.time()
            spk = 1
            if desp.op_type == ConvOpType.kBackwardWeight.value:
                # TODO support splitk parallel
                spk = 32
            if subm:
                if desp.op_type == ConvOpType.kForward.value:
                    indice_pairs = pair_fwd
                elif desp.op_type == ConvOpType.kBackwardInput.value:
                    indice_pairs = pair_bwd
                else:
                    indice_pairs = pair_fwd
                mask_output = mask_output_fwd
                # print([bin(x.item()) for x in masks])
                for j in range(num_split):
                    beta = 1 if j == 1 else 0
                    mask_filter = 0xffffffff
                    mask_filter = masks[j].item()

                    reverse_mask = False
                    if desp.op_type == ConvOpType.kBackwardWeight.value:
                        mask_op = mask_output[j]
                    else:
                        mask_op = pair_mask_fwd_splits[j]
                    if desp.op_type == ConvOpType.kBackwardInput.value:
                        reverse_mask = True
                    CONV.run_with_tuned_result(
                        BestConvAlgoByProfile(desp, spk),
                        desp.op_type,
                        inp_tv,
                        weight_tv,
                        output_tv,
                        torch_tensor_to_tv(mask_op, dtype=tv.uint32),
                        torch_tensor_to_tv(mask_argsort_fwd_splits[j]),
                        torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
                        torch_tensor_to_tv(indice_pairs),
                        reverse_mask,
                        mask_filter=mask_filter,
                        mask_width=32,
                        beta=beta,
                        verbose=True,
                    )
            else:
                if desp.op_type == ConvOpType.kForward.value:
                    indice_pairs = pair_fwd  # inp -> out
                    mask_ops = pair_mask_fwd_splits
                    mask_argsorts = mask_argsort_fwd_splits
                    mask_output = mask_output_fwd
                elif desp.op_type == ConvOpType.kBackwardInput.value:
                    indice_pairs = pair_bwd  # out -> inp
                    mask_ops = pair_mask_bwd_splits
                    mask_argsorts = mask_argsort_bwd_splits
                    mask_output = mask_output_bwd

                    print([bin(x.item()) for x in masks])
                else:
                    indice_pairs = pair_fwd  # inp -> out
                    mask_ops = pair_mask_fwd_splits
                    mask_argsorts = mask_argsort_fwd_splits
                    mask_output = mask_output_fwd

                for j in range(2):
                    beta = 1 if j == 1 else 0
                    mask_filter = masks[j].item()
                    reverse_mask = False
                    if desp.op_type == ConvOpType.kBackwardWeight.value:
                        mask_op = mask_output[j]
                    else:
                        mask_op = mask_ops[j]

                    CONV.run_with_tuned_result(
                        BestConvAlgoByProfile(desp, spk),
                        desp.op_type,
                        inp_tv,
                        weight_tv,
                        output_tv,
                        torch_tensor_to_tv(mask_op, dtype=tv.uint32),
                        torch_tensor_to_tv(mask_argsorts[j]),
                        torch_tensor_to_tv(mask_output[j], dtype=tv.uint32),
                        torch_tensor_to_tv(indice_pairs),
                        reverse_mask,
                        mask_filter=mask_filter,
                        mask_width=32,
                        beta=beta,
                        verbose=True,
                    )

            torch.cuda.synchronize()
            duration = time.time() - t
            if desp.op_type == ConvOpType.kForward.value:
                output_ref = np.zeros_like(output, dtype=np.float32)
                # ref algorithm
                for filter_offset in range(kv):
                    if subm and filter_offset > kv // 2:
                        nhot = indice_num_per_loc_np[kv - 1 - filter_offset]
                    elif subm and filter_offset == kv // 2:
                        nhot = voxels.shape[0]
                    else:
                        nhot = indice_num_per_loc_np[filter_offset]
                    a_inds = indice_pairs_np[0][filter_offset][:nhot]
                    c_inds = indice_pairs_np[1][filter_offset][:nhot]
                    # print(a_inds_cpu[:10])
                    a = inp[a_inds]
                    cc = a.astype(
                        np.float32) @ weight_ref[filter_offset].T.astype(
                            np.float32)
                    output_ref[c_inds] += cc

                output_cpu = output_tv.cpu().numpy().astype(np.float32)
                duration = time.time() - t
                my = output_cpu.reshape(-1)
                print("ERROR", np.linalg.norm(output_ref.reshape(-1) - my))

            elif desp.op_type == ConvOpType.kBackwardInput.value:
                dinput_ref = np.zeros_like(inp, dtype=np.float32)
                # ref algorithm
                for filter_offset in range(kv):
                    if subm and filter_offset > kv // 2:
                        nhot = indice_num_per_loc_np[kv - 1 - filter_offset]
                    elif subm and filter_offset == kv // 2:
                        nhot = voxels.shape[0]
                    else:
                        nhot = indice_num_per_loc_np[filter_offset]
                    a_inds = indice_pairs_np[1][filter_offset][:nhot]
                    c_inds = indice_pairs_np[0][filter_offset][:nhot]

                    # print(a_inds_cpu[:10])
                    a = output[a_inds]
                    # NK @ KC
                    cc = a.astype(
                        np.float32) @ weight_ref[filter_offset].astype(
                            np.float32)
                    dinput_ref[c_inds] += cc
                din_cpu = inp_tv.cpu().numpy()
                print(
                    "ERROR",
                    np.linalg.norm(
                        din_cpu.reshape(-1) - dinput_ref.reshape(-1)))
            else:
                dw_ref = np.zeros_like(weight_ref,
                                       dtype=np.float32)  # KV, K, C
                for filter_offset in range(kv):
                    if subm and filter_offset > kv // 2:
                        nhot = indice_num_per_loc_np[kv - 1 - filter_offset]
                    elif subm and filter_offset == kv // 2:
                        nhot = voxels.shape[0]
                    else:
                        nhot = indice_num_per_loc_np[filter_offset]
                    o_inds = indice_pairs_np[1][filter_offset][:nhot]
                    i_inds = indice_pairs_np[0][filter_offset][:nhot]
                    # print(a_inds_cpu[:10])
                    out_gather = output[o_inds]  # [N, K]
                    inp_gather = inp[i_inds]  # [N, C]
                    # KN @ NC
                    dw_res = out_gather.astype(
                        np.float32).T @ inp_gather.astype(np.float32)
                    dw_ref[filter_offset] = dw_res
                # print(indice_pairs_np_test[0])
                dw_ref_kcrs = dw_ref.transpose(1, 0, 2)
                dw_cpu = weight_tv.cpu().numpy().reshape(K, np.prod(ksize), C)

                print(
                    "ERROR",
                    np.linalg.norm(
                        dw_cpu.reshape(-1) - dw_ref_kcrs.reshape(-1)))
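The forward, backward-input, and backward-weight reference checks above all follow the same gather-GEMM-scatter pattern over the kv kernel offsets. For readability, the forward branch can be factored into a standalone NumPy helper like the following (a restatement of the loop above, not additional functionality; the argument names mirror the variables in dev_subm_inds_v2):

import numpy as np


def sparse_conv_forward_ref(inp, weight_ref, indice_pairs_np,
                            indice_num_per_loc_np, kv, subm, num_out):
    # inp: [N_in, C], weight_ref: [kv, K, C], indice_pairs_np: [2, kv, N]
    out = np.zeros([num_out, weight_ref.shape[1]], dtype=np.float32)
    for filter_offset in range(kv):
        if subm and filter_offset > kv // 2:
            nhot = indice_num_per_loc_np[kv - 1 - filter_offset]
        elif subm and filter_offset == kv // 2:
            nhot = inp.shape[0]  # the center offset maps every voxel to itself
        else:
            nhot = indice_num_per_loc_np[filter_offset]
        in_inds = indice_pairs_np[0][filter_offset][:nhot]   # gather input rows
        out_inds = indice_pairs_np[1][filter_offset][:nhot]  # scatter output rows
        partial = inp[in_inds].astype(np.float32) @ weight_ref[filter_offset].T.astype(np.float32)
        out[out_inds] += partial  # accumulate per kernel offset
    return out


# equivalent to the "ref algorithm" loop in the kForward branch:
# output_ref = sparse_conv_forward_ref(inp, weight_ref, indice_pairs_np,
#                                      indice_num_per_loc_np, kv, subm,
#                                      out_inds.shape[0])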