Beispiel #1
0
    def __init__(self, snode):
        """Bind scalar element accessors for ``snode`` according to its data type.

        Real (floating-point) SNodes are read with ``read_float`` and written
        with ``write_float``. Integer SNodes are read with ``read_int`` when the
        type is signed or ``read_uint`` otherwise, and are always written with
        ``write_int``.

        The resulting callables are stored as ``self.getter(*key)`` and
        ``self.setter(value, *key)``. Each asserts that ``key`` has exactly
        ``_ti_core.get_max_num_indices()`` entries before touching ``snode``.

        Args:
            snode: The SNode whose elements will be accessed.
        """
        # Choose the accessor pair once; the actual attribute lookup on
        # ``snode`` still happens on every call.
        if _ti_core.is_real(snode.data_type()):
            read_method, write_method = "read_float", "write_float"
        elif _ti_core.is_signed(snode.data_type()):
            read_method, write_method = "read_int", "write_int"
        else:
            # Unsigned integers need their own reader but share the int writer.
            read_method, write_method = "read_uint", "write_int"

        def getter(*key):
            # Keys are expected to arrive already padded to the max index count.
            assert len(key) == _ti_core.get_max_num_indices()
            return getattr(snode, read_method)(key)

        def setter(value, *key):
            assert len(key) == _ti_core.get_max_num_indices()
            getattr(snode, write_method)(key, value)

        self.getter = getter
        self.setter = setter
Beispiel #2
0
def create_field_member(dtype, name, needs_grad, needs_dual):
    """Create the global expression(s) that back a single field member.

    The primal global variable is always created and appended to
    ``pytaichi.global_vars``. For real (floating-point) dtypes, an adjoint
    (``<name>.grad``) and a dual (``<name>.dual``) global are also created and
    linked to the primal. They are appended to ``pytaichi.grad_vars`` /
    ``pytaichi.dual_vars`` only when ``needs_grad`` / ``needs_dual`` is set.

    Args:
        dtype: Element type; normalized through ``cook_dtype``.
        name: Name for the primal; companions get ``.grad`` / ``.dual`` suffixes.
        needs_grad: Whether to register the adjoint in ``pytaichi.grad_vars``.
        needs_dual: Whether to register the dual in ``pytaichi.dual_vars``.

    Returns:
        Tuple ``(x, x_grad, x_dual)``. ``x_grad`` and ``x_dual`` are ``None``
        when ``dtype`` is not real.

    Raises:
        TaichiRuntimeError: If there is no active program (``ti.init()`` not
            called yet), or if ``needs_grad`` / ``needs_dual`` is requested
            for a non-real dtype.
    """
    dtype = cook_dtype(dtype)

    prog = get_runtime().prog
    if prog is None:
        raise TaichiRuntimeError(
            "Cannot create field, maybe you forgot to call `ti.init()` first?"
        )

    dtype_is_real = _ti_core.is_real(dtype)
    # Validate before creating anything, so a rejected request does not leave
    # a primal registered in `pytaichi.global_vars` that is never returned.
    if not dtype_is_real and (needs_grad or needs_dual):
        raise TaichiRuntimeError(
            f'{dtype} is not supported for field with `needs_grad=True` or `needs_dual=True`.'
        )

    # primal
    x = Expr(prog.make_id_expr(""))
    # Presumably records the user's declaration site for later diagnostics
    # (stacklevel chosen to skip internal frames) -- verify against get_traceback.
    x.declaration_tb = get_traceback(stacklevel=4)
    x.ptr = _ti_core.global_new(x.ptr, dtype)
    x.ptr.set_name(name)
    x.ptr.set_is_primal(True)
    pytaichi.global_vars.append(x)

    x_grad = None
    x_dual = None
    if dtype_is_real:
        # adjoint: always created and linked for real dtypes; only registered
        # for gradient tracking when requested.
        x_grad = Expr(prog.make_id_expr(""))
        x_grad.declaration_tb = get_traceback(stacklevel=4)
        x_grad.ptr = _ti_core.global_new(x_grad.ptr, dtype)
        x_grad.ptr.set_name(name + ".grad")
        x_grad.ptr.set_is_primal(False)
        x.ptr.set_adjoint(x_grad.ptr)
        if needs_grad:
            pytaichi.grad_vars.append(x_grad)

        # dual: same pattern as the adjoint.
        x_dual = Expr(prog.make_id_expr(""))
        x_dual.ptr = _ti_core.global_new(x_dual.ptr, dtype)
        x_dual.ptr.set_name(name + ".dual")
        x_dual.ptr.set_is_primal(False)
        x.ptr.set_dual(x_dual.ptr)
        if needs_dual:
            pytaichi.dual_vars.append(x_dual)

    return x, x_grad, x_dual
Beispiel #3
0
    def __init__(self, ndarray):
        """Bind scalar element accessors for ``ndarray`` according to its dtype.

        Real (floating-point) arrays are read with ``read_float`` and written
        with ``write_float``. Integer arrays are read with ``read_int`` when
        signed or ``read_uint`` otherwise, and are always written with
        ``write_int``.

        The resulting callables are stored as ``self.getter(*key)`` and
        ``self.setter(value, *key)``. ``key`` is forwarded to ``ndarray`` as a
        tuple without any length check here.

        Args:
            ndarray: The ndarray whose elements will be accessed.
        """
        # Select the accessor names up front; the bound method is still looked
        # up on ``ndarray`` at each call.
        if _ti_core.is_real(ndarray.dtype):
            read_method, write_method = "read_float", "write_float"
        elif _ti_core.is_signed(ndarray.dtype):
            read_method, write_method = "read_int", "write_int"
        else:
            # Unsigned integers use a dedicated reader but the shared int writer.
            read_method, write_method = "read_uint", "write_int"

        def getter(*key):
            return getattr(ndarray, read_method)(key)

        def setter(value, *key):
            getattr(ndarray, write_method)(key, value)

        self.getter = getter
        self.setter = setter