Exemple #1
0
 def forward(ctx,
             features,
             filters,
             indice_pairs,
             indice_pair_num,
             num_activate_out,
             algo,
             timer: CUDAKernelTimer = CUDAKernelTimer(False)):
     """Run the forward sparse convolution through ``ops.indice_conv``.

     Saves ``indice_pairs``, ``indice_pair_num``, ``features`` and ``filters``
     for the backward pass, and stores ``algo`` and ``timer`` on ``ctx``.

     If the kernel raises, this does three things before the original exception
     propagates unchanged:

     * prints a one-line summary of the input shapes/arguments to stderr;
     * dumps the index pairs with ``spconv_save_debug_data``;
     * re-raises the exception.
     """
     ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
     ctx.algo = algo
     ctx.timer = timer
     try:
         # The positional ``False`` is forwarded unchanged to ops.indice_conv
         # (presumably its ``inverse`` flag -- TODO confirm).
         return ops.indice_conv(features,
                                filters,
                                indice_pairs,
                                indice_pair_num,
                                num_activate_out,
                                False,
                                algo=algo,
                                timer=timer)
     except Exception:
         msg = "[Exception|indice_conv]"
         msg += f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
         msg += f"pairnum={indice_pair_num},act={num_activate_out},algo={algo}"
         print(msg, file=sys.stderr)
         spconv_save_debug_data((indice_pairs, indice_pair_num))
         # A bare ``raise`` re-raises the active exception with its original
         # traceback instead of recording this line as a new raise point.
         raise
Exemple #2
0
    def forward(ctx,
                features: torch.Tensor,
                filters: torch.Tensor,
                pair_fwd: torch.Tensor,
                pair_bwd: torch.Tensor,
                pair_mask_fwd_splits: List[torch.Tensor],
                pair_mask_bwd_splits: List[torch.Tensor],
                mask_argsort_fwd_splits: List[torch.Tensor],
                mask_argsort_bwd_splits: List[torch.Tensor],
                num_activate_out: int,
                masks: List[np.ndarray],
                is_train: bool,
                is_subm: bool,
                timer: CUDAKernelTimer = CUDAKernelTimer(False)):
        """Forward pass of the implicit-GEMM sparse convolution.

        Calls ``ops.implicit_gemm`` with the forward pair/mask data and
        returns the first element of its result (the output features).

        Saved for backward via ``save_for_backward``: ``features``,
        ``filters``, ``pair_fwd`` and ``pair_bwd``.

        Kept as plain ``ctx`` attributes:

        * the forward/backward mask and argsort splits;
        * ``masks``, ``is_subm`` and ``timer``;
        * the ``mask_out``/``mask_width`` values returned by the kernel.

        ``num_activate_out`` is only passed to the kernel; it is not stored
        on ``ctx``.
        """
        result, out_mask, width = ops.implicit_gemm(
            features, filters, pair_fwd, pair_mask_fwd_splits,
            mask_argsort_fwd_splits, num_activate_out, masks, is_train,
            is_subm, timer)

        ctx.save_for_backward(features, filters, pair_fwd, pair_bwd)
        # Values produced by the forward kernel.
        ctx.mask_width = width
        ctx.mask_out = out_mask
        ctx.timer = timer
        # Forward/backward mask data, kept for the backward pass.
        ctx.pair_mask_fwd_splits = pair_mask_fwd_splits
        ctx.mask_argsort_fwd_splits = mask_argsort_fwd_splits
        ctx.pair_mask_bwd_splits = pair_mask_bwd_splits
        ctx.mask_argsort_bwd_splits = mask_argsort_bwd_splits
        ctx.masks = masks
        ctx.is_subm = is_subm
        return result
