def build_range_for(ctx, node):
    """Lower ``for i in range(...)`` into a Taichi frontend range-for loop.

    Accepts ``range(end)`` (begin becomes the constant 0) and
    ``range(begin, end)``. Both bounds are cast to ``ti.i32``; the translated
    loop body is spliced in right before ``ti.core.end_frontend_range_for()``,
    and a ``del <loop_var>`` is appended after the loop.

    Raises:
        TaichiSyntaxError: if ``range`` receives anything other than one or
            two positional arguments (e.g. a ``step``).
    """
    # for i in range(n)
    node.body = build_stmts(ctx, node.body)
    loop_var = node.target.id
    ctx.check_loop_var(loop_var)
    template = '''
if 1:
    {} = ti.Expr(ti.core.make_id_expr(''))
    ___begin = ti.Expr(0)
    ___end = ti.Expr(0)
    ___begin = ti.cast(___begin, ti.i32)
    ___end = ti.cast(___end, ti.i32)
    ti.core.begin_frontend_range_for({}.ptr, ___begin.ptr, ___end.ptr)
    ti.core.end_frontend_range_for()
'''.format(loop_var, loop_var)
    t = ast.parse(template).body[0]
    num_args = len(node.iter.args)
    # Explicit check instead of ``assert``: under ``python -O`` an assert is
    # stripped and ``range(a, b, step)`` would silently compile as ``range(a)``.
    if num_args not in (1, 2):
        raise TaichiSyntaxError(
            'range() in a Taichi for-loop takes 1 or 2 arguments '
            f'(begin, end), got {num_args}')
    if num_args == 2:
        bgn = build_expr(ctx, node.iter.args[0])
        end = build_expr(ctx, node.iter.args[1])
    else:
        bgn = StmtBuilder.make_constant(value=0)
        end = build_expr(ctx, node.iter.args[0])
    t.body[1].value.args[0] = bgn  # ___begin = ti.Expr(<bgn>)
    t.body[2].value.args[0] = end  # ___end = ti.Expr(<end>)
    # Index 6 is ti.core.end_frontend_range_for(); the loop body goes before it.
    t.body = t.body[:6] + node.body + t.body[6:]
    t.body.append(parse_stmt('del {}'.format(loop_var)))
    return ast.copy_location(t, node)
def build_AugAssign(ctx, node):
    """Rewrite ``target op= value`` into ``target.augassign(value, 'OpName')``.

    ``'OpName'`` is the class name of the AST operator node (for example
    ``'Add'`` for ``+=``). The target is switched to a Load context because
    it becomes the receiver of a method call.
    """
    node.target = build_expr(ctx, node.target)
    node.value = build_expr(ctx, node.value)
    template = 'x.augassign(0, 0)'
    t = ast.parse(template).body[0]
    t.value.func.value = node.target
    t.value.func.value.ctx = ast.Load()
    t.value.args[0] = node.value
    # ast.Constant replaces the ast.Str node, which was deprecated in Python
    # 3.8 and removed in 3.14. On 3.8+ ast.Str(...) already built a Constant,
    # so the generated tree is unchanged there.
    t.value.args[1] = ast.Constant(value=type(node.op).__name__, kind=None)
    return ast.copy_location(t, node)
def build_If(ctx, node):
    """Lower an ``if`` statement into Taichi frontend if/else calls.

    A condition that is a call through an attribute named ``static``
    (e.g. ``ti.static(...)``) leaves the statement as a plain Python ``if``.
    Any other condition is stored into ``__cond`` and wrapped in
    ``ti.begin_frontend_if``. The translated body goes into the true scope
    and the else-branch goes into the false scope.
    """
    node.test = build_expr(ctx, node.test)
    node.body = build_stmts(ctx, node.body)
    node.orelse = build_stmts(ctx, node.orelse)

    cond = node.test
    if (isinstance(cond, ast.Call) and isinstance(cond.func, ast.Attribute)
            and cond.func.attr == 'static'):
        # Compile-time branch: keep it untouched.
        return node

    template = '''
if 1:
  __cond = 0
  ti.begin_frontend_if(__cond)
  ti.core.begin_frontend_if_true()
  ti.core.pop_scope()
  ti.core.begin_frontend_if_false()
  ti.core.pop_scope()
'''
    wrapper = ast.parse(template).body[0]
    skeleton = wrapper.body
    skeleton[0].value = cond  # __cond = <cond>
    # Final layout:
    #   __cond, begin_if, if_true, <body>, pop_scope, if_false, <orelse>, pop_scope
    wrapper.body = (skeleton[:3] + node.body + skeleton[3:5] + node.orelse +
                    skeleton[5:])
    return ast.copy_location(wrapper, node)
