Example #1
0
    def bind_flex_scales(self):
        """Fold the current flex scales into the gemm parameter list, then bind pointers.

        alpha is multiplied by the product of the A and B scales, beta by the
        C scale, and the reciprocal of the C scale is stored as the output
        scale. The slot indices match the params list built for ``fgemm``
        (alpha at 6, beta at 7, output scale at 20, right after the maxabs
        pointer) — presumably the list from ``_build_maxas_kernel``; keep
        them in sync if that layout changes.
        """
        # TODO: hardcoding these sucks
        alpha_slot, beta_slot, out_scale_slot = 6, 7, 20

        scale_ab = self.flex_entry_A.scale * self.flex_entry_B.scale
        scale_c = self.flex_entry_C.scale

        scaled_alpha = self.alpha * scale_ab
        scaled_beta = self.beta * scale_c

        params = self.params
        params[alpha_slot] = scaled_alpha
        params[beta_slot] = scaled_beta
        params[out_scale_slot] = 1. / scale_c

        FlexPtrDescription.bind_ptr(params)
Example #2
0
    def bind_flex_scales(self):
        """Write flex-scaled alpha/beta and the output scale into every fprop kernel.

        Each kernel parameter list keeps alpha at index 8, beta at index 9 and
        the output scale as its final entry. All lists are updated first; the
        flex pointers are bound in a second pass.
        """
        in_filter_scale = self.flex_entry_I.scale * self.flex_entry_F.scale
        out_scale = self.flex_entry_O.scale
        scaled_alpha = self.alpha * in_filter_scale
        scaled_beta = self.beta * out_scale

        for params in self.kernels:
            params[8], params[9], params[-1] = scaled_alpha, scaled_beta, 1. / out_scale

        for params in self.kernels:
            FlexPtrDescription.bind_ptr(params)
Example #3
0
    def bind_flex_scales(self):
        """Write flex-scaled alpha/beta and output scales into the bprop kernels.

        ``self.kernels`` presumably holds ``[shuffle_kernel, *bprop kernels]``
        with an optional trailing conversion kernel (when ``convert_out``), so
        the bprop kernels are the slice starting at index 1. Afterwards the
        flex pointers of every kernel are bound.
        """
        err_filter_scale = self.flex_entry_E.scale * self.flex_entry_F.scale
        out_scale = self.flex_entry_O.scale
        scaled_alpha = self.alpha * err_filter_scale
        scaled_beta = self.beta * out_scale

        first = 1
        stop = first + len(self.bprop_kernels)
        for params in self.kernels[first:stop]:
            params[8] = scaled_alpha
            params[9] = scaled_beta
            params[-1] = 1. / out_scale

        if self.convert_out:
            # NOTE(review): index -2 of the conversion kernel's params is
            # assumed to hold its output scale — confirm against
            # _prepare_convert_kernel.
            self.kernels[-1][-2] = 1. / out_scale

        for params in self.kernels:
            FlexPtrDescription.bind_ptr(params)
Example #4
0
    def bind_buffers(self):
        """Allocate the flex entries and append flex arguments to the LUT kernels.

        After the base-class binding, the O, I and E flex entries are
        allocated. The LUT bprop kernel receives the O maxabs pointer followed
        by the O, I and E scales; the LUT sort kernel receives the I scale.
        """
        super(FlexLUTBpropKernel, self).bind_buffers()
        for entry in (self.flex_entry_O, self.flex_entry_I, self.flex_entry_E):
            entry.allocate()

        for kernel, params in self.kernels:
            if kernel.name == lut_bprop_kernel_name:
                params.extend([
                    FlexPtrDescription(self.flex_entry_O),
                    self.flex_entry_O.scale,
                    self.flex_entry_I.scale,
                    self.flex_entry_E.scale,
                ])
            elif kernel.name == lut_sort_kernel_name:
                params.append(self.flex_entry_I.scale)

        self.clear_tensor()  # Additional tensor clear due to flex specifics
Example #5
0
def _ew_bind_flex_scales(kernel):
    """Copy current flex scales into an elementwise kernel's params, then bind pointers.

    For each ``(index, flex_scale_desc)`` pair in ``kernel.flex_scale_info``
    the entry's scale is written to ``kernel.params[index]``; scales flagged
    ``is_output`` are written as their reciprocal.
    """
    for param_index, desc in kernel.flex_scale_info:
        entry_scale = desc.flex_entry.scale
        if desc.is_output:
            kernel.params[param_index] = 1.0 / entry_scale
        else:
            kernel.params[param_index] = entry_scale
    FlexPtrDescription.bind_ptr(kernel.params)
