def reduction_function_mangler(kernel, func_id, arg_dtypes):
    """Mangle the ``init`` and ``update`` calls of an :class:`ArgExtFunction`.

    :arg kernel: the calling kernel; its ``target`` and ``index_dtype``
        attributes are read.
    :arg func_id: the function identifier being mangled.
    :arg arg_dtypes: unused by this mangler.
    :returns: a ``CallMangleInfo`` if *func_id* is an ``ArgExtFunction``
        named ``"init"`` or ``"update"``, otherwise *None*.
    :raises LoopyError: if the kernel's target is not an ``OpenCLTarget``.
    """
    if not isinstance(func_id, ArgExtFunction):
        return None

    is_init = func_id.name == "init"
    if not (is_init or func_id.name == "update"):
        return None

    from loopy.target.opencl import OpenCLTarget
    if not isinstance(kernel.target, OpenCLTarget):
        raise LoopyError("only OpenCL supported for now")

    reduction = func_id.reduction_op

    from loopy.kernel.data import CallMangleInfo

    if is_init:
        # The initializer takes no arguments.
        target_name = "%s_init" % reduction.prefix(func_id.scalar_dtype)
    else:
        target_name = "%s_update" % reduction.prefix(func_id.scalar_dtype)

    result_dtypes = reduction.result_dtypes(
        kernel, func_id.scalar_dtype, func_id.inames)

    if is_init:
        call_arg_dtypes = ()
    else:
        # The update step takes two (scalar, index) pairs.
        call_arg_dtypes = (
            func_id.scalar_dtype,
            kernel.index_dtype,
            func_id.scalar_dtype,
            kernel.index_dtype,
        )

    return CallMangleInfo(
        target_name=target_name,
        result_dtypes=result_dtypes,
        arg_dtypes=call_arg_dtypes,
    )
def pyopencl_function_mangler(target, name, arg_dtypes):
    """Mangle single-argument math calls on complex-valued arguments.

    The target name is ``cfloat_<name>`` for ``complex64`` arguments and
    ``cdouble_<name>`` for ``complex128`` arguments.

    :arg target: unused by this mangler.
    :arg name: the function identifier; only strings are handled.
    :arg arg_dtypes: argument dtypes; only a single complex dtype is handled.
    :returns: a ``CallMangleInfo``, or *None* if this mangler does not apply.
    :raises RuntimeError: for a complex dtype other than ``complex64`` or
        ``complex128``.
    """
    if len(arg_dtypes) != 1 or not isinstance(name, str):
        return None

    (arg_dtype,) = arg_dtypes
    if not arg_dtype.is_complex():
        return None

    if arg_dtype.numpy_dtype == np.complex64:
        tpname = "cfloat"
    elif arg_dtype.numpy_dtype == np.complex128:
        tpname = "cdouble"
    else:
        raise RuntimeError("unexpected complex type '%s'" % arg_dtype)

    # Functions whose result has the same complex type as the argument.
    complex_valued = ("sqrt", "exp", "log", "sin", "cos", "tan",
                      "sinh", "cosh", "tanh", "conj")
    # Functions whose result is the matching real type.
    real_valued = ("real", "imag", "abs")

    if name in complex_valued:
        return CallMangleInfo(
            target_name="%s_%s" % (tpname, name),
            result_dtypes=(arg_dtype,),
            arg_dtypes=(arg_dtype,))

    if name in real_valued:
        # The dtype of the real part of a zero of the complex type.
        real_dtype = NumpyType(np.dtype(arg_dtype.numpy_dtype.type(0).real))
        return CallMangleInfo(
            target_name="%s_%s" % (tpname, name),
            result_dtypes=(real_dtype,),
            arg_dtypes=(arg_dtype,))

    return None
