def _transpile_assign_stmt(target, env, value, is_toplevel, depth=0):
    """Emit the CUDA code that assigns ``value`` to ``target``.

    Args:
        target (ast.expr): Left-hand side of the assignment. ``ast.Name``,
            ``ast.Subscript`` and ``ast.Tuple`` are handled.
        env: Transpile environment. ``env[name]`` is looked up to find out
            whether the name is already known; new names are recorded in
            ``env.locals`` and, when they are not declared inline, in
            ``env.decls``.
        value: Typed expression to assign (its ``ctype`` is read).
        is_toplevel (bool): Whether the statement is at the top level of
            the function body.
        depth (int): Nesting level of tuple unpacking. It is used to name
            the temporary ``_temp{depth}`` and is ``0`` for a direct call.

    Returns:
        list: The generated code pieces (``str`` or ``CodeBlock``).

    Raises:
        ValueError: If a tuple target is assigned a non-tuple value or the
            number of elements does not match.
        NotImplementedError: If the target kind is not supported.
    """
    if isinstance(target, ast.Name):
        name = target.id
        if env[name] is None:
            env.locals[name] = Data(name, value.ctype)
            if is_toplevel and depth == 0:
                # A new top-level variable is declared and initialized in
                # a single statement, so it is not recorded in `env.decls`.
                return [value.ctype.declvar(name, value) + ';']
            env.decls[name] = Data(name, value.ctype)
        return _emit_assign_stmt(env[name], value, env)

    if isinstance(target, ast.Subscript):
        target = _transpile_expr(target, env)
        return _emit_assign_stmt(target, value, env)

    if isinstance(target, ast.Tuple):
        if not isinstance(value.ctype, _cuda_types.Tuple):
            raise ValueError(f'{value.ctype} cannot be unpack')
        size = len(target.elts)
        if len(value.ctype.types) > size:
            raise ValueError(f'too many values to unpack (expected {size})')
        if len(value.ctype.types) < size:
            raise ValueError(f'not enough values to unpack (expected {size})')
        # Store the right-hand side into a temporary once, then assign each
        # element of the temporary to the corresponding target element.
        temp_name = f'_temp{depth}'
        codes = [value.ctype.declvar(temp_name, value) + ';']
        for i, (elt, ctype) in enumerate(zip(target.elts, value.ctype.types)):
            code = f'thrust::get<{i}>({temp_name})'
            codes.extend(_transpile_assign_stmt(
                elt, env, Data(code, ctype), is_toplevel, depth + 1))
        return [CodeBlock('', codes)]

    # Falling through would return `None`, which callers cannot extend a
    # code list with; report the unsupported target explicitly instead.
    raise NotImplementedError('Not supported: {}'.format(type(target)))
def _transpile_function_internal(
        func, attributes, mode, consts, in_types, ret_type):
    """Transpile a single function definition into CUDA source code.

    Args:
        func (ast.FunctionDef): The function definition to transpile.
        attributes (str): Text placed in front of the return type in the
            generated function signature.
        mode: Passed through to ``Environment``.
        consts (dict): Mapping from names to Python objects; each value is
            wrapped in ``Constant``.
        in_types: Types of the positional arguments, in order.
        ret_type: Return type passed to ``Environment``. When it is still
            ``None`` after the body has been transpiled, the function is
            given a void return type.

    Returns:
        tuple: ``(code, env)`` where ``code`` is the generated source as a
        ``str`` and ``env`` is the environment used for the transpilation.

    Raises:
        NotImplementedError: If ``func`` is not a plain function definition
            or uses ``*args``, keyword-only arguments, ``**kwargs`` or
            default values.
        TypeError: If the number of arguments does not match ``in_types``.
    """
    consts = {k: Constant(v) for k, v in consts.items()}

    if not isinstance(func, ast.FunctionDef):
        # TODO(asi1024): Support for `ast.ClassDef`.
        raise NotImplementedError('Not supported: {}'.format(type(func)))

    if func.decorator_list and sys.version_info >= (3, 9):
        # Code path for Python versions that support `ast.unparse`.
        for deco in func.decorator_list:
            deco_code = ast.unparse(deco)
            if not any(word in deco_code
                       for word in ['rawkernel', 'vectorize']):
                warnings.warn(
                    f'Decorator {deco_code} may not supported in JIT.',
                    RuntimeWarning)

    arguments = func.args
    if arguments.vararg is not None:
        raise NotImplementedError('`*args` is not supported currently.')
    if len(arguments.kwonlyargs) > 0:  # same length with `kw_defaults`.
        raise NotImplementedError(
            'keyword only arguments are not supported currently .')
    if arguments.kwarg is not None:
        raise NotImplementedError('`**kwargs` is not supported currently.')
    if len(arguments.defaults) > 0:
        raise NotImplementedError(
            'Default values are not supported currently.')

    args = [arg.arg for arg in arguments.args]
    if len(args) != len(in_types):
        raise TypeError(
            f'{func.name}() takes {len(args)} positional arguments '
            f'but {len(in_types)} were given.')

    params = {x: CudaObject(x, t) for x, t in zip(args, in_types)}
    env = Environment(mode, consts, params, ret_type)
    body = _transpile_stmts(func.body, True, env)
    param_decls = ', '.join([env[a].ctype.declvar(a) for a in args])

    # Only the names recorded in `env.decls` need a declaration at the head
    # of the function. Names that exist only in `env.locals` were already
    # declared together with their initial value inside the body, so
    # declaring them here as well would declare them twice.
    local_vars = [v.ctype.declvar(n) + ';' for n, v in env.decls.items()]

    if env.ret_type is None:
        # NOTE(review): other code in this file refers to `_cuda_types` and
        # `_internal_types`; confirm that `_types` is imported here.
        env.ret_type = _types.Void()

    head = f'{attributes} {env.ret_type} {func.name}({param_decls})'
    code = CodeBlock(head, local_vars + body)
    return str(code), env