Example #6
0
    def gen_kernels(self, runtime, N, C, K, D, H, W, T, R, S, M, P, Q,
                    pad_d, pad_h, pad_w, str_d, str_h, str_w, dil_d, dil_h, dil_w):
        """Select and parameterize the flex fprop convolution kernels.

        Wraps the I, F and O tensor descriptions, records the layer geometry
        on ``self``, precomputes the constant products and integer-division
        magic numbers the kernels consume, and builds ``self.kernels`` with
        one parameter list per fprop kernel, laid out as::

            [kernel, grid, block, None, 0, O, I, F,
             alpha (index 8), beta (index 9), flags (index 10), offset,
             *kernel_args, maxabs_ptr, output_scale (last)]

        alpha, beta and the output scale start as 1.0, 0.0 and 1.0;
        presumably ``bind_flex_scales`` overwrites them later.

        Arguments:
            runtime: Not used by this variant.
            N, C, K, D, H, W, T, R, S, M, P, Q: Layer geometry, arranged as in
                the ``dimI``/``dimF``/``dimO`` tuples below.
            pad_d, pad_h, pad_w, str_d, str_h, str_w: Padding and strides.
            dil_d, dil_h, dil_w: Not used by this variant.

        Raises:
            TypeError: If ``self.dtype`` is not flex.
            AssertionError: If N is not a multiple of 32, K is not a multiple
                of the vector size, or C*T*R*S exceeds 2**16.
        """
        self.I = TensorDescriptionWrapper(self.I, len(self.I.shape))
        self.F = TensorDescriptionWrapper(self.F, len(self.F.shape))
        self.O = TensorDescriptionWrapper(self.O, len(self.O.shape))

        self.flex_entry_I = self.I.flex_entry()
        self.flex_entry_F = self.F.flex_entry()
        self.flex_entry_O = self.O.flex_entry()

        vec_size = 4 if self.dtype.itemsize == 4 else 8

        assert N % 32 == 0, "N dim must be multiple of 32"
        assert K % vec_size == 0, "K dim must be multiple of %d" % vec_size

        if self.dtype.type == "flex":
            clss = "fconv"
        else:
            raise TypeError("Type not supported.")

        self.C = C
        self.K = K
        self.M = M
        self.P = P
        self.Q = Q
        self.NCK = (N, C, K)
        self.TRS = (T, R, S)
        self.DHW = (D, H, W)
        self.MPQ = (M, P, Q)
        self.padding = (pad_d, pad_h, pad_w)
        self.strides = (str_d, str_h, str_w)

        self.all_params = (N, C, K, D, H, W, T, R, S, pad_d, pad_h, pad_w, str_d, str_h, str_w)

        self.dimI = (C, D, H, W, N)
        self.dimF = (C, T, R, S, K)
        # K-leading filter layout. It is stored as dimFb (as in the bprop and
        # update kernels) so that it no longer overwrites dimF above.
        self.dimFb = (K, T, R, S, C)
        self.dimO = (K, M, P, Q, N)
        self.dimI2 = (C * D * H * W, N)
        self.dimF2 = (C * T * R * S, K)
        self.dimF2t = (K, C * T * R * S)
        self.dimO2 = (K * M * P * Q, N)
        self.dimS = (K, 1)
        self.sizeI = reduce(mul, self.dimI, 1)
        self.sizeF = reduce(mul, self.dimF, 1)
        self.sizeO = reduce(mul, self.dimO, 1)
        self.nOut = reduce(mul, self.MPQ, 1) * K

        # precompute some multiplications for fast constant memory access
        WN = W * N
        HWN = H * WN
        DHWN = D * HWN
        RS = R * S
        RST = T * RS
        CRST = C * RST
        KRST = K * RST
        PQ = P * Q
        PQM = M * PQ
        QN = Q * N
        PQN = P * QN
        MPQN = M * PQN

        # Reject CRST above 2**16 (same condition as the former
        # `if CRST > 2**16: assert CRST < 2**16` pair).
        assert CRST <= 2**16, "Integer division is faster with 16bit numerators"

        # precompute the magic numbers and shift amounts for integer division
        magic_PQ = _magic64(PQ)
        magic_Q = _magic64(Q)
        magic_RS = _magic32(RST + 32, RS)
        magic_S = _magic32(RS + 32, S)

        # flop count for benchmarking
        self.flops = PQM * K * N * CRST * 2.0

        tile_N = 128 if N > 64 else 64
        grid_N = _grid_dim(tile_N, N)
        tiles_CK = (128, 64, 32) if tile_N == 128 else (128, 64)

        # FPROP #
        self.fprop_kernels = kernel_specs.xprop_conv_kernels(
            clss, "fprop", "K", tile_N, grid_N, K, tiles_CK, PQM, RST,
            _flatten([N, K, D, H, W, WN, HWN, DHWN,
                      C, KRST, RST, RS, magic_RS, S, magic_S,
                      pad_d, pad_h, pad_w, str_d, str_h, str_w,
                      Q, PQ, QN, PQN, MPQN, magic_Q, magic_PQ]))

        # shared lookup table size
        self.fprop_lut_size = RST * 4 * 2

        # NOTE(review): an older comment here read "Set to 5 for the current
        # T1000 HW config", but the value is 32 — confirm which is intended.
        self.trunc_rows = 32
        flags = self.trunc_rows << 8

        self.kernels = []
        for kernel in self.fprop_kernels:
            # Each spec is presumably [name, grid, block, offset, args]. alpha
            # (index 8) and beta (index 9) are placeholders here and are
            # populated in a separate pass (bind_flex_scales).
            self.kernels.append([
                kernel_specs.get_kernel(kernel[0]), kernel[1], kernel[2], None,
                0, self.O, self.I, self.F, 1.0, 0.0, flags,
                kernel[3]] + kernel[4])

        for kernel in self.kernels:
            # Append the output maxabs pointer and a placeholder output scale.
            kernel.extend((FlexPtrDescription(self.flex_entry_O), 1.0))
            # Clears bit 0 of the flags word (index 10).
            kernel[10] &= 0xfffffffe  # Enable output flag

        # record output flex id for autoflex
        self.output_flex_ids = [self.flex_entry_O.flex_id]