def bessel_function_mangler(kernel, name, arg_dtypes):
    """Mangle two-argument ``bessel_j`` / ``bessel_y`` calls.

    ``bessel_j`` maps to ``bessel_jv`` and ``bessel_y`` maps to ``bessel_yn``.
    Both are declared as taking ``(int32, float64)`` and returning
    ``float64``.

    :arg kernel: unused by this mangler.
    :arg name: the function identifier.
    :arg arg_dtypes: argument dtypes; exactly two are required.
    :returns: a ``CallMangleInfo``, or *None* if this mangler does not apply.
    :raises TypeError: if the first argument's NumPy dtype kind is not
        ``"i"``.
    """
    if name == "bessel_j":
        target_name = "bessel_jv"
    elif name == "bessel_y":
        target_name = "bessel_yn"
    else:
        return None

    if len(arg_dtypes) != 2:
        return None

    n_dtype, _ = arg_dtypes

    # *technically* takes a float, but let's not worry about that.
    if n_dtype.numpy_dtype.kind != "i":
        # Interpolate the function name; it was previously left as a
        # literal "%s" in the message.
        raise TypeError("%s expects an integer first argument" % name)

    # Imported only once a CallMangleInfo is actually going to be built.
    from loopy.kernel.data import CallMangleInfo
    from loopy.types import NumpyType

    return CallMangleInfo(
        target_name,
        (NumpyType(np.float64),),
        (NumpyType(np.int32), NumpyType(np.float64)),
    )
def _opencl_common_numpy_dtype(arg_dtypes):
    """Return the NumPy dtype that the dtypes in *arg_dtypes* promote to."""
    numpy_dtypes = [dtype.numpy_dtype for dtype in arg_dtypes]

    # np.find_common_type was removed in NumPy 2.0. Keep using it where it
    # still exists so results are unchanged there, and fall back to
    # np.result_type otherwise.
    legacy_common_type = getattr(np, "find_common_type", None)
    if legacy_common_type is not None:
        return legacy_common_type([], numpy_dtypes)
    return np.result_type(*numpy_dtypes)


def opencl_function_mangler(kernel, name, arg_dtypes):
    """Mangle OpenCL built-in function calls.

    Handled, in this order: integer ``min``/``max``, ``dot``, the functions
    listed in ``_CL_SIMPLE_MULTI_ARG_FUNCTIONS`` and the vector literal
    constructors listed in ``VECTOR_LITERAL_FUNCS``.

    :arg kernel: the calling kernel; ``kernel.target.vector_dtype`` is used
        for vector literals.
    :arg name: the function identifier; only strings are handled.
    :arg arg_dtypes: dtypes of the call's arguments.
    :returns: a ``CallMangleInfo``, or *None* if this mangler does not apply.
    :raises LoopyError: if a function from ``_CL_SIMPLE_MULTI_ARG_FUNCTIONS``
        receives the wrong number of arguments or complex arguments.
    """
    if not isinstance(name, str):
        return None

    # OpenCL has min(), max() for integer types
    if name in ["max", "min"] and len(arg_dtypes) == 2:
        dtype = _opencl_common_numpy_dtype(arg_dtypes)

        if dtype.kind == "i":
            result_dtype = NumpyType(dtype)
            return CallMangleInfo(
                target_name=name,
                result_dtypes=(result_dtype,),
                arg_dtypes=2 * (result_dtype,))

    if name == "dot":
        # The result is the scalar type of the "s0" field of the first
        # argument; only the dtype entry of the field record is needed.
        scalar_dtype = arg_dtypes[0].numpy_dtype.fields["s0"][0]
        return CallMangleInfo(
            target_name=name,
            result_dtypes=(NumpyType(scalar_dtype),),
            arg_dtypes=(arg_dtypes[0],) * 2)

    if name in _CL_SIMPLE_MULTI_ARG_FUNCTIONS:
        num_args = _CL_SIMPLE_MULTI_ARG_FUNCTIONS[name]
        if len(arg_dtypes) != num_args:
            raise LoopyError("%s takes %d arguments (%d received)"
                             % (name, num_args, len(arg_dtypes)))

        dtype = _opencl_common_numpy_dtype(arg_dtypes)

        if dtype.kind == "c":
            raise LoopyError("%s does not support complex numbers" % name)

        result_dtype = NumpyType(dtype)
        return CallMangleInfo(
            target_name=name,
            result_dtypes=(result_dtype,),
            arg_dtypes=(result_dtype,) * num_args)

    if name in VECTOR_LITERAL_FUNCS:
        base_tp_name, dtype, count = VECTOR_LITERAL_FUNCS[name]

        if count != len(arg_dtypes):
            return None

        # The target name is a parenthesised type, e.g. "(<type><count>) ".
        return CallMangleInfo(
            target_name="(%s%d) " % (base_tp_name, count),
            result_dtypes=(
                kernel.target.vector_dtype(NumpyType(dtype), count),),
            arg_dtypes=(NumpyType(dtype),) * count)

    return None
