def _conv2d_winograd_nhwc_impl(
    data,
    weight,
    strides,
    padding,
    dilation,
    out_dtype,
    tile_size,
    pre_computed=False,
    auto_scheduler_rewritten_layout="",
):
    """Conv2D Winograd implementation in NHWC layout.
    This is a clean version to be used by the auto-scheduler for both CPU and GPU.

    Parameters
    ----------
    data : tvm.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]
    weight : tvm.Tensor
        4-D with shape [filter_height, filter_width, in_channel, num_filter].
        When `pre_computed` is set it is read instead as the already
        transformed kernel with shape [alpha, alpha, num_filter, in_channel],
        where alpha = tile_size + filter_size - 1.
    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]. Only stride 1 is accepted.
    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]
    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]. Only dilation 1
        is accepted.
    out_dtype : str, optional
        Specifies the output data type.
    tile_size : int
        The size of the tile to use for the Winograd filter
    pre_computed: bool = False
        Whether the kernel is precomputed
    auto_scheduler_rewritten_layout: str = ""
        The layout after auto-scheduler's layout rewrite pass.

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    N, H, W, CI = get_const_tuple(data.shape)
    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
    if not pre_computed:
        KH, KW, CI, CO = get_const_tuple(weight.shape)
    else:
        if auto_scheduler_rewritten_layout:
            H_CAT, W_CAT, CO, CI = get_const_tuple(
                auto_scheduler.get_shape_from_rewritten_layout(
                    auto_scheduler_rewritten_layout, ["eps", "nu", "co", "ci"]
                )
            )
            auto_scheduler.remove_index_check(weight)
        else:
            H_CAT, W_CAT, CO, CI = get_const_tuple(weight.shape)

        # The transformed kernel is alpha x alpha; recover the original size.
        KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1

    pad_t, pad_l, pad_b, pad_r = get_pad_tuple(padding, (KH, KW))
    HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
    assert HSTR == 1 and WSTR == 1 and KH == 3 and KW == 3

    r = KW
    m = tile_size
    alpha = m + r - 1
    A, B, G = winograd_transform_matrices(m, r, out_dtype)

    # Extents of the input after the user-requested padding. They must be
    # captured before H and W are overwritten with the output extents below.
    padded_h = H + pad_t + pad_b
    padded_w = W + pad_l + pad_r

    H = (H + pad_t + pad_b - KH) // HSTR + 1
    W = (W + pad_l + pad_r - KW) // WSTR + 1
    nH, nW = (H + m - 1) // m, (W + m - 1) // m
    P = N * nH * nW

    # input_tile reads rows up to (nH - 1) * m + alpha - 1 and columns up to
    # (nW - 1) * m + alpha - 1, so each axis is padded up to exactly that
    # extent. A single value derived from nW and the output height was
    # previously used for both axes, which under-pads (and can even become
    # negative) when the feature map is not square.
    pad_extra_h = (nH - 1) * m + alpha - padded_h
    pad_extra_w = (nW - 1) * m + alpha - padded_w
    data_pad = pad(
        data,
        (0, pad_t, pad_l, 0),
        (0, pad_b + pad_extra_h, pad_r + pad_extra_w, 0),
        name="data_pad",
    )

    if not pre_computed:
        # Transform the kernel on the fly.
        r_kh = te.reduce_axis((0, KH), name="r_kh")
        r_kw = te.reduce_axis((0, KW), name="r_kw")
        kernel_pack = te.compute(
            (alpha, alpha, CO, CI),
            lambda eps, nu, co, ci: te.sum(
                weight[r_kh][r_kw][ci][co] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
            ),
            name="kernel_pack",
        )
        attrs = {}
    else:
        kernel_pack = weight
        attrs = {"layout_free_placeholders": [kernel_pack]}

    # pack data tile
    input_tile = te.compute(
        (alpha, alpha, P, CI),
        lambda eps, nu, p, ci: data_pad[p // (nH * nW)][((p // nW) % nH) * m + eps][
            (p % nW) * m + nu
        ][ci],
        name="input_tile",
    )

    # transform data
    r_a = te.reduce_axis((0, alpha), "r_a")
    r_b = te.reduce_axis((0, alpha), "r_b")
    data_pack = te.compute(
        (alpha, alpha, P, CI),
        lambda eps, nu, p, ci: te.sum(
            input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
        ),
        name="data_pack",
        attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]},
        # the attrs are necessary hints for the auto-scheduler
    )

    # do batch gemm
    ci = te.reduce_axis((0, CI), name="ci")
    bgemm = te.compute(
        (alpha, alpha, P, CO),
        lambda eps, nu, p, co: te.sum(
            data_pack[eps][nu][p][ci] * kernel_pack[eps][nu][co][ci], axis=[ci]
        ),
        name="bgemm",
        attrs=attrs,
    )

    if auto_scheduler_rewritten_layout:
        bgemm = auto_scheduler.rewrite_compute_body(bgemm, auto_scheduler_rewritten_layout)

    # inverse transform
    r_a = te.reduce_axis((0, alpha), "r_a")
    r_b = te.reduce_axis((0, alpha), "r_b")
    inverse = te.compute(
        (m, m, P, CO),
        lambda vh, vw, p, co: te.sum(
            bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
        ),
        name="inverse",
        attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]},
        # the attrs are necessary hints for the auto-scheduler
    )

    # output: scatter the m x m tiles back into the NHWC result
    output = te.compute(
        (N, H, W, CO),
        lambda n, h, w, co: inverse[h % m, w % m, n * nH * nW + (h // m) * nW + (w // m), co],
        name="conv2d_winograd",
    )
    return output
def conv2d_nhwc(
    Input,
    Filter,
    stride,
    padding,
    dilation,
    out_dtype="float32",
    auto_scheduler_rewritten_layout="",
):
    """Convolution operator in NHWC layout.

    Parameters
    ----------
    Input : tvm.te.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    Filter : tvm.te.Tensor
        4-D with shape [filter_height, filter_width, in_channel, num_filter]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of 2 or 4 ints
        padding size, or
        [pad_height, pad_width] for 2 ints, or
        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    out_dtype: str = "float32",
        The type of output tensor; both operands are cast to it before the
        multiply-accumulate.

    auto_scheduler_rewritten_layout: str = ""
        The layout after auto-scheduler's layout rewrite pass.

    Returns
    -------
    output : tvm.te.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2

    # Broadcast scalar stride / dilation to both spatial axes.
    stride_h, stride_w = (stride, stride) if isinstance(stride, int) else stride
    dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation

    if not auto_scheduler_rewritten_layout:
        kernel_h, kernel_w, _, num_filter = Filter.shape
    else:
        # The filter placeholder no longer has its logical shape, so take the
        # logical extents from the rewritten layout instead.
        kernel_h, kernel_w, _, num_filter = auto_scheduler.get_shape_from_rewritten_layout(
            auto_scheduler_rewritten_layout, ["ry", "rx", "rc", "ff"]
        )
        auto_scheduler.remove_index_check(Filter)

    batch, in_height, in_width, in_channel = Input.shape

    # Output extents, computed from the dilated kernel footprint.
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    out_channel = num_filter
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

    PaddedInput = pad(
        Input, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="PaddedInput"
    )

    rc = te.reduce_axis((0, in_channel), name="rc")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    # Parameter names become the output iterator names, so they are kept as-is.
    def _conv_body(nn, yy, xx, ff):
        pixel = PaddedInput[
            nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc
        ].astype(out_dtype)
        tap = Filter[ry, rx, rc, ff].astype(out_dtype)
        return te.sum(pixel * tap, axis=[ry, rx, rc])

    Output = te.compute(
        (batch, out_height, out_width, out_channel),
        _conv_body,
        name="Conv2dOutput",
        tag="conv2d_nhwc",
        attrs={"layout_free_placeholders": [Filter]},
    )

    if not auto_scheduler_rewritten_layout:
        return Output
    return auto_scheduler.rewrite_compute_body(Output, auto_scheduler_rewritten_layout)