Example #7
0
    def gen_kernels(self, runtime, N, C, K, D, H, W, T, R, S, M, P, Q,
                    pad_d, pad_h, pad_w, str_d, str_h, str_w, dil_d, dil_h, dil_w):
        """Select and parameterize the flex weight-update (updat) kernels.

        Wraps the I, E and U tensor descriptions and records the layer
        geometry on ``self``. It then builds ``self.kernels`` from two parts.
        The first part is one parameter list per update kernel, writing into a
        scratch buffer of ``4 * prod(U.shape)`` bytes. The second is a final
        conversion kernel ("f4" -> U). Each update-kernel list is laid out as::

            [kernel, grid, block, None, 0, U_scratch, I, E,
             alpha (index 8), beta (index 9), flags (index 10), offset_K,
             *kernel_args, maxabs_ptr, output_scale (last)]

        Arguments:
            runtime: Passed to ``ScratchBufferWrapper`` for the scratch buffer.
            N, C, K, D, H, W, T, R, S, M, P, Q: Layer geometry, arranged as in
                the ``dimI``/``dimF``/``dimO`` tuples below.
            pad_d, pad_h, pad_w, str_d, str_h, str_w: Padding and strides.
            dil_d, dil_h, dil_w: Not used by this variant.

        Raises:
            TypeError: If ``self.dtype`` is not flex.
            AssertionError: If N is not a multiple of 32, K is not a multiple
                of the vector size, or C*T*R*S exceeds 2**16.
        """
        self.I = TensorDescriptionWrapper(self.I, len(self.I.shape))
        self.E = TensorDescriptionWrapper(self.E, len(self.E.shape))
        self.U = TensorDescriptionWrapper(self.U, len(self.U.shape))

        self.flex_entry_I = self.I.flex_entry()
        self.flex_entry_E = self.E.flex_entry()
        self.flex_entry_U = self.U.flex_entry()

        # Scratch size in bytes: 4 bytes per element of U (converted with
        # "f4" at the end).
        U_size = int(np.prod(self.U.shape) * 4)

        vec_size = 4 if self.dtype.itemsize == 4 else 8

        assert N % 32 == 0, "N dim must be multiple of 32"
        assert K % vec_size == 0, "K dim must be multiple of %d" % vec_size

        if self.dtype.type == "flex":
            clss = "fconv"
        else:
            raise TypeError("Type not supported.")

        self.C = C
        self.K = K
        self.M = M
        self.P = P
        self.Q = Q
        self.NCK = (N, C, K)
        self.TRS = (T, R, S)
        self.DHW = (D, H, W)
        self.MPQ = (M, P, Q)
        self.padding = (pad_d, pad_h, pad_w)
        self.strides = (str_d, str_h, str_w)

        self.all_params = (N, C, K, D, H, W, T, R, S, pad_d, pad_h, pad_w, str_d, str_h, str_w)

        self.dimI = (C, D, H, W, N)
        self.dimF = (C, T, R, S, K)
        self.dimFb = (K, T, R, S, C)
        self.dimO = (K, M, P, Q, N)
        self.dimI2 = (C * D * H * W, N)
        self.dimF2 = (C * T * R * S, K)
        self.dimF2t = (K, C * T * R * S)
        self.dimO2 = (K * M * P * Q, N)
        self.dimS = (K, 1)
        self.sizeI = reduce(mul, self.dimI, 1)
        self.sizeF = reduce(mul, self.dimF, 1)
        self.sizeO = reduce(mul, self.dimO, 1)
        self.nOut = reduce(mul, self.MPQ, 1) * K

        # precompute some multiplications for fast constant memory access
        WN = W * N
        HWN = H * WN
        DHWN = D * HWN
        RS = R * S
        RST = T * RS
        CRST = C * RST
        CRSTK = K * CRST
        PQ = P * Q
        PQM = M * PQ
        QN = Q * N
        PQN = P * QN
        MPQN = M * PQN

        # Reject CRST above 2**16 (same condition as the former
        # `if CRST > 2**16: assert CRST < 2**16` pair).
        assert CRST <= 2**16, "Integer division is faster with 16bit numerators"

        # precompute the magic numbers and shift amounts for integer division
        magic_RST = _magic32(CRST, RST)
        magic_RS = _magic32(RST + 32, RS)
        magic_S = _magic32(RS + 32, S)

        # flop count for benchmarking
        self.flops = PQM * K * N * CRST * 2.0

        # UPDATE #

        grid_C = _grid_dim(128, CRST)
        sm_count = _get_sm_count()

        # in float32 for big feature_map layers the smaller tile is actually faster
        # so restrict tile selection to just that.
        # NOTE(review): self.dtype.type was required to be "flex" above, so
        # this float32 branch cannot currently be taken.
        if self.dtype.type is np.float32 and PQ > 56 * 56:
            K_tiles = (64,)
        else:
            K_tiles = (128, 64)

        # determ is always "" here (no suffix in the kernel name), and
        # self.determ starts at 0, so the `*=` in the loop leaves it at 0.
        determ = ""
        self.determ = 0

        self.updat_kernels = []
        for tile_K, grid_K, offset_K in kernel_specs.K_partitions(K, K_tiles):

            kernel_name = "%s_updat%s_C128_K%d" % (clss, determ, tile_K)
            base_blocks = M * grid_C * grid_K

            grid_P, grid_Q, threads = kernel_specs.update_grid(kernel_name,
                                                               base_blocks, P, Q, sm_count)

            grid_PQ = grid_P * grid_Q
            magic_PQu = _magic64(grid_PQ)
            magic_Qu = _magic64(grid_Q)

            block = (threads, 1, 1)
            if RST > 1:
                grid = (M * grid_PQ, grid_C, grid_K)
            else:
                grid = (grid_C, grid_K, M * grid_PQ)

            self.determ *= M * grid_PQ
            self.determ_shape = (M * grid_PQ, CRSTK)

            self.updat_kernels.append([kernel_name, grid, block, offset_K, _flatten([
                N, K, D, H, W, WN, HWN, DHWN,
                C, CRST, RST, magic_RST, RS, magic_RS, S, magic_S,
                pad_d, pad_h, pad_w, str_d, str_h, str_w,
                P, Q, PQ, QN, PQN, MPQN, magic_Qu, magic_PQu,
                grid_P, grid_Q, grid_PQ])])

        # NOTE(review): an older comment here read "Set to 5 for the current
        # T1000 HW config", but the value is 32 — confirm which is intended.
        self.trunc_rows = 32
        flags = self.trunc_rows << 8

        # Have to convert output from float to flex
        U_data = ScratchBufferWrapper(U_size, 0, runtime)
        shape = [int(np.prod(self.U.shape[:-1])), self.U.shape[-1]]
        convert_kernel = _prepare_convert_kernel(U_data, "f4", self.U, shape,
                                                 FlexPtrDescription(self.flex_entry_U))

        self.kernels = []
        for kernel in self.updat_kernels:
            # alpha (index 8) and beta (index 9) are placeholders here and are
            # presumably populated in a separate pass (bind_flex_scales).
            self.kernels.append([
                kernel_specs.get_kernel(kernel[0]), kernel[1], kernel[2], None,
                0, U_data, self.I, self.E, 1.0, 0.0, flags,
                kernel[3]] + kernel[4])

        for kernel in self.kernels:
            # Append the U maxabs pointer and a placeholder output scale.
            kernel.extend((FlexPtrDescription(self.flex_entry_U), 1.0))
            # Clears bit 0 of the flags word (index 10).
            kernel[10] &= 0xfffffffe  # Enable output flag

        self.kernels.append(convert_kernel)

        # record output flex id for autoflex
        self.output_flex_ids = [self.flex_entry_U.flex_id]
