def compute(shape, fcompute, name="compute", tag="", attrs=None):
    """Construct a new tensor by computing over the shape domain.

    The compute rule is result[axis] = fcompute(axis)

    The positional parameter names of ``fcompute`` become the names of the
    iteration axes. When ``fcompute`` declares no positional parameters, the
    axes are named ``i0, i1, ...`` and one index is created per dimension of
    ``shape``. ``fcompute`` must expose ``__code__`` (e.g. a plain function or
    lambda).

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor. A single PrimExpr is treated as a 1-D shape;
        float extents are converted to int.

    fcompute: lambda function of indices-> value
        Specifies the input source expression

    name: str, optional
        The name hint of the tensor

    tag: str, optional
        Additional tag information about the compute. Must be empty when a
        TagScope is active; the scope's tag is used instead.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    Returns
    -------
    tensor: Tensor or tuple of Tensor
        The created tensor, or a tuple of tensors when the op has a number of
        outputs other than one.

    Raises
    ------
    ValueError
        If ``tag`` is non-empty while a TagScope is active.
    """
    current_scope = _tag.TagScope.get_current()
    if current_scope is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = current_scope.tag

    if isinstance(shape, tvm.tir.PrimExpr):
        shape = (shape,)
    # for python3 -- presumably because true division can yield float extents
    shape = tuple(int(dim) if isinstance(dim, float) else dim for dim in shape)
    ndim = len(shape)

    code = fcompute.__code__
    n_params = code.co_argcount
    if n_params:
        # Use fcompute's own positional parameter names as axis names.
        arg_names = code.co_varnames[:n_params]
        out_ndim = n_params
    else:
        arg_names = ["i%d" % idx for idx in range(ndim)]
        out_ndim = ndim

    # NOTE(review): both branches above give arg_names exactly out_ndim
    # entries, so this check cannot fire as written.
    if len(arg_names) != out_ndim:
        raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)

    # zip stops at the shorter side, so at most len(shape) axes are created.
    # iter_type 0 -- presumably data-parallel in TVM's IterVar types; confirm.
    dim_var = []
    for axis_name, extent in zip(arg_names, shape[:out_ndim]):
        dim_var.append(tvm.tir.IterVar((0, extent), axis_name, 0))
    body = fcompute(*(iv.var for iv in dim_var))

    if isinstance(body, _tensor.TensorIntrinCall):
        # Dimensions not named by fcompute become extra "ax<k>" axes.
        # iter_type 4 -- presumably opaque in TVM's IterVar types; confirm.
        for offset, extent in enumerate(shape[out_ndim:]):
            dim_var.append(tvm.tir.IterVar((0, extent), "ax" + str(offset), 4))
        op_node = _ffi_api.TensorComputeOp(
            name,
            tag,
            dim_var,
            body.reduce_axis,
            out_ndim,
            body.intrin,
            body.tensors,
            body.regions,
            body.scalar_inputs,
        )
    else:
        # A single expression is wrapped so ComputeOp always receives a sequence.
        bodies = body if isinstance(body, (list, tuple)) else [body]
        op_node = _ffi_api.ComputeOp(name, tag, attrs, dim_var, convert(bodies))

    num_outputs = op_node.num_outputs
    if num_outputs == 1:
        return op_node.output(0)
    return tuple(op_node.output(k) for k in range(num_outputs))