def batch_matmul(
    tensor_a,
    tensor_b,
    oshape=None,
    out_dtype=None,
    transpose_a=False,
    transpose_b=True,
    auto_scheduler_rewritten_layout="",
):
    """Compute batch matrix multiplication of `tensor_a` and `tensor_b`.

    Both `tensor_a` and `tensor_b` can be transposed. For legacy reason, we use NT format
    (transpose_a=False, transpose_b=True) by default.

    Parameters
    ----------
    tensor_a : tvm.te.Tensor
        3-D with shape [batch, M, K] or [batch, K, M].

    tensor_b : tvm.te.Tensor
        3-D with shape [batch, K, N] or [batch, N, K].

    oshape : List[Optional]
        Explicit intended output shape of the computation. Can be useful in cases
        with dynamic input shapes. When omitted it is derived from the operands,
        and a batch extent of 1 on either side is broadcast against the other.

    out_dtype : Optional[str]
        Specifies the output data type for mixed precision batch matmul.
        Defaults to the dtype of `tensor_a`.

    transpose_a : Optional[bool] = False
        Whether the first tensor is in transposed format.

    transpose_b : Optional[bool] = True
        Whether the second tensor is in transposed format.

    auto_scheduler_rewritten_layout: Optional[str] = ""
        The layout after auto-scheduler's layout rewrite pass.

    Returns
    -------
    output : tvm.te.Tensor
        3-D with shape [batch, M, N]
    """
    assert len(tensor_a.shape) == 3, "tensor_a only support 3-dim"
    a_dims = get_const_tuple(tensor_a.shape)
    if transpose_a:
        XB, XK, XI = a_dims
    else:
        XB, XI, XK = a_dims

    if not auto_scheduler_rewritten_layout:
        assert len(tensor_b.shape) == 3, "tensor_b only support 3-dim"
        b_dims = get_const_tuple(tensor_b.shape)
        if transpose_b:
            YB, YJ, YK = b_dims
        else:
            YB, YK, YJ = b_dims
    else:
        # Infer shape for the rewritten layout
        YB, YK, YJ = auto_scheduler.get_shape_from_rewritten_layout(
            auto_scheduler_rewritten_layout, ["b", "k", "j"]
        )
        auto_scheduler.remove_index_check(tensor_b)

    assert XK == YK or isinstance(YK, tvm.tir.expr.Var), "shapes of x and y are inconsistent"
    k = te.reduce_axis((0, XK), name="k")

    if oshape is None:
        assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
        if any(isinstance(extent, tvm.tir.expr.Var) for extent in (XB, YB)):
            batch = tvm.tir.expr.SizeVar("batch", "int32")
        else:
            batch = te.max(XB, YB)
        oshape = (batch, XI, YJ)

    if out_dtype is None:
        out_dtype = tensor_a.dtype
        if tensor_a.dtype != tensor_b.dtype:
            logger.warning(
                "tensor_a has different data type with tensor_b: %s, %s",
                tensor_a.dtype,
                tensor_b.dtype,
            )

    # Pick the operand layouts; anything that matches none of the listed
    # flag pairs is treated as NN, exactly like a trailing `else`.
    requested = (transpose_a, transpose_b)
    for flags, variant_name in (
        ((True, True), "T_batch_matmul_TT"),
        ((True, False), "T_batch_matmul_TN"),
        ((False, True), "T_batch_matmul_NT"),
    ):
        if requested == flags:
            a_transposed, b_transposed = flags
            compute_name = variant_name
            break
    else:
        a_transposed, b_transposed = False, False
        compute_name = "T_batch_matmul_NN"

    # Parameter names become the output iterator names, so they are kept as-is.
    def _fcompute(b, i, j):
        # An operand with batch extent 1 is always read at index 0.
        batch_a = b if XB != 1 else 0
        batch_b = b if YB != 1 else 0
        if a_transposed:
            lhs = tensor_a[batch_a, k, i].astype(out_dtype)
        else:
            lhs = tensor_a[batch_a, i, k].astype(out_dtype)
        if b_transposed:
            rhs = tensor_b[batch_b, j, k].astype(out_dtype)
        else:
            rhs = tensor_b[batch_b, k, j].astype(out_dtype)
        return te.sum(lhs * rhs, axis=k)

    output = te.compute(
        oshape,
        _fcompute,
        name=compute_name,
        tag="batch_matmul",
        attrs={"layout_free_placeholders": [tensor_b]},
    )

    if not auto_scheduler_rewritten_layout:
        return output
    return auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)
