def _transpile_function(func, attributes, mode, consts, in_types, ret_type):
    """Transpile the function.

    Args:
        func (ast.FunctionDef): Target function.
        attributes (str): The attributes of target function.
        mode ('numpy' or 'cuda'): The rule for typecast.
        consts (dict): The dictionary with keys as variable names and
            values as concrete data object.
        in_types (list of _types.TypeBase): The types of arguments.
        ret_type (_types.TypeBase): The type of return value.

    Returns:
        code (str): The generated CUDA code.
        env (Environment): More details of analysis result of the function,
            which includes preambles, estimated return type and more.

    Raises:
        NotImplementedError: If ``func`` is not an ``ast.FunctionDef``, or
            if its signature uses ``*args``, keyword-only arguments,
            ``**kwargs`` or default values.
        TypeError: If the number of positional parameters of ``func``
            differs from ``len(in_types)``.
    """
    consts = {k: Constant(v) for k, v in consts.items()}

    if not isinstance(func, ast.FunctionDef):
        # TODO(asi1024): Support for `ast.ClassDef`.
        raise NotImplementedError('Not supported: {}'.format(type(func)))

    if len(func.decorator_list) > 0:
        if sys.version_info >= (3, 9):
            # Code path for Python versions that support `ast.unparse`.
            for deco in func.decorator_list:
                deco_code = ast.unparse(deco)
                # Match by substring rather than equality: a decorator is
                # usually written with a module prefix and/or a call, e.g.
                # `jit.rawkernel()`, which never equals the bare name.
                if not any(word in deco_code
                           for word in ('rawkernel', 'vectorize')):
                    warnings.warn(
                        f'Decorator {deco_code} may not supported in JIT.',
                        RuntimeWarning)

    arguments = func.args
    if arguments.vararg is not None:
        raise NotImplementedError('`*args` is not supported currently.')
    if len(arguments.kwonlyargs) > 0:  # same length with `kw_defaults`.
        raise NotImplementedError(
            'keyword only arguments are not supported currently .')
    if arguments.kwarg is not None:
        raise NotImplementedError('`**kwargs` is not supported currently.')
    if len(arguments.defaults) > 0:
        raise NotImplementedError(
            'Default values are not supported currently.')

    args = [arg.arg for arg in arguments.args]
    if len(args) != len(in_types):
        raise TypeError(
            f'{func.name}() takes {len(args)} positional arguments '
            f'but {len(in_types)} were given.')

    # Each parameter becomes a typed CUDA object visible to the body.
    params = {x: CudaObject(x, t) for x, t in zip(args, in_types)}
    env = Environment(mode, consts, params, ret_type)
    body = _transpile_stmts(func.body, True, env)

    # `params` is reused from here on as the rendered parameter list.
    params = ', '.join([env[a].ctype.declvar(a) for a in args])
    local_vars = [v.ctype.declvar(n) + ';' for n, v in env.locals.items()]

    # No return type was given or inferred while transpiling the body.
    if env.ret_type is None:
        env.ret_type = _types.Void()

    head = f'{attributes} {env.ret_type} {func.name}({params})'
    code = CodeBlock(head, local_vars + body)
    return str(code), env