def tuple_function_mangler(kernel, name, arg_dtypes):
    """Mangle ``make_tuple`` calls to ``loopy_make_tuple``.

    The result dtypes are the argument dtypes, passed through unchanged.

    :returns: a ``CallMangleInfo`` for ``make_tuple``, otherwise *None*.
    """
    if name != "make_tuple":
        return None

    from loopy.kernel.data import CallMangleInfo

    return CallMangleInfo(
        target_name="loopy_make_tuple",
        result_dtypes=arg_dtypes,
        arg_dtypes=arg_dtypes,
    )
def single_arg_function_mangler(kernel, name, arg_dtypes):
    """Mangle any single-argument call to itself.

    The target name is *name* unchanged, and the result dtype equals the
    argument dtype.

    :returns: a ``CallMangleInfo`` when exactly one argument dtype is given,
        otherwise *None*.
    """
    if len(arg_dtypes) != 1:
        return None

    (only_dtype,) = arg_dtypes

    from loopy.kernel.data import CallMangleInfo

    return CallMangleInfo(name, (only_dtype,), (only_dtype,))
def reduction_function_mangler(kernel, func_id, arg_dtypes):
    """Mangle the binary operation of an ``ArgExtOp`` or ``SegmentedOp``.

    Both kinds of identifier are handled the same way: the first two entries
    of *arg_dtypes* are the scalar dtype and its companion dtype (the index
    dtype for ``ArgExtOp``, the segment flag dtype for ``SegmentedOp``), and
    the mangled call takes two such pairs.

    :arg kernel: the calling kernel; its ``target`` attribute is checked.
    :arg func_id: the function identifier being mangled.
    :arg arg_dtypes: dtypes of the call's arguments.
    :returns: a ``CallMangleInfo``, or *None* if *func_id* is neither an
        ``ArgExtOp`` nor a ``SegmentedOp``.
    :raises LoopyError: if the kernel's target is not a ``CTarget``.
    """
    if not isinstance(func_id, (ArgExtOp, SegmentedOp)):
        return None

    from loopy.target.opencl import CTarget
    if not isinstance(kernel.target, CTarget):
        raise LoopyError("%s: only C-like targets supported for now" % func_id)

    reduction = func_id.reduction_op
    scalar_dtype = arg_dtypes[0]
    companion_dtype = arg_dtypes[1]

    from loopy.kernel.data import CallMangleInfo

    target_name = "%s_op" % reduction.prefix(scalar_dtype, companion_dtype)
    result_dtypes = reduction.result_dtypes(
        kernel, scalar_dtype, companion_dtype)

    return CallMangleInfo(
        target_name=target_name,
        result_dtypes=result_dtypes,
        arg_dtypes=(
            scalar_dtype,
            companion_dtype,
            scalar_dtype,
            companion_dtype,
        ),
    )
def no_ret_f_mangler(kernel, name, arg_dtypes):
    """Mangle a call to ``f`` that takes no arguments.

    The (empty) *arg_dtypes* is reused as the result dtypes, so the mangled
    call has no results either.

    :returns: a ``CallMangleInfo`` for a zero-argument ``f``, otherwise
        *None*.
    """
    if not isinstance(name, str):
        return None

    if name != "f" or len(arg_dtypes) != 0:
        return None

    from loopy.kernel.data import CallMangleInfo

    return CallMangleInfo(
        target_name="f",
        result_dtypes=arg_dtypes,
        arg_dtypes=arg_dtypes,
    )
