def test_ast_resolver_alias():
    """Resolve `kernel` through the plain module name and through an alias.

    Both imports are done inside the function so that `locals()` is the
    scope handed to `ASTResolver.resolve_to`.
    """
    import taichi
    taichi.init()

    # Attribute access spelled with the module's own name.
    plain_expr = ast.parse('taichi.kernel', mode='eval').body
    assert ASTResolver.resolve_to(plain_expr, taichi.kernel, locals())

    # Same attribute, reached through a second binding of the module.
    import taichi as tc
    aliased_expr = ast.parse('tc.kernel', mode='eval').body
    assert ASTResolver.resolve_to(aliased_expr, tc.kernel, locals())
def test_ast_resolver_wrong_ti():
    """A local named `ti` that is not taichi must not resolve to `taichi.kernel`."""
    import taichi
    taichi.init()

    # Build a stand-in object exposing a `kernel` attribute and bind it to the
    # name `ti`; the binding is only consumed indirectly via `locals()`.
    impostor_type = namedtuple('FakeTi', ['kernel'])
    ti = impostor_type(kernel='fake')

    expr = ast.parse('ti.kernel', mode='eval').body
    assert not ASTResolver.resolve_to(expr, taichi.kernel, locals())
def build_NamedExpr(ctx, node):
    """Build a named-expression node and store the resulting assignment in `node.ptr`.

    The value and the target are built first (in that order) and written back
    onto the node. The assignment is flagged as static when the value is a
    call whose callee resolves to `ti.static` in this module's globals.

    Returns:
        The same `node`, with `value`, `target` and `ptr` updated.
    """
    node.value = build_stmt(ctx, node.value)
    node.target = build_stmt(ctx, node.target)

    rhs = node.value
    if isinstance(rhs, ast.Call):
        # Only a call expression can be a `ti.static(...)` wrapper.
        is_static_assign = ASTResolver.resolve_to(rhs.func, ti.static,
                                                  globals())
    else:
        is_static_assign = False

    node.ptr = IRBuilder.build_assign_basic(ctx, node.target, node.value.ptr,
                                            is_static_assign)
    return node
def get_decorator(ctx, node):
    """Classify a call node as one of the recognised wrapper calls.

    Args:
        ctx: Context object; its `global_vars` is the scope used for resolution.
        node: AST node to inspect.

    Returns:
        'static', 'grouped' or 'ndrange' for the first candidate the callee
        resolves to (checked in that order), or '' when `node` is not a call
        or nothing matches.
    """
    if not isinstance(node, ast.Call):
        return ''

    candidates = (
        (impl.static, 'static'),
        (impl.grouped, 'grouped'),
        (ndrange, 'ndrange'),
    )
    # Lazily test candidates so resolution stops at the first hit.
    hits = (label for target, label in candidates
            if ASTResolver.resolve_to(node.func, target, ctx.global_vars))
    return next(hits, '')
def get_decorator(node):
    """Classify a call node as one of the recognised `ti` wrapper calls.

    Resolution is done against this module's globals.

    Returns:
        'static', 'grouped' or 'ndrange' for the first of `ti.static`,
        `ti.grouped`, `ti.ndrange` that the callee resolves to, or '' when
        `node` is not a call or nothing matches.
    """
    if not isinstance(node, ast.Call):
        return ''

    known_wrappers = (
        (ti.static, 'static'),
        (ti.grouped, 'grouped'),
        (ti.ndrange, 'ndrange'),
    )
    # Candidates are tested one at a time; the first match wins.
    matching_names = (
        label for target, label in known_wrappers
        if ASTResolver.resolve_to(node.func, target, globals()))
    return next(matching_names, '')
def build_AnnAssign(ctx, node):
    """Build an annotated assignment and return the resulting `node.ptr`.

    The value, target and annotation are built in that order; their return
    values are not stored back on the node. The assignment is flagged as
    static when the value is a call whose callee resolves to `ti.static` in
    this module's globals.
    """
    build_stmt(ctx, node.value)
    build_stmt(ctx, node.target)
    build_stmt(ctx, node.annotation)

    rhs = node.value
    if isinstance(rhs, ast.Call):
        is_static_assign = ASTResolver.resolve_to(rhs.func, ti.static,
                                                  globals())
    else:
        is_static_assign = False

    node.ptr = ASTTransformer.build_assign_annotated(
        ctx,
        node.target,
        node.value.ptr,
        is_static_assign,
        node.annotation.ptr,
    )
    # Unlike the other assignment builders here, this returns the built
    # value rather than the node itself.
    return node.ptr