def _transpile_function(
        func, attributes, mode, consts, in_types, ret_type):
    """Transpile the function.

    Args:
        func (ast.FunctionDef): Target function.
        attributes (str): The attributes of target function.
        mode ('numpy' or 'cuda'): The rule for typecast.
        consts (dict): The dictionary with keys as variable names and
            values as concrete data object.
        in_types (list of _types.TypeBase): The types of arguments.
        ret_type (_types.TypeBase): The type of return value.

    Returns:
        code (str): The generated CUDA code.
        env (Environment): More details of analysis result of the function,
            which includes preambles, estimated return type and more.

    Raises:
        NotImplementedError: If ``func`` is not an ``ast.FunctionDef``, is
            decorated, or its signature uses ``*args``, keyword-only
            arguments, ``**kwargs`` or default values.
        TypeError: If the number of positional parameters of ``func``
            differs from ``len(in_types)``.
    """
    if not isinstance(func, ast.FunctionDef):
        # TODO(asi1024): Support for `ast.ClassDef`.
        raise NotImplementedError('Not supported: {}'.format(type(func)))
    if len(func.decorator_list) > 0:
        raise NotImplementedError('Decorator is not supported')

    arguments = func.args
    if arguments.vararg is not None:
        raise NotImplementedError('`*args` is not supported currently.')
    if len(arguments.kwonlyargs) > 0:  # same length with `kw_defaults`.
        raise NotImplementedError(
            'keyword only arguments are not supported currently .')
    if arguments.kwarg is not None:
        raise NotImplementedError('`**kwargs` is not supported currently.')
    if len(arguments.defaults) > 0:
        raise NotImplementedError(
            'Default values are not supported currently.')

    args = [arg.arg for arg in arguments.args]
    if len(args) != len(in_types):
        # Both halves must be f-strings, otherwise the given count is
        # shown as the literal text `{len(in_types)}`.
        raise TypeError(
            f'{func.name}() takes {len(args)} positional arguments '
            f'but {len(in_types)} were given.')

    env = Environment(
        mode,
        {k: Constant(v) for k, v in consts.items()},
        {x: CudaObject(x, t) for x, t in zip(args, in_types)},
        ret_type)
    body = _transpile_stmts(func.body, True, env)
    params = ', '.join([f'{env[a].ctype} {a}' for a in args])
    local_vars = [f'{v.ctype} {n};' for n, v in env.locals.items()]

    # No return type was given or inferred while transpiling the body;
    # without this fallback the signature would be rendered with `None`.
    if env.ret_type is None:
        env.ret_type = _types.Void()

    head = f'{attributes} {env.ret_type} {func.name}({params})'
    code = CodeBlock(head, local_vars + body)
    return str(code), env
def _transpile_function_internal(
        func, attributes, mode, consts, in_types, ret_type):
    """Generate CUDA source for a single function definition.

    Args:
        func (ast.FunctionDef): Function definition node to transpile.
        attributes (str): Text placed in front of the return type in the
            generated function head.
        mode: Passed through to ``Environment``.
        consts (dict): Maps variable names to concrete objects; every value
            is wrapped in ``Constant``.
        in_types (list): One type per positional parameter of ``func``.
        ret_type: Return type handed to ``Environment``; when the
            environment still has no return type after the body has been
            transpiled, ``_types.Void()`` is used.

    Returns:
        tuple: ``(code, env)`` where ``code`` is the generated source as a
        ``str`` and ``env`` is the ``Environment`` used for transpilation.

    Raises:
        NotImplementedError: If ``func`` is not an ``ast.FunctionDef``, or
            if its signature uses ``*args``, keyword-only arguments,
            ``**kwargs`` or default values.
        TypeError: If the parameter count differs from ``len(in_types)``.
    """
    consts = {name: Constant(obj) for name, obj in consts.items()}

    if not isinstance(func, ast.FunctionDef):
        # TODO(asi1024): Support for `ast.ClassDef`.
        raise NotImplementedError('Not supported: {}'.format(type(func)))

    # `ast.unparse` only exists on Python 3.9+, so decorators are left
    # unchecked on older interpreters.
    if func.decorator_list and sys.version_info >= (3, 9):
        known_words = ['rawkernel', 'vectorize']
        for deco in func.decorator_list:
            deco_code = ast.unparse(deco)
            if any(word in deco_code for word in known_words):
                continue
            warnings.warn(
                f'Decorator {deco_code} may not supported in JIT.',
                RuntimeWarning)

    signature = func.args
    if signature.vararg is not None:
        raise NotImplementedError('`*args` is not supported currently.')
    if signature.kwonlyargs:  # same length with `kw_defaults`.
        raise NotImplementedError(
            'keyword only arguments are not supported currently .')
    if signature.kwarg is not None:
        raise NotImplementedError('`**kwargs` is not supported currently.')
    if signature.defaults:
        raise NotImplementedError(
            'Default values are not supported currently.')

    args = [node.arg for node in signature.args]
    if len(args) != len(in_types):
        raise TypeError(
            f'{func.name}() takes {len(args)} positional arguments '
            f'but {len(in_types)} were given.')

    arg_objects = {
        arg_name: CudaObject(arg_name, arg_type)
        for arg_name, arg_type in zip(args, in_types)
    }
    env = Environment(mode, consts, arg_objects, ret_type)
    body = _transpile_stmts(func.body, True, env)

    params = ', '.join(env[a].ctype.declvar(a) for a in args)

    local_vars = []
    for var_name, var in env.locals.items():
        local_vars.append(var.ctype.declvar(var_name) + ';')

    if env.ret_type is None:
        env.ret_type = _types.Void()

    head = f'{attributes} {env.ret_type} {func.name}({params})'
    return str(CodeBlock(head, local_vars + body)), env
