Ejemplo n.º 1
0
 def get_all_available(
         self,
         a: tv.Tensor,
         b: tv.Tensor,
         c: tv.Tensor,
         trans_a: bool,
         trans_b: bool,
         trans_c: bool,
         arch: Tuple[int, int],
         shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle):
     """Collect every registered GEMM descriptor usable for these operands.

     When ``trans_c`` is set the problem is first rewritten so the output
     is not transposed: ``a`` and ``b`` trade places and each takes the
     negated transpose flag of the other.

     For every algorithm name reported by
     ``get_available_algo_str_from_arch(arch)`` the descriptors stored in
     ``self.static_key_to_desps`` under the key (transpose flags, operand
     dtypes, ``shuffle_type.value``, algorithm name) are examined, and a
     descriptor is kept when ``supported_ldx`` accepts ``a.dim(1)``,
     ``b.dim(1)`` and ``c.dim(1)``.

     Returns:
         The accepted descriptors, in lookup order.
     """
     if trans_c:
         # Negate-then-exchange of the two flags, done in a single step.
         trans_a, trans_b = not trans_b, not trans_a
         a, b = b, a
         trans_c = False

     matched: List[GemmAlgoDesp] = []
     for algo_name in get_available_algo_str_from_arch(arch):
         lookup_key = (trans_a, trans_b, trans_c, a.dtype, b.dtype, c.dtype,
                       shuffle_type.value, algo_name)
         candidates = self.static_key_to_desps.get(lookup_key)
         if not candidates:
             continue
         for candidate in candidates:
             # Volta tensor-op kernels are very slow on every architecture
             # other than Volta, so they are dropped from (7, 5) upwards.
             if arch >= (7, 5) and candidate.algo == GemmAlgo.Volta.value:
                 continue
             lda, ldb, ldc = a.dim(1), b.dim(1), c.dim(1)
             if candidate.supported_ldx(lda, ldb, ldc):
                 matched.append(candidate)
     return matched
Ejemplo n.º 2
0
    def get_all_available(self, inp: tv.Tensor, weight: tv.Tensor,
                          out: tv.Tensor, layout_i: ConvLayout,
                          layout_w: ConvLayout, layout_o: ConvLayout,
                          arch: Tuple[int, int], op_type: ConvOpType,
                          mask_width: int):
        """Collect every registered conv descriptor usable for these tensors.

        For each algorithm name reported by
        ``get_available_algo_str_from_arch(arch)``, descriptors are looked up
        in ``self.static_key_to_desps`` by (layout types, interleaves, tensor
        dtypes, algorithm name, ``op_type.value``). A descriptor is returned
        when ``supported_ldx_conv`` accepts the last dimension of ``inp``,
        ``weight`` and ``out`` and, for descriptors whose ``op_type`` is the
        backward-weight value, ``mask_width`` is a multiple of
        ``desp.tile_shape[2]``.

        Returns:
            The accepted descriptors, in lookup order.

        Raises:
            AssertionError: a backward-weight descriptor is examined while
                ``mask_width`` is not positive.
        """
        matched: List[ConvAlgoDesp] = []
        for algo_name in get_available_algo_str_from_arch(arch):
            lookup_key = (layout_i.layout_type.value,
                          layout_w.layout_type.value,
                          layout_o.layout_type.value, layout_i.interleave,
                          layout_w.interleave, layout_o.interleave, inp.dtype,
                          weight.dtype, out.dtype, algo_name, op_type.value)
            candidates = self.static_key_to_desps.get(lookup_key)
            if not candidates:
                continue
            for candidate in candidates:
                # Volta tensor-op kernels are very slow on every architecture
                # other than Volta, so they are dropped from (7, 5) upwards.
                if arch >= (7, 5) and candidate.algo == GemmAlgo.Volta.value:
                    continue
                ldi, ldw, ldo = inp.dim(-1), weight.dim(-1), out.dim(-1)
                if candidate.op_type == ConvOpType.kBackwardWeight.value:
                    assert mask_width > 0
                    mask_ok = mask_width % candidate.tile_shape[2] == 0
                else:
                    mask_ok = True
                if candidate.supported_ldx_conv(ldi, ldw, ldo) and mask_ok:
                    matched.append(candidate)
        return matched