def build_Expr(ctx, node):
    """Transform an expression statement.

    If the statement wraps a call, the call is passed through ``build_expr``.
    Any other expression statement is returned unchanged.
    """
    if isinstance(node.value, ast.Call):
        # Only the wrapped call is replaced; the ast.Expr wrapper itself is
        # what gets returned, because a statement slot cannot hold a bare
        # ast.Call.
        node.value = build_expr(ctx, node.value)
    # TODO(#2495): Deprecate maybe_transform_ti_func_call_to_stmt
    return node
def _handle_string_mod_args(ctx, msg):
    """Split a ``'fmt' % args`` message expression into its two parts.

    Args:
        ctx: the builder context, forwarded to ``build_expr``.
        msg: an AST for which ``StmtBuilder._is_string_mod_args`` holds.

    Returns:
        ``(fmt, args)``: ``fmt`` is the left-hand format string; ``args`` is
        the built ``ast.Tuple`` of format arguments. A single non-tuple
        right operand is wrapped into a 1-tuple.
    """
    assert StmtBuilder._is_string_mod_args(msg)
    left = msg.left
    # ``.s`` is a deprecated alias of ``Constant.value`` that was removed in
    # Python 3.14. Read ``.value`` when available and fall back to ``.s`` for
    # the pre-3.8 ``ast.Str`` node, which is not an ``ast.Constant``.
    s = left.value if isinstance(left, ast.Constant) else left.s
    if isinstance(msg.right, ast.Tuple):
        t = msg.right
    else:
        # assuming the format is `str % single_item`
        t = ast.Tuple(elts=[msg.right], ctx=ast.Load())
    t = build_expr(ctx, t)
    return s, t
def build_Assert(ctx, node):
    """Lower ``assert test[, msg]`` into ``ti.ti_assert(test, msg, args)``.

    ``msg`` may be a constant or a ``'fmt' % args`` expression. In the
    latter case the format arguments are passed separately as ``args``;
    otherwise ``args`` is an empty list. With no message, the source text of
    the test expression is used.

    Raises:
        ValueError: if the message is neither a constant nor a
            ``'fmt' % args`` expression.
    """
    extra_args = ast.List(elts=[], ctx=ast.Load())
    if node.msg is not None:
        if isinstance(node.msg, ast.Constant):
            msg = node.msg.value
        elif hasattr(ast, 'Str') and isinstance(node.msg, ast.Str):
            # Pre-3.8 string node; ``ast.Str`` is gone on Python 3.14+.
            msg = node.msg.s
        elif StmtBuilder._is_string_mod_args(node.msg):
            msg = build_expr(ctx, node.msg)
            msg, extra_args = StmtBuilder._handle_string_mod_args(ctx, msg)
        else:
            raise ValueError(
                f"assert info must be constant, not {ast.dump(node.msg)}")
    else:
        # No explicit message: report the source text of the condition.
        msg = astor.to_source(node.test)
    node.test = build_expr(ctx, node.test)
    new_node = parse_stmt('ti.ti_assert(0, 0, [])')
    new_node.value.args[0] = node.test
    # repr() always yields a valid literal, so quotes, backslashes and
    # newlines in the message (including source text such as x == 'a')
    # survive intact. Wrapping the text in '...' by hand broke on those.
    new_node.value.args[1] = parse_expr(repr(str(msg).strip()))
    new_node.value.args[2] = extra_args
    new_node = ast.copy_location(new_node, node)
    return new_node
def build_Expr(ctx, node):
    """Transform an expression statement.

    Non-call expression statements are returned unchanged. A wrapped call is
    passed through ``build_expr``. When the runtime's
    ``experimental_real_function`` flag is set, ``f(a, b)`` is additionally
    rewritten to ``ti.insert_expr_stmt_if_ti_func(f, a, b)``. Whether ``f``
    is a Taichi function is not known at this point, so that helper decides
    at call time whether to insert a FrontendExprStmt.
    """
    if isinstance(node.value, ast.Call):
        # The ast.Expr wrapper is kept; only its call is rebuilt.
        node.value = build_expr(ctx, node.value)
        if impl.get_runtime().experimental_real_function:
            call = node.value
            call.args = [call.func, *call.args]
            call.func = parse_expr('ti.insert_expr_stmt_if_ti_func')
    return node
def build_Return(ctx, node):
    """Transform a ``return`` statement.

    Inside a kernel, or when ``experimental_real_function`` is enabled, a
    return with a value becomes
    ``ti.core.create_kernel_return(ti.cast(ti.Expr(value), <returns>).ptr)``.
    In every other case the (built) ``return`` node is returned unchanged.

    Raises:
        TaichiSyntaxError: if a value is returned but ``ctx.returns`` (the
            return-type annotation) is None.
    """
    node.value = build_expr(ctx, node.value)
    lowered = ctx.is_kernel or impl.get_runtime().experimental_real_function
    if not lowered or node.value is None:
        return node
    # TODO: check if it's at the end of a kernel, throw TaichiSyntaxError if not
    if ctx.returns is None:
        kind = "kernel" if ctx.is_kernel else "function"
        raise TaichiSyntaxError(
            f'A {kind} with a return value must be annotated '
            'with a return type, e.g. def func() -> ti.f32')
    casted = parse_expr('ti.cast(ti.Expr(0), 0)')
    casted.args[0].args[0] = node.value  # ti.Expr(<value>)
    casted.args[1] = ctx.returns  # target type of the cast
    ret_stmt = parse_stmt('ti.core.create_kernel_return(ret.ptr)')
    # args[0] is the attribute access ``ret.ptr``; replace only its object
    # part so the call receives ``<casted>.ptr``.
    ret_stmt.value.args[0].value = casted
    return ret_stmt
