Example #1
    def __init__(self, device: torch.device, key_dtype: torch.dtype,
                 value_dtype: torch.dtype,
                 max_size: int = -1) -> None:
        is_cpu = device.type == "cpu"
        self.is_cpu = is_cpu
        self.key_dtype = key_dtype
        self.value_dtype = value_dtype
        key_data_tv = tv.Tensor()
        value_data_tv = tv.Tensor()
        if is_cpu:
            self.keys_data = None
            self.values_data = None
        else:
            assert max_size > 0, "you must provide max_size for a fixed-size CUDA hash table, usually 2x the number of keys"
            assert device is not None, "you must specify a device for the CUDA hash table."
            self.keys_data = torch.empty([max_size], dtype=key_dtype, device=device)
            self.values_data = torch.empty([max_size], dtype=value_dtype, device=device)
            key_data_tv = torch_tensor_to_tv(self.keys_data)
            value_data_tv = torch_tensor_to_tv(self.values_data)
        stream = 0
        if not self.is_cpu:
            stream = get_current_stream()
        self.key_itemsize = _TORCH_DTYPE_TO_ITEMSIZE[self.key_dtype]
        self.value_itemsize = _TORCH_DTYPE_TO_ITEMSIZE[self.value_dtype]
        self._valid_value_dtype_for_arange = set([torch.int32, torch.int64])

        self._table = _HashTable(is_cpu, self.key_itemsize, self.value_itemsize, key_data_tv, value_data_tv, stream)
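A minimal usage sketch for the constructor above. The import path (spconv.pytorch.hash) is an assumption about where this HashTable class lives; the argument order follows the __init__ shown here.

import torch
# assumed import path; adjust to wherever this HashTable class is exposed in your install
from spconv.pytorch.hash import HashTable

device = torch.device("cuda:0")
num_keys = 1024
# CUDA tables are fixed-size, so max_size must be given (usually ~2x the expected key count)
table = HashTable(device, torch.int64, torch.int64, max_size=num_keys * 2)

# CPU tables grow dynamically, so max_size can be omitted
cpu_table = HashTable(torch.device("cpu"), torch.int64, torch.int64)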
Example #2
    def items(self, max_size: int = -1):
        """Return (keys, values, count): all pairs stored in the table plus a
        count tensor holding how many of the returned entries are valid.
        """
        count_tv = tv.Tensor()
        count = torch.Tensor()
        stream = 0
        if not self.is_cpu:
            stream = get_current_stream()
            assert self.values_data is not None
            if self.key_itemsize == 4:
                count = torch.zeros([1], dtype=torch.int32, device=self.values_data.device)
                count_tv = torch_tensor_to_tv(count, dtype=tv.uint32)
            elif self.key_itemsize == 8:
                count = torch.zeros([1], dtype=torch.int64, device=self.values_data.device)
                count_tv = torch_tensor_to_tv(count, dtype=tv.uint64)
            else:
                raise NotImplementedError
            if max_size == -1:
                max_size = self.values_data.shape[0]
            keys = torch.empty([max_size], dtype=self.key_dtype, device=self.values_data.device)
            values = torch.empty([max_size], dtype=self.value_dtype, device=self.values_data.device)
        else:
            max_size = self._table.size_cpu()
            count = torch.tensor([max_size], dtype=torch.int64)
            keys = torch.empty([max_size], dtype=self.key_dtype)
            values = torch.empty([max_size], dtype=self.value_dtype)
        keys_tv = torch_tensor_to_tv(keys)
        values_tv = torch_tensor_to_tv(values)
        self._table.items(keys_tv, values_tv, count_tv, stream)
        return keys, values, count
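On CUDA, the returned keys/values buffers have max_size slots and count reports how many entries the kernel actually wrote, so the result normally needs to be sliced. A hedged sketch continuing the HashTable usage from Example #1 (after some keys have been inserted):

keys_out, values_out, count = table.items()
n = int(count.item())        # number of valid entries
keys_out = keys_out[:n]
values_out = values_out[:n]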
Example #3
    def assign_arange_(self):
        """iterate table, assign values with "arange" value.
        equivalent to 1. get key by items(), 2. use key and arange(key.shape[0]) to insert
        """
        count_tv = tv.Tensor()
        count = torch.Tensor()
        stream = 0
        if not self.is_cpu:
            stream = get_current_stream()
            assert self.values_data is not None
            if self.key_itemsize == 4:
                count = torch.zeros([1], dtype=torch.int32, device=self.values_data.device)
                count_tv = torch_tensor_to_tv(count, dtype=tv.uint32)
            elif self.key_itemsize == 8:
                count = torch.zeros([1], dtype=torch.int64, device=self.values_data.device)
                count_tv = torch_tensor_to_tv(count, dtype=tv.uint64)
            else:
                raise NotImplementedError
        else:
            assert self.value_dtype in self._valid_value_dtype_for_arange
            max_size = self._table.size_cpu()
            count = torch.tensor([max_size], dtype=torch.int64)

        self._table.assign_arange_(count_tv, stream)
        return count
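assign_arange_ is the usual way to turn a set of unique keys into a dense 0..N-1 index mapping, e.g. when deduplicating coordinates. A hedged sketch using only the calls shown in these examples (import path assumed as in Example #1):

import torch
from spconv.pytorch.hash import HashTable  # assumed import path

coords = torch.tensor([7, 3, 7, 11, 3], dtype=torch.int64, device="cuda")
table = HashTable(coords.device, torch.int64, torch.int64, max_size=coords.shape[0] * 2)
table.insert(coords)             # key-only insert; values stay undefined until assigned
count = table.assign_arange_()   # every unique key now holds a dense index

keys_out, values_out, _ = table.items()
n = int(count.item())
# keys_out[:n] are the unique coordinates, values_out[:n] their dense 0..n-1 indices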
Example #4
    def insert(self, keys: torch.Tensor, values: Optional[torch.Tensor] = None):
        """insert hash table by keys and values
        if values is None, only key is inserted, the value is undefined.
        """
        keys_tv = torch_tensor_to_tv(keys)
        values_tv = tv.Tensor()
        if values is not None:
            values_tv = torch_tensor_to_tv(values)
        stream = 0
        if not self.is_cpu:
            stream = get_current_stream()

        return self._table.insert(keys_tv, values_tv, stream)
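A brief continuation of the sketch above: values must match the value_dtype the table was built with, and omitting values only registers key membership.

keys = torch.tensor([3, 7, 11], dtype=torch.int64, device="cuda")
values = torch.tensor([30, 70, 110], dtype=torch.int64, device="cuda")
table.insert(keys, values)   # key/value insert
table.insert(keys)           # key-only insert; values for these keys are undefined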
Example #5
    def run_with_tuned_result(self,
                              profile_res: BestConvAlgoByProfile,
                              op_type: Union[ConvOpType, int],
                              inp: tv.Tensor,
                              weight: tv.Tensor,
                              output: tv.Tensor,
                              mask: tv.Tensor,
                              mask_argsort: tv.Tensor,
                              mask_output: tv.Tensor,
                              indices: tv.Tensor,
                              reverse_mask: bool,
                              mask_filter: int = 0xffffffff,
                              mask_width: int = -1,
                              alpha: float = 1.0,
                              beta: float = 0.0,
                              stream: int = 0,
                              workspace: tv.Tensor = tv.Tensor(),
                              verbose: bool = False,
                              timer: CUDAKernelTimer = CUDAKernelTimer(False)):
        """Run the convolution with the algorithm selected by a previous
        profiling pass and return the algorithm descriptor that was used.
        """
        channel_k = output.dim(1)
        channel_c = inp.dim(1)
        # GemmMainUnitTest.stream_synchronize(stream)
        algo_desp = profile_res.algo_desp
        assert algo_desp is not None
        split_k_slices = 1
        if profile_res.splitk > 1:
            split_k_slices = profile_res.splitk
        if isinstance(op_type, int):
            op_type_value = op_type
        else:
            op_type_value = op_type.value
        params = ConvParams(NDIM_DONT_CARE, op_type_value)
        params.conv_algo_desp = profile_res.algo_desp
        params.input = inp
        params.verbose = verbose
        params.weight = weight.view([channel_k, -1, channel_c])
        params.output = output
        params.split_k_slices = split_k_slices
        params.alpha = alpha
        params.beta = beta
        params.stream = stream
        params.mask_argsort = mask_argsort
        params.indices = indices
        params.mask = mask
        params.mask_filter = mask_filter
        params.mask_width = mask_width
        params.mask_output = mask_output
        params.reverse_mask = reverse_mask
        if timer.enable:
            assert timer._timer is not None
            params.timer = timer._timer
        # torch.cuda.synchronize()
        # t = time.time()
        params.workspace = workspace
        ConvMainUnitTest.implicit_gemm2(params)
        # torch.cuda.synchronize()
        # dura = time.time() - t
        # print("F", algo_desp, dura)

        # GemmMainUnitTest.stream_synchronize(stream)
        return algo_desp
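The only tensor reshaping done here is weight.view([channel_k, -1, channel_c]), which flattens the spatial kernel dimensions into a single kernel-volume axis, giving a [K, kernel_volume, C] view. A plain-PyTorch illustration of that layout, independent of this library (the KRSC-style weight layout is an assumption made for the example):

import torch

K, C = 64, 32            # output / input channels
kernel = (3, 3, 3)       # 3D kernel, kernel volume = 27

w = torch.randn(K, *kernel, C)    # [K, 3, 3, 3, C]
w_flat = w.view(K, -1, C)         # [K, 27, C], the shape assigned to params.weight above
assert w_flat.shape == (K, 27, C)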
Example #6
    def tune_and_cache(self,
                       op_type: ConvOpType,
                       inp: tv.Tensor,
                       weight: tv.Tensor,
                       output: tv.Tensor,
                       layout_i: ConvLayout,
                       layout_w: ConvLayout,
                       layout_o: ConvLayout,
                       arch: Tuple[int, int],
                       mask: tv.Tensor,
                       mask_argsort: tv.Tensor,
                       indices: tv.Tensor,
                       reverse_mask: bool,
                       mask_filter: int = 0xffffffff,
                       mask_width: int = -1,
                       mask_output: tv.Tensor = tv.Tensor(),
                       alpha: float = 1.0,
                       beta: float = 0.0,
                       stream: int = 0):
        avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
                                       layout_o, arch, op_type, mask_width)
        inp = inp.clone()
        weight = weight.clone()
        output = output.clone()

        channel_k = output.dim(1)
        channel_c = inp.dim(1)

        times: List[float] = []
        all_profile_res: List[BestConvAlgoByProfile] = []
        for desp in avail:
            # for sparse conv, ndim isn't used, so we just provide a constant value.
            params = ConvParams(NDIM_DONT_CARE, op_type.value)
            params.conv_algo_desp = desp
            params.input = inp
            params.weight = weight.view([channel_k, -1, channel_c])
            params.output = output

            params.mask_width = mask_width
            params.alpha = alpha
            params.beta = beta
            params.stream = stream
            params.mask_argsort = mask_argsort
            params.indices = indices
            params.mask = mask
            params.mask_output = mask_output
            if op_type == ConvOpType.kBackwardWeight:
                assert not mask_output.empty()
            if op_type == ConvOpType.kBackwardInput:
                params.reverse_mask = reverse_mask
            params.mask_filter = mask_filter
            if desp.split_k_serial and op_type == ConvOpType.kBackwardWeight:
                splitk_tests = [1, 2, 4, 8, 16, 32, 64]
                # splitk_tests = [1]
            else:
                splitk_tests = [1]
            spk_speeds = []
            for spk in splitk_tests:
                this_times = []
                for j in range(3):
                    GemmMainUnitTest.stream_synchronize(stream)
                    t = time.time()
                    params.split_k_slices = spk
                    ConvMainUnitTest.implicit_gemm2(params)
                    GemmMainUnitTest.stream_synchronize(stream)
                    this_times.append(time.time() - t)
                times.append(np.mean(this_times[1:]))
                spk_speeds.append(times[-1])

                all_profile_res.append(BestConvAlgoByProfile(desp, splitk=spk))
        if not all_profile_res:
            raise ValueError(f"can't find a suitable algorithm for {op_type}")
        min_time = 1000
        min_idx = -1
        for i, t in enumerate(times):
            if t < min_time:
                min_time = t
                min_idx = i
        res = all_profile_res[min_idx]
        if op_type != ConvOpType.kBackwardWeight:
            # forward and dgrad don't depend on mask_width, so drop it from the cache key
            mask_width = -1
        key = (inp.dtype, weight.dtype, output.dtype, channel_k, channel_c,
               arch[0], arch[1], mask_width)
        with self.lock:
            if op_type == ConvOpType.kForward:
                self.kc_forward_cache[key] = res
            elif op_type == ConvOpType.kBackwardInput:
                self.kc_dgrad_cache[key] = res
            elif op_type == ConvOpType.kBackwardWeight:
                self.kc_wgrad_cache[key] = res
            else:
                raise NotImplementedError
        return res, min_time
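The selection loop above is a plain benchmark-and-argmin pattern: run each candidate algorithm (and each split-k value) three times, discard the first run as warm-up, average the rest, and keep the fastest. A self-contained sketch of that pattern with a dummy workload standing in for the conv kernel:

import time
import numpy as np

def run_candidate(split_k: int) -> None:
    # dummy workload standing in for ConvMainUnitTest.implicit_gemm2
    x = np.random.rand(256, 256)
    for _ in range(split_k):
        x = x @ x.T / 256.0

candidates = [1, 2, 4, 8]
times = []
for spk in candidates:
    runs = []
    for _ in range(3):
        t = time.time()
        run_candidate(spk)
        runs.append(time.time() - t)
    times.append(float(np.mean(runs[1:])))   # drop the warm-up run

best = candidates[int(np.argmin(times))]
print("fastest candidate:", best)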
Example #7
    def run_with_tuned_result(
        self,
        profile_res: BestAlgoByProfile,
        a: tv.Tensor,
        b: tv.Tensor,
        c: tv.Tensor,
        trans_a: bool,
        trans_b: bool,
        trans_c: bool,
        arch: Tuple[int, int],
        stream: int,
        shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
        a_inds: tv.Tensor = tv.Tensor(),
        b_inds: tv.Tensor = tv.Tensor(),
        c_inds: tv.Tensor = tv.Tensor(),
        hint: int = AlgoHint.NoHint.value,
        alpha: float = 1.0,
        beta: float = 0.0,
        gather_data: tv.Tensor = tv.Tensor(),
        workspace: tv.Tensor = tv.Tensor(),
        timer: CUDAKernelTimer = CUDAKernelTimer(False)):
        """Run the GEMM with the algorithm selected by a previous profiling
        pass and return the algorithm descriptor that was used.
        """
        m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a,
                                               trans_b, trans_c,
                                               shuffle_type.value,
                                               a_inds.shape, b_inds.shape,
                                               c_inds.shape)
        # GemmMainUnitTest.stream_synchronize(stream)
        algo_desp = profile_res.algo_desp
        assert algo_desp is not None
        split_k_slices = 1
        # TODO better splitk selection
        # if algo_desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
        #     split_k_slices = max(min(32, k // 128), 1)
        if profile_res.splitk > 1:
            split_k_slices = profile_res.splitk
        params = GemmParams()
        params.a = a
        params.b = b
        params.c = c
        params.a_inds = a_inds
        params.b_inds = b_inds
        params.c_inds = c_inds
        params.algo_desp = algo_desp
        params.split_k_slices = split_k_slices
        params.stream = stream
        params.alpha = alpha
        params.beta = beta
        params.workspace = workspace
        # gather = 0
        # if profile_res.external_gather and not gather_data.empty():
        #     GemmMainUnitTest.stream_synchronize(stream)
        #     tt = time.time()
        #     assert not gather_data.empty()
        #     params.a_inds = tv.Tensor()
        #     params.a = gather_data
        #     # print(profile_res.gather_params, gather_data.shape, a.shape, a_inds.shape)
        #     GATHER.gather(gather_data,
        #                    a,
        #                    a_inds,
        #                    *profile_res.gather_params,
        #                    stream=stream)
        #     GemmMainUnitTest.stream_synchronize(stream)
        #     gather = time.time() - tt
        if timer.enable:
            assert timer._timer is not None
            params.timer = timer._timer

        GemmMainUnitTest.matmul2(params)
        # GemmMainUnitTest.stream_synchronize(stream)
        return algo_desp
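split_k_slices partitions the reduction (K) dimension across several partial GEMMs whose results are summed (serial split-K). A plain-NumPy illustration of why that decomposition is exact:

import numpy as np

m, n, k, slices = 8, 6, 32, 4
a = np.random.rand(m, k)
b = np.random.rand(k, n)

full = a @ b
# each slice multiplies a strided subset of the K dimension; summing the
# partial products reproduces the full GEMM
partial = sum(a[:, s::slices] @ b[s::slices, :] for s in range(slices))
assert np.allclose(full, partial)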
Example #8
    def tune_and_cache(
            self,
            a: tv.Tensor,
            b: tv.Tensor,
            c: tv.Tensor,
            trans_a: bool,
            trans_b: bool,
            trans_c: bool,
            arch: Tuple[int, int],
            shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
            a_inds: tv.Tensor = tv.Tensor(),
            b_inds: tv.Tensor = tv.Tensor(),
            c_inds: tv.Tensor = tv.Tensor(),
            hint: int = AlgoHint.NoHint.value,
            alpha: float = 1.0,
            beta: float = 0.0,
            gather_data: tv.Tensor = tv.Tensor(),
            scatter_data: tv.Tensor = tv.Tensor(),
            # mm_func
            stream: int = 0):
        m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a,
                                               trans_b, trans_c,
                                               shuffle_type.value,
                                               a_inds.shape, b_inds.shape,
                                               c_inds.shape)
        avail = self.get_all_available(a, b, c, trans_a, trans_b, trans_c,
                                       arch, shuffle_type)

        c_ = c.clone()
        times: List[float] = []
        best_gather_params = (-1, -1, -1, -1)
        best_scatter_params = (-1, -1, -1, -1)

        all_profile_res: List[BestAlgoByProfile] = []
        for desp in avail:
            c_.zero_()
            split_k_slices = 1
            # TODO better splitk selection
            if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
                split_k_slices = max(min(32, k // 128), 1)
            params = GemmParams()
            params.a = a
            params.b = b
            params.c = c_
            params.a_inds = a_inds
            params.b_inds = b_inds
            params.c_inds = c_inds
            params.algo_desp = desp
            params.alpha = alpha
            params.beta = beta
            params.stream = stream
            if desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
                splitk_tests = [1, 2, 4, 8, 16, 32, 64]
            else:
                splitk_tests = [1]
            spk_speeds = []
            for spk in splitk_tests:
                this_times = []
                for j in range(3):
                    GemmMainUnitTest.stream_synchronize(stream)
                    t = time.time()
                    params.split_k_slices = spk
                    GemmMainUnitTest.matmul2(params)
                    GemmMainUnitTest.stream_synchronize(stream)
                    this_times.append(time.time() - t)
                times.append(np.mean(this_times[1:]))
                spk_speeds.append(times[-1])

                all_profile_res.append(BestAlgoByProfile(desp, splitk=spk))

        if not all_profile_res:
            raise ValueError("can't find a suitable gemm algorithm for these inputs")
        min_time = 1000
        min_idx = -1
        for i, t in enumerate(times):
            if t < min_time:
                min_time = t
                min_idx = i
        res = all_profile_res[min_idx]
        with self.lock:
            if hint & AlgoHint.BackwardWeight.value:
                key = (a.dtype, b.dtype, c.dtype, m, n)
                self.mn_cache[key] = res
            elif hint & AlgoHint.BackwardInput.value:
                key = (a.dtype, b.dtype, c.dtype, n, k)
                self.nk_dgrad_cache[key] = res
            elif hint & AlgoHint.Fowrard.value:
                key = (a.dtype, b.dtype, c.dtype, n, k)
                self.nk_forward_cache[key] = res
            else:
                raise NotImplementedError

        return res, min_time
Example #9
    def select(self,
               a: tv.Tensor,
               b: tv.Tensor,
               c: tv.Tensor,
               trans_a: bool,
               trans_b: bool,
               trans_c: bool,
               arch: Tuple[int, int],
               shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
               a_inds: tv.Tensor = tv.Tensor(),
               b_inds: tv.Tensor = tv.Tensor(),
               c_inds: tv.Tensor = tv.Tensor(),
               hint: int = AlgoHint.NoHint.value):
        m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a,
                                               trans_b, trans_c,
                                               shuffle_type.value,
                                               a_inds.shape, b_inds.shape,
                                               c_inds.shape)
        if trans_c:
            # computing C transposed is the same as swapping the operands and
            # toggling their transpose flags: (A @ B)^T == B^T @ A^T
            trans_a = not trans_a
            trans_b = not trans_b
            trans_a, trans_b = trans_b, trans_a
            a, b = b, a
            trans_c = False
        avail_algos = get_available_algo_str_from_arch(arch)
        finally_algos: List[GemmAlgoDesp] = []
        for algo in avail_algos:
            static_key = (trans_a, trans_b, trans_c, a.dtype, b.dtype, c.dtype,
                          shuffle_type.value, algo)
            desps = self.static_key_to_desps.get(static_key, None)
            if desps is None or len(desps) == 0:
                continue
            meta = self.static_key_to_meta[static_key]
            # for shuffle stride algos, we need to make channel tile size as large as possible.
            # so if ShuffleAC, we need to make k largest.
            selected_algo_desps = GemmMainUnitTest.simple_select_tile_shape(
                m,
                n,
                k,
                meta.tile_ms,
                meta.tile_ns,
                meta.tile_ks,
                meta.tile_shape_to_algos,
                large_k_first=shuffle_type == ShuffleStrideType.ShuffleAC)
            if not selected_algo_desps:
                candidate = desps
            else:
                candidate = [desps[i] for i in selected_algo_desps]
            # select by hint
            if hint == 0:
                return candidate[0]
            if hint & (AlgoHint.Fowrard.value | AlgoHint.BackwardInput.value):
                # m may be huge, n and k are small
                # don't need mixed precision
                # don't need splitk
                finally_algos = []
                if a.dtype == tv.float16:
                    dacc = tv.float16
                    dcomp = tv.float16
                    for can in candidate:
                        if can.dacc == dacc and can.dcomp == dcomp:
                            finally_algos.append(can)
                else:
                    finally_algos = candidate
            elif hint & AlgoHint.BackwardWeight.value:
                # k is huge
                # don't support i8
                # if f16, acc and comp must be f32
                finally_algos = []
                candidate_filtered: List[GemmAlgoDesp] = list(
                    filter(lambda x: x.split_k_serial, candidate))
                if not candidate_filtered:
                    candidate_filtered = candidate
                if a.dtype == tv.int8:
                    continue
                elif a.dtype == tv.float16:
                    dacc = tv.float32
                    dcomp = tv.float32
                    for can in candidate_filtered:
                        if can.dacc == dacc and can.dcomp == dcomp:
                            finally_algos.append(can)
                else:
                    finally_algos = candidate_filtered
            else:
                return candidate[0]
        # print(finally_algos)
        if finally_algos:
            return finally_algos[0]
        return None
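The trans_c handling in select relies on the identity (A @ B)^T == B^T @ A^T: producing C transposed is the same as swapping the operands and toggling their transpose flags, which is why the code flips both flags, swaps them along with a and b, and clears trans_c. A quick NumPy check of that identity:

import numpy as np

a = np.random.rand(4, 5)
b = np.random.rand(5, 3)

c_t = (a @ b).T        # the "trans_c" result
swapped = b.T @ a.T    # swap operands and transpose each instead
assert np.allclose(c_t, swapped)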