Ejemplo n.º 3
0
    def get_all_available(self, inp: tv.Tensor, weight: tv.Tensor,
                          out: tv.Tensor, layout_i: ConvLayout,
                          layout_w: ConvLayout, layout_o: ConvLayout,
                          arch: Tuple[int, int], op_type: ConvOpType,
                          mask_width: int, fp32_accum: Optional[bool] = None):
        """Collect every registered conv descriptor usable for these tensors.

        Descriptors are looked up in ``self.static_key_to_desps`` by (layout
        types, interleaves, tensor dtypes, algorithm name, ``op_type.value``)
        for each algorithm name reported by
        ``get_available_algo_str_from_arch(arch)``, then filtered:

        * Volta tensor-op descriptors are dropped when ``arch >= (7, 5)``.
        * When all three tensors are float16 and ``arch >= (7, 0)``, SIMT
          descriptors are dropped, and descriptors whose accumulator dtype
          (``desp.dacc``) is float16 are dropped if fp32 accumulation is
          requested.
        * ``supported_ldx_conv`` must accept the last dimension of ``inp``,
          ``weight`` and ``out``.
        * For descriptors whose ``op_type`` is the backward-weight value,
          ``mask_width`` must be a multiple of ``desp.tile_shape[2]``.

        Args:
            fp32_accum: explicit choice for fp32 accumulation with float16
                tensors. When ``None`` it is decided heuristically: enabled
                if the reduce size exceeds ``128 * 27``, where the reduce
                size is ``weight.dim(-1) * kv`` for forward and
                ``weight.dim(0) * kv`` for backward-input, with ``kv`` the
                product of ``weight.shape[1:-1]``. Other op types leave it
                disabled. Ignored unless all tensors are float16.

        Returns:
            The accepted descriptors, in lookup order.

        Raises:
            AssertionError: a backward-weight descriptor is examined while
                ``mask_width`` is not positive.
        """
        avail_algos = get_available_algo_str_from_arch(arch)
        finally_algos: List[ConvAlgoDesp] = []
        is_fp16 = (inp.dtype == tv.float16 and weight.dtype == tv.float16
                   and out.dtype == tv.float16)
        use_f32_as_accum = False
        if is_fp16:
            if fp32_accum is None:
                # For 3d conv, a reduce axis that is too large may cause nan
                # during forward, so fall back to fp32 accumulation.
                kv = int(np.prod(weight.shape[1:-1]))
                if op_type == ConvOpType.kForward:
                    use_f32_as_accum = weight.dim(-1) * kv > 128 * 27
                elif op_type == ConvOpType.kBackwardInput:
                    use_f32_as_accum = weight.dim(0) * kv > 128 * 27
            else:
                use_f32_as_accum = fp32_accum

        # Loop-invariant architecture predicates.
        skip_volta = arch >= (7, 5)
        fp16_with_tensor_core = arch >= (7, 0) and is_fp16

        for algo in avail_algos:
            static_key = (layout_i.layout_type.value,
                          layout_w.layout_type.value,
                          layout_o.layout_type.value, layout_i.interleave,
                          layout_w.interleave, layout_o.interleave, inp.dtype,
                          weight.dtype, out.dtype, algo, op_type.value)
            desps = self.static_key_to_desps.get(static_key)
            if not desps:
                continue
            for desp in desps:
                # Skip volta tensor op since it is very slow in architectures
                # except volta.
                if skip_volta and desp.algo == GemmAlgo.Volta.value:
                    continue
                if fp16_with_tensor_core:
                    # Skip simt fp16 kernels if we have tensor core.
                    # ``desp.algo`` is compared against the enum *value*, the
                    # same way the Volta check above does it; comparing with
                    # the enum member itself would never match a raw value.
                    if desp.algo == GemmAlgo.Simt.value:
                        continue
                    if use_f32_as_accum and desp.dacc == tv.float16:
                        continue

                ldi = inp.dim(-1)
                ldw = weight.dim(-1)
                ldo = out.dim(-1)
                mask_width_valid = True
                if desp.op_type == ConvOpType.kBackwardWeight.value:
                    assert mask_width > 0
                    mask_width_valid = mask_width % desp.tile_shape[2] == 0
                if desp.supported_ldx_conv(ldi, ldw, ldo) and mask_width_valid:
                    finally_algos.append(desp)
        return finally_algos
