Exemplo n.º 1
0
    def build_range_for(ctx, node):
        # for i in range(n)
        node.body = build_stmts(ctx, node.body)
        loop_var = node.target.id
        ctx.check_loop_var(loop_var)
        template = '''
if 1:
    {} = ti.Expr(ti.core.make_id_expr(''))
    ___begin = ti.Expr(0)
    ___end = ti.Expr(0)
    ___begin = ti.cast(___begin, ti.i32)
    ___end = ti.cast(___end, ti.i32)
    ti.core.begin_frontend_range_for({}.ptr, ___begin.ptr, ___end.ptr)
    ti.core.end_frontend_range_for()
        '''.format(loop_var, loop_var)
        t = ast.parse(template).body[0]

        assert len(node.iter.args) in [1, 2]
        if len(node.iter.args) == 2:
            bgn = build_expr(ctx, node.iter.args[0])
            end = build_expr(ctx, node.iter.args[1])
        else:
            bgn = StmtBuilder.make_constant(value=0)
            end = build_expr(ctx, node.iter.args[0])

        t.body[1].value.args[0] = bgn
        t.body[2].value.args[0] = end
        t.body = t.body[:6] + node.body + t.body[6:]
        t.body.append(parse_stmt('del {}'.format(loop_var)))
        return ast.copy_location(t, node)
Exemplo n.º 2
0
 def build_AugAssign(ctx, node):
     """Lower ``target op= value`` into ``target.augassign(value, '<OpName>')``.

     Both the target and the value are built with ``build_expr``. The
     operator is passed as a string holding the AST operator class name
     (e.g. ``'Add'`` for ``+=``).

     Returns:
         An ``ast.Expr`` wrapping the ``augassign`` call, located at ``node``.
     """
     node.target = build_expr(ctx, node.target)
     node.value = build_expr(ctx, node.value)
     template = 'x.augassign(0, 0)'
     t = ast.parse(template).body[0]
     t.value.func.value = node.target
     # The target is now the object the method is loaded from.
     t.value.func.value.ctx = ast.Load()
     t.value.args[0] = node.value
     # ``ast.Str`` is deprecated since Python 3.8 and removed in 3.14; on 3.8+
     # it already produced an ``ast.Constant``, so build that directly. String
     # constants carry no ``ctx`` field.
     t.value.args[1] = ast.Constant(value=type(node.op).__name__, kind=None)
     return ast.copy_location(t, node)
Exemplo n.º 3
0
    def build_If(ctx, node):
        node.test = build_expr(ctx, node.test)
        node.body = build_stmts(ctx, node.body)
        node.orelse = build_stmts(ctx, node.orelse)

        is_static_if = isinstance(node.test, ast.Call) and isinstance(
            node.test.func, ast.Attribute)
        if is_static_if:
            attr = node.test.func
            if attr.attr == 'static':
                is_static_if = True
            else:
                is_static_if = False

        if is_static_if:
            # Do nothing
            return node

        template = '''
if 1:
  __cond = 0
  ti.begin_frontend_if(__cond)
  ti.core.begin_frontend_if_true()
  ti.core.pop_scope()
  ti.core.begin_frontend_if_false()
  ti.core.pop_scope()
'''
        t = ast.parse(template).body[0]
        cond = node.test
        t.body[0].value = cond
        t.body = t.body[:5] + node.orelse + t.body[5:]
        t.body = t.body[:3] + node.body + t.body[3:]
        return ast.copy_location(t, node)
Exemplo n.º 4
0
 def build_Expr(ctx, node):
     """Build an expression statement.

     Only a call expression is passed through ``build_expr``. Any other
     single-expression statement is returned untouched. The result is always
     the ``ast.Expr`` wrapper, never the bare ``ast.Call``.
     """
     # TODO(#2495): Deprecate maybe_transform_ti_func_call_to_stmt
     if isinstance(node.value, ast.Call):
         node.value = build_expr(ctx, node.value)
     return node
Exemplo n.º 5
0
 def _handle_string_mod_args(ctx, msg):
     """Split a ``'fmt' % args`` node into its format string and built args.

     Args:
         ctx: Builder context, forwarded to ``build_expr``.
         msg: A node accepted by ``StmtBuilder._is_string_mod_args``. Its
             ``left`` is read as the format string and its ``right`` as the
             operand(s); presumably an ``ast.BinOp`` with ``%``, so verify
             against ``_is_string_mod_args``.

     Returns:
         ``(fmt, args)``. ``fmt`` is the left-hand string. ``args`` is the
         built ``ast.Tuple`` of right-hand operands; a single non-tuple
         operand is wrapped in a one-element tuple.
     """
     assert StmtBuilder._is_string_mod_args(msg)
     left = msg.left
     # ``.s`` is a deprecated alias (removed in Python 3.14). Read ``.value``
     # from ``ast.Constant`` (3.8+ parser output) and fall back to
     # ``ast.Str.s`` on older interpreters.
     s = left.value if isinstance(left, ast.Constant) else left.s
     if isinstance(msg.right, ast.Tuple):
         t = msg.right
     else:
         # assuming the format is `str % single_item`
         t = ast.Tuple(elts=[msg.right], ctx=ast.Load())
     t = build_expr(ctx, t)
     return s, t