def build_Call(ctx, node):
    """Rewrite a call node so selected Python calls are redirected to `ti.ti_*`.

    Behaviour, in order:
      * A call whose callee resolves to `ti.static` is returned untouched.
      * Callee, positional arguments and keyword values are built.
      * `<expr>.format(...)` becomes `ti.ti_format(<expr>, ...)`.
      * Bare-name calls to print/min/max/int/float/any/all are redirected to
        the matching `ti.ti_*` expression.

    Raises:
        TaichiSyntaxError: if the callee object looks like the `exit`/`quit`
            builtin (module `_sitebuiltins`, class name `Quitter`).

    Returns:
        The (mutated) `node`.
    """
    if ASTResolver.resolve_to(node.func, ti.static, globals()):
        # Do not modify the expression if the function called is ti.static
        return node

    node.func = build_expr(ctx, node.func)
    node.args = build_exprs(ctx, node.args)
    for keyword in node.keywords:
        keyword.value = build_expr(ctx, keyword.value)

    # `receiver.format(args)` -> `ti.ti_format(receiver, args)`
    if isinstance(node.func, ast.Attribute) and node.func.attr == 'format':
        node.args.insert(0, node.func.value)
        node.func = parse_expr('ti.ti_format')

    # Bare builtin names and the expression each one is redirected to.
    builtin_redirects = {
        'print': 'ti.ti_print',
        'min': 'ti.ti_min',
        'max': 'ti.ti_max',
        'int': 'ti.ti_int',
        'float': 'ti.ti_float',
        'any': 'ti.ti_any',
        'all': 'ti.ti_all',
    }
    if isinstance(node.func, ast.Name):
        redirect = builtin_redirects.get(node.func.id)
        if redirect is not None:
            node.func = parse_expr(redirect)

    # Local marker that is never read in this function.
    # NOTE(review): presumably inspected by traceback filtering — confirm.
    _taichi_skip_traceback = 1

    # NOTE(review): `ti_func` is `node.func` as left by the steps above, so the
    # attribute probes below inspect that object — confirm this is intended.
    ti_func = node.func
    if '_sitebuiltins' == getattr(ti_func, '__module__', ''):
        callee_class = getattr(ti_func, '__class__', '')
        if getattr(callee_class, '__name__', '') == 'Quitter':
            raise TaichiSyntaxError(
                f'exit or quit not supported in Taichi-scope')

    defined_in_main = getattr(ti_func, '__module__', '') == '__main__'
    if defined_in_main and not getattr(ti_func, '__wrapped__', ''):
        warnings.warn(
            f'Calling into non-Taichi function {ti_func.__name__}.'
            ' This means that scope inside that function will not be processed'
            ' by the Taichi transformer. Proceed with caution! '
            ' Maybe you want to decorate it with @ti.func?',
            UserWarning,
            stacklevel=2)
    return node
def build_Assign(ctx, node):
    """Build an assignment statement, handling every target in turn.

    The value and targets are built and written back onto the node. Each
    target is then passed to `IRBuilder.build_assign_unpack` together with the
    built value's `ptr`. The static flag is set when the value is a call whose
    callee resolves to `ti.static` in this module's globals.

    Returns:
        The same `node`.
    """
    node.value = build_stmt(ctx, node.value)
    node.targets = build_stmts(ctx, node.targets)

    rhs = node.value
    if isinstance(rhs, ast.Call):
        is_static_assign = ASTResolver.resolve_to(rhs.func, ti.static,
                                                  globals())
    else:
        is_static_assign = False

    # Iterating over all targets is what supports chained assignments
    # (`a = b = value`).
    # Ref https://github.com/taichi-dev/taichi/issues/2659.
    for target in node.targets:
        IRBuilder.build_assign_unpack(ctx, target, node.value.ptr,
                                      is_static_assign)
    return node
