def __init__(self,
             device: torch.device,
             key_dtype: torch.dtype,
             value_dtype: torch.dtype,
             max_size: int = -1) -> None:
    is_cpu = device.type == "cpu"
    self.is_cpu = is_cpu
    self.key_dtype = key_dtype
    self.value_dtype = value_dtype
    key_data_tv = tv.Tensor()
    value_data_tv = tv.Tensor()
    if is_cpu:
        self.keys_data = None
        self.values_data = None
    else:
        assert max_size > 0, "you must provide max_size for fixed-size cuda hash table, usually *2 of num of keys"
        assert device is not None, "you must specify device for cuda hash table."
        self.keys_data = torch.empty([max_size], dtype=key_dtype, device=device)
        self.values_data = torch.empty([max_size], dtype=value_dtype, device=device)
        key_data_tv = torch_tensor_to_tv(self.keys_data)
        value_data_tv = torch_tensor_to_tv(self.values_data)
    stream = 0
    if not self.is_cpu:
        stream = get_current_stream()
    self.key_itemsize = _TORCH_DTYPE_TO_ITEMSIZE[self.key_dtype]
    self.value_itemsize = _TORCH_DTYPE_TO_ITEMSIZE[self.value_dtype]
    self._valid_value_dtype_for_arange = set([torch.int32, torch.int64])
    self._table = _HashTable(is_cpu, self.key_itemsize, self.value_itemsize,
                             key_data_tv, value_data_tv, stream)
def items(self, max_size: int = -1):
    """Dump the table contents.

    Returns (keys, values, count), where count holds the number of items
    actually stored in the table.
    """
    count_tv = tv.Tensor()
    count = torch.Tensor()
    stream = 0
    if not self.is_cpu:
        stream = get_current_stream()
    if not self.is_cpu:
        assert self.values_data is not None
        if self.key_itemsize == 4:
            count = torch.zeros([1], dtype=torch.int32, device=self.values_data.device)
            count_tv = torch_tensor_to_tv(count, dtype=tv.uint32)
        elif self.key_itemsize == 8:
            count = torch.zeros([1], dtype=torch.int64, device=self.values_data.device)
            count_tv = torch_tensor_to_tv(count, dtype=tv.uint64)
        else:
            raise NotImplementedError
    if not self.is_cpu:
        assert self.values_data is not None
        if max_size == -1:
            max_size = self.values_data.shape[0]
        keys = torch.empty([max_size], dtype=self.key_dtype, device=self.values_data.device)
        values = torch.empty([max_size], dtype=self.value_dtype, device=self.values_data.device)
    else:
        max_size = self._table.size_cpu()
        count = torch.tensor([max_size], dtype=torch.int64)
        keys = torch.empty([max_size], dtype=self.key_dtype)
        values = torch.empty([max_size], dtype=self.value_dtype)
    keys_tv = torch_tensor_to_tv(keys)
    values_tv = torch_tensor_to_tv(values)
    self._table.items(keys_tv, values_tv, count_tv, stream)
    return keys, values, count
def assign_arange_(self):
    """Iterate the table and assign each key a unique "arange" value.

    Equivalent to: 1. get keys via items(), 2. insert those keys again with
    values arange(keys.shape[0]).
    """
    count_tv = tv.Tensor()
    count = torch.Tensor()
    stream = 0
    if not self.is_cpu:
        stream = get_current_stream()
    else:
        assert self.value_dtype in self._valid_value_dtype_for_arange
    if not self.is_cpu:
        assert self.values_data is not None
        if self.key_itemsize == 4:
            count = torch.zeros([1], dtype=torch.int32, device=self.values_data.device)
            count_tv = torch_tensor_to_tv(count, dtype=tv.uint32)
        elif self.key_itemsize == 8:
            count = torch.zeros([1], dtype=torch.int64, device=self.values_data.device)
            count_tv = torch_tensor_to_tv(count, dtype=tv.uint64)
        else:
            raise NotImplementedError
    else:
        max_size = self._table.size_cpu()
        count = torch.tensor([max_size], dtype=torch.int64)
    self._table.assign_arange_(count_tv, stream)
    return count
def insert(self, keys: torch.Tensor, values: Optional[torch.Tensor] = None):
    """Insert keys (and optionally values) into the hash table.

    If values is None, only the keys are inserted and the stored values are
    left undefined.
    """
    keys_tv = torch_tensor_to_tv(keys)
    values_tv = tv.Tensor()
    if values is not None:
        values_tv = torch_tensor_to_tv(values)
    stream = 0
    if not self.is_cpu:
        stream = get_current_stream()
    return self._table.insert(keys_tv, values_tv, stream)
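# Hedged usage sketch (not part of the original source) showing how the four
# methods above are typically combined to deduplicate and index keys. The
# import path spconv.pytorch.hash.HashTable is an assumption; adjust it to
# where HashTable lives in your spconv install.
import torch
from spconv.pytorch.hash import HashTable  # assumed import path

def _dedup_keys_example(keys_1d: torch.Tensor):
    # keys_1d: int64 CUDA tensor of possibly duplicated 1D keys.
    # Fixed-size CUDA table: max_size is usually ~2x the number of keys.
    table = HashTable(keys_1d.device, torch.int64, torch.int64,
                      max_size=keys_1d.shape[0] * 2)
    table.insert(keys_1d)       # keys only, values left undefined
    table.assign_arange_()      # give every unique key an index 0..N-1
    keys, values, count = table.items()
    n = int(count.item())       # number of items actually stored
    return keys[:n], values[:n]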
def run_with_tuned_result(self,
                          profile_res: BestConvAlgoByProfile,
                          op_type: Union[ConvOpType, int],
                          inp: tv.Tensor,
                          weight: tv.Tensor,
                          output: tv.Tensor,
                          mask: tv.Tensor,
                          mask_argsort: tv.Tensor,
                          mask_output: tv.Tensor,
                          indices: tv.Tensor,
                          reverse_mask: bool,
                          mask_filter: int = 0xffffffff,
                          mask_width: int = -1,
                          alpha: float = 1.0,
                          beta: float = 0.0,
                          stream: int = 0,
                          workspace: tv.Tensor = tv.Tensor(),
                          verbose: bool = False,
                          timer: CUDAKernelTimer = CUDAKernelTimer(False)):
    channel_k = output.dim(1)
    channel_c = inp.dim(1)
    algo_desp = profile_res.algo_desp
    assert algo_desp is not None
    split_k_slices = 1
    if profile_res.splitk > 1:
        split_k_slices = profile_res.splitk
    if isinstance(op_type, int):
        op_type_value = op_type
    else:
        op_type_value = op_type.value
    params = ConvParams(NDIM_DONT_CARE, op_type_value)
    params.conv_algo_desp = profile_res.algo_desp
    params.input = inp
    params.verbose = verbose
    params.weight = weight.view([channel_k, -1, channel_c])
    params.output = output
    params.split_k_slices = split_k_slices
    params.alpha = alpha
    params.beta = beta
    params.stream = stream
    params.mask_argsort = mask_argsort
    params.indices = indices
    params.mask = mask
    params.mask_filter = mask_filter
    params.mask_width = mask_width
    params.mask_output = mask_output
    params.reverse_mask = reverse_mask
    if timer.enable:
        assert timer._timer is not None
        params.timer = timer._timer
    params.workspace = workspace
    ConvMainUnitTest.implicit_gemm2(params)
    return algo_desp
def tune_and_cache(self,
                   op_type: ConvOpType,
                   inp: tv.Tensor,
                   weight: tv.Tensor,
                   output: tv.Tensor,
                   layout_i: ConvLayout,
                   layout_w: ConvLayout,
                   layout_o: ConvLayout,
                   arch: Tuple[int, int],
                   mask: tv.Tensor,
                   mask_argsort: tv.Tensor,
                   indices: tv.Tensor,
                   reverse_mask: bool,
                   mask_filter: int = 0xffffffff,
                   mask_width: int = -1,
                   mask_output: tv.Tensor = tv.Tensor(),
                   alpha: float = 1.0,
                   beta: float = 0.0,
                   stream: int = 0):
    avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
                                   layout_o, arch, op_type, mask_width)
    inp = inp.clone()
    weight = weight.clone()
    output = output.clone()
    channel_k = output.dim(1)
    channel_c = inp.dim(1)
    times: List[float] = []
    all_profile_res: List[BestConvAlgoByProfile] = []
    for desp in avail:
        # for sparse conv, ndim isn't used, so we just provide a constant value.
        params = ConvParams(NDIM_DONT_CARE, op_type.value)
        params.conv_algo_desp = desp
        params.input = inp
        params.weight = weight.view([channel_k, -1, channel_c])
        params.output = output
        params.mask_width = mask_width
        params.alpha = alpha
        params.beta = beta
        params.stream = stream
        params.mask_argsort = mask_argsort
        params.indices = indices
        params.mask = mask
        params.mask_output = mask_output
        if op_type == ConvOpType.kBackwardWeight:
            assert not mask_output.empty()
        if op_type == ConvOpType.kBackwardInput:
            params.reverse_mask = reverse_mask
        params.mask_filter = mask_filter
        if desp.split_k_serial and op_type == ConvOpType.kBackwardWeight:
            splitk_tests = [1, 2, 4, 8, 16, 32, 64]
        else:
            splitk_tests = [1]
        spk_speeds = []
        for spk in splitk_tests:
            this_times = []
            for j in range(3):
                GemmMainUnitTest.stream_synchronize(stream)
                t = time.time()
                params.split_k_slices = spk
                ConvMainUnitTest.implicit_gemm2(params)
                GemmMainUnitTest.stream_synchronize(stream)
                this_times.append(time.time() - t)
            # drop the first (warm-up) run, average the rest.
            times.append(np.mean(this_times[1:]))
            spk_speeds.append(times[-1])
            all_profile_res.append(BestConvAlgoByProfile(desp, splitk=spk))
    if not all_profile_res:
        raise ValueError("can't find suitable algorithm for", op_type)
    min_time = 1000
    min_idx = -1
    for i, t in enumerate(times):
        if t < min_time:
            min_time = t
            min_idx = i
    res = all_profile_res[min_idx]
    if op_type != ConvOpType.kBackwardWeight:
        # fwd and dgrad don't need mask_width
        mask_width = -1
    key = (inp.dtype, weight.dtype, output.dtype, channel_k, channel_c,
           arch[0], arch[1], mask_width)
    with self.lock:
        if op_type == ConvOpType.kForward:
            self.kc_forward_cache[key] = res
        elif op_type == ConvOpType.kBackwardInput:
            self.kc_dgrad_cache[key] = res
        elif op_type == ConvOpType.kBackwardWeight:
            self.kc_wgrad_cache[key] = res
        else:
            raise NotImplementedError
    return res, min_time
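# Hedged sketch (not part of the original source): the timing pattern used in
# tune_and_cache above, isolated as a generic helper. profile_and_pick, run_fn
# and candidates are hypothetical stand-ins for ConvMainUnitTest.implicit_gemm2
# and the available kernel descriptions; the real code also varies split-k per
# candidate and synchronizes the CUDA stream around each run.
import time
from typing import Callable, List, Sequence, Tuple

def profile_and_pick(candidates: Sequence, run_fn: Callable[[object], None],
                     repeats: int = 3) -> Tuple[object, float]:
    times: List[float] = []
    for cand in candidates:
        this_times = []
        for _ in range(repeats):
            t = time.time()
            run_fn(cand)
            this_times.append(time.time() - t)
        # as in tune_and_cache, the first run is treated as warm-up and dropped.
        times.append(sum(this_times[1:]) / (repeats - 1))
    best = min(range(len(times)), key=times.__getitem__)
    return candidates[best], times[best]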
def run_with_tuned_result(self,
                          profile_res: BestAlgoByProfile,
                          a: tv.Tensor,
                          b: tv.Tensor,
                          c: tv.Tensor,
                          trans_a: bool,
                          trans_b: bool,
                          trans_c: bool,
                          arch: Tuple[int, int],
                          stream: int,
                          shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
                          a_inds: tv.Tensor = tv.Tensor(),
                          b_inds: tv.Tensor = tv.Tensor(),
                          c_inds: tv.Tensor = tv.Tensor(),
                          hint: int = AlgoHint.NoHint.value,
                          alpha: float = 1.0,
                          beta: float = 0.0,
                          gather_data: tv.Tensor = tv.Tensor(),
                          workspace: tv.Tensor = tv.Tensor(),
                          timer: CUDAKernelTimer = CUDAKernelTimer(False)):
    m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a, trans_b,
                                           trans_c, shuffle_type.value,
                                           a_inds.shape, b_inds.shape,
                                           c_inds.shape)
    algo_desp = profile_res.algo_desp
    assert algo_desp is not None
    split_k_slices = 1
    # TODO better splitk selection
    if profile_res.splitk > 1:
        split_k_slices = profile_res.splitk
    params = GemmParams()
    params.a = a
    params.b = b
    params.c = c
    params.a_inds = a_inds
    params.b_inds = b_inds
    params.c_inds = c_inds
    params.algo_desp = algo_desp
    params.split_k_slices = split_k_slices
    params.stream = stream
    params.alpha = alpha
    params.beta = beta
    params.workspace = workspace
    if timer.enable:
        assert timer._timer is not None
        params.timer = timer._timer
    GemmMainUnitTest.matmul2(params)
    return algo_desp
def tune_and_cache(self,
                   a: tv.Tensor,
                   b: tv.Tensor,
                   c: tv.Tensor,
                   trans_a: bool,
                   trans_b: bool,
                   trans_c: bool,
                   arch: Tuple[int, int],
                   shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
                   a_inds: tv.Tensor = tv.Tensor(),
                   b_inds: tv.Tensor = tv.Tensor(),
                   c_inds: tv.Tensor = tv.Tensor(),
                   hint: int = AlgoHint.NoHint.value,
                   alpha: float = 1.0,
                   beta: float = 0.0,
                   gather_data: tv.Tensor = tv.Tensor(),
                   scatter_data: tv.Tensor = tv.Tensor(),
                   stream: int = 0):
    m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a, trans_b,
                                           trans_c, shuffle_type.value,
                                           a_inds.shape, b_inds.shape,
                                           c_inds.shape)
    avail = self.get_all_available(a, b, c, trans_a, trans_b, trans_c, arch,
                                   shuffle_type)
    c_ = c.clone()
    times: List[float] = []
    best_gather_params = (-1, -1, -1, -1)
    best_scatter_params = (-1, -1, -1, -1)
    all_profile_res: List[BestAlgoByProfile] = []
    for desp in avail:
        c_.zero_()
        split_k_slices = 1
        # TODO better splitk selection
        if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
            split_k_slices = max(min(32, k // 128), 1)
        params = GemmParams()
        params.a = a
        params.b = b
        params.c = c_
        params.a_inds = a_inds
        params.b_inds = b_inds
        params.c_inds = c_inds
        params.algo_desp = desp
        params.alpha = alpha
        params.beta = beta
        params.stream = stream
        if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
            splitk_tests = [1, 2, 4, 8, 16, 32, 64]
        else:
            splitk_tests = [1]
        spk_speeds = []
        for spk in splitk_tests:
            this_times = []
            for j in range(3):
                GemmMainUnitTest.stream_synchronize(stream)
                t = time.time()
                params.split_k_slices = spk
                GemmMainUnitTest.matmul2(params)
                GemmMainUnitTest.stream_synchronize(stream)
                this_times.append(time.time() - t)
            # drop the first (warm-up) run, average the rest.
            times.append(np.mean(this_times[1:]))
            spk_speeds.append(times[-1])
            all_profile_res.append(BestAlgoByProfile(desp, splitk=spk))
    min_time = 1000
    min_idx = -1
    for i, t in enumerate(times):
        if t < min_time:
            min_time = t
            min_idx = i
    res = all_profile_res[min_idx]
    with self.lock:
        if hint & AlgoHint.BackwardWeight.value:
            key = (a.dtype, b.dtype, c.dtype, m, n)
            self.mn_cache[key] = res
        elif hint & AlgoHint.BackwardInput.value:
            key = (a.dtype, b.dtype, c.dtype, n, k)
            self.nk_dgrad_cache[key] = res
        elif hint & AlgoHint.Fowrard.value:
            key = (a.dtype, b.dtype, c.dtype, n, k)
            self.nk_forward_cache[key] = res
        else:
            raise NotImplementedError
    return res, min_time
def select(self,
           a: tv.Tensor,
           b: tv.Tensor,
           c: tv.Tensor,
           trans_a: bool,
           trans_b: bool,
           trans_c: bool,
           arch: Tuple[int, int],
           shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
           a_inds: tv.Tensor = tv.Tensor(),
           b_inds: tv.Tensor = tv.Tensor(),
           c_inds: tv.Tensor = tv.Tensor(),
           hint: int = AlgoHint.NoHint.value):
    m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a, trans_b,
                                           trans_c, shuffle_type.value,
                                           a_inds.shape, b_inds.shape,
                                           c_inds.shape)
    if trans_c:
        # C^T = (A @ B)^T = B^T @ A^T: swap the operands and flip both
        # transpose flags so we only ever handle a non-transposed output.
        trans_a = not trans_a
        trans_b = not trans_b
        trans_a, trans_b = trans_b, trans_a
        a, b = b, a
        trans_c = False
    avail_algos = get_available_algo_str_from_arch(arch)
    finally_algos: List[GemmAlgoDesp] = []
    for algo in avail_algos:
        static_key = (trans_a, trans_b, trans_c, a.dtype, b.dtype, c.dtype,
                      shuffle_type.value, algo)
        desps = self.static_key_to_desps.get(static_key, None)
        if desps is None or len(desps) == 0:
            continue
        meta = self.static_key_to_meta[static_key]
        # for shuffle stride algos, we need to make the channel tile size as
        # large as possible, so for ShuffleAC we prefer the largest k tile.
        selected_algo_desps = GemmMainUnitTest.simple_select_tile_shape(
            m, n, k, meta.tile_ms, meta.tile_ns, meta.tile_ks,
            meta.tile_shape_to_algos,
            large_k_first=shuffle_type == ShuffleStrideType.ShuffleAC)
        if not selected_algo_desps:
            candidate = desps
        else:
            candidate = [desps[i] for i in selected_algo_desps]
        # select by hint
        if hint == 0:
            return candidate[0]
        if hint & (AlgoHint.Fowrard.value | AlgoHint.BackwardInput.value):
            # m may be huge while n and k are small:
            # no mixed precision and no split-k needed.
            finally_algos = []
            if a.dtype == tv.float16:
                dacc = tv.float16
                dcomp = tv.float16
                for can in candidate:
                    if can.dacc == dacc and can.dcomp == dcomp:
                        finally_algos.append(can)
            else:
                finally_algos = candidate
        elif hint & AlgoHint.BackwardWeight.value:
            # k is huge: int8 isn't supported, and for f16 the accumulator
            # and compute dtypes must be f32.
            finally_algos = []
            candidate_filtered: List[GemmAlgoDesp] = list(
                filter(lambda x: x.split_k_serial, candidate))
            if not candidate_filtered:
                candidate_filtered = candidate
            if a.dtype == tv.int8:
                continue
            elif a.dtype == tv.float16:
                dacc = tv.float32
                dcomp = tv.float32
                for can in candidate_filtered:
                    if can.dacc == dacc and can.dcomp == dcomp:
                        finally_algos.append(can)
            else:
                finally_algos = candidate_filtered
        else:
            return candidate[0]
    if finally_algos:
        return finally_algos[0]
    return None
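# Hedged sanity check (not part of the original source): the trans_c handling
# in select() above relies on the identity (A @ B)^T = B^T @ A^T, i.e. a
# transposed output can be produced by swapping the operands and flipping both
# transpose flags. A quick numpy check of that identity:
import numpy as np

A = np.random.rand(4, 3)
B = np.random.rand(3, 5)
assert np.allclose((A @ B).T, B.T @ A.T)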