Exemple #3
0
 def __init__(self,
              features: torch.Tensor,
              indices: torch.Tensor,
              spatial_shape: List[int],
              batch_size: int,
              grid: Optional[torch.Tensor] = None,
              voxel_num: Optional[torch.Tensor] = None,
              indice_dict: Optional[dict] = None,
              benchmark: bool = False,
              permanent_thrust_allocator: bool = False,
              enable_timer: bool = False):
     """Build a sparse tensor from per-point features and integer coordinates.

     Args:
         features: [num_points, num_features] feature tensor (must be 2-D).
         indices: [num_points, ndim + 1] int32 indice tensor. The batch index
             is saved in ``indices[:, 0]``.
         spatial_shape: spatial shape of your sparse data. Its length must
             equal ``ndim`` (``indices.shape[1] - 1``).
         batch_size: batch size of your sparse data; must be positive.
         grid: pre-allocated grid tensor. Should be used when the volume of
             the spatial shape is very large. Defaults to an empty tensor.
         voxel_num: optional tensor stored as-is on ``self.voxel_num``
             (used for TensorRT).
         indice_dict: dict stored on ``self.indice_dict``. It is stored by
             reference, not copied. A new empty dict is created when omitted.
         benchmark: whether to enable benchmarking. If enabled, all sparse
             operators will be recorded to this SparseConvTensor (presumably
             in ``benchmark_record`` -- TODO confirm).
         permanent_thrust_allocator: if True, create a ``ThrustSortAllocator``
             on ``features.device`` and keep it on this tensor. Otherwise
             ``thrust_allocator`` is None.
         enable_timer: whether the ``CUDAKernelTimer`` created for this tensor
             is enabled.

     Raises:
         AssertionError: if a shape, dtype or batch-size check fails. These
             are ``assert`` statements and are skipped under ``python -O``.
     """
     # Check the ranks before reading indices.shape[1]. Otherwise a 1-D
     # ``indices`` fails with an IndexError instead of the intended assertion.
     assert features.ndim == 2, "features must be a 2-D tensor"
     assert indices.ndim == 2, "indices must be a 2-D tensor"
     ndim = indices.shape[1] - 1  # column 0 holds the batch index
     assert len(spatial_shape) == ndim, "spatial shape must equal to ndim"
     assert indices.dtype == torch.int32, "only support int32"
     assert batch_size > 0, "batch_size must be positive"
     self._features = features
     self.indices = indices
     self.spatial_shape = spatial_shape
     self.batch_size = batch_size
     # Fresh dict per instance; avoids a shared mutable default argument.
     if indice_dict is None:
         indice_dict = {}
     self.indice_dict = indice_dict
     if grid is None:
         grid = torch.Tensor()  # empty tensor
     self.grid = grid
     self.voxel_num = voxel_num  # for tensorrt
     self.benchmark = benchmark
     self.benchmark_record = {}
     self.thrust_allocator: Optional[ThrustSortAllocator] = None
     if permanent_thrust_allocator:
         self.thrust_allocator = ThrustSortAllocator(features.device)
     self._timer = CUDAKernelTimer(enable_timer)