Exemplo n.º 6
0
    def build_Assert(ctx, node):
        """Lower ``assert test[, msg]`` into ``ti.ti_assert(test, msg, args)``.

        Message handling:
          * constant message -> its value (converted with ``str``),
          * ``'fmt' % args`` -> the format string, plus the built args as
            the third argument,
          * no message -> the source text of the test (via ``astor``).

        Raises:
            ValueError: if the message is none of the forms above.
        """
        extra_args = ast.List(elts=[], ctx=ast.Load())
        if node.msg is not None:
            if isinstance(node.msg, ast.Constant):
                msg = node.msg.value
            elif hasattr(ast, 'Str') and isinstance(node.msg, ast.Str):
                # Python < 3.8 parses string literals as ``ast.Str``. The
                # alias is removed in newer Pythons, so only touch it if it
                # still exists.
                msg = node.msg.s
            elif StmtBuilder._is_string_mod_args(node.msg):
                msg = build_expr(ctx, node.msg)
                msg, extra_args = StmtBuilder._handle_string_mod_args(ctx, msg)
            else:
                raise ValueError(
                    f"assert info must be constant, not {ast.dump(node.msg)}")
        else:
            msg = astor.to_source(node.test)
        node.test = build_expr(ctx, node.test)

        new_node = parse_stmt('ti.ti_assert(0, 0, [])')
        new_node.value.args[0] = node.test
        # repr() yields a correctly quoted and escaped literal, so messages
        # containing quotes or backslashes (e.g. the auto-generated message
        # for ``assert s == 'a'``) still parse. str() accepts non-string
        # constant messages such as ``assert x, 42``.
        new_node.value.args[1] = parse_expr(repr(str(msg).strip()))
        new_node.value.args[2] = extra_args
        new_node = ast.copy_location(new_node, node)
        return new_node
Exemplo n.º 7
0
    def build_Expr(ctx, node):
        """Build an expression statement, rewriting function calls.

        Non-call expressions are returned unchanged. A call is built with
        ``build_expr``. When the runtime's ``experimental_real_function``
        flag is set, the call is additionally rewritten to
        ``ti.insert_expr_stmt_if_ti_func(<callee>, *<original args>)``.
        The result is always the ``ast.Expr`` wrapper, not the ``ast.Call``.
        """
        if isinstance(node.value, ast.Call):
            # A function call.
            call = build_expr(ctx, node.value)
            node.value = call
            if impl.get_runtime().experimental_real_function:
                # Whether the callee is a Taichi function is only known at
                # run time, so defer the FrontendExprStmt insertion to
                # ti.insert_expr_stmt_if_ti_func, passing the callee first.
                call.args = [call.func, *call.args]
                call.func = parse_expr('ti.insert_expr_stmt_if_ti_func')
        # Otherwise: a statement with a single expression, left as-is.
        return node
Exemplo n.º 8
0
 def build_Return(ctx, node):
     """Lower a ``return`` statement.

     In a kernel, or when ``experimental_real_function`` is enabled,
     ``return <value>`` becomes
     ``ti.core.create_kernel_return(ti.cast(ti.Expr(<value>), <ctx.returns>).ptr)``.
     In all other cases, and for a bare ``return``, the (built) node is
     returned as-is.

     Raises:
         TaichiSyntaxError: if a value is returned but ``ctx.returns`` (the
             return annotation) is ``None``.
     """
     # A bare ``return`` has no value to build; the check below already
     # treats ``node.value is None`` as a valid case.
     if node.value is not None:
         node.value = build_expr(ctx, node.value)
     if not (ctx.is_kernel or impl.get_runtime().experimental_real_function):
         return node
     # TODO: check if it's at the end of a kernel, throw TaichiSyntaxError if not
     if node.value is None:
         return node
     if ctx.returns is None:
         raise TaichiSyntaxError(
             f'A {"kernel" if ctx.is_kernel else "function"} '
             'with a return value must be annotated '
             'with a return type, e.g. def func() -> ti.f32')
     ret_expr = parse_expr('ti.cast(ti.Expr(0), 0)')
     ret_expr.args[0].args[0] = node.value
     ret_expr.args[1] = ctx.returns
     ret_stmt = parse_stmt('ti.core.create_kernel_return(ret.ptr)')
     # For args[0], it is an ast.Attribute, because it loads the
     # attribute, |ptr|, of the expression |ret_expr|. Therefore we
     # only need to replace the object part, i.e. args[0].value
     ret_stmt.value.args[0].value = ret_expr
     # Carry the original statement's source position, like the other
     # builders do, so diagnostics point at the user's ``return``.
     return ast.copy_location(ret_stmt, node)