def build_Assign(ctx, node):
    """Lower an assignment into a single statement built by `StmtBuilder`.

    The value and targets are built and written back onto the node. When the
    value is a call whose callee resolves to `ti.static` (in this module's
    globals), the node is returned as is. Otherwise one statement is generated
    per target and the results are combined with
    `StmtBuilder.make_single_statement`.
    """
    node.value = build_expr(ctx, node.value)
    node.targets = build_exprs(ctx, node.targets)

    if isinstance(node.value, ast.Call) and ASTResolver.resolve_to(
            node.value.func, ti.static, globals()):
        return node

    def lower_target(target):
        # Tuple targets need unpacking; everything else is a basic assignment.
        if isinstance(target, ast.Tuple):
            return StmtBuilder.build_assign_unpack(ctx, node, target)
        return StmtBuilder.build_assign_basic(ctx, node, target, node.value)

    # One generated statement per target supports chained assignments.
    # Ref https://github.com/taichi-dev/taichi/issues/2659.
    generated = [lower_target(target) for target in node.targets]
    return StmtBuilder.make_single_statement(generated)
def build_FunctionDef(ctx, node):
    """Rewrite a function definition by prepending argument declarations.

    Only plain positional parameters are accepted (asserted below). Depending
    on `ctx.is_kernel` and the runtime's `experimental_real_function` flag the
    parameters are either turned into kernel-argument declaration statements
    (and removed from the signature) or, for inlined functions, renamed with a
    `_by_value__` suffix and copied into a fresh variable of the original name.

    The final body is: `import taichi as ti`, the generated declarations, then
    the transformed original body.

    Raises:
        TaichiSyntaxError: if any decorator on the definition resolves to
            `ti.func`.

    Returns:
        The mutated `node`.
    """
    args = node.args
    assert args.vararg is None
    assert args.kwonlyargs == []
    assert args.kw_defaults == []
    assert args.kwarg is None

    # Statements to be placed in front of the transformed body.
    arg_decls = []

    def declaration_for(template, name):
        # Parse an `x = ...` template and retarget it at `name`.
        stmt = parse_stmt(template)
        stmt.targets[0].id = name
        return stmt

    def transform_as_kernel():
        # Return type: emit a declaration and strip the annotation.
        if node.returns is not None:
            ret_init = parse_stmt(
                'ti.lang.kernel_arguments.decl_scalar_ret(0)')
            ret_init.value.args[0] = node.returns
            ctx.returns = node.returns
            arg_decls.append(ret_init)
            node.returns = None

        for i, arg in enumerate(args.args):
            annotation = ctx.func.argument_annotations[i]

            # Template arguments (class instances ("self"), fields, SNodes,
            # etc.) are passed in directly and need no declaration.
            if isinstance(annotation, ti.template):
                continue

            if isinstance(annotation, ti.linalg.sparse_matrix_builder):
                arg_init = declaration_for(
                    'x = ti.lang.kernel_arguments.decl_sparse_matrix()',
                    arg.arg)
                ctx.create_variable(arg.arg)
            elif isinstance(annotation, ti.any_arr):
                arg_init = declaration_for(
                    'x = ti.lang.kernel_arguments.decl_any_arr_arg(0, 0, 0, 0)',
                    arg.arg)
                ctx.create_variable(arg.arg)
                features = ctx.arg_features[i]
                array_dt = to_taichi_type(features[0])
                array_dim = features[1]
                array_element_shape = features[2]
                array_layout = features[3]
                # Fill the four placeholder arguments of the template call.
                call_args = arg_init.value.args
                call_args[0] = parse_expr(
                    'ti.' + ti.core.data_type_name(array_dt))
                call_args[1] = parse_expr(f"{array_dim}")
                call_args[2] = parse_expr(f"{array_element_shape}")
                call_args[3] = parse_expr(f"ti.{array_layout}")
            else:
                arg_init = declaration_for(
                    'x = ti.lang.kernel_arguments.decl_scalar_arg(0)',
                    arg.arg)
                arg_init.value.args[0] = arg.annotation
            arg_decls.append(arg_init)

        # remove original args
        node.args.args = []

    def transform_as_inlined_func():
        for i, arg in enumerate(args.args):
            # Remove annotations because they are not used.
            arg.annotation = None
            # Template arguments are passed by reference.
            if isinstance(ctx.func.argument_annotations[i], ti.template):
                ctx.create_variable(ctx.func.argument_names[i])
                continue
            # Non-template arguments are copied so they are passed by value:
            # the parameter is renamed and the original name is bound to a
            # copy of it.
            original_name = arg.arg
            arg_init = declaration_for('x = ti.expr_init_func(0)',
                                       original_name)
            ctx.create_variable(original_name)
            arg_init.value.args[0] = parse_expr(original_name +
                                                '_by_value__')
            arg.arg = original_name + '_by_value__'
            arg_decls.append(arg_init)

    is_kernel = ctx.is_kernel
    if is_kernel:  # ti.kernel
        nested_def_error = "Function definition not allowed in 'ti.kernel'."
    else:  # ti.func
        nested_def_error = "Function definition not allowed in 'ti.func'."
    for decorator in node.decorator_list:
        if ASTResolver.resolve_to(decorator, ti.func, globals()):
            raise TaichiSyntaxError(nested_def_error)

    if is_kernel or impl.get_runtime().experimental_real_function:
        transform_as_kernel()
    else:
        transform_as_inlined_func()

    with ctx.variable_scope_guard():
        node.body = build_stmts(ctx, node.body)

    node.body = [parse_stmt('import taichi as ti')] + arg_decls + node.body
    return node
