def __init__(self, snode):
    """Build host-side ``getter`` / ``setter`` closures for ``snode``.

    The SNode's data type decides, once, which pair of SNode methods is
    used:

    * real types: ``read_float`` / ``write_float``
    * signed (non-real) types: ``read_int`` / ``write_int``
    * all other (non-real, non-signed) types: ``read_uint`` / ``write_int``

    ``getter(*key)`` returns the stored value at ``key``;
    ``setter(value, *key)`` stores ``value`` at ``key``. Both assert that
    ``key`` has exactly ``_ti_core.get_max_num_indices()`` components.
    """
    dtype = snode.data_type()
    if _ti_core.is_real(dtype):
        read, write = snode.read_float, snode.write_float
    elif _ti_core.is_signed(dtype):
        read, write = snode.read_int, snode.write_int
    else:
        # NOTE(review): non-signed integers are read via read_uint but are
        # still written via write_int; confirm the backend accepts the full
        # unsigned range on writes.
        read, write = snode.read_uint, snode.write_int

    def getter(*key):
        # The key must carry exactly the maximum number of indices.
        assert len(key) == _ti_core.get_max_num_indices()
        return read(key)

    def setter(value, *key):
        assert len(key) == _ti_core.get_max_num_indices()
        write(key, value)

    self.getter = getter
    self.setter = setter
def create_field_member(dtype, name, needs_grad, needs_dual):
    """Create the expression(s) backing one field member.

    A primal field expression is always created and appended to
    ``pytaichi.global_vars``. When ``_ti_core.is_real(dtype)`` holds, an
    adjoint (``<name>.grad``) and a dual (``<name>.dual``) expression are
    also created and linked to the primal via ``set_adjoint`` /
    ``set_dual``. They are appended to ``pytaichi.grad_vars`` /
    ``pytaichi.dual_vars`` only when ``needs_grad`` / ``needs_dual`` is set.

    Args:
        dtype: Element type; normalized with ``cook_dtype`` first.
        name: Name of the primal expression. The adjoint and dual use the
            same name with ``".grad"`` / ``".dual"`` appended.
        needs_grad: Whether to register the adjoint expression.
        needs_dual: Whether to register the dual expression.

    Returns:
        Tuple ``(x, x_grad, x_dual)``. ``x_grad`` and ``x_dual`` are ``None``
        when ``dtype`` is not a real type.

    Raises:
        TaichiRuntimeError: If there is no active program (``ti.init()`` has
            not been called), or if ``needs_grad`` / ``needs_dual`` is
            requested for a dtype that is not real.
    """
    dtype = cook_dtype(dtype)

    # primal
    prog = get_runtime().prog
    if prog is None:
        raise TaichiRuntimeError(
            "Cannot create field, maybe you forgot to call `ti.init()` first?"
        )

    x = Expr(prog.make_id_expr(""))
    # stacklevel is counted from this frame; presumably tuned so the recorded
    # traceback points at the user's field declaration. Keep these
    # get_traceback calls directly in this function so the depth stays valid.
    x.declaration_tb = get_traceback(stacklevel=4)
    x.ptr = _ti_core.global_new(x.ptr, dtype)
    x.ptr.set_name(name)
    x.ptr.set_is_primal(True)
    pytaichi.global_vars.append(x)

    x_grad = None
    x_dual = None
    if _ti_core.is_real(dtype):
        # adjoint: always created and linked; registered only on request.
        # Reuse the already-validated `prog` rather than re-querying the
        # runtime.
        x_grad = Expr(prog.make_id_expr(""))
        x_grad.declaration_tb = get_traceback(stacklevel=4)
        x_grad.ptr = _ti_core.global_new(x_grad.ptr, dtype)
        x_grad.ptr.set_name(name + ".grad")
        x_grad.ptr.set_is_primal(False)
        x.ptr.set_adjoint(x_grad.ptr)
        if needs_grad:
            pytaichi.grad_vars.append(x_grad)

        # dual: always created and linked; registered only on request
        x_dual = Expr(prog.make_id_expr(""))
        x_dual.ptr = _ti_core.global_new(x_dual.ptr, dtype)
        x_dual.ptr.set_name(name + ".dual")
        x_dual.ptr.set_is_primal(False)
        x.ptr.set_dual(x_dual.ptr)
        if needs_dual:
            pytaichi.dual_vars.append(x_dual)
    elif needs_grad or needs_dual:
        raise TaichiRuntimeError(
            f'{dtype} is not supported for field with `needs_grad=True` or `needs_dual=True`.'
        )

    return x, x_grad, x_dual
def __init__(self, ndarray):
    """Build host-side ``getter`` / ``setter`` closures for ``ndarray``.

    ``ndarray.dtype`` decides, once, which pair of Ndarray methods is used:

    * real types: ``read_float`` / ``write_float``
    * signed (non-real) types: ``read_int`` / ``write_int``
    * all other (non-real, non-signed) types: ``read_uint`` / ``write_int``

    ``getter(*key)`` returns the stored value at ``key``;
    ``setter(value, *key)`` stores ``value`` at ``key``. Unlike the SNode
    accessor, no check is made on the number of key components.
    """
    dtype = ndarray.dtype
    if _ti_core.is_real(dtype):
        read, write = ndarray.read_float, ndarray.write_float
    elif _ti_core.is_signed(dtype):
        read, write = ndarray.read_int, ndarray.write_int
    else:
        # NOTE(review): non-signed integers are read via read_uint but are
        # still written via write_int; confirm the backend accepts the full
        # unsigned range on writes.
        read, write = ndarray.read_uint, ndarray.write_int

    def getter(*key):
        return read(key)

    def setter(value, *key):
        write(key, value)

    self.getter = getter
    self.setter = setter