def forward(ctx,
            features,
            filters,
            indice_pairs,
            indice_pair_num,
            num_activate_out,
            algo,
            timer: CUDAKernelTimer = CUDAKernelTimer(False)):
    # save tensors needed by the backward pass
    ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
    ctx.algo = algo
    ctx.timer = timer
    try:
        return ops.indice_conv(features,
                               filters,
                               indice_pairs,
                               indice_pair_num,
                               num_activate_out,
                               False,
                               algo=algo,
                               timer=timer)
    except Exception as e:
        # report shapes and dump the pair data for offline debugging
        msg = "[Exception|indice_conv]"
        msg += f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
        msg += f"pairnum={indice_pair_num},act={num_activate_out},algo={algo}"
        print(msg, file=sys.stderr)
        spconv_save_debug_data((indice_pairs, indice_pair_num))
        raise e
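# Usage sketch (assumption): this forward is typically a staticmethod of a
# torch.autograd.Function subclass. The wrapper class is not shown in this
# snippet, so the name SparseConvFunction below is illustrative only.
# Calling it through .apply() runs the forward above and lets autograd reach
# features and filters, while the pair tensors are treated as constants.
def _example_indice_conv(features, filters, indice_pairs, indice_pair_num,
                         num_activate_out, algo):
    return SparseConvFunction.apply(features, filters, indice_pairs,
                                    indice_pair_num, num_activate_out, algo)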
def forward(ctx,
            features: torch.Tensor,
            filters: torch.Tensor,
            pair_fwd: torch.Tensor,
            pair_bwd: torch.Tensor,
            pair_mask_fwd_splits: List[torch.Tensor],
            pair_mask_bwd_splits: List[torch.Tensor],
            mask_argsort_fwd_splits: List[torch.Tensor],
            mask_argsort_bwd_splits: List[torch.Tensor],
            num_activate_out: int,
            masks: List[np.ndarray],
            is_train: bool,
            is_subm: bool,
            timer: CUDAKernelTimer = CUDAKernelTimer(False)):
    out, mask_out, mask_width = ops.implicit_gemm(features, filters, pair_fwd,
                                                  pair_mask_fwd_splits,
                                                  mask_argsort_fwd_splits,
                                                  num_activate_out, masks,
                                                  is_train, is_subm, timer)
    # plain tensors go through save_for_backward; lists of tensors and other
    # python objects are stashed directly on ctx
    ctx.save_for_backward(features, filters, pair_fwd, pair_bwd)
    ctx.mask_width = mask_width
    ctx.mask_out = mask_out
    ctx.timer = timer
    ctx.pair_mask_fwd_splits = pair_mask_fwd_splits
    ctx.mask_argsort_fwd_splits = mask_argsort_fwd_splits
    ctx.pair_mask_bwd_splits = pair_mask_bwd_splits
    ctx.mask_argsort_bwd_splits = mask_argsort_bwd_splits
    # ctx.num_activate_out = num_activate_out
    ctx.masks = masks
    ctx.is_subm = is_subm
    return out
def __init__(self,
             features: torch.Tensor,
             indices: torch.Tensor,
             spatial_shape: List[int],
             batch_size: int,
             grid: Optional[torch.Tensor] = None,
             voxel_num: Optional[torch.Tensor] = None,
             indice_dict: Optional[dict] = None,
             benchmark: bool = False,
             permanent_thrust_allocator: bool = False,
             enable_timer: bool = False):
    """
    Args:
        features: [num_points, num_features] feature tensor.
        indices: [num_points, ndim + 1] indice tensor. The batch index is
            stored in indices[:, 0].
        spatial_shape: spatial shape of your sparse data.
        batch_size: batch size of your sparse data.
        grid: pre-allocated grid tensor. Should be used when the volume of
            the spatial shape is very large.
        benchmark: whether to enable benchmarking. If enabled, all sparse
            operators will be recorded in this SparseConvTensor.
    """
    ndim = indices.shape[1] - 1
    assert features.ndim == 2
    assert indices.ndim == 2
    assert len(spatial_shape) == ndim, "spatial shape length must equal ndim"
    assert indices.dtype == torch.int32, "only support int32"
    assert batch_size > 0
    self._features = features
    self.indices = indices
    self.spatial_shape = spatial_shape
    self.batch_size = batch_size
    if indice_dict is None:
        indice_dict = {}
    self.indice_dict = indice_dict
    if grid is None:
        grid = torch.Tensor()  # empty tensor
    self.grid = grid
    self.voxel_num = voxel_num  # for tensorrt
    self.benchmark = benchmark
    self.benchmark_record = {}
    self.thrust_allocator: Optional[ThrustSortAllocator] = None
    if permanent_thrust_allocator:
        self.thrust_allocator = ThrustSortAllocator(features.device)
    self._timer = CUDAKernelTimer(enable_timer)
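# Example sketch: building the sparse tensor described above. The shapes
# follow the docstring; the concrete sizes are illustrative only, and in
# practice indices should be unique [batch_idx, z, y, x] voxel coordinates.
def _example_build_sparse_tensor():
    num_points, num_features = 1000, 16
    spatial_shape = [41, 1600, 1408]  # e.g. [D, H, W] for 3D data
    features = torch.rand(num_points, num_features)
    indices = torch.zeros(num_points, 4, dtype=torch.int32)  # must be int32
    return SparseConvTensor(features, indices, spatial_shape, batch_size=1)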
def forward(ctx,
            features: torch.Tensor,
            filters: torch.Tensor,
            pair_fwd: torch.Tensor,
            pair_bwd: torch.Tensor,
            pair_mask_fwd_splits: List[torch.Tensor],
            pair_mask_bwd_splits: List[torch.Tensor],
            mask_argsort_fwd_splits: List[torch.Tensor],
            mask_argsort_bwd_splits: List[torch.Tensor],
            num_activate_out: int,
            masks: List[np.ndarray],
            is_train: bool,
            is_subm: bool,
            timer: CUDAKernelTimer = CUDAKernelTimer(False),
            fp32_accum: Optional[bool] = None):
    try:
        out, mask_out, mask_width = ops.implicit_gemm(
            features, filters, pair_fwd, pair_mask_fwd_splits,
            mask_argsort_fwd_splits, num_activate_out, masks, is_train,
            is_subm, timer, fp32_accum)
    except Exception as e:
        # report shapes and dump the pair/mask data for offline debugging
        msg = "[Exception|implicit_gemm]"
        msg += f"feat={features.shape},w={filters.shape},pair={pair_fwd.shape},"
        msg += f"act={num_activate_out},issubm={is_subm},istrain={is_train}"
        print(msg, file=sys.stderr)
        spconv_save_debug_data(
            (pair_fwd, pair_bwd, pair_mask_fwd_splits, pair_mask_bwd_splits,
             mask_argsort_fwd_splits, mask_argsort_bwd_splits, masks))
        raise e
    ctx.save_for_backward(features, filters, pair_fwd, pair_bwd)
    ctx.mask_width = mask_width
    ctx.mask_out = mask_out
    ctx.timer = timer
    ctx.pair_mask_fwd_splits = pair_mask_fwd_splits
    ctx.mask_argsort_fwd_splits = mask_argsort_fwd_splits
    ctx.pair_mask_bwd_splits = pair_mask_bwd_splits
    ctx.mask_argsort_bwd_splits = mask_argsort_bwd_splits
    # ctx.num_activate_out = num_activate_out
    ctx.masks = masks
    ctx.is_subm = is_subm
    ctx.fp32_accum = fp32_accum
    return out
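# Sketch (illustrative, not part of the library): the try/except blocks in
# the forwards above repeat the same report-and-dump pattern; a helper like
# this could factor it out. spconv_save_debug_data is the dump hook already
# used above; the helper itself and its signature are hypothetical.
def _report_and_dump(op_name: str, shape_info: str, debug_payload, exc):
    msg = f"[Exception|{op_name}]" + shape_info
    print(msg, file=sys.stderr)
    spconv_save_debug_data(debug_payload)
    raise exc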
def forward(ctx,
            features,
            filters,
            indice_pairs,
            indice_pair_num,
            num_activate_out,
            algo,
            timer: CUDAKernelTimer = CUDAKernelTimer(False)):
    ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
    ctx.algo = algo
    ctx.timer = timer
    return ops.indice_conv(features,
                           filters,
                           indice_pairs,
                           indice_pair_num,
                           num_activate_out,
                           False,
                           algo=algo,
                           timer=timer)
def run_with_tuned_result(self,
                          profile_res: BestConvAlgoByProfile,
                          op_type: Union[ConvOpType, int],
                          inp: tv.Tensor,
                          weight: tv.Tensor,
                          output: tv.Tensor,
                          mask: tv.Tensor,
                          mask_argsort: tv.Tensor,
                          mask_output: tv.Tensor,
                          indices: tv.Tensor,
                          reverse_mask: bool,
                          mask_filter: int = 0xffffffff,
                          mask_width: int = -1,
                          alpha: float = 1.0,
                          beta: float = 0.0,
                          stream: int = 0,
                          workspace: tv.Tensor = tv.Tensor(),
                          verbose: bool = False,
                          timer: CUDAKernelTimer = CUDAKernelTimer(False)):
    channel_k = output.dim(1)
    channel_c = inp.dim(1)
    # GemmMainUnitTest.stream_synchronize(stream)
    algo_desp = profile_res.algo_desp
    assert algo_desp is not None
    split_k_slices = 1
    if profile_res.splitk > 1:
        split_k_slices = profile_res.splitk
    if isinstance(op_type, int):
        op_type_value = op_type
    else:
        op_type_value = op_type.value
    params = ConvParams(NDIM_DONT_CARE, op_type_value)
    params.conv_algo_desp = profile_res.algo_desp
    params.input = inp
    params.verbose = verbose
    # flatten the spatial filter positions into one axis: [K, R*S*..., C]
    params.weight = weight.view([channel_k, -1, channel_c])
    params.output = output
    params.split_k_slices = split_k_slices
    params.alpha = alpha
    params.beta = beta
    params.stream = stream
    params.mask_argsort = mask_argsort
    params.indices = indices
    params.mask = mask
    params.mask_filter = mask_filter
    params.mask_width = mask_width
    params.mask_output = mask_output
    params.reverse_mask = reverse_mask
    if timer.enable:
        assert timer._timer is not None
        params.timer = timer._timer
    # torch.cuda.synchronize()
    # t = time.time()
    params.workspace = workspace
    ConvMainUnitTest.implicit_gemm2(params)
    # torch.cuda.synchronize()
    # dura = time.time() - t
    # print("F", algo_desp, dura)
    # GemmMainUnitTest.stream_synchronize(stream)
    return algo_desp
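# Shape sketch (illustrative, using torch tensors for clarity): the
# weight.view([channel_k, -1, channel_c]) call above flattens every spatial
# filter position into one axis. With the K-first / C-last layout assumed
# here, a [K, 3, 3, C] 2D kernel becomes [K, 9, C] before implicit_gemm2.
_w = torch.empty(64, 3, 3, 32)   # [K, R, S, C]
_w_flat = _w.view(64, -1, 32)    # -> [K, R*S, C] == [64, 9, 32]
assert _w_flat.shape == (64, 9, 32)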
def run_with_tuned_result(
        self,
        profile_res: BestAlgoByProfile,
        a: tv.Tensor,
        b: tv.Tensor,
        c: tv.Tensor,
        trans_a: bool,
        trans_b: bool,
        trans_c: bool,
        arch: Tuple[int, int],
        stream: int,
        shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
        a_inds: tv.Tensor = tv.Tensor(),
        b_inds: tv.Tensor = tv.Tensor(),
        c_inds: tv.Tensor = tv.Tensor(),
        hint: int = AlgoHint.NoHint.value,
        alpha: float = 1.0,
        beta: float = 0.0,
        gather_data: tv.Tensor = tv.Tensor(),
        workspace: tv.Tensor = tv.Tensor(),
        timer: CUDAKernelTimer = CUDAKernelTimer(False)):
    m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a,
                                           trans_b, trans_c,
                                           shuffle_type.value, a_inds.shape,
                                           b_inds.shape, c_inds.shape)
    # GemmMainUnitTest.stream_synchronize(stream)
    algo_desp = profile_res.algo_desp
    assert algo_desp is not None
    split_k_slices = 1
    # TODO better splitk selection
    # if algo_desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
    #     split_k_slices = max(min(32, k // 128), 1)
    if profile_res.splitk > 1:
        split_k_slices = profile_res.splitk
    params = GemmParams()
    params.a = a
    params.b = b
    params.c = c
    params.a_inds = a_inds
    params.b_inds = b_inds
    params.c_inds = c_inds
    params.algo_desp = algo_desp
    params.split_k_slices = split_k_slices
    params.stream = stream
    params.alpha = alpha
    params.beta = beta
    params.workspace = workspace
    # gather = 0
    # if profile_res.external_gather and not gather_data.empty():
    #     GemmMainUnitTest.stream_synchronize(stream)
    #     tt = time.time()
    #     assert not gather_data.empty()
    #     params.a_inds = tv.Tensor()
    #     params.a = gather_data
    #     # print(profile_res.gather_params, gather_data.shape, a.shape, a_inds.shape)
    #     GATHER.gather(gather_data,
    #                   a,
    #                   a_inds,
    #                   *profile_res.gather_params,
    #                   stream=stream)
    #     GemmMainUnitTest.stream_synchronize(stream)
    #     gather = time.time() - tt
    if timer.enable:
        assert timer._timer is not None
        params.timer = timer._timer
    GemmMainUnitTest.matmul2(params)
    # GemmMainUnitTest.stream_synchronize(stream)
    return algo_desp
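# Rough reference (assumption, torch-style equivalent): with a shuffle-stride
# GEMM the kernel reads rows of `a` indirectly through `a_inds` and writes
# rows of `c` through `c_inds`, instead of running separate gather/scatter
# passes (the commented-out GATHER path above performed that gather
# explicitly). Ignoring alpha/beta scaling, split-k, and accumulation across
# kernel offsets, the math is roughly:
def _reference_shuffle_matmul(a, b, a_inds, c_inds, c):
    c[c_inds] = a[a_inds] @ b
    return c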