Example #8
0
    def gen_kernels(self, runtime, N, C, K, D, H, W, T, R, S, M, P, Q,
                    pad_d, pad_h, pad_w, str_d, str_h, str_w, dil_d, dil_h, dil_w):
        """Select and parameterize the flex bprop kernels.

        Wraps the E, F and O tensor descriptions, records the layer geometry
        on ``self`` and builds ``self.kernels`` as::

            [shuffle_kernel, *bprop_kernels, (convert_kernel if convert_out)]

        The shuffle kernel rearranges the filter into a scratch buffer first.
        It is a transpose when ``C < 16`` or C is not a multiple of the vector
        size, and a dim shuffle otherwise. In the first case a special C1
        kernel writes into a second scratch buffer, and a final conversion
        kernel ("f2") writes that into O. Each bprop-kernel entry is laid out as::

            [kernel, grid, block, None, 0, Out, E, F_scratch,
             alpha (index 8), beta (index 9), flags (index 10), offset,
             *kernel_args, maxabs_ptr, output_scale (last)]

        Arguments:
            runtime: Passed to ``ScratchBufferWrapper`` for the scratch buffers.
            N, C, K, D, H, W, T, R, S, M, P, Q: Layer geometry, arranged as in
                the ``dimI``/``dimF``/``dimO`` tuples below.
            pad_d, pad_h, pad_w, str_d, str_h, str_w: Padding and strides.
            dil_d, dil_h, dil_w: Not used by this variant.

        Raises:
            TypeError: If ``self.dtype`` is not flex.
            AssertionError: If N is not a multiple of 32, K is not a multiple
                of the vector size, or C*T*R*S exceeds 2**16.
        """
        self.E = TensorDescriptionWrapper(self.E, len(self.E.shape))
        self.F = TensorDescriptionWrapper(self.F, len(self.F.shape))
        self.O = TensorDescriptionWrapper(self.O, len(self.O.shape))

        self.flex_entry_E = self.E.flex_entry()
        self.flex_entry_F = self.F.flex_entry()
        self.flex_entry_O = self.O.flex_entry()

        # Scratch sizes in bytes at 2 bytes per element.
        F_size = int(np.prod(self.F.shape) * 2)
        O_size = int(np.prod(self.O.shape) * 2)

        vec_size = 4 if self.dtype.itemsize == 4 else 8

        assert N % 32 == 0, "N dim must be multiple of 32"
        assert K % vec_size == 0, "K dim must be multiple of %d" % vec_size

        if self.dtype.type == "flex":
            clss = "fconv"
        else:
            raise TypeError("Type not supported.")

        self.C = C
        self.K = K
        self.M = M
        self.P = P
        self.Q = Q
        self.NCK = (N, C, K)
        self.TRS = (T, R, S)
        self.DHW = (D, H, W)
        self.MPQ = (M, P, Q)
        self.padding = (pad_d, pad_h, pad_w)
        self.strides = (str_d, str_h, str_w)

        self.all_params = (N, C, K, D, H, W, T, R, S, pad_d, pad_h, pad_w, str_d, str_h, str_w)

        self.dimI = (C, D, H, W, N)
        self.dimF = (C, T, R, S, K)
        self.dimFb = (K, T, R, S, C)
        self.dimO = (K, M, P, Q, N)
        self.dimI2 = (C * D * H * W, N)
        self.dimF2 = (C * T * R * S, K)
        self.dimF2t = (K, C * T * R * S)
        self.dimO2 = (K * M * P * Q, N)
        self.dimS = (K, 1)
        self.sizeI = reduce(mul, self.dimI, 1)
        self.sizeF = reduce(mul, self.dimF, 1)
        self.sizeO = reduce(mul, self.dimO, 1)
        self.nOut = reduce(mul, self.MPQ, 1) * K

        # precompute some multiplications for fast constant memory access
        HW = H * W
        DHW = D * HW
        WN = W * N
        HWN = H * WN
        DHWN = D * HWN
        RS = R * S
        RST = T * RS
        CRST = C * RST
        PQ = P * Q
        PQM = M * PQ
        QN = Q * N
        PQN = P * QN
        MPQN = M * PQN

        # Reject CRST above 2**16 (same condition as the former
        # `if CRST > 2**16: assert CRST < 2**16` pair).
        assert CRST <= 2**16, "Integer division is faster with 16bit numerators"

        # precompute the magic numbers and shift amounts for integer division
        magic_HW = _magic64(HW)
        magic_W = _magic64(W)
        magic_PQ = _magic64(PQ)
        magic_Q = _magic64(Q)
        magic_RST = _magic32(CRST, RST)
        magic_RS = _magic32(RST + 32, RS)
        magic_S = _magic32(RS + 32, S)
        magic_str_w = _magic32(W + S, str_w)
        magic_str_h = _magic32(H + R, str_h)
        magic_str_d = _magic32(D + T, str_d)

        # flop count for benchmarking
        self.flops = PQM * K * N * CRST * 2.0

        tile_N = 128 if N > 64 else 64
        grid_N = _grid_dim(tile_N, N)
        tiles_CK = (128, 64, 32) if tile_N == 128 else (128, 64)

        # BPROP #
        if C < 16 or C % vec_size != 0:
            # special kernel for deconv into first layer
            kernel_name = "%s_bprop_C1_N64" % clss

            grid = (PQM, _grid_dim(32, CRST), _grid_dim(64, N))
            block = (32, 1, 1)

            self.bprop_kernels = [[kernel_name, grid, block, 0, _flatten([
                N, K, D, H, W, WN, HWN, DHWN,
                C, CRST, RST, magic_RST, RS, magic_RS, S, magic_S,
                pad_d, pad_h, pad_w, str_d, str_h, str_w,
                Q, PQ, QN, PQN, MPQN, magic_Q, magic_PQ,
                CRST * 8 * self.dtype.itemsize, MPQN * 8 * self.dtype.itemsize])]]

            # generate the kernel args for transpose CRST,K => K,CRST
            self.shuffle_args = [CRST, K]
            gridX = (K >> 5) + (K & 31 != 0)
            gridY = (CRST >> 5) + (CRST & 31 != 0)
            self.shuffle_grid = (gridX, gridY, 1)
            self.shuffle_block = (32, 8, 1)
            # Non-zero here selects the scratch-output + conversion path below.
            self.bprop_zero = self.sizeI * self.dtype.itemsize
            self.bprop_lut_size = 0

        else:
            self.bprop_kernels = kernel_specs.xprop_conv_kernels(
                clss, "bprop", "C", tile_N, grid_N, C, tiles_CK, DHW, RST, _flatten([
                    N, C, M, P, Q, QN, PQN, MPQN,
                    K, CRST, RST, RS, magic_RS, S, magic_S,
                    pad_d, pad_h, pad_w, str_d, str_h, str_w,
                    W, HW, WN, HWN, DHWN, magic_W, magic_HW,
                    R, T, magic_str_w, magic_str_h, magic_str_d]))

            # generate the kernel args for dim shuffling CRSTK => KRSTC
            self.shuffle_args = _flatten([
                RST * K, RS * K, S * K, K,
                RST * C, RS * C, S * C, C,
                RS, magic_RS, S, magic_S])
            gridX = (K >> 5) + (K & 31 != 0)
            gridY = (C >> 5) + (C & 31 != 0)
            self.shuffle_grid = (gridX, gridY, RST)
            self.shuffle_block = (32, 8, 1)
            self.bprop_zero = 0
            self.bprop_lut_size = RST * 4 * 2

        # NOTE(review): an older comment here read "Set to 5 for the current
        # T1000 HW config", but the value is 32 — confirm which is intended.
        self.trunc_rows = 32
        flags = self.trunc_rows << 8

        # Must dim shuffle filter data for bprop kernel
        F_data = ScratchBufferWrapper(F_size, 0, runtime)
        if self.bprop_zero:
            # Output goes to a second scratch region placed after F_data.
            Out = ScratchBufferWrapper(O_size, F_size, runtime)
            shuffle_kernel = _get_transpose_kernel(self.dtype)
        else:
            Out = self.O
            # can point to transpose or dimshuffle kernel
            shuffle_kernel = _get_shuffle_kernel(self.dtype)
        shuffle_args = [self.shuffle_grid, self.shuffle_block, None,
                        F_data, self.F] + self.shuffle_args
        shuffle_kernel = [shuffle_kernel] + shuffle_args

        # Have to zero output buffer and use type conversion for kernel using atomics
        convert_kernel = None
        if self.bprop_zero:
            shape = [int(np.prod(self.O.shape[:-1])), self.O.shape[-1]]
            convert_kernel = _prepare_convert_kernel(Out, "f2", self.O, shape,
                                                     FlexPtrDescription(self.flex_entry_O))
            self.convert_out = True
        else:
            self.convert_out = False

        self.kernels = []
        for kernel in self.bprop_kernels:
            # alpha (index 8) and beta (index 9) are placeholders here and are
            # presumably populated in a separate pass (bind_flex_scales).
            self.kernels.append([
                kernel_specs.get_kernel(kernel[0]), kernel[1], kernel[2], None,
                0, Out, self.E, F_data, 1.0, 0.0, flags, kernel[3]] + kernel[4])

        for kernel in self.kernels:
            # Append the O maxabs pointer and a placeholder output scale.
            kernel.extend((FlexPtrDescription(self.flex_entry_O), 1.0))
            # Clears bit 0 of the flags word (index 10).
            kernel[10] &= 0xfffffffe  # Enable output flag

        self.kernels = [shuffle_kernel] + self.kernels
        if self.convert_out:
            self.kernels.append(convert_kernel)

        # record output flex id for autoflex
        self.output_flex_ids = [self.flex_entry_O.flex_id]