def _numpy_single_arg_function_mangler(kernel, name, arg_dtypes): if (not isinstance(name, str) or not hasattr(np, name) or len(arg_dtypes) != 1): return None arg_dtype, = arg_dtypes from loopy.kernel.data import CallMangleInfo return CallMangleInfo(target_name="_lpy_np." + name, result_dtypes=(arg_dtype, ), arg_dtypes=arg_dtypes)
def mangle_function(self, identifier, arg_dtypes, ast_builder=None):
    """Return the result of the first mangler that recognizes *identifier*.

    The manglers of *ast_builder* are tried first, followed by
    ``self.function_manglers``.

    :arg identifier: the function identifier to mangle.
    :arg arg_dtypes: dtypes of the call's arguments.
    :arg ast_builder: if *None*, obtained from
        ``self.target.get_device_ast_builder()``.
    :returns: a ``CallMangleInfo``, or *None* if no mangler applies.
    :raises ValueError: if a mangler returns a tuple whose length is
        neither 2 nor 3.
    """
    if ast_builder is None:
        ast_builder = self.target.get_device_ast_builder()

    candidates = ast_builder.function_manglers() + self.function_manglers

    for mangler in candidates:
        outcome = mangler(self, identifier, arg_dtypes)
        if outcome is None:
            continue

        from loopy.kernel.data import CallMangleInfo

        if isinstance(outcome, CallMangleInfo):
            assert len(outcome.arg_dtypes) == len(arg_dtypes)
            return outcome

        # Deprecated protocol: the mangler returned a bare tuple.
        assert isinstance(outcome, tuple)

        from warnings import warn
        warn(
            "'%s' returned a tuple instead of a CallMangleInfo instance. "
            "This is deprecated." % mangler.__name__,
            DeprecationWarning)

        n_entries = len(outcome)
        if n_entries == 2:
            # (result_dtype, target_name): argument dtypes are unknown.
            result_dtype, target_name = outcome
            actual_arg_dtypes = None
        elif n_entries == 3:
            result_dtype, target_name, actual_arg_dtypes = outcome
        else:
            raise ValueError(
                "unexpected size of tuple returned by '%s'"
                % mangler.__name__)

        return CallMangleInfo(
            target_name=target_name,
            result_dtypes=(result_dtype,),
            arg_dtypes=actual_arg_dtypes)

    return None
def random123_function_mangler(kernel, name, arg_dtypes):
    """Mangle calls to the generators registered in ``FUNC_NAMES_TO_RNG``.

    For a variant whose full name is ``fn``, three call names are handled:
    ``fn`` (mangled to ``fn_gen``, returning two counter vectors), and
    ``fn_f32`` / ``fn_f64`` (returning a float vector and a counter vector).
    All take a counter vector and a key vector.

    :arg kernel: the calling kernel; ``kernel.target.vector_dtype`` is used
        to build the vector dtypes.
    :arg name: the function identifier.
    :arg arg_dtypes: unused by this mangler.
    :returns: a ``CallMangleInfo``, or *None* if this mangler does not apply.
    """
    try:
        rng_variant = FUNC_NAMES_TO_RNG[name]
    except KeyError:
        return None

    from loopy.types import NumpyType

    target = kernel.target
    # The counter and key words are unsigned integers of the variant's width.
    base_dtype = {32: np.uint32, 64: np.uint64}[rng_variant.bits]
    ctr_dtype = target.vector_dtype(NumpyType(base_dtype), rng_variant.width)
    key_dtype = target.vector_dtype(
        NumpyType(base_dtype), rng_variant.key_width)

    from loopy.kernel.data import CallMangleInfo

    fn = rng_variant.full_name

    if name == fn:
        return CallMangleInfo(
            target_name=fn + "_gen",
            result_dtypes=(ctr_dtype, ctr_dtype),
            arg_dtypes=(ctr_dtype, key_dtype))

    for suffix, float_type in (("_f32", np.float32), ("_f64", np.float64)):
        if name == fn + suffix:
            float_vec_dtype = target.vector_dtype(
                NumpyType(float_type), rng_variant.width)
            return CallMangleInfo(
                target_name=name,
                result_dtypes=(float_vec_dtype, ctr_dtype),
                arg_dtypes=(ctr_dtype, key_dtype))

    return None