Exemplo n.º 9
0
    def build_Assign(ctx, node):
        """Lower an assignment statement, including chained targets.

        If the built right-hand side resolves to a ``ti.static(...)`` call,
        the node is returned untouched. Otherwise each target is lowered on
        its own:
          * tuple targets via ``StmtBuilder.build_assign_unpack``,
          * all others via ``StmtBuilder.build_assign_basic``.
        The resulting statements are merged with
        ``StmtBuilder.make_single_statement``.

        NOTE(review): every target receives the same ``node.value`` node;
        whether that evaluates the RHS once per target depends on the
        helpers, so confirm.
        """
        node.value = build_expr(ctx, node.value)
        node.targets = build_exprs(ctx, node.targets)

        rhs = node.value
        if isinstance(rhs, ast.Call) and ASTResolver.resolve_to(
                rhs.func, ti.static, globals()):
            return node

        def lower_target(target):
            # One statement per target. Chained assignments produce several,
            # which are merged below.
            # Ref https://github.com/taichi-dev/taichi/issues/2659.
            if isinstance(target, ast.Tuple):
                return StmtBuilder.build_assign_unpack(ctx, node, target)
            return StmtBuilder.build_assign_basic(ctx, node, target,
                                                  node.value)

        return StmtBuilder.make_single_statement(
            [lower_target(target) for target in node.targets])
Exemplo n.º 10
0
 def build_Raise(ctx, node):
     """Build the exception expression of a ``raise`` statement in place.

     Only ``node.exc`` is passed through ``build_expr``; ``node.cause`` (the
     ``from`` clause) is not touched. Returns the same node.
     """
     built_exc = build_expr(ctx, node.exc)
     node.exc = built_exc
     return node
Exemplo n.º 11
0
    def build_Assign(ctx, node):
        """Lower a single-target assignment into Taichi frontend calls.

        If the built right-hand side resolves to ``ti.static(...)``, the node
        is returned untouched.

        For a tuple target ``a, b = value``, the value is first stored in
        ``__tmp_tuple = ti.expr_init_list(value, n)``. Each element is then
        assigned from ``__tmp_tuple[i]``, and ``__tmp_tuple`` is deleted
        afterwards.

        Each individual target is lowered as follows:
          * a local name that ``ctx.is_creation`` reports as new becomes
            ``name = ti.expr_init(value)`` and is registered with
            ``ctx.create_variable``;
          * anything else becomes ``target.assign(value)``.

        Only one target is supported (``a = b = ...`` fails the assert).
        """
        assert (len(node.targets) == 1)
        node.value = build_expr(ctx, node.value)
        node.targets = build_exprs(ctx, node.targets)

        is_static_assign = isinstance(
            node.value, ast.Call) and ASTResolver.resolve_to(
                node.value.func, ti.static, globals())
        if is_static_assign:
            return node

        def assign_one(target, value):
            # Lower ``target = value`` for a single non-tuple target.
            # Returns an unlocated statement.
            if isinstance(target, ast.Name) and ctx.is_creation(target.id):
                target.ctx = ast.Store()
                # Create, no AST resolution needed
                init = ast.Attribute(value=ast.Name(id='ti', ctx=ast.Load()),
                                     attr='expr_init',
                                     ctx=ast.Load())
                rhs = ast.Call(func=init, args=[value], keywords=[])
                ctx.create_variable(target.id)
                return ast.Assign(targets=[target],
                                  value=rhs,
                                  type_comment=None)
            # Assign: the target becomes the object ``.assign`` is loaded from.
            target.ctx = ast.Load()
            func = ast.Attribute(value=target, attr='assign', ctx=ast.Load())
            call = ast.Call(func=func, args=[value], keywords=[])
            return ast.Expr(value=call)

        target = node.targets[0]
        if not isinstance(target, ast.Tuple):
            return ast.copy_location(assign_one(target, node.value), node)

        elements = target.elts
        holder = parse_stmt('__tmp_tuple = ti.expr_init_list(0, '
                            f'{len(elements)})')
        holder.value.args[0] = node.value
        stmts = [holder]

        def tuple_indexed(i):
            # Build the expression ``__tmp_tuple[i]``.
            indexing = parse_stmt('__tmp_tuple[0]')
            StmtBuilder.set_subscript_index(indexing.value, parse_expr(str(i)))
            return indexing.value

        for i, element in enumerate(elements):
            stmts.append(assign_one(element, tuple_indexed(i)))

        for stmt in stmts:
            ast.copy_location(stmt, node)
        stmts.append(parse_stmt('del __tmp_tuple'))
        return StmtBuilder.make_single_statement(stmts)