Exemple #4
0
    def forward(ctx,
                features: torch.Tensor,
                filters: torch.Tensor,
                pair_fwd: torch.Tensor,
                pair_bwd: torch.Tensor,
                pair_mask_fwd_splits: List[torch.Tensor],
                pair_mask_bwd_splits: List[torch.Tensor],
                mask_argsort_fwd_splits: List[torch.Tensor],
                mask_argsort_bwd_splits: List[torch.Tensor],
                num_activate_out: int,
                masks: List[np.ndarray],
                is_train: bool,
                is_subm: bool,
                timer: CUDAKernelTimer = CUDAKernelTimer(False),
                fp32_accum: Optional[bool] = None):
        """Forward pass of the implicit-GEMM sparse convolution.

        Calls ``ops.implicit_gemm`` with the forward pair/mask data and
        returns the first element of its result (the output features).
        ``fp32_accum`` is forwarded unchanged to the kernel.

        Saved for backward via ``save_for_backward``: ``features``,
        ``filters``, ``pair_fwd`` and ``pair_bwd``.

        Kept as plain ``ctx`` attributes:

        * the forward/backward mask and argsort splits;
        * ``masks``, ``is_subm``, ``timer`` and ``fp32_accum``;
        * the ``mask_out``/``mask_width`` values returned by the kernel.

        If the kernel call raises, this does three things before the original
        exception propagates unchanged and ``ctx`` is left unpopulated:

        * prints a one-line summary of the inputs to stderr;
        * dumps all pair/mask data with ``spconv_save_debug_data``;
        * re-raises the exception.
        """
        try:
            out, mask_out, mask_width = ops.implicit_gemm(
                features, filters, pair_fwd, pair_mask_fwd_splits,
                mask_argsort_fwd_splits, num_activate_out, masks, is_train,
                is_subm, timer, fp32_accum)
        except Exception:
            msg = "[Exception|implicit_gemm]"
            msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape},"
            msg += f"act={num_activate_out},issubm={is_subm},istrain={is_train}"
            print(msg, file=sys.stderr)
            spconv_save_debug_data(
                (pair_fwd, pair_bwd, pair_mask_fwd_splits,
                 pair_mask_bwd_splits, mask_argsort_fwd_splits,
                 mask_argsort_bwd_splits, masks))
            # A bare ``raise`` re-raises the active exception with its
            # original traceback instead of adding this line as a raise point.
            raise

        ctx.save_for_backward(features, filters, pair_fwd, pair_bwd)
        ctx.mask_width = mask_width
        ctx.mask_out = mask_out
        ctx.timer = timer
        ctx.pair_mask_fwd_splits = pair_mask_fwd_splits
        ctx.mask_argsort_fwd_splits = mask_argsort_fwd_splits
        ctx.pair_mask_bwd_splits = pair_mask_bwd_splits
        ctx.mask_argsort_bwd_splits = mask_argsort_bwd_splits
        # num_activate_out is only passed to the kernel; it is not kept on ctx.
        ctx.masks = masks
        ctx.is_subm = is_subm
        ctx.fp32_accum = fp32_accum
        return out
Exemple #5
0
 def forward(ctx, features, filters, indice_pairs, indice_pair_num,
             num_activate_out, algo,
             timer: CUDAKernelTimer = CUDAKernelTimer(False)):
     """Run the forward sparse convolution through ``ops.indice_conv``.

     Saves ``indice_pairs``, ``indice_pair_num``, ``features`` and ``filters``
     for the backward pass, and records ``algo`` and ``timer`` on ``ctx``.
     Returns whatever ``ops.indice_conv`` returns.
     """
     ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
     ctx.algo, ctx.timer = algo, timer
     # The positional ``False`` is passed through as-is (presumably the
     # ``inverse`` flag of ops.indice_conv -- TODO confirm).
     conv_out = ops.indice_conv(features, filters, indice_pairs,
                                indice_pair_num, num_activate_out, False,
                                algo=algo, timer=timer)
     return conv_out
Exemple #6
0
    def run_with_tuned_result(self,
                              profile_res: BestConvAlgoByProfile,
                              op_type: Union[ConvOpType, int],
                              inp: tv.Tensor,
                              weight: tv.Tensor,
                              output: tv.Tensor,
                              mask: tv.Tensor,
                              mask_argsort: tv.Tensor,
                              mask_output: tv.Tensor,
                              indices: tv.Tensor,
                              reverse_mask: bool,
                              mask_filter: int = 0xffffffff,
                              mask_width: int = -1,
                              alpha: float = 1.0,
                              beta: float = 0.0,
                              stream: int = 0,
                              workspace: tv.Tensor = tv.Tensor(),
                              verbose: bool = False,
                              timer: CUDAKernelTimer = CUDAKernelTimer(False)):
        """Launch an implicit-GEMM convolution using a previously tuned algorithm.

        Args:
            profile_res: tuning result. Its ``algo_desp`` must not be None,
                and ``splitk`` gives the split-K slice count (values <= 1
                mean no split).
            op_type: convolution op type, given either as a ``ConvOpType`` or
                as its raw int value.
            inp: input tensor; ``inp.dim(1)`` is taken as the channel count C.
            weight: filter tensor, viewed as ``[K, -1, C]`` before the launch.
            output: output tensor; ``output.dim(1)`` is taken as the channel
                count K.
            mask, mask_argsort, mask_output, indices, reverse_mask,
            mask_filter, mask_width: copied onto ``ConvParams`` unchanged.
            alpha, beta: scalars copied onto ``ConvParams``.
            stream: stream handle copied onto ``ConvParams``.
            workspace: workspace tensor copied onto ``ConvParams``.
            verbose: copied onto ``ConvParams.verbose``.
            timer: if ``timer.enable`` is set, its underlying timer is
                attached to the params.

        Returns:
            The ``algo_desp`` from ``profile_res`` that was used.
        """
        channel_k = output.dim(1)
        channel_c = inp.dim(1)
        algo_desp = profile_res.algo_desp
        assert algo_desp is not None
        split_k_slices = profile_res.splitk if profile_res.splitk > 1 else 1
        # Accept either the enum member or an already-extracted int value.
        if isinstance(op_type, int):
            op_type_value = op_type
        else:
            op_type_value = op_type.value
        params = ConvParams(NDIM_DONT_CARE, op_type_value)
        params.conv_algo_desp = algo_desp
        params.input = inp
        params.verbose = verbose
        # The kernel expects the weight as [K, -1, C].
        params.weight = weight.view([channel_k, -1, channel_c])
        params.output = output
        params.split_k_slices = split_k_slices
        params.alpha = alpha
        params.beta = beta
        params.stream = stream
        params.mask_argsort = mask_argsort
        params.indices = indices
        params.mask = mask
        params.mask_filter = mask_filter
        params.mask_width = mask_width
        params.mask_output = mask_output
        params.reverse_mask = reverse_mask
        if timer.enable:
            assert timer._timer is not None
            params.timer = timer._timer
        params.workspace = workspace
        ConvMainUnitTest.implicit_gemm2(params)
        return algo_desp