def __call__(self, kernel, name, arg_dtypes): """ A function that will return a :class:`loopy.kernel.data.CallMangleInfo` to interface with the calling :class:`loopy.LoopKernel` """ if name != self.func_name: return None from loopy.types import to_loopy_type from loopy.kernel.data import CallMangleInfo def __compare(d1, d2): # compare dtypes ignoring atomic return to_loopy_type(d1, for_atomic=True) == \ to_loopy_type(d2, for_atomic=True) # check types if len(arg_dtypes) != len(arg_dtypes): raise Exception( 'Unexpected number of arguments provided to mangler ' '{}, expected {}, got {}'.format(self.func_name, len(self.func_arg_dtypes), len(arg_dtypes))) for i, (d1, d2) in enumerate(zip(self.func_arg_dtypes, arg_dtypes)): if not __compare(d1, d2): raise Exception( 'Argument at index {} for mangler {} does not ' 'match expected dtype. Expected {}, got {}'.format( i, self.func_name, str(d1), str(d2))) # get target for creation target = arg_dtypes[0].target return CallMangleInfo(target_name=self.func_name, result_dtypes=tuple( to_loopy_type(x, target=target) for x in self.func_result_dtypes), arg_dtypes=arg_dtypes)
def _c_math_float_variant(name, dtype):
    """Return *name* with the C suffix matching the floating-point *dtype*."""
    if dtype == np.float64:
        return name  # fabs
    if dtype == np.float32:
        return name + "f"  # fabsf

    # np.float128 does not exist on every platform; look it up defensively
    # so unsupported dtypes still reach the LoopyTypeError below.
    float128 = getattr(np, "float128", None)
    if float128 is not None and dtype == float128:
        return name + "l"  # fabsl

    raise LoopyTypeError("%s does not support type %s" % (name, dtype))


def c_math_mangler(target, name, arg_dtypes, modify_name=True):
    """Function mangler for math functions defined in the C standard.

    ``abs``, ``min`` and ``max`` are converted to ``fabs``, ``fmin`` and
    ``fmax``. Only floating-point arguments are handled; for any other
    argument kind this mangler returns *None*.

    :arg target: unused by this mangler.
    :arg name: the function identifier; only strings are handled.
    :arg arg_dtypes: dtypes of the call's arguments.
    :arg modify_name: if *True*, function names are modified according to
        the floating point types of the arguments (e.g. ``cos(double)``,
        ``cosf(float)``). This should be set to *True* for C and Cuda,
        *False* for OpenCL.
    :returns: a ``CallMangleInfo``, or *None* if this mangler does not apply.
    :raises LoopyTypeError: for complex arguments to a binary function, or
        for a floating-point type without a C variant when *modify_name*
        is set.
    """
    if not isinstance(name, str):
        return None

    if name in ["abs", "min", "max"]:
        name = "f" + name

    # unitary functions
    if (name in ["fabs", "acos", "asin", "atan", "cos", "cosh", "sin", "sinh",
                 "tanh", "exp", "log", "log10", "sqrt", "ceil", "floor"]
            and len(arg_dtypes) == 1
            and arg_dtypes[0].numpy_dtype.kind == "f"):

        if modify_name:
            name = _c_math_float_variant(name, arg_dtypes[0].numpy_dtype)

        return CallMangleInfo(
            target_name=name,
            result_dtypes=arg_dtypes,
            arg_dtypes=arg_dtypes)

    # binary functions
    if name in ["fmax", "fmin", "copysign"] and len(arg_dtypes) == 2:
        numpy_dtypes = [dtype.numpy_dtype for dtype in arg_dtypes]

        # np.find_common_type was removed in NumPy 2.0. Keep using it where
        # it still exists so results are unchanged there, and fall back to
        # np.result_type otherwise.
        legacy_common_type = getattr(np, "find_common_type", None)
        if legacy_common_type is not None:
            dtype = legacy_common_type([], numpy_dtypes)
        else:
            dtype = np.result_type(*numpy_dtypes)

        if dtype.kind == "c":
            # Interpolate the function name; it was previously left as a
            # literal "%s" in the message.
            raise LoopyTypeError("%s does not support complex numbers" % name)

        elif dtype.kind == "f":
            if modify_name:
                name = _c_math_float_variant(name, dtype)

            result_dtype = NumpyType(dtype)
            return CallMangleInfo(
                target_name=name,
                result_dtypes=(result_dtype,),
                arg_dtypes=2 * (result_dtype,))

    return None