Example #9
0
 def bind_flex_scales(self):
     """Bind the flex pointer descriptions in each (kernel, params) pair's params."""
     for _kernel, params in self.kernels:
         FlexPtrDescription.bind_ptr(params)
Example #10
0
 def bind_flex_scales(self):
     """Allocate the flex entry of this op's tensor, then bind pointers in params."""
     op_tensor = self.transformer.get_op_tensor(self.op)
     op_tensor.flex_entry.allocate()
     FlexPtrDescription.bind_ptr(self.params)
Example #11
0
 def bind_buffers(self):
     """Bind base buffers, then append the maxabs pointer for the op tensor's flex entry."""
     super(FlexPoolKernel, self).bind_buffers()
     flex_entry = self.transformer.get_op_tensor(self.op).flex_entry
     self.params.append(FlexPtrDescription(flex_entry))
Example #12
0
    def _build_maxas_kernel(self, op, size=None):
        """
        Uses tensor dimensions and axis ordering to select a sass kernel and use
        maxas to compile it for later use.

        Sets ``self.A``, ``self.B``, ``self.C``, ``self.kernel``,
        ``self.alpha``, ``self.beta`` and ``self.params``. For a flex output it
        also stores the A/B/C flex entries and appends ``[maxabs_ptr,
        output_scale]`` to ``self.params``. It also sets ``self.output_flex_ids``.

        Arguments:
            op (DotOp): Graph op being transformed into this kernel
            size (str): Optional preselected tile size such as "128x64". It is
                overridden to "128x128" when the output is flex.

        Raises:
            TypeError: If the output is neither flex, float16 nor float32.
            AssertionError: On non-2d tensors, non-contiguous strides,
                mismatched gemm dimensions, or a "tt" transpose layout.
        """
        # Get inputs to gemm
        C = TensorDescriptionWrapper(op.tensor_description(), 2)
        A, B = (TensorDescriptionWrapper(_, 2) for _ in op.call_info())

        # If both inputs are 1d, need to transpose one of them
        if min(A.strides) == 0 and min(B.strides) == 0:
            A.strides = tuple(reversed(A.strides))
            A.shape = tuple(reversed(A.shape))
            vector_dot = True
        else:
            vector_dot = False

        self.C = C
        self.A = A
        self.B = B

        # Kernels only support 2d tensors
        assert len(A.shape) == 2
        assert len(B.shape) == 2
        assert len(C.shape) == 2

        # one dimension must be contiguous
        assert min(A.strides) == 1 or max(A.strides) == 1
        assert min(B.strides) == 1 or max(B.strides) == 1
        assert min(C.strides) == 1 or max(C.strides) == 1 or vector_dot

        # Flex only has the 128x128 tile size. The override must happen before
        # the lda/ldb scaling below, which depends on `size`. Otherwise a
        # caller-supplied "32x64"/"16x64" would skip the scaling that the
        # 128x128 kernel receives in every other case.
        if C.is_flex():
            size = "128x128"

        lda = max(A.strides)
        ldb = max(B.strides)
        ldc = max(C.strides)

        if A.is_trans:
            opA = 't'
            if size not in ("32x64", "16x64"):
                lda *= 8 * A.dtype.itemsize  # saves a kernel register
        else:
            opA = 'n'

        if B.is_trans:
            opB = 't'
        else:
            opB = 'n'
            if size not in ("32x64", "16x64"):
                ldb *= 8 * B.dtype.itemsize  # saves a kernel register

        # Two-letter transpose layout ("nn", "nt" or "tn"). Kept separate from
        # the `op` argument instead of rebinding it.
        trans = opA + opB
        assert trans != "tt"

        m = A.shape[0]
        n = B.shape[1]
        k = A.shape[1]

        assert m == C.shape[0]
        assert n == C.shape[1]
        assert k == B.shape[0]

        # Some basic tile size selection.
        # Your best bet is to benchmark your code with all 3 sizes
        # and manually fine tune the selection for each layer.
        # TODO: Perhaps I'll add an autotuning mode.
        if size is None:
            # find the shorter side
            short = min(m, n)
            # anything bigger than this just use 128
            if short < 384 - 16:
                # compute remainder of 128
                short128 = short % 128
                # if remainder is 0 or at least 112 just use 128
                if 0 < short128 < 112:
                    # to figure out when to use 64 over 32 we need to calc
                    # occupancy at 64
                    if 48 < short128 <= 64:
                        occupancy64 = short // 64
                        wide = max(m, n)
                        occupancy64 *= (wide // 128 +
                                        (wide % 128 != 0)) // _get_sm_count()
                        # 64 is only faster than 32 when occupancy is more than
                        # 1 warp per scheduler.
                        if occupancy64 > 1:
                            size = 64
                        else:
                            size = 32
                    else:
                        size = 32
                else:
                    size = 128
            # There's a large regime where 64 is faster, but it's hard to
            # characterize
            else:
                size = 128

            # match the kernel to the optimal short size but avoid not
            # implemented kernels
            if m >= n:
                if trans == "nt":
                    size = 128
                sizeA, sizeB = (128, size)
            else:
                if trans == "tn":
                    size = 128
                # temp till I can write these kernels (coming soon)
                elif size == 64:
                    size = 32
                sizeA, sizeB = (size, 128)

            size = "%dx%d" % (sizeA, sizeB)

        else:
            sizeA, sizeB = (int(s) for s in size.split('x'))

        # Number of tiles along m and n, rounded up.
        gridA = m // sizeA + (m % sizeA != 0)
        gridB = n // sizeB + (n % sizeB != 0)

        # nt and nn are more efficient with k%16==0
        k_vec = 8 if sizeA in (16, 32) or sizeB == 32 else 16

        vec_opt = None
        if trans == "tn":
            if (m % 4 == 0 and n % 4 == 0 and A.strides[1] % 4 == 0
                    and B.strides[0] % 4 == 0):
                vec_opt = ("vec", )
        elif trans == "nn":
            if (k % k_vec == 0 and n % 4 == 0 and A.strides[0] % k_vec == 0
                    and B.strides[0] % 4 == 0):
                vec_opt = ("vec", )
        elif trans == "nt":
            if (k % k_vec == 0 and n % 4 == 0 and A.strides[0] % k_vec == 0
                    and B.strides[1] % k_vec == 0):
                vec_opt = ("vec", )

        # Kernel class follows the output type.
        if C.is_flex():
            clss = "fgemm"
        elif C.dtype.type is np.float16:
            clss = "hgemm"
        elif C.dtype.type is np.float32:
            clss = "sgemm"
        else:
            raise TypeError("Only floating point dot currently supported.")

        # TODO: Flex may not have all "size" options (Urs)
        self.kernel = kernel_specs.get_kernel("_".join((clss, trans, size)),
                                              vec_opt)
        # alpha, beta
        self.alpha = 1.0
        self.beta = 0.0
        # create params
        # Layout: 0 grid, 1 block, 2 None, 3-5 C/A/B descriptions, 6 alpha,
        # 7 beta, 8 zero, 9-11 lda/ldb/ldc, 12-14 m/n/k, 15-18 zeros; flex
        # appends 19 maxabs ptr and 20 output scale.
        # if params list changes, indices in bind_flex_scales may need updating
        self.params = [(1, int(gridA), int(gridB)),
                       (self.kernel.threads, 1, 1), None, C.td, A.td, B.td,
                       self.alpha, self.beta, 0,
                       int(lda),
                       int(ldb),
                       int(ldc),
                       int(m),
                       int(n),
                       int(k), 0, 0, 0, 0]
        if clss == "fgemm":
            # save flex entries for bind_flex_scales
            self.flex_entry_A = A.flex_entry()
            self.flex_entry_B = B.flex_entry()
            self.flex_entry_C = C.flex_entry()

            # flex params
            self.params += [FlexPtrDescription(self.flex_entry_C),
                            1.0]  # maxabs ptr, output scale
            # record output flex id for autoflex
            self.output_flex_ids = [self.flex_entry_C.flex_id]