Exemple #7
0
    def run_with_tuned_result(
        self,
        profile_res: BestAlgoByProfile,
        a: tv.Tensor,
        b: tv.Tensor,
        c: tv.Tensor,
        trans_a: bool,
        trans_b: bool,
        trans_c: bool,
        arch: Tuple[int, int],
        stream: int,
        shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
        a_inds: tv.Tensor = tv.Tensor(),
        b_inds: tv.Tensor = tv.Tensor(),
        c_inds: tv.Tensor = tv.Tensor(),
        hint: int = AlgoHint.NoHint.value,
        alpha: float = 1.0,
        beta: float = 0.0,
        gather_data: tv.Tensor = tv.Tensor(),
        workspace: tv.Tensor = tv.Tensor(),
        timer: CUDAKernelTimer = CUDAKernelTimer(False)):
        """Launch a GEMM using a previously tuned algorithm.

        The operands, their gather/scatter index tensors, ``alpha``/``beta``,
        ``stream`` and ``workspace`` are copied onto a fresh ``GemmParams``.
        ``GemmMainUnitTest.matmul2`` is then called with those params.

        The split-K slice count comes from ``profile_res.splitk``; values
        <= 1 mean no split. ``trans_a``/``trans_b``/``trans_c`` and
        ``shuffle_type`` are only passed to
        ``GemmMainUnitTest.extract_mnk``; they are not set on the params.
        ``arch``, ``hint`` and ``gather_data`` are accepted but not used by
        this body.

        Returns:
            The ``algo_desp`` from ``profile_res`` that was used.
        """
        # The resulting m/n/k are not consumed below; the call is kept as-is
        # (presumably it also sanity-checks the shapes -- TODO confirm).
        _m, _n, _k = GemmMainUnitTest.extract_mnk(
            a.shape, b.shape, trans_a, trans_b, trans_c, shuffle_type.value,
            a_inds.shape, b_inds.shape, c_inds.shape)
        desc = profile_res.algo_desp
        assert desc is not None
        # TODO better splitk selection
        slices = profile_res.splitk if profile_res.splitk > 1 else 1

        params = GemmParams()
        params.a, params.b, params.c = a, b, c
        params.a_inds, params.b_inds, params.c_inds = a_inds, b_inds, c_inds
        params.algo_desp = desc
        params.split_k_slices = slices
        params.stream = stream
        params.alpha, params.beta = alpha, beta
        params.workspace = workspace
        if timer.enable:
            assert timer._timer is not None
            params.timer = timer._timer

        GemmMainUnitTest.matmul2(params)
        return desc