def build_IfExp(ctx, node):
    """Lower a conditional expression ``body if test else orelse``.

    The test, body and orelse sub-nodes are built first and written back
    onto ``node``. One of three lowerings is then chosen:

    * if any of the three built operands is a Taichi class instance, the
      result is ``ti.select(test, body, orelse)``;
    * if the test carries the ``"static"`` decorator, the branch selected by
      the Python truthiness of ``node.test.ptr`` is built and its ``ptr`` is
      used directly;
    * otherwise a result variable is created and assigned inside a frontend
      if / else pair of scopes.

    Args:
        ctx: builder context, passed through to ``build_stmt`` and
            ``IRBuilder.get_decorator``.
        node: the conditional-expression AST node (its ``test``, ``body`` and
            ``orelse`` attributes are read and replaced).

    Returns:
        ``node``, with ``node.ptr`` set to the lowered value.
    """
    node.test = build_stmt(ctx, node.test)
    node.body = build_stmt(ctx, node.body)
    node.orelse = build_stmt(ctx, node.orelse)

    # Any Taichi-class operand: lower to a select instead of control flow.
    if ti.is_taichi_class(node.test.ptr) or ti.is_taichi_class(
            node.body.ptr) or ti.is_taichi_class(node.orelse.ptr):
        node.ptr = ti.select(node.test.ptr, node.body.ptr, node.orelse.ptr)
        return node

    is_static_if = (IRBuilder.get_decorator(ctx, node.test) == "static")

    if is_static_if:
        # Compile-time branch: pick a side using Python truthiness of the test.
        # NOTE(review): body/orelse were already built above; the selected
        # branch is built a second time here — confirm that is intended.
        if node.test.ptr:
            node.body = build_stmt(ctx, node.body)
            node.ptr = node.body.ptr
        else:
            node.orelse = build_stmt(ctx, node.orelse)
            node.ptr = node.orelse.ptr
        return node

    # Runtime branch: a result variable assigned in each side of a frontend if.
    # presumably expr_init(None) creates an as-yet-unassigned variable — confirm.
    val = ti.expr_init(None)

    ti.begin_frontend_if(node.test.ptr)
    # True-branch scope: assign body value, then close the scope.
    ti.core.begin_frontend_if_true()
    val.assign(node.body.ptr)
    ti.core.pop_scope()
    # False-branch scope: assign orelse value, then close the scope.
    ti.core.begin_frontend_if_false()
    val.assign(node.orelse.ptr)
    ti.core.pop_scope()

    node.ptr = val
    return node
def wrapped(a, b):
    """Apply the wrapped binary function ``foo`` to ``a`` and ``b``.

    A Taichi-class operand drives an element-wise application: ``a`` is
    checked first. If only ``b`` is a Taichi class, it applies ``rev_foo``
    (presumably ``foo`` with swapped arguments, from the enclosing scope)
    with ``a`` as the other operand. Plain operands go straight to ``foo``.
    """
    if not ti.is_taichi_class(a):
        if ti.is_taichi_class(b):
            return b.element_wise_binary(rev_foo, a)
        return foo(a, b)
    return a.element_wise_binary(foo, b)