def _transpile_stmt(stmt, is_toplevel, env):
    """Transpile the statement.

    Args:
        stmt (ast.stmt): The statement node to transpile.
        is_toplevel (bool): Whether the statement is at the top level of
            the function body. Bodies of ``for``, ``while`` and
            non-constant ``if`` are transpiled with ``False``.
        env: Transpile environment. It is read and updated (return type,
            constants, local variables and declarations).

    Returns (list of [CodeBlock or str]): The generated CUDA code.

    Raises:
        NotImplementedError: If the statement is a kind that is not
            supported yet (class, nested function, ``del``, multiple
            assignment targets, for-else / while-else, non-range ``for``,
            ``break``, ``continue``).
        ValueError: If the statement is not allowed in the target function
            (async for, with, raise/try, import, global/nonlocal) or the
            return type cannot be inferred.
        TypeError: If the types of an assignment or a loop variable do not
            match.
        AssertionError: If an ``assert`` with a constant false condition is
            found, or the statement kind is unknown.
    """
    if isinstance(stmt, ast.ClassDef):
        raise NotImplementedError('class is not supported currently.')
    if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
        raise NotImplementedError(
            'Nested functions are not supported currently.')

    if isinstance(stmt, ast.Return):
        value = _transpile_expr(stmt.value, env)
        value = Data.init(value, env)
        t = value.ctype
        if env.ret_type is None:
            env.ret_type = t
        elif env.ret_type != t:
            raise ValueError(
                f'Failed to infer the return type: {env.ret_type} or {t}')
        return [f'return {value.code};']

    if isinstance(stmt, ast.Delete):
        raise NotImplementedError('`del` is not supported currently.')

    if isinstance(stmt, ast.Assign):
        if len(stmt.targets) != 1:
            raise NotImplementedError('Not implemented.')
        value = _transpile_expr(stmt.value, env)
        target = stmt.targets[0]
        if is_constants(value) and isinstance(target, ast.Name):
            name = target.id
            if not isinstance(value.obj, _typeclasses):
                # The constant is kept as a compile-time constant rather
                # than being emitted; this is only done at the top level.
                if is_toplevel:
                    if env[name] is not None and not is_constants(env[name]):
                        raise TypeError(f'Type mismatch of variable: `{name}`')
                    env.consts[name] = value
                    return []
                else:
                    raise TypeError(
                        'Cannot assign constant value not at top-level.')
        value = Data.init(value, env)
        return _transpile_assign_stmt(target, env, value, is_toplevel)

    if isinstance(stmt, ast.AugAssign):
        value = _transpile_expr(stmt.value, env)
        target = _transpile_expr(stmt.target, env)
        assert isinstance(target, Data)
        value = Data.init(value, env)
        result = _eval_operand(stmt.op, (target, value), env)
        if not numpy.can_cast(
                result.ctype.dtype, target.ctype.dtype, 'same_kind'):
            raise TypeError('dtype mismatch')
        return [target.ctype.assign(target, result) + ';']

    if isinstance(stmt, ast.For):
        if len(stmt.orelse) > 0:
            raise NotImplementedError('for-else is not supported.')
        if not isinstance(stmt.target, ast.Name):
            # `stmt.target.id` below exists only for a plain name.
            raise NotImplementedError(
                'for-loop is supported only for a single variable target.')
        name = stmt.target.id
        iters = _transpile_expr(stmt.iter, env)
        # Check the iterator kind before its attributes are used so that an
        # unsupported iterator is reported with this message.
        if not isinstance(iters, _internal_types.Range):
            raise NotImplementedError(
                'for-loop is supported only for range iterator.')

        if env[name] is None:
            var = Data(name, iters.ctype)
            env.locals[name] = var
            env.decls[name] = var
        elif env[name].ctype.dtype != iters.ctype.dtype:
            raise TypeError(f'Data type mismatch of variable: `{name}`: '
                            f'{env[name].ctype.dtype} != {iters.ctype.dtype}')

        body = _transpile_stmts(stmt.body, False, env)

        init_code = (f'{iters.ctype} '
                     f'__it = {iters.start.code}, '
                     f'__stop = {iters.stop.code}, '
                     f'__step = {iters.step.code}')
        # When the sign of the step is not known, it is tested in the
        # generated loop condition.
        cond = '__step >= 0 ? __it < __stop : __it > __stop'
        if iters.step_is_positive is True:
            cond = '__it < __stop'
        elif iters.step_is_positive is False:
            cond = '__it > __stop'

        head = f'for ({init_code}; {cond}; __it += __step)'
        return [CodeBlock(head, [f'{name} = __it;'] + body)]

    if isinstance(stmt, ast.AsyncFor):
        raise ValueError('`async for` is not allowed.')

    if isinstance(stmt, ast.While):
        if len(stmt.orelse) > 0:
            raise NotImplementedError('while-else is not supported.')
        condition = _transpile_expr(stmt.test, env)
        condition = _astype_scalar(condition, _cuda_types.bool_, 'unsafe', env)
        condition = Data.init(condition, env)
        body = _transpile_stmts(stmt.body, False, env)
        head = f'while ({condition.code})'
        return [CodeBlock(head, body)]

    if isinstance(stmt, ast.If):
        condition = _transpile_expr(stmt.test, env)
        if is_constants(condition):
            # A constant condition selects the branch at transpile time;
            # only the selected branch is transpiled.
            stmts = stmt.body if condition.obj else stmt.orelse
            return _transpile_stmts(stmts, is_toplevel, env)
        head = f'if ({condition.code})'
        then_body = _transpile_stmts(stmt.body, False, env)
        else_body = _transpile_stmts(stmt.orelse, False, env)
        return [CodeBlock(head, then_body), CodeBlock('else', else_body)]

    if isinstance(stmt, (ast.With, ast.AsyncWith)):
        raise ValueError('Switching contexts are not allowed.')
    if isinstance(stmt, (ast.Raise, ast.Try)):
        raise ValueError('throw/catch are not allowed.')

    if isinstance(stmt, ast.Assert):
        value = _transpile_expr(stmt.test, env)
        if is_constants(value):
            # Raised explicitly so that the check is kept under `-O`.
            if not value.obj:
                raise AssertionError(
                    'Assertion with a constant false condition.')
            return [';']
        else:
            # Use the code string of the expression; the expression object
            # itself cannot be concatenated with `str`.
            return ['assert(' + value.code + ');']

    if isinstance(stmt, (ast.Import, ast.ImportFrom)):
        raise ValueError('Cannot import modules from the target functions.')
    if isinstance(stmt, (ast.Global, ast.Nonlocal)):
        raise ValueError('Cannot use global/nonlocal in the target functions.')

    if isinstance(stmt, ast.Expr):
        value = _transpile_expr(stmt.value, env)
        return [';'] if is_constants(value) else [value.code + ';']
    if isinstance(stmt, ast.Pass):
        return [';']
    if isinstance(stmt, ast.Break):
        raise NotImplementedError('Not implemented.')
    if isinstance(stmt, ast.Continue):
        raise NotImplementedError('Not implemented.')

    # Raised explicitly (instead of `assert False`) so that an unknown
    # statement never falls through and returns `None` under `-O`.
    raise AssertionError('Not supported: {}'.format(type(stmt)))