def _transpile_stmt(stmt, is_toplevel, env):
    """Transpile the statement.

    Args:
        stmt (ast.stmt): The statement node to transpile.
        is_toplevel (bool): ``False`` for statements nested in a loop body
            or in a branch of a non-constant ``if``. Assigning a constant
            that is not an instance of ``_typeclasses`` is only accepted
            when this is ``True``.
        env (Environment): Transpilation state. It is updated in place
            with newly declared variables, recorded constants and the
            inferred return type.

    Returns (list of [CodeBlock or str]):
        The generated CUDA code.

    Raises:
        NotImplementedError: For constructs that are not supported yet
            (classes, nested functions, ``del``, multiple or tuple
            assignment targets, loop ``else`` clauses, non-``range``
            ``for`` loops, ``break`` and ``continue``).
        ValueError: For constructs that are not allowed (``async for``,
            ``with``, ``raise``/``try``, imports, ``global``/``nonlocal``)
            and for conflicting inferred return types.
        TypeError: For variable type/dtype mismatches and for constant
            assignment outside of the top level.
    """
    if isinstance(stmt, ast.ClassDef):
        raise NotImplementedError('class is not supported currently.')
    if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
        raise NotImplementedError(
            'Nested functions are not supported currently.')

    if isinstance(stmt, ast.Return):
        value = _transpile_expr(stmt.value, env)
        value = _to_cuda_object(value, env)
        t = value.ctype
        # The first `return` fixes the return type; later ones must agree.
        if env.ret_type is None:
            env.ret_type = t
        elif env.ret_type != t:
            raise ValueError(
                f'Failed to infer the return type: {env.ret_type} or {t}')
        return [f'return {value.code};']

    if isinstance(stmt, ast.Delete):
        raise NotImplementedError('`del` is not supported currently.')

    if isinstance(stmt, ast.Assign):
        if len(stmt.targets) != 1:
            raise NotImplementedError('Not implemented.')
        target = stmt.targets[0]
        if not isinstance(target, ast.Name):
            raise NotImplementedError('Tuple is not supported.')
        name = target.id
        value = _transpile_expr(stmt.value, env)

        if is_constants([value]):
            if not isinstance(value.obj, _typeclasses):
                if is_toplevel:
                    if env[name] is not None and not is_constants([env[name]]):
                        raise TypeError(f'Type mismatch of variable: `{name}`')
                    # Recorded as a constant; no CUDA code is emitted.
                    env.consts[name] = value
                    return []
                else:
                    raise TypeError(
                        'Cannot assign constant value not at top-level.')

        value = _to_cuda_object(value, env)
        if env[name] is None:
            env[name] = CudaObject(target.id, value.ctype)
        elif is_constants([env[name]]):
            raise TypeError(f'Type mismatch of variable: `{name}`')
        elif env[name].ctype.dtype != value.ctype.dtype:
            raise TypeError(
                f'Data type mismatch of variable: `{name}`: '
                f'{env[name].ctype.dtype} != {value.ctype.dtype}')
        return [f'{target.id} = {value.code};']

    if isinstance(stmt, ast.AugAssign):
        value = _transpile_expr(stmt.value, env)
        target = _transpile_expr(stmt.target, env)
        assert isinstance(target, CudaObject)
        value = _to_cuda_object(value, env)
        result = _eval_operand(stmt.op, (target, value), env)
        # The result is stored back into the target, so it must be
        # castable to the target dtype under the 'same_kind' rule.
        if not numpy.can_cast(
                result.ctype.dtype, target.ctype.dtype, 'same_kind'):
            raise TypeError('dtype mismatch')
        return [f'{target.code} = {result.code};']

    if isinstance(stmt, ast.For):
        if len(stmt.orelse) > 0:
            raise NotImplementedError('for-else is not supported.')
        if not isinstance(stmt.target, ast.Name):
            raise NotImplementedError('Tuple is not supported.')
        name = stmt.target.id
        iters = _transpile_expr(stmt.iter, env)
        # Reject unsupported iterators before touching `iters.ctype` or
        # the environment, so the caller gets this error rather than an
        # unrelated one.
        if not isinstance(iters, Range):
            raise NotImplementedError(
                'for-loop is supported only for range iterator.')

        if env[name] is None:
            env[name] = CudaObject(stmt.target.id, iters.ctype)
        elif env[name].ctype.dtype != iters.ctype.dtype:
            raise TypeError(
                f'Data type mismatch of variable: `{name}`: '
                f'{env[name].ctype.dtype} != {iters.ctype.dtype}')

        body = _transpile_stmts(stmt.body, False, env)

        init_code = (f'{iters.ctype} '
                     f'__it = {iters.start.code}, '
                     f'__stop = {iters.stop.code}, '
                     f'__step = {iters.step.code}')
        # When the sign of the step is unknown, the loop condition has to
        # select the comparison direction in the generated code.
        cond = '__step >= 0 ? __it < __stop : __it > __stop'
        if iters.step_is_positive is True:
            cond = '__it < __stop'
        elif iters.step_is_positive is False:
            cond = '__it > __stop'

        head = f'for ({init_code}; {cond}; __it += __step)'
        return [CodeBlock(head, [f'{name} = __it;'] + body)]

    if isinstance(stmt, ast.AsyncFor):
        raise ValueError('`async for` is not allowed.')

    if isinstance(stmt, ast.While):
        if len(stmt.orelse) > 0:
            raise NotImplementedError('while-else is not supported.')
        condition = _transpile_expr(stmt.test, env)
        condition = _astype_scalar(condition, _types.bool_, 'unsafe', env)
        condition = _to_cuda_object(condition, env)
        body = _transpile_stmts(stmt.body, False, env)
        head = f'while ({condition.code})'
        return [CodeBlock(head, body)]

    if isinstance(stmt, ast.If):
        condition = _transpile_expr(stmt.test, env)
        if is_constants([condition]):
            # The condition is a constant: only the selected branch is
            # transpiled, and it stays at the current nesting level.
            stmts = stmt.body if condition.obj else stmt.orelse
            return _transpile_stmts(stmts, is_toplevel, env)
        head = f'if ({condition.code})'
        then_body = _transpile_stmts(stmt.body, False, env)
        else_body = _transpile_stmts(stmt.orelse, False, env)
        return [CodeBlock(head, then_body), CodeBlock('else', else_body)]

    if isinstance(stmt, (ast.With, ast.AsyncWith)):
        raise ValueError('Switching contexts are not allowed.')
    if isinstance(stmt, (ast.Raise, ast.Try)):
        raise ValueError('throw/catch are not allowed.')

    if isinstance(stmt, ast.Assert):
        value = _transpile_expr(stmt.test, env)
        if is_constants([value]):
            # Constant assertions are checked here, at transpile time.
            assert value.obj
            return [';']
        else:
            return ['assert(' + value.code + ');']

    if isinstance(stmt, (ast.Import, ast.ImportFrom)):
        raise ValueError('Cannot import modules from the target functions.')
    if isinstance(stmt, (ast.Global, ast.Nonlocal)):
        raise ValueError('Cannot use global/nonlocal in the target functions.')

    if isinstance(stmt, ast.Expr):
        value = _transpile_expr(stmt.value, env)
        return [';'] if is_constants([value]) else [value.code + ';']

    if isinstance(stmt, ast.Pass):
        return [';']
    if isinstance(stmt, ast.Break):
        raise NotImplementedError('Not implemented.')
    if isinstance(stmt, ast.Continue):
        raise NotImplementedError('Not implemented.')

    # Raised explicitly so that an unknown node type is still reported
    # when assertions are disabled with `-O`.
    raise AssertionError(f'Unexpected statement: {type(stmt)}')