def build_IfExp(ctx, node):
    """Lower a conditional expression ``body if test else orelse``.

    The test, body and orelse sub-nodes are built first. Their return values
    are discarded here; the code then reads ``.ptr`` from each sub-node, so
    it presumably relies on ``build_stmt`` setting ``ptr`` on the node it
    builds — confirm. One of three lowerings is chosen:

    * if any operand's ``ptr`` is a Taichi class instance, the result is
      ``ti.select(test, body, orelse)``;
    * if the test carries the ``"static"`` decorator, the branch selected by
      the Python truthiness of ``node.test.ptr`` is built and its result is
      used directly;
    * otherwise a result variable is created and assigned inside a frontend
      if / else pair of scopes.

    Args:
        ctx: transformer context, passed through to ``build_stmt`` and
            ``ASTTransformer.get_decorator``.
        node: the conditional-expression AST node (``test``, ``body`` and
            ``orelse`` attributes are read).

    Returns:
        The lowered value; the same object is also stored in ``node.ptr``.
    """
    build_stmt(ctx, node.test)
    build_stmt(ctx, node.body)
    build_stmt(ctx, node.orelse)

    # Any Taichi-class operand: lower to a select instead of control flow.
    if ti.is_taichi_class(node.test.ptr) or ti.is_taichi_class(
            node.body.ptr) or ti.is_taichi_class(node.orelse.ptr):
        node.ptr = ti.select(node.test.ptr, node.body.ptr, node.orelse.ptr)
        return node.ptr

    is_static_if = (ASTTransformer.get_decorator(ctx, node.test) == "static")

    if is_static_if:
        # Compile-time branch: pick a side using Python truthiness of the test.
        # NOTE(review): the selected branch was already built above and is
        # built again here — confirm that is intended.
        if node.test.ptr:
            node.ptr = build_stmt(ctx, node.body)
        else:
            node.ptr = build_stmt(ctx, node.orelse)
        return node.ptr

    # Runtime branch: a result variable assigned in each side of a frontend if.
    # presumably expr_init(None) creates an as-yet-unassigned variable — confirm.
    val = ti.expr_init(None)

    ti.begin_frontend_if(node.test.ptr)
    # True-branch scope: assign body value, then close the scope.
    _ti_core.begin_frontend_if_true()
    val.assign(node.body.ptr)
    _ti_core.pop_scope()
    # False-branch scope: assign orelse value, then close the scope.
    _ti_core.begin_frontend_if_false()
    val.assign(node.orelse.ptr)
    _ti_core.pop_scope()

    node.ptr = val
    return node.ptr
def wrapped(a, b):
    """Apply the wrapped binary function ``imp_foo`` to ``a`` and ``b``.

    If ``a`` is a Taichi class, it applies ``imp_foo`` element-wise with
    ``b``. Otherwise, if ``b`` is a Taichi class, it applies ``rev_foo``
    element-wise with ``a``. Otherwise ``imp_foo`` is called directly.
    """
    # Local marker kept in this frame (name suggests traceback filtering).
    _taichi_skip_traceback = 1
    if ti.is_taichi_class(a):
        outcome = a.element_wise_binary(imp_foo, b)
    elif ti.is_taichi_class(b):
        outcome = b.element_wise_binary(rev_foo, a)
    else:
        outcome = imp_foo(a, b)
    return outcome
def wrapped(a, b):
    """Apply the augmented-assignment function ``imp_foo`` to ``a`` and ``b``.

    A Taichi-class target ``a`` is updated element-wise. A scalar target with
    a Taichi-class value ``b`` is rejected with ``SyntaxError``. Two plain
    operands go straight to ``imp_foo``.
    """
    if ti.is_taichi_class(a):
        return a.element_wise_binary(imp_foo, b)
    if not ti.is_taichi_class(b):
        return imp_foo(a, b)
    raise SyntaxError(
        f'cannot augassign taichi class {type(b)} to scalar expr')
def wrapped(a, b):
    """Apply the wrapped binary function ``foo`` to ``a`` and ``b``.

    A Taichi-class ``a`` drives an element-wise ``foo``. A Taichi-class ``b``
    drives an element-wise application of ``foo`` with its arguments swapped
    back, so ``a`` stays the left operand. Otherwise both operands are
    wrapped in ``Expr`` and passed to ``foo``.
    """
    if ti.is_taichi_class(a):
        return a.element_wise_binary(foo, b)
    if ti.is_taichi_class(b):
        return b.element_wise_binary(lambda x, y: foo(y, x), a)
    lhs = Expr(a)
    rhs = Expr(b)
    return foo(lhs, rhs)
def wrapped(a, b):
    """Apply the augmented-assignment function ``imp_foo`` to ``a`` and ``b``.

    A Taichi-class target ``a`` goes through
    ``element_wise_writeback_binary``. A scalar target combined with a
    Taichi-class value ``b`` raises ``TaichiSyntaxError``. Plain operands go
    straight to ``imp_foo``.
    """
    # Local marker kept in this frame (name suggests traceback filtering).
    _taichi_skip_traceback = 1
    if not ti.is_taichi_class(a):
        if ti.is_taichi_class(b):
            raise TaichiSyntaxError(
                f'cannot augassign taichi class {type(b)} to scalar expr')
        return imp_foo(a, b)
    return a.element_wise_writeback_binary(imp_foo, b)