def build_FunctionDef(ctx, node):
    """Process a function definition, binding its parameters in `ctx`.

    Only plain positional parameters are accepted (asserted below). For
    kernels, and for functions when the runtime's `experimental_real_function`
    flag is set, each non-template parameter is declared through
    `ti.lang.kernel_arguments` and the signature's argument list is emptied.
    Otherwise the function is inlined: `ctx.argument_data` is checked against
    the signature, extended with built default values, and each parameter is
    bound to its data (templates directly, others via `ti.expr_init_func`).

    Raises:
        TaichiSyntaxError: if a decorator resolves to `ti.func`, or if the
            number of provided arguments does not fit the signature.

    Returns:
        The same `node`.
    """
    args = node.args
    assert args.vararg is None
    assert args.kwonlyargs == []
    assert args.kw_defaults == []
    assert args.kwarg is None

    def transform_as_kernel():
        # Treat return type
        if node.returns is not None:
            ti.lang.kernel_arguments.decl_scalar_ret(ctx.func.return_type)

        for i, arg in enumerate(args.args):
            annotation = ctx.func.argument_annotations[i]
            if isinstance(annotation, ti.template):
                continue
            if isinstance(annotation, ti.linalg.sparse_matrix_builder):
                declared = ti.lang.kernel_arguments.decl_sparse_matrix()
                ctx.create_variable(arg.arg, declared)
            elif isinstance(annotation, ti.any_arr):
                features = ctx.arg_features[i]
                declared = ti.lang.kernel_arguments.decl_any_arr_arg(
                    to_taichi_type(features[0]), features[1], features[2],
                    features[3])
                ctx.create_variable(arg.arg, declared)
            else:
                # Scalar arguments are stored in `ctx.global_vars` rather than
                # through `ctx.create_variable`.
                ctx.global_vars[
                    arg.arg] = ti.lang.kernel_arguments.decl_scalar_arg(
                        annotation)
        # remove original args
        node.args.args = []

    def transform_as_inlined_func():
        len_args = len(args.args)
        len_default = len(args.defaults)
        len_provided = len(ctx.argument_data)
        len_minimum = len_args - len_default
        if not len_minimum <= len_provided <= len_args:
            if len_default:
                raise TaichiSyntaxError(
                    f"Function receives {len_minimum} to {len_args} argument(s) and {len_provided} provided."
                )
            raise TaichiSyntaxError(
                f"Function receives {len_args} argument(s) and {len_provided} provided."
            )

        # Fill in the parameters the caller left out from their defaults.
        default_start = len_provided - len_minimum
        ctx.argument_data = list(ctx.argument_data)
        for default_node in args.defaults[default_start:]:
            ctx.argument_data.append(build_stmt(ctx, default_node).ptr)
        assert len(args.args) == len(ctx.argument_data)

        for i, (arg, data) in enumerate(zip(args.args, ctx.argument_data)):
            # Remove annotations because they are not used.
            arg.annotation = None
            if isinstance(ctx.func.argument_annotations[i], ti.template):
                # Template arguments are passed by reference.
                ctx.create_variable(ctx.func.argument_names[i], data)
            else:
                # Create a copy for non-template arguments,
                # so that they are passed by value.
                ctx.create_variable(arg.arg, ti.expr_init_func(data))

    is_kernel = ctx.is_kernel
    if is_kernel:  # ti.kernel
        nested_def_error = "Function definition not allowed in 'ti.kernel'."
    else:  # ti.func
        nested_def_error = "Function definition not allowed in 'ti.func'."
    for decorator in node.decorator_list:
        if ASTResolver.resolve_to(decorator, ti.func, globals()):
            raise TaichiSyntaxError(nested_def_error)

    if is_kernel or impl.get_runtime().experimental_real_function:
        transform_as_kernel()
    else:
        transform_as_inlined_func()

    with ctx.variable_scope_guard():
        build_stmts(ctx, node.body)

    return node