Ejemplo n.º 4
0
    def run_with_tuned_result(self,
                              profile_res: BestConvAlgoByProfile,
                              op_type: Union[ConvOpType, int],
                              inp: tv.Tensor,
                              weight: tv.Tensor,
                              output: tv.Tensor,
                              mask: tv.Tensor,
                              mask_argsort: tv.Tensor,
                              mask_output: tv.Tensor,
                              indices: tv.Tensor,
                              reverse_mask: bool,
                              mask_filter: int = 0xffffffff,
                              mask_width: int = -1,
                              alpha: float = 1.0,
                              beta: float = 0.0,
                              stream: int = 0,
                              workspace: tv.Tensor = tv.Tensor(),
                              verbose: bool = False,
                              timer: CUDAKernelTimer = CUDAKernelTimer(False)):
        """Run ``ConvMainUnitTest.implicit_gemm2`` with a tuned algorithm.

        Fills a ``ConvParams`` from the arguments and the descriptor stored
        in ``profile_res`` and launches the kernel once.

        Args:
            profile_res: tuning result; its ``algo_desp`` selects the kernel
                and its ``splitk`` is used as the split-k slice count when
                greater than 1 (otherwise 1 slice is used).
            op_type: operation type, either the enum or its raw int value.
            weight: passed to the kernel viewed as
                ``[output.dim(1), -1, inp.dim(1)]``.
            timer: when ``timer.enable`` is true its underlying timer is
                forwarded to the kernel parameters.

        Returns:
            The algorithm descriptor that was used.

        Raises:
            AssertionError: ``profile_res.algo_desp`` is ``None``, or the
                timer is enabled but has no underlying timer object.
        """
        channel_k = output.dim(1)
        channel_c = inp.dim(1)
        algo_desp = profile_res.algo_desp
        assert algo_desp is not None

        # Anything not above 1 means "no split-k".
        split_k_slices = profile_res.splitk if profile_res.splitk > 1 else 1
        op_type_value = op_type if isinstance(op_type, int) else op_type.value

        # For sparse conv the ndim argument is a placeholder constant.
        params = ConvParams(NDIM_DONT_CARE, op_type_value)
        params.conv_algo_desp = algo_desp
        params.input = inp
        params.verbose = verbose
        params.weight = weight.view([channel_k, -1, channel_c])
        params.output = output
        params.split_k_slices = split_k_slices
        params.alpha = alpha
        params.beta = beta
        params.stream = stream
        params.mask_argsort = mask_argsort
        params.indices = indices
        params.mask = mask
        params.mask_filter = mask_filter
        params.mask_width = mask_width
        params.mask_output = mask_output
        params.reverse_mask = reverse_mask
        if timer.enable:
            assert timer._timer is not None
            params.timer = timer._timer
        params.workspace = workspace
        ConvMainUnitTest.implicit_gemm2(params)
        return algo_desp