def wrapped(a, b, c):
    """Apply the wrapped ternary function to ``a``, ``b`` and ``c``.

    The first Taichi-class operand, checked in the order ``a``, ``b``,
    ``c``, drives an element-wise application. It uses the helper whose
    name encodes the argument order (``abc_foo``, ``bac_foo``,
    ``cab_foo``). With no Taichi-class operand, ``abc_foo`` is called
    directly.
    """
    # Local marker kept in this frame (name suggests traceback filtering).
    _taichi_skip_traceback = 1
    if ti.is_taichi_class(a):
        return a.element_wise_ternary(abc_foo, b, c)
    if ti.is_taichi_class(b):
        return b.element_wise_ternary(bac_foo, a, c)
    if ti.is_taichi_class(c):
        return c.element_wise_ternary(cab_foo, a, b)
    return abc_foo(a, b, c)
def __pow__(self, power, modulo=None):
    """Return ``self ** power`` as a Taichi expression.

    The exponent is handled in three ways:

    * Integer exponents with ``abs(power) <= 100`` are expanded at compile
      time by square-and-multiply. This emits a chain of multiplications,
      each intermediate materialised with ``ti.expr_init``.
    * Any other exponent (non-``int``, or larger magnitude) is lowered to a
      single ``expr_pow`` node.
    * If ``power`` is a Taichi class instance, the operation is applied
      element-wise over it, with this expression as the base.

    Args:
        power: the exponent.
        modulo: accepted only to match the three-argument ``pow()``
            protocol; modular exponentiation is not supported.

    Returns:
        ``Expr(1)`` for ``power == 0``; ``1 / self**abs(power)`` for a
        negative integer exponent; otherwise the power expression.

    Raises:
        TypeError: if ``modulo`` is not ``None``.
    """
    import taichi as ti
    if modulo is not None:
        # Previously ``modulo`` was silently ignored, so ``pow(x, y, m)``
        # returned ``x ** y`` without reduction. Fail loudly instead.
        raise TypeError(
            'pow() with a modulo argument is not supported for Taichi expressions')
    if ti.is_taichi_class(power):
        # element_wise_binary calls f(entry, self); swap back so that the
        # base stays ``self`` and each entry is the exponent.
        return power.element_wise_binary(lambda x, y: pow(y, x), self)
    if not isinstance(power, int) or abs(power) > 100:
        return Expr(taichi_lang_core.expr_pow(self.ptr, Expr(power).ptr))
    if power == 0:
        return Expr(1)
    negative = power < 0
    power = abs(power)
    tmp = self
    ret = None
    while True:
        if power & 1:
            ret = tmp if ret is None else ti.expr_init(ret * tmp)
        power >>= 1
        if not power:
            break
        # Square only while bits remain. The previous loop squared once more
        # after the final bit, emitting a multiplication whose value was
        # never used.
        tmp = ti.expr_init(tmp * tmp)
    if negative:
        return 1 / ret
    return ret
def wrapped(a):
    """Apply the wrapped unary function ``imp_foo`` to ``a``.

    Taichi-class operands are handled element-wise; anything else is passed
    to ``imp_foo`` directly.
    """
    # Local marker kept in this frame (name suggests traceback filtering).
    _taichi_skip_traceback = 1
    return a.element_wise_unary(imp_foo) if ti.is_taichi_class(a) else imp_foo(a)
def wrapped(a):
    """Apply the wrapped unary function ``foo`` to ``a``.

    Non-Taichi-class operands are passed to ``foo`` directly; Taichi-class
    operands are handled element-wise.
    """
    if not ti.is_taichi_class(a):
        return foo(a)
    return a.element_wise_unary(foo)
def numpy_or_constant(x):
    """Return ``x.to_numpy()`` for a Taichi class instance, else ``x`` unchanged."""
    import taichi as ti
    return x.to_numpy() if ti.is_taichi_class(x) else x
def element_wise_writeback_binary(a, foo, b):
    """Apply ``foo`` component-wise to the complex numbers ``a`` and ``b``.

    A Taichi-class ``b`` is first replaced by ``b.variable()``. Anything that
    is not already a ``Complex`` is wrapped with ``Complex(...)``. ``foo`` is
    called on the real parts (``.x``) and then on the imaginary parts
    (``.y``), and the two results form the returned ``Complex``.
    """
    other = b.variable() if ti.is_taichi_class(b) else b
    if not isinstance(other, Complex):
        other = Complex(other)
    real_part = foo(a.x, other.x)
    imag_part = foo(a.y, other.y)
    return Complex(real_part, imag_part)