예제 #1
0
파일: impl.py 프로젝트: k-ye/taichi
def create_field_member(dtype, name, needs_grad, needs_dual):
    """Create the expressions backing a single field member.

    Always creates the primal global-variable expression named ``name`` and
    appends it to ``pytaichi.global_vars``. When ``_ti_core.is_real(dtype)``
    is true, it also creates an adjoint expression (named ``name + ".grad"``)
    and a dual expression (named ``name + ".dual"``), both linked to the
    primal. They are appended to ``pytaichi.grad_vars`` /
    ``pytaichi.dual_vars`` only when ``needs_grad`` / ``needs_dual`` is set.

    Args:
        dtype: Element type of the member; normalized with ``cook_dtype``.
        name: Name of the primal expression. The adjoint and dual
            expressions reuse it with ``.grad`` / ``.dual`` suffixes.
        needs_grad: Whether to register the adjoint expression.
        needs_dual: Whether to register the dual expression.

    Returns:
        tuple: ``(x, x_grad, x_dual)``. ``x_grad`` and ``x_dual`` are ``None``
        when ``dtype`` is not real.

    Raises:
        TaichiRuntimeError: If there is no active program
            (``get_runtime().prog`` is ``None``), or if ``needs_grad`` or
            ``needs_dual`` is requested for a dtype that is not real.
    """
    dtype = cook_dtype(dtype)

    prog = get_runtime().prog
    if prog is None:
        raise TaichiRuntimeError(
            "Cannot create field, maybe you forgot to call `ti.init()` first?"
        )

    # Validate the autodiff request before registering anything, so a
    # rejected call does not leave an orphaned primal expression in
    # pytaichi.global_vars.
    is_real = _ti_core.is_real(dtype)
    if not is_real and (needs_grad or needs_dual):
        raise TaichiRuntimeError(
            f'{dtype} is not supported for field with `needs_grad=True` or `needs_dual=True`.'
        )

    # All three expressions come from the same declaration, so one traceback
    # captured from this frame (same stacklevel as before) is shared by all.
    declaration_tb = get_traceback(stacklevel=4)

    # primal
    x = Expr(prog.make_id_expr(""))
    x.declaration_tb = declaration_tb
    x.ptr = _ti_core.global_new(x.ptr, dtype)
    x.ptr.set_name(name)
    x.ptr.set_is_primal(True)
    pytaichi.global_vars.append(x)

    if not is_real:
        # Non-real dtypes get no adjoint/dual companions.
        return x, None, None

    # adjoint: always created and linked to the primal, but only registered
    # when gradients were requested.
    x_grad = Expr(prog.make_id_expr(""))
    x_grad.declaration_tb = declaration_tb
    x_grad.ptr = _ti_core.global_new(x_grad.ptr, dtype)
    x_grad.ptr.set_name(name + ".grad")
    x_grad.ptr.set_is_primal(False)
    x.ptr.set_adjoint(x_grad.ptr)
    if needs_grad:
        pytaichi.grad_vars.append(x_grad)

    # dual: always created and linked to the primal, but only registered
    # when forward-mode duals were requested.
    x_dual = Expr(prog.make_id_expr(""))
    x_dual.declaration_tb = declaration_tb
    x_dual.ptr = _ti_core.global_new(x_dual.ptr, dtype)
    x_dual.ptr.set_name(name + ".dual")
    x_dual.ptr.set_is_primal(False)
    x.ptr.set_dual(x_dual.ptr)
    if needs_dual:
        pytaichi.dual_vars.append(x_dual)

    return x, x_grad, x_dual
예제 #2
0
def create_field_member(dtype, name):
    """Build the primal expression for a field member, plus its gradient.

    The primal expression is named ``name``, marked as primal, and appended
    to ``pytaichi.global_vars``. If ``_ti_core.needs_grad`` accepts the
    (cooked) dtype, a second expression named ``name + ".grad"`` is created,
    marked non-primal and attached to the primal via ``set_grad``. It is not
    appended to any registry here.

    Args:
        dtype: Element type of the member; normalized with ``cook_dtype``.
        name: Name of the primal expression.

    Returns:
        tuple: ``(primal, adjoint)``, where ``adjoint`` is ``None`` when the
        dtype does not take a gradient.
    """
    cooked = cook_dtype(dtype)

    # Primal expression, registered globally.
    primal = Expr(_ti_core.make_id_expr(""))
    primal.declaration_tb = get_traceback(stacklevel=4)
    primal.ptr = _ti_core.global_new(primal.ptr, cooked)
    primal.ptr.set_name(name)
    primal.ptr.set_is_primal(True)
    pytaichi.global_vars.append(primal)

    # Guard clause: dtypes without gradient support stop here.
    if not _ti_core.needs_grad(cooked):
        return primal, None

    # Adjoint expression, linked to the primal only.
    adjoint = Expr(_ti_core.make_id_expr(""))
    adjoint.ptr = _ti_core.global_new(adjoint.ptr, cooked)
    adjoint.ptr.set_name(name + ".grad")
    adjoint.ptr.set_is_primal(False)
    primal.ptr.set_grad(adjoint.ptr)
    return primal, adjoint