def conv2d_nhwc(
    Input,
    Filter,
    stride,
    padding,
    dilation,
    out_dtype="float32",
    auto_scheduler_rewritten_layout="",
):
    """Convolution operator in NHWC layout.

    Parameters
    ----------
    Input : tvm.te.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    Filter : tvm.te.Tensor
        4-D with shape [filter_height, filter_width, in_channel, num_filter]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of 2 or 4 ints
        padding size, or
        [pad_height, pad_width] for 2 ints, or
        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints

    dilation: int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    out_dtype: str = "float32"
        Type both operands are cast to before the multiply-accumulate.

    auto_scheduler_rewritten_layout: str = ""
        The layout after auto-scheduler's layout rewrite pass. When set, the
        logical kernel extents are recovered from the rank of `Filter`.

    Returns
    -------
    output : tvm.te.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2

    # Broadcast scalar stride / dilation to both spatial axes.
    stride_h, stride_w = (stride, stride) if isinstance(stride, int) else stride
    dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation

    if not auto_scheduler_rewritten_layout:
        kernel_h, kernel_w, _, num_filter = Filter.shape
    else:
        # Infer shape for the rewritten layout
        # todo(merrymercy): wrap this with a more general interface.
        fshape = Filter.shape
        rank = len(fshape)
        if rank == 17:
            # For mali.
            # GPU tile structure is SSSRRSRS
            # See the function comment of DoMultiLevelTiling in utils.h for a
            # more detailed explanation.
            kernel_h = fshape[6] * fshape[9] * fshape[13]
            kernel_w = fshape[7] * fshape[10] * fshape[14]
            num_filter = fshape[12] * fshape[16]
            for outer in range(6):
                num_filter = num_filter * fshape[outer]
        elif rank >= 10:
            # For cpu tile structure SSRSRS
            base = rank - 10
            kernel_h = fshape[2 + base] * fshape[6 + base]
            kernel_w = fshape[3 + base] * fshape[7 + base]
            num_filter = fshape[5 + base] * fshape[9 + base]
            for outer in range(base + 2):
                num_filter = num_filter * fshape[outer]
        elif rank == 4:
            num_filter, kernel_h, kernel_w, _ = fshape
        else:
            raise ValueError(
                "Don't know how to infer the layout for filter shape: %s. "
                "Please add a new branch to handle this case." % str(Filter)
            )
        auto_scheduler.remove_index_check(Filter)

    batch, in_height, in_width, in_channel = Input.shape

    # Output extents, computed from the dilated kernel footprint.
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    out_channel = num_filter
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

    PaddedInput = pad(
        Input, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="PaddedInput"
    )

    rc = te.reduce_axis((0, in_channel), name="rc")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    # Parameter names become the output iterator names, so they are kept as-is.
    def _conv_body(nn, yy, xx, ff):
        pixel = PaddedInput[
            nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc
        ].astype(out_dtype)
        tap = Filter[ry, rx, rc, ff].astype(out_dtype)
        return te.sum(pixel * tap, axis=[ry, rx, rc])

    Output = te.compute(
        (batch, out_height, out_width, out_channel),
        _conv_body,
        name="Conv2dOutput",
        tag="conv2d_nhwc",
        attrs={"layout_free_placeholders": [Filter]},
    )

    if not auto_scheduler_rewritten_layout:
        return Output
    return auto_scheduler.rewrite_compute_body(Output, auto_scheduler_rewritten_layout)
def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""):
    """The default implementation of dense in topi.

    Parameters
    ----------
    data : tvm.te.Tensor
        2-D with shape [batch, in_dim]

    weight : tvm.te.Tensor
        2-D with shape [out_dim, in_dim]

    bias : Optional[tvm.te.Tensor]
        1-D with shape [out_dim]

    out_dtype : Optional[str]
        The output type. This is used for mixed precision.
        Defaults to the dtype of `data`.

    auto_scheduler_rewritten_layout: str = ""
        The layout after auto-scheduler's layout rewrite pass.

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [batch, out_dim]
    """
    assert len(data.shape) == 2, "only support 2-dim dense"
    if bias is not None:
        assert len(bias.shape) == 1
    if out_dtype is None:
        out_dtype = data.dtype
    batch, in_dim = data.shape

    if auto_scheduler_rewritten_layout:
        # Infer shape for the rewritten layout
        out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
            auto_scheduler_rewritten_layout, ["j", "k"]
        )
        auto_scheduler.remove_index_check(weight)
    else:
        out_dim, red_dim = weight.shape
    assert in_dim == red_dim

    k = te.reduce_axis((0, in_dim), name="k")
    matmul = te.compute(
        (batch, out_dim),
        lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k),
        name="T_dense",
        tag="dense",
        attrs={"layout_free_placeholders": [weight]},
    )

    # Rewrite the reduction stage before anything is stacked on top of it:
    # it is the stage that indexes `weight` and declares it layout-free.
    # The bias add below neither reads `weight` nor carries that attribute,
    # so it must not be the tensor handed to the rewrite.
    if auto_scheduler_rewritten_layout:
        matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout)

    if bias is not None:
        matmul = te.compute(
            (batch, out_dim),
            lambda i, j: matmul[i, j] + bias[j].astype(out_dtype),
            tag=tag.BROADCAST,
        )

    return matmul
def conv(
    inp: te.Tensor,
    filt: te.Tensor,
    stride: Union[int, Sequence[int]],
    padding: Union[int, Sequence[int]],
    dilation: Union[int, Sequence[int]],
    groups: int,
    order: str,
    out_dtype: Union[str, None] = None,
    auto_scheduler_rewritten_layout: Optional[str] = None,
):
    """Convolution operator in NCHW or NHWC layout.

    Supports 1D, 2D, 3D, ... and grouping.

    Parameters
    ----------
    inp : tvm.te.Tensor
        N-D with shape [batch, in_channel, in_height, in_width, ...] ordered by `order`

    filt : tvm.te.Tensor
        N-D with shape [num_filter, in_channel // groups, filter_height, filter_width, ...]
        for NCHW or [filter_height, filter_width, ..., in_channel // groups, num_filter]
        for NHWC

    stride : int or a list/tuple of dim ints
        (where dim=2 for NCHW, dim=1 for NCH, etc.)
        Stride size, or [stride_height, stride_width, ...]

    padding : int or a list/tuple of dim or 2*dim ints
        (where dim=2 for NCHW, dim=1 for NCH, etc.)
        padding size, or
        [pad_height, pad_width, ...] for dim ints, or
        [pad_top, pad_left, pad_bottom, pad_right] for 2*dim ints

    dilation : int or a list/tuple of dim ints
        dilation size, or [dilation_height, dilation_width, ...]

    groups : int
        number of groups; must divide both the input channel count and the
        number of filters.

    order : str
        Ordering of dimensions. N indicates batch dimension, C indicates
        channels, any other character indicates HW (or H or HWD for 1D and 3D).

    out_dtype : str
        Elements are converted to this type before elementwise multiplication
        and summation. Defaults to the dtype of `inp`.

    auto_scheduler_rewritten_layout: str
        Layout from autoscheduler's layout rewriting.

    Returns
    -------
    Output : tvm.te.Tensor
        N-D with shape [batch, out_channel, out_height, out_width, ...] ordered by `order`.
    """
    # Number of spatial dimensions: everything except batch and channel.
    dim = len(inp.shape) - 2
    if out_dtype is None:
        out_dtype = inp.dtype
    assert isinstance(stride, int) or len(stride) == dim
    assert isinstance(dilation, int) or len(dilation) == dim
    # Broadcast scalar stride / dilation to one entry per spatial dimension.
    if isinstance(stride, int):
        strides = [stride for _ in range(dim)]
    else:
        strides = stride

    if isinstance(dilation, int):
        dilations = [dilation for _ in range(dim)]
    else:
        dilations = list(dilation)

    # transform from order to NCHW: positions in `order` of N, of C, and then
    # of every remaining (spatial) character in the order they appear
    permutation_to = [order.find("N"), order.find("C")] + [
        x.span()[0] for x in re.finditer("[^NC]", order)
    ]
    # transform from NCHW to order (argsort of a permutation is its inverse)
    permutation_from = np.argsort(permutation_to)
    # transform from CHW to order: drop the batch entry and shift down every
    # index that came after it, so the result indexes into [rc, *rs]
    permutation_from_reductions = permutation_from[1:].copy()
    permutation_from_reductions[permutation_from_reductions > permutation_from[0]] -= 1

    # kernel permutation, if C appears before HW then num_filter is first, otherwise it is last
    # tkonolige: I don't really understand kernel ordering for NHWC, it seems
    # like num_filters should match the N dimension
    if order.find("C") < re.search("[^NC]", order).span()[0]:
        permutation_to_kernel = [0, 1] + list(range(2, dim + 2))
    else:
        permutation_to_kernel = [dim + 1, dim] + list(range(dim))
    permutation_from_kernel = np.argsort(permutation_to_kernel)

    # Read both shapes in canonical (batch/filter, channel, spatial...) order.
    batch, in_channel, *dimensions = np.array(get_const_tuple(inp.shape))[
        permutation_to
    ].tolist()
    num_filter, _, *kernel_dimensions = np.array(get_const_tuple(filt.shape))[
        permutation_to_kernel
    ].tolist()

    # Autoscheduler may have messed with the input layout, so we extract the
    # dimensions that it gives us
    if auto_scheduler_rewritten_layout:
        num_filter, _, *kernel_dimensions = auto_scheduler.get_shape_from_rewritten_layout(
            auto_scheduler_rewritten_layout,
            ["ff", "rc"] + [f"r{i}" for i in ["y", "x", "z"][: len(kernel_dimensions)]],
        )
        auto_scheduler.remove_index_check(filt)

    assert in_channel % groups == 0, "input channels must divide group size"
    assert num_filter % groups == 0, "output channels must divide group size"

    # Footprint of each kernel axis once dilation is taken into account.
    dilated_kernel_dimensions = [(k - 1) * dil + 1 for k, dil in zip(kernel_dimensions, dilations)]
    pad_begin, pad_end = get_pad_tuple_generic(padding, dilated_kernel_dimensions)
    # compute the output shape
    out_channel = num_filter
    out_dimensions = [
        simplify(d - (k - 1) * dil - 1 + pb + pe) // stride + 1
        for d, k, dil, pb, pe, stride in zip(
            dimensions, kernel_dimensions, dilations, pad_begin, pad_end, strides
        )
    ]
    # compute graph
    # Padding is built in canonical order (none on batch/channel) and then
    # permuted back into the caller's `order`.
    pad_before = list(np.array([0, 0] + pad_begin)[permutation_from])
    pad_after = list(np.array([0, 0] + pad_end)[permutation_from])
    temp = pad(inp, pad_before, pad_after, name="pad_temp")
    rc = te.reduce_axis((0, in_channel // groups), name="rc")
    # One reduction axis per spatial kernel dimension, named ry, rx, rz.
    rs = [te.reduce_axis((0, k), name=f"r{i}") for i, k in zip(["y", "x", "z"], kernel_dimensions)]

    def compute(*args):
        # `args` arrive in the caller's `order`; bring them into canonical
        # (batch, out_channel, spatial...) order first.
        nn, ff, *dim_indices = list(np.array(args)[permutation_to])

        if groups == 1:
            simplified_channel_index = rc
        else:
            # Offset the reduction channel into the input-channel slice of
            # the group that output channel `ff` falls in.
            simplified_channel_index = ff // (num_filter // groups) * (in_channel // groups) + rc

        return te.sum(
            temp.__getitem__(
                tuple(
                    np.array(
                        [nn, simplified_channel_index]
                        + [
                            di * stride + r * dil
                            for di, stride, r, dil in zip(dim_indices, strides, rs, dilations)
                        ]
                    )[permutation_from]
                )
            ).astype(out_dtype)
            * filt.__getitem__(tuple(np.array([ff, rc] + rs)[permutation_from_kernel])).astype(
                out_dtype
            ),
            # Schedules depend on reduction axes being in the same order as the
            # layout, so we reorder here.
            axis=np.array([rc, *rs])[permutation_from_reductions].tolist(),
        )

    out = te.compute(
        list(np.array([batch, out_channel] + out_dimensions)[permutation_from]),
        compute,
        # tag is expected to be lowercase
        tag=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}",
        name=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}",
        attrs={"layout_free_placeholders": [filt]},
        varargs_names=list(np.array(["nn", "ff", "yy", "xx", "zz"])[permutation_from]),
    )
    # if we used autoscheduler's changed layout we need to rewrite the ordering
    # of the output dimensions
    if auto_scheduler_rewritten_layout:
        out = auto_scheduler.rewrite_compute_body(out, auto_scheduler_rewritten_layout)
    return out
def matmul(
    tensor_a,
    tensor_b,
    bias=None,
    out_dtype=None,
    transpose_a=False,
    transpose_b=False,
    auto_scheduler_rewritten_layout="",
):
    """The default implementation of matmul in topi.

    Parameters
    ----------
    tensor_a : tvm.te.Tensor
        2-D with shape [batch, in_dim], or [in_dim, batch] when `transpose_a`
        is set.

    tensor_b : tvm.te.Tensor
        2-D with shape [in_dim, out_dim], or [out_dim, in_dim] when
        `transpose_b` is set.

    bias : Optional[tvm.te.Tensor]
        1-D with shape [out_dim]

    out_dtype : Optional[str]
        The output type. This is used for mixed precision.
        Defaults to the dtype of `tensor_a`.

    transpose_a : Optional[bool] = False
        Whether the tensor_a is in transposed format.

    transpose_b : Optional[bool] = False
        Whether the tensor_b is in transposed format.

    auto_scheduler_rewritten_layout: Optional[str] = ""
        The layout after auto-scheduler's layout rewrite pass.

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [batch, out_dim]
    """
    # TODO(jcf94): Add multi-dim support for tensor_a
    assert len(tensor_a.shape) == 2, "only support 2-dim matmul"
    if bias is not None:
        assert len(bias.shape) == 1
    if out_dtype is None:
        out_dtype = tensor_a.dtype
    if transpose_a:
        in_dim, batch = tensor_a.shape
    else:
        batch, in_dim = tensor_a.shape

    if auto_scheduler_rewritten_layout:
        # Infer shape for the rewritten layout
        out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
            auto_scheduler_rewritten_layout, ["j", "k"]
        )
        auto_scheduler.remove_index_check(tensor_b)
    elif transpose_b:
        out_dim, red_dim = tensor_b.shape
    else:
        red_dim, out_dim = tensor_b.shape

    # cmp should be done by values
    assert int(in_dim) == int(red_dim)

    k = te.reduce_axis((0, in_dim), name="k")
    if (transpose_a, transpose_b) == (True, True):
        compute_lambda = lambda i, j: te.sum(
            tensor_a[k, i].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
        )
        compute_name = "T_matmul_TT"
        compute_tag = "matmul"
    elif (transpose_a, transpose_b) == (True, False):
        compute_lambda = lambda i, j: te.sum(
            tensor_a[k, i].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
        )
        compute_name = "T_matmul_TN"
        compute_tag = "matmul"
    elif (transpose_a, transpose_b) == (False, True):
        compute_lambda = lambda i, j: te.sum(
            tensor_a[i, k].astype(out_dtype) * tensor_b[j, k].astype(out_dtype), axis=k
        )
        compute_name = "T_matmul_NT"
        # TODO(jcf94): Remove `dense` when `matmul` is finally ready
        compute_tag = "dense"
    else:  # (transpose_a, transpose_b) == (False, False):
        compute_lambda = lambda i, j: te.sum(
            tensor_a[i, k].astype(out_dtype) * tensor_b[k, j].astype(out_dtype), axis=k
        )
        compute_name = "T_matmul_NN"
        compute_tag = "matmul"

    mat = te.compute(
        (batch, out_dim),
        compute_lambda,
        name=compute_name,
        tag=compute_tag,
        attrs={"layout_free_placeholders": [tensor_b]},
    )

    # Rewrite the reduction stage before anything is stacked on top of it:
    # it is the stage that indexes `tensor_b` and declares it layout-free.
    # The bias add below neither reads `tensor_b` nor carries that attribute,
    # so it must not be the tensor handed to the rewrite.
    if auto_scheduler_rewritten_layout:
        mat = auto_scheduler.rewrite_compute_body(mat, auto_scheduler_rewritten_layout)

    if bias is not None:
        mat = te.compute(
            (batch, out_dim),
            lambda i, j: mat[i, j] + bias[j].astype(out_dtype),
            tag=tag.BROADCAST,
        )

    return mat