def get_all_available(self,
                      a: tv.Tensor,
                      b: tv.Tensor,
                      c: tv.Tensor,
                      trans_a: bool,
                      trans_b: bool,
                      trans_c: bool,
                      arch: Tuple[int, int],
                      shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle):
    # normalize the problem so C is never transposed: since
    # (A @ B)^T == B^T @ A^T, swap the operands and flip their
    # transpose flags instead.
    if trans_c:
        trans_a = not trans_a
        trans_b = not trans_b
        trans_a, trans_b = trans_b, trans_a
        a, b = b, a
        trans_c = False
    avail_algos = get_available_algo_str_from_arch(arch)
    finally_algos: List[GemmAlgoDesp] = []
    for algo in avail_algos:
        static_key = (trans_a, trans_b, trans_c, a.dtype, b.dtype, c.dtype,
                      shuffle_type.value, algo)
        desps = self.static_key_to_desps.get(static_key, None)
        if desps is None or len(desps) == 0:
            continue
        for desp in desps:
            # skip volta tensor op since it is very slow in architectures
            # except volta.
            if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
                continue
            lda = a.dim(1)
            ldb = b.dim(1)
            ldc = c.dim(1)
            if desp.supported_ldx(lda, ldb, ldc):
                finally_algos.append(desp)
    return finally_algos
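
# A minimal standalone sketch (hypothetical helper, not part of this class)
# verifying the normalization above: a request for C^T can be served as a
# plain C = A @ B problem with the operands swapped and their transpose
# flags flipped, because (A @ B)^T == B^T @ A^T.
def _demo_trans_c_normalization():
    import numpy as np
    rng = np.random.default_rng(0)
    a = rng.standard_normal((4, 3))
    b = rng.standard_normal((3, 5))
    c_trans = (a @ b).T       # the trans_c formulation
    c_normalized = b.T @ a.T  # swapped operands, flipped transpose flags
    assert np.allclose(c_trans, c_normalized)
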
def get_all_available(self,
                      inp: tv.Tensor,
                      weight: tv.Tensor,
                      out: tv.Tensor,
                      layout_i: ConvLayout,
                      layout_w: ConvLayout,
                      layout_o: ConvLayout,
                      arch: Tuple[int, int],
                      op_type: ConvOpType,
                      mask_width: int):
    avail_algos = get_available_algo_str_from_arch(arch)
    finally_algos: List[ConvAlgoDesp] = []
    for algo in avail_algos:
        static_key = (layout_i.layout_type.value, layout_w.layout_type.value,
                      layout_o.layout_type.value, layout_i.interleave,
                      layout_w.interleave, layout_o.interleave, inp.dtype,
                      weight.dtype, out.dtype, algo, op_type.value)
        desps = self.static_key_to_desps.get(static_key, None)
        if desps is None or len(desps) == 0:
            continue
        for desp in desps:
            # skip volta tensor op since it is very slow in architectures
            # except volta.
            if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
                continue
            ldi = inp.dim(-1)
            ldw = weight.dim(-1)
            ldo = out.dim(-1)
            mask_width_valid = True
            if desp.op_type == ConvOpType.kBackwardWeight.value:
                # backward-weight kernels reduce in chunks of tile_shape[2]
                # (the tile k extent), so mask_width must be divisible by it.
                assert mask_width > 0
                mask_width_valid = mask_width % desp.tile_shape[2] == 0
            if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
                finally_algos.append(desp)
    return finally_algos
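
# A small sketch with hypothetical numbers (illustrative only, not a real
# kernel description) showing the mask_width filter above: only mask widths
# divisible by the tile k extent pass.
def _demo_mask_width_filter():
    tile_shape = [64, 64, 32]  # hypothetical [tile_m, tile_n, tile_k]
    for mask_width in (16, 32, 48, 64):
        valid = mask_width % tile_shape[2] == 0
        print(mask_width, "usable" if valid else "skipped")
    # -> 16 skipped, 32 usable, 48 skipped, 64 usable
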
def get_all_available(self,
                      inp: tv.Tensor,
                      weight: tv.Tensor,
                      out: tv.Tensor,
                      layout_i: ConvLayout,
                      layout_w: ConvLayout,
                      layout_o: ConvLayout,
                      arch: Tuple[int, int],
                      op_type: ConvOpType,
                      mask_width: int,
                      fp32_accum: Optional[bool] = None):
    avail_algos = get_available_algo_str_from_arch(arch)
    finally_algos: List[ConvAlgoDesp] = []
    is_fp16 = (inp.dtype == tv.float16 and weight.dtype == tv.float16
               and out.dtype == tv.float16)
    use_f32_as_accum = False
    kv = int(np.prod(weight.shape[1:-1]))
    # for 3d conv, if the reduce axis is too large, fp16 accumulation may
    # produce nan during forward, so fall back to fp32 accumulators.
    if is_fp16:
        if fp32_accum is None:
            if op_type == ConvOpType.kForward:
                use_f32_as_accum = weight.dim(-1) * kv > 128 * 27
            elif op_type == ConvOpType.kBackwardInput:
                use_f32_as_accum = weight.dim(0) * kv > 128 * 27
        else:
            use_f32_as_accum = fp32_accum
    for algo in avail_algos:
        static_key = (layout_i.layout_type.value, layout_w.layout_type.value,
                      layout_o.layout_type.value, layout_i.interleave,
                      layout_w.interleave, layout_o.interleave, inp.dtype,
                      weight.dtype, out.dtype, algo, op_type.value)
        desps = self.static_key_to_desps.get(static_key, None)
        if desps is None or len(desps) == 0:
            continue
        for desp in desps:
            # skip volta tensor op since it is very slow in architectures
            # except volta.
            if arch >= (7, 5) and desp.algo == GemmAlgo.Volta.value:
                continue
            if arch >= (7, 0) and is_fp16:
                # skip simt fp16 kernels if we have tensor cores. compare
                # against the enum value as above; comparing the stored
                # string to the enum itself would never match.
                if desp.algo == GemmAlgo.Simt.value:
                    continue
                if use_f32_as_accum and desp.dacc == tv.float16:
                    continue
            ldi = inp.dim(-1)
            ldw = weight.dim(-1)
            ldo = out.dim(-1)
            mask_width_valid = True
            if desp.op_type == ConvOpType.kBackwardWeight.value:
                assert mask_width > 0
                mask_width_valid = mask_width % desp.tile_shape[2] == 0
            if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
                finally_algos.append(desp)
    return finally_algos
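
# A sketch of the fp32-accumulator heuristic above with hypothetical shapes,
# assuming the [K, *kernel_size, C] weight layout implied by
# weight.shape[1:-1] above: the fp16 forward reduction spans
# C * kernel_volume elements, and anything longer than 128 * 27 falls back
# to fp32 accumulators.
def _demo_fp32_accum_heuristic():
    import numpy as np
    weight_shape = [64, 3, 3, 3, 160]      # hypothetical: K=64, 3x3x3, C=160
    kv = int(np.prod(weight_shape[1:-1]))  # kernel volume: 27
    reduce_len = weight_shape[-1] * kv     # 160 * 27 = 4320
    print(reduce_len > 128 * 27)           # True -> prefer fp32 accumulators
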
def run_with_tuned_result(self,
                          profile_res: BestConvAlgoByProfile,
                          op_type: Union[ConvOpType, int],
                          inp: tv.Tensor,
                          weight: tv.Tensor,
                          output: tv.Tensor,
                          mask: tv.Tensor,
                          mask_argsort: tv.Tensor,
                          mask_output: tv.Tensor,
                          indices: tv.Tensor,
                          reverse_mask: bool,
                          mask_filter: int = 0xffffffff,
                          mask_width: int = -1,
                          alpha: float = 1.0,
                          beta: float = 0.0,
                          stream: int = 0,
                          workspace: tv.Tensor = tv.Tensor(),
                          verbose: bool = False,
                          timer: CUDAKernelTimer = CUDAKernelTimer(False)):
    channel_k = output.dim(1)
    channel_c = inp.dim(1)
    algo_desp = profile_res.algo_desp
    assert algo_desp is not None
    split_k_slices = 1
    if profile_res.splitk > 1:
        split_k_slices = profile_res.splitk
    if isinstance(op_type, int):
        op_type_value = op_type
    else:
        op_type_value = op_type.value
    params = ConvParams(NDIM_DONT_CARE, op_type_value)
    params.conv_algo_desp = profile_res.algo_desp
    params.input = inp
    params.verbose = verbose
    params.weight = weight.view([channel_k, -1, channel_c])
    params.output = output
    params.split_k_slices = split_k_slices
    params.alpha = alpha
    params.beta = beta
    params.stream = stream
    params.mask_argsort = mask_argsort
    params.indices = indices
    params.mask = mask
    params.mask_filter = mask_filter
    params.mask_width = mask_width
    params.mask_output = mask_output
    params.reverse_mask = reverse_mask
    if timer.enable:
        assert timer._timer is not None
        params.timer = timer._timer
    params.workspace = workspace
    ConvMainUnitTest.implicit_gemm2(params)
    return algo_desp
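
# Hypothetical call site, shown only to illustrate the signature above:
# replay a previously tuned forward conv. `tuner`, `profile_res` (a
# BestConvAlgoByProfile returned by tune_and_cache) and all tensors are
# assumed to come from the surrounding sparse-conv pipeline.
def _demo_run_with_tuned_result(tuner, profile_res, inp, weight, out, mask,
                                mask_argsort, mask_output, indices, stream):
    return tuner.run_with_tuned_result(profile_res,
                                       ConvOpType.kForward,
                                       inp,
                                       weight,
                                       out,
                                       mask,
                                       mask_argsort,
                                       mask_output,
                                       indices,
                                       reverse_mask=False,
                                       stream=stream)
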
def tune_and_cache(self,
                   op_type: ConvOpType,
                   inp: tv.Tensor,
                   weight: tv.Tensor,
                   output: tv.Tensor,
                   layout_i: ConvLayout,
                   layout_w: ConvLayout,
                   layout_o: ConvLayout,
                   arch: Tuple[int, int],
                   mask: tv.Tensor,
                   mask_argsort: tv.Tensor,
                   indices: tv.Tensor,
                   reverse_mask: bool,
                   mask_filter: int = 0xffffffff,
                   mask_width: int = -1,
                   mask_output: tv.Tensor = tv.Tensor(),
                   alpha: float = 1.0,
                   beta: float = 0.0,
                   stream: int = 0):
    avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
                                   layout_o, arch, op_type, mask_width)
    # clone the operands so profiling runs don't pollute user data.
    inp = inp.clone()
    weight = weight.clone()
    output = output.clone()
    channel_k = output.dim(1)
    channel_c = inp.dim(1)
    times: List[float] = []
    all_profile_res: List[BestConvAlgoByProfile] = []
    for desp in avail:
        # for sparse conv, ndim isn't used, so we just provide a constant
        # value.
        params = ConvParams(NDIM_DONT_CARE, op_type.value)
        params.conv_algo_desp = desp
        params.input = inp
        params.weight = weight.view([channel_k, -1, channel_c])
        params.output = output
        params.mask_width = mask_width
        params.alpha = alpha
        params.beta = beta
        params.stream = stream
        params.mask_argsort = mask_argsort
        params.indices = indices
        params.mask = mask
        params.mask_output = mask_output
        if op_type == ConvOpType.kBackwardWeight:
            assert not mask_output.empty()
        if op_type == ConvOpType.kBackwardInput:
            params.reverse_mask = reverse_mask
        params.mask_filter = mask_filter
        if desp.split_k_serial and op_type == ConvOpType.kBackwardWeight:
            splitk_tests = [1, 2, 4, 8, 16, 32, 64]
        else:
            splitk_tests = [1]
        spk_speeds = []
        for spk in splitk_tests:
            this_times = []
            for j in range(3):
                GemmMainUnitTest.stream_synchronize(stream)
                t = time.time()
                params.split_k_slices = spk
                ConvMainUnitTest.implicit_gemm2(params)
                GemmMainUnitTest.stream_synchronize(stream)
                this_times.append(time.time() - t)
            # drop the first (warmup) run and average the rest.
            times.append(np.mean(this_times[1:]))
            spk_speeds.append(times[-1])
            all_profile_res.append(BestConvAlgoByProfile(desp, splitk=spk))
    if not all_profile_res:
        raise ValueError(f"can't find suitable algorithm for {op_type}")
    min_idx = int(np.argmin(times))
    min_time = times[min_idx]
    res = all_profile_res[min_idx]
    if op_type != ConvOpType.kBackwardWeight:
        # fwd and dgrad don't need mask_width
        mask_width = -1
    key = (inp.dtype, weight.dtype, output.dtype, channel_k, channel_c,
           arch[0], arch[1], mask_width)
    with self.lock:
        if op_type == ConvOpType.kForward:
            self.kc_forward_cache[key] = res
        elif op_type == ConvOpType.kBackwardInput:
            self.kc_dgrad_cache[key] = res
        elif op_type == ConvOpType.kBackwardWeight:
            self.kc_wgrad_cache[key] = res
        else:
            raise NotImplementedError
    return res, min_time
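
# The timing pattern used above, distilled into a generic sketch: run the
# kernel a few times with a synchronize before and after each run, drop the
# first (warmup) sample, and average the rest. `launch` and `sync` are
# caller-provided callables, not spconv APIs.
def _time_kernel(launch, sync, n_runs: int = 3) -> float:
    import time
    import numpy as np
    samples = []
    for _ in range(n_runs):
        sync()                # make sure prior work isn't counted
        t = time.time()
        launch()              # enqueue the kernel under test
        sync()                # wait for it to actually finish
        samples.append(time.time() - t)
    return float(np.mean(samples[1:]))
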