def build_FunctionDef(ctx, node):
    """Process a function definition, binding its parameters in `ctx`.

    Only plain positional parameters are accepted (asserted below).

    * Kernels (`ctx.is_kernel`): the return annotation, if any, is built and
      declared via `decl_scalar_ret`; every non-template parameter has its
      annotation built and is declared via `decl_scalar_arg`. Only template
      and scalar parameters are handled by this variant. The signature's
      argument list is then emptied.
    * Functions: always inlined. Each parameter is bound to the matching entry
      of `ctx.argument_data` — templates directly (by reference), everything
      else through `ti.expr_init_func` (by value).

    Raises:
        TaichiSyntaxError: if a decorator resolves to `ti.func`, or if the
            number of provided arguments differs from the number of
            parameters.

    Returns:
        The same `node`.
    """
    args = node.args
    assert args.vararg is None
    assert args.kwonlyargs == []
    assert args.kw_defaults == []
    assert args.kwarg is None

    def transform_as_kernel():
        # Treat return type
        if node.returns is not None:
            node.returns = build_stmt(ctx, node.returns)
            ti.lang.kernel_arguments.decl_scalar_ret(node.returns.ptr)
            ctx.returns = node.returns.ptr

        for i, arg in enumerate(args.args):
            # Template arguments (class instances ("self"), fields, SNodes,
            # etc.) are passed in directly and need no declaration.
            if isinstance(ctx.func.argument_annotations[i], ti.template):
                continue
            arg.annotation = build_stmt(ctx, arg.annotation)
            ctx.create_variable(
                arg.arg,
                ti.lang.kernel_arguments.decl_scalar_arg(arg.annotation.ptr))
        # remove original args
        node.args.args = []

    if ctx.is_kernel:  # ti.kernel
        for decorator in node.decorator_list:
            if ASTResolver.resolve_to(decorator, ti.func, globals()):
                raise TaichiSyntaxError(
                    "Function definition not allowed in 'ti.kernel'.")
        transform_as_kernel()

    else:  # ti.func
        for decorator in node.decorator_list:
            if ASTResolver.resolve_to(decorator, ti.func, globals()):
                raise TaichiSyntaxError(
                    "Function definition not allowed in 'ti.func'.")

        # NOTE: the experimental real-function path is not taken in this
        # variant; functions are always transformed as force-inlined.
        len_args = len(args.args)
        len_provided = len(ctx.argument_data)
        if len_args != len_provided:
            # Report both counts so the caller can see what went wrong.
            raise TaichiSyntaxError(
                f"Function receives {len_args} argument(s) and {len_provided} provided."
            )

        for i, (arg, data) in enumerate(zip(args.args, ctx.argument_data)):
            # Remove annotations because they are not used.
            arg.annotation = None
            if isinstance(ctx.func.argument_annotations[i], ti.template):
                # Template arguments are passed by reference.
                ctx.create_variable(ctx.func.argument_names[i], data)
            else:
                # Create a copy for non-template arguments,
                # so that they are passed by value.
                ctx.create_variable(arg.arg, ti.expr_init_func(data))

    with ctx.variable_scope_guard():
        build_stmts(ctx, node.body)

    return node
def test_ast_resolver_basic():
    """Resolve `ti.kernel` when `ti` is bound in the local scope."""
    # The import lives inside the function so the module-level namespace stays
    # untouched and `locals()` carries the `ti` binding.
    import taichi as ti
    ti.init()

    kernel_expr = ast.parse('ti.kernel', mode='eval').body
    assert ASTResolver.resolve_to(kernel_expr, ti.kernel, locals())
def test_ast_resolver_chain():
    """Resolve a multi-level attribute chain to the object it names."""
    import taichi as ti
    ti.init()

    # The expression spells out the full path, while the expected object is
    # reached through the shorter `ti.atomic_add` name.
    chained_expr = ast.parse('ti.lang.ops.atomic_add', mode='eval').body
    assert ASTResolver.resolve_to(chained_expr, ti.atomic_add, locals())
def test_ast_resolver_direct_import():
    """Resolve a bare name that was imported directly from the package."""
    import taichi as ti
    ti.init()

    # `kernel` is bound as a plain local name, with no attribute access.
    from taichi import kernel
    bare_name_expr = ast.parse('kernel', mode='eval').body
    assert ASTResolver.resolve_to(bare_name_expr, kernel, locals())