def compute(shape, fcompute, name="compute", tag="", attrs=None, varargs_names=None):
    """Construct a new tensor by computing over the shape domain.

    The compute rule is result[axis] = fcompute(axis)

    The positional parameter names of ``fcompute`` become the names of the
    iteration axes. Several cases apply:

    - If ``fcompute`` takes no positional parameters and no ``*args``, the
      axes are named ``i0, i1, ...``, one per dimension of ``shape``.
    - If it takes ``*args``, the varargs absorb every dimension left after the
      named positional parameters.
    - Otherwise, only the leading ``len(args)`` dimensions are indexed.

    For a bound method, the already-bound first parameter (e.g. ``self``) is
    not treated as an index.

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor. A single PrimExpr is treated as a 1-D shape;
        float extents are converted to int.

    fcompute: lambda function of indices-> value
        Specifies the input source expression

    name: str, optional
        The name hint of the tensor

    tag: str, optional
        Additional tag information about the compute. Must be empty when a
        TagScope is active; the scope's tag is used instead.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    varargs_names: list, optional
        The names to use for each of the varargs (any sequence of str is
        accepted). If not supplied, the varargs will be called i0, i1, ...
        Ignored when ``fcompute`` does not take ``*args``.

    Returns
    -------
    tensor: Tensor or tuple of Tensor
        The created tensor, or a tuple of tensors when the op has a number of
        outputs other than one.

    Raises
    ------
    ValueError
        If ``tag`` is non-empty while a TagScope is active, or if ``fcompute``
        takes more index arguments than ``shape`` has dimensions.
    RuntimeError
        If ``varargs_names`` does not have one name per dimension left for
        ``*args``.
    AssertionError
        If ``fcompute`` has ``**kwargs``, default values or keyword-only
        parameters (checked with ``assert``).
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
    # for python3 -- presumably because true division can yield float extents
    shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
    out_ndim = len(shape)

    argspec = inspect.getfullargspec(fcompute)
    # getfullargspec still reports the already-bound first parameter (e.g.
    # ``self``) of a bound method. It is not an index argument, so drop it.
    # Otherwise axes get misnamed and out_ndim is off by one.
    pos_args = list(argspec.args)
    if inspect.ismethod(fcompute):
        pos_args = pos_args[1:]

    if len(pos_args) == 0 and argspec.varargs is None:
        arg_names = ["i%d" % i for i in range(out_ndim)]
    elif argspec.varargs is not None:
        # if there is a varargs, it takes the remaining dimensions of out_ndim
        num_remaining_args = out_ndim - len(pos_args)
        if varargs_names is not None:
            if len(varargs_names) != num_remaining_args:
                raise RuntimeError(
                    f"Number of varargs ({num_remaining_args}) does not match number "
                    f"of varargs_names ({len(varargs_names)})"
                )
            # list() so tuples or other sequences of names are accepted too
            arg_names = pos_args + list(varargs_names)
        else:
            arg_names = pos_args + [f"i{i}" for i in range(num_remaining_args)]
    else:
        arg_names = pos_args
        # if there are fewer args than out dimensions, the remaining dimensions
        # are implicitly broadcast
        out_ndim = len(arg_names)
    assert argspec.varkw is None, "Variable keyword arguments not supported in fcompute"
    assert argspec.defaults is None, "Default arguments not supported in fcompute"
    assert len(argspec.kwonlyargs) == 0, "Keyword arguments are not supported in fcompute"

    # Every index argument needs its own dimension. Without this check, zip
    # below would silently create fewer IterVars than fcompute expects, and
    # the call would then fail with an opaque TypeError.
    if len(arg_names) > len(shape):
        raise ValueError(
            "Number of args to fcompute does not match dimension, "
            "args=%d, dimension=%d" % (len(arg_names), len(shape))
        )

    # iter_type 0 -- presumably data-parallel in TVM's IterVar types; confirm.
    dim_var = [tvm.tir.IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
    body = fcompute(*[v.var for v in dim_var])

    if isinstance(body, _tensor.TensorIntrinCall):
        # Dimensions not named by fcompute become extra "ax<k>" axes.
        # iter_type 4 -- presumably opaque in TVM's IterVar types; confirm.
        for i, s in enumerate(shape[out_ndim:]):
            var_name = "ax" + str(i)
            dim_var.append(tvm.tir.IterVar((0, s), var_name, 4))
        op_node = _ffi_api.TensorComputeOp(
            name,
            tag,
            dim_var,
            body.reduce_axis,
            out_ndim,
            body.intrin,
            body.tensors,
            body.regions,
            body.scalar_inputs,
        )
    else:
        # A single expression is wrapped so ComputeOp always receives a sequence.
        if not isinstance(body, (list, tuple)):
            body = [body]
        body = convert(body)
        op_node = _ffi_api.ComputeOp(name, tag, attrs, dim_var, body)

    num = op_node.num_outputs
    outputs = tuple(op_node.output(i) for i in range(num))
    return outputs[0] if num == 1 else outputs