Ejemplo n.º 5
0
    def tune_and_cache(self,
                       op_type: ConvOpType,
                       inp: tv.Tensor,
                       weight: tv.Tensor,
                       output: tv.Tensor,
                       layout_i: ConvLayout,
                       layout_w: ConvLayout,
                       layout_o: ConvLayout,
                       arch: Tuple[int, int],
                       mask: tv.Tensor,
                       mask_argsort: tv.Tensor,
                       indices: tv.Tensor,
                       reverse_mask: bool,
                       mask_filter: int = 0xffffffff,
                       mask_width: int = -1,
                       mask_output: tv.Tensor = tv.Tensor(),
                       alpha: float = 1.0,
                       beta: float = 0.0,
                       stream: int = 0):
        """Benchmark all candidate conv algorithms and cache the fastest.

        Every descriptor returned by ``get_all_available`` is run through
        ``ConvMainUnitTest.implicit_gemm2`` on clones of ``inp``, ``weight``
        and ``output``. Each configuration is launched three times; the first
        launch is treated as warm-up and the mean of the remaining two is its
        score. Descriptors with ``split_k_serial`` are additionally swept over
        split-k values 1..64 (powers of two) for backward-weight.

        The winner is stored, under ``self.lock``, in the cache matching
        ``op_type`` keyed by (dtypes, ``output.dim(1)``, ``inp.dim(1)``,
        arch major/minor, mask width). The mask width in the key is ``-1``
        unless ``op_type`` is backward-weight.

        Returns:
            ``(best, best_time)``: the fastest ``BestConvAlgoByProfile`` and
            its mean launch time in seconds.

        Raises:
            ValueError: no descriptor is available for this configuration.
            AssertionError: backward-weight is requested with an empty
                ``mask_output``.
            NotImplementedError: ``op_type`` has no associated cache.
        """
        avail = self.get_all_available(inp, weight, output, layout_i, layout_w,
                                       layout_o, arch, op_type, mask_width)
        if not avail:
            # Nothing to benchmark; fail before cloning the tensors.
            raise ValueError("can't find suitable algorithm for", op_type)

        # The kernels are launched on copies, not on the caller's tensors.
        inp = inp.clone()
        weight = weight.clone()
        output = output.clone()

        channel_k = output.dim(1)
        channel_c = inp.dim(1)
        is_wgrad = op_type == ConvOpType.kBackwardWeight

        times: List[float] = []
        all_profile_res: List[BestConvAlgoByProfile] = []
        for desp in avail:
            # for sparse conv, ndim isn't used, so we just provide a constant value.
            params = ConvParams(NDIM_DONT_CARE, op_type.value)
            params.conv_algo_desp = desp
            params.input = inp
            params.weight = weight.view([channel_k, -1, channel_c])
            params.output = output

            params.mask_width = mask_width
            params.alpha = alpha
            params.beta = beta
            params.stream = stream
            params.mask_argsort = mask_argsort
            params.indices = indices
            params.mask = mask
            params.mask_output = mask_output
            if is_wgrad:
                assert not mask_output.empty()
            if op_type == ConvOpType.kBackwardInput:
                params.reverse_mask = reverse_mask
            params.mask_filter = mask_filter
            if desp.split_k_serial and is_wgrad:
                splitk_tests = [1, 2, 4, 8, 16, 32, 64]
            else:
                splitk_tests = [1]
            for spk in splitk_tests:
                # Configure outside the timed region.
                params.split_k_slices = spk
                this_times = []
                for _ in range(3):
                    GemmMainUnitTest.stream_synchronize(stream)
                    start = time.perf_counter()
                    ConvMainUnitTest.implicit_gemm2(params)
                    GemmMainUnitTest.stream_synchronize(stream)
                    this_times.append(time.perf_counter() - start)
                # Drop the first (warm-up) launch from the average.
                times.append(np.mean(this_times[1:]))
                all_profile_res.append(BestConvAlgoByProfile(desp, splitk=spk))

        # First index holding the smallest time; no artificial upper bound,
        # so a slow-but-valid candidate is still selected correctly.
        min_idx = min(range(len(times)), key=times.__getitem__)
        min_time = times[min_idx]
        res = all_profile_res[min_idx]

        # fwd and dgrad don't need the mask width in the cache key.
        key_mask_width = mask_width if is_wgrad else -1
        key = (inp.dtype, weight.dtype, output.dtype, channel_k, channel_c,
               arch[0], arch[1], key_mask_width)
        with self.lock:
            if op_type == ConvOpType.kForward:
                self.kc_forward_cache[key] = res
            elif op_type == ConvOpType.kBackwardInput:
                self.kc_dgrad_cache[key] = res
            elif op_type == ConvOpType.kBackwardWeight:
                self.kc_wgrad_cache[key] = res
            else:
                raise NotImplementedError
        return res, min_time