def build_Assign(ctx, node):
    """Transform an assignment, including chained ones (``a = b = v``).

    Assignments whose value is a call resolving to ``ti.static`` are returned
    unchanged. Otherwise each target gets its own generated statement:
    ``StmtBuilder.build_assign_unpack`` for tuple targets and
    ``StmtBuilder.build_assign_basic`` for everything else. The statements
    are then merged with ``StmtBuilder.make_single_statement``.
    See https://github.com/taichi-dev/taichi/issues/2659.
    """
    node.value = build_expr(ctx, node.value)
    node.targets = build_exprs(ctx, node.targets)

    if isinstance(node.value, ast.Call) and ASTResolver.resolve_to(
            node.value.func, ti.static, globals()):
        return node

    per_target = [
        StmtBuilder.build_assign_unpack(ctx, node, target)
        if isinstance(target, ast.Tuple) else
        StmtBuilder.build_assign_basic(ctx, node, target, node.value)
        for target in node.targets
    ]
    return StmtBuilder.make_single_statement(per_target)
def build_Raise(ctx, node):
    """Transform a ``raise`` statement.

    Only the raised expression (``node.exc``) is passed through
    ``build_expr``. The ``from ...`` cause is left as-is, and the same node
    is returned.
    """
    built_exc = build_expr(ctx, node.exc)
    node.exc = built_exc
    return node
def build_Assign(ctx, node):
    """Transform a single-target assignment.

    Calls resolving to ``ti.static`` are returned unchanged. A name seen for
    the first time is created as ``name = ti.expr_init(value)``. Any other
    target becomes ``target.assign(value)``. Tuple targets are unpacked
    through a temporary ``__tmp_tuple = ti.expr_init_list(value, n)``,
    assigned element by element, and the temporary is deleted afterwards.
    """
    assert len(node.targets) == 1
    node.value = build_expr(ctx, node.value)
    node.targets = build_exprs(ctx, node.targets)

    if isinstance(node.value, ast.Call) and ASTResolver.resolve_to(
            node.value.func, ti.static, globals()):
        return node

    def expr_init_of(arg):
        # ``ti.expr_init(<arg>)`` built directly: a fresh local needs no
        # AST resolution.
        callee = ast.Attribute(value=ast.Name(id='ti', ctx=ast.Load()),
                               attr='expr_init',
                               ctx=ast.Load())
        return ast.Call(func=callee, args=[arg], keywords=[])

    def assign_call(target, arg):
        # ``<target>.assign(<arg>)``; the target becomes the call receiver.
        target.ctx = ast.Load()
        method = ast.Attribute(value=target, attr='assign', ctx=ast.Load())
        return ast.Call(func=method, args=[arg], keywords=[])

    lhs = node.targets[0]

    if not isinstance(lhs, ast.Tuple):
        if isinstance(lhs, ast.Name) and ctx.is_creation(lhs.id):
            rhs = expr_init_of(node.value)
            ctx.create_variable(lhs.id)
            created = ast.Assign(targets=node.targets,
                                 value=rhs,
                                 type_comment=None)
            return ast.copy_location(created, node)
        stmt = ast.Expr(value=assign_call(lhs, node.value))
        return ast.copy_location(stmt, node)

    elements = lhs.elts
    holder = parse_stmt(
        f'__tmp_tuple = ti.expr_init_list(0, {len(elements)})')
    holder.value.args[0] = node.value
    generated = [holder]

    def element_of_tmp(index):
        # ``__tmp_tuple[<index>]`` as a Load expression.
        subscript = parse_stmt('__tmp_tuple[0]').value
        StmtBuilder.set_subscript_index(subscript, parse_expr(str(index)))
        return subscript

    for index, element in enumerate(elements):
        if isinstance(element, ast.Name) and ctx.is_creation(element.id):
            element.ctx = ast.Store()
            rhs = expr_init_of(element_of_tmp(index))
            ctx.create_variable(element.id)
            generated.append(
                ast.Assign(targets=[element], value=rhs, type_comment=None))
        else:
            generated.append(
                ast.Expr(value=assign_call(element, element_of_tmp(index))))

    for stmt in generated:
        ast.copy_location(stmt, node)
    # The trailing ``del`` is appended after locations are copied, as before.
    generated.append(parse_stmt('del __tmp_tuple'))
    return StmtBuilder.make_single_statement(generated)