Example #1
    def run_with_tuned_result(self,
                              profile_res: BestConvAlgoByProfile,
                              op_type: Union[ConvOpType, int],
                              inp: tv.Tensor,
                              weight: tv.Tensor,
                              output: tv.Tensor,
                              mask: tv.Tensor,
                              mask_argsort: tv.Tensor,
                              mask_output: tv.Tensor,
                              indices: tv.Tensor,
                              reverse_mask: bool,
                              mask_filter: int = 0xffffffff,
                              mask_width: int = -1,
                              alpha: float = 1.0,
                              beta: float = 0.0,
                              stream: int = 0,
                              workspace: tv.Tensor = tv.Tensor(),
                              verbose: bool = False,
                              timer: CUDAKernelTimer = CUDAKernelTimer(False)):
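        # replay a previously profiled algorithm: the kernel descriptor and
        # split-k choice come from profile_res, no search happens here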
        channel_k = output.dim(1)
        channel_c = inp.dim(1)
        algo_desp = profile_res.algo_desp
        assert algo_desp is not None
        split_k_slices = 1
        if profile_res.splitk > 1:
            split_k_slices = profile_res.splitk
        if isinstance(op_type, int):
            op_type_value = op_type
        else:
            op_type_value = op_type.value
        params = ConvParams(NDIM_DONT_CARE, op_type_value)
        params.conv_algo_desp = profile_res.algo_desp
        params.input = inp
        params.verbose = verbose
        params.weight = weight.view([channel_k, -1, channel_c])
        params.output = output
        params.split_k_slices = split_k_slices
        params.alpha = alpha
        params.beta = beta
        params.stream = stream
        params.mask_argsort = mask_argsort
        params.indices = indices
        params.mask = mask
        params.mask_filter = mask_filter
        params.mask_width = mask_width
        params.mask_output = mask_output
        params.reverse_mask = reverse_mask
        if timer.enable:
            assert timer._timer is not None
            params.timer = timer._timer
        params.workspace = workspace
        ConvMainUnitTest.implicit_gemm2(params)
        return algo_desp
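
A minimal sketch of how this method might be driven, assuming a tuner object (named tuner here, a hypothetical name) and pre-built tv.Tensor buffers; profile_res is the BestConvAlgoByProfile produced by an earlier profiling pass such as tune_and_cache in Example #2:

# hypothetical driver: replay the algorithm picked during profiling
# (tuner, profile_res, inp, weight, out, mask, mask_argsort, mask_output
#  and indices are assumed to already exist)
algo = tuner.run_with_tuned_result(
    profile_res, ConvOpType.kForward,
    inp, weight, out,
    mask, mask_argsort, mask_output, indices,
    reverse_mask=False,
    mask_width=-1,
    stream=0)
print("executed with", algo)
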
Example #2
    def tune_and_cache(self,
                       op_type: ConvOpType,
                       inp: tv.Tensor,
                       weight: tv.Tensor,
                       output: tv.Tensor,
                       layout_i: ConvLayout,
                       layout_w: ConvLayout,
                       layout_o: ConvLayout,
                       arch: Tuple[int, int],
                       mask: tv.Tensor,
                       mask_argsort: tv.Tensor,
                       indices: tv.Tensor,
                       reverse_mask: bool,
                       mask_filter: int = 0xffffffff,
                       mask_width: int = -1,
                       mask_output: tv.Tensor = tv.Tensor(),
                       alpha: float = 1.0,
                       beta: float = 0.0,
                       stream: int = 0):
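        # enumerate every kernel implementation that can handle these dtypes,
        # layouts and the target compute arch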
        avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
                                       layout_o, arch, op_type, mask_width)
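        # profile on clones so the timing runs below don't clobber the caller's buffers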
        inp = inp.clone()
        weight = weight.clone()
        output = output.clone()

        channel_k = output.dim(1)
        channel_c = inp.dim(1)

        times: List[float] = []
        all_profile_res: List[BestConvAlgoByProfile] = []
        for desp in avail:
            # for sparse conv, ndim isn't used, so we just provide a constant value.
            params = ConvParams(NDIM_DONT_CARE, op_type.value)
            params.conv_algo_desp = desp
            params.input = inp
            params.weight = weight.view([channel_k, -1, channel_c])
            params.output = output

            params.mask_width = mask_width
            params.alpha = alpha
            params.beta = beta
            params.stream = stream
            params.mask_argsort = mask_argsort
            params.indices = indices
            params.mask = mask
            params.mask_output = mask_output
            if op_type == ConvOpType.kBackwardWeight:
                assert not mask_output.empty()
            if op_type == ConvOpType.kBackwardInput:
                params.reverse_mask = reverse_mask
            params.mask_filter = mask_filter
            if desp.split_k_serial and op_type == ConvOpType.kBackwardWeight:
                splitk_tests = [1, 2, 4, 8, 16, 32, 64]
            else:
                splitk_tests = [1]
            spk_speeds = []
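            # run each split-k candidate a few times; the first run is treated
            # as a warm-up and excluded from the averaged time below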
            for spk in splitk_tests:
                this_times = []
                for j in range(3):
                    GemmMainUnitTest.stream_synchronize(stream)
                    t = time.time()
                    params.split_k_slices = spk
                    ConvMainUnitTest.implicit_gemm2(params)
                    GemmMainUnitTest.stream_synchronize(stream)
                    this_times.append(time.time() - t)
                times.append(np.mean(this_times[1:]))
                spk_speeds.append(times[-1])

                all_profile_res.append(BestConvAlgoByProfile(desp, splitk=spk))
        if not all_profile_res:
            raise ValueError(f"can't find a suitable algorithm for {op_type}")
        # pick the fastest (descriptor, split-k) pair that was timed
        min_idx = int(np.argmin(times))
        min_time = times[min_idx]
        res = all_profile_res[min_idx]
        if op_type != ConvOpType.kBackwardWeight:
            # forward and dgrad don't use mask_width, so normalize it in the cache key
            mask_width = -1
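        # cache key: tensor dtypes, channel sizes, target arch, and mask_width (wgrad only)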
        key = (inp.dtype, weight.dtype, output.dtype, channel_k, channel_c,
               arch[0], arch[1], mask_width)
        with self.lock:
            if op_type == ConvOpType.kForward:
                self.kc_forward_cache[key] = res
            elif op_type == ConvOpType.kBackwardInput:
                self.kc_dgrad_cache[key] = res
            elif op_type == ConvOpType.kBackwardWeight:
                self.kc_wgrad_cache[key] = res
            else:
                raise NotImplementedError
        return res, min_time
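
The profiling pattern above (a few timed runs per candidate, the first treated as a warm-up and dropped, the mean of the rest compared, the winner cached under a key) is not specific to spconv. Below is a minimal self-contained sketch of the same idea, where the placeholder algorithms are just sleep durations:

import time
from statistics import mean

def profile_and_cache(candidates, run, key, cache, repeats=3):
    # candidates: opaque algorithm descriptors; run(desp) executes one of them
    # and blocks until it finishes (the role of stream_synchronize above)
    times = []
    for desp in candidates:
        samples = []
        for _ in range(repeats):
            t = time.time()
            run(desp)
            samples.append(time.time() - t)
        # drop the first (warm-up) sample, keep the mean of the rest
        times.append(mean(samples[1:]))
    best = min(range(len(candidates)), key=times.__getitem__)
    cache[key] = candidates[best]
    return candidates[best], times[best]

# toy usage: each "algorithm" is just a sleep duration in seconds
cache = {}
best, t = profile_and_cache([0.002, 0.001, 0.003], time.sleep,
                            key=("f32", "sm80"), cache=cache)
print(best, t, cache)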