Example #1
 def pad_key(self, key):
     # Normalize `key` to a tuple and pad it with zeros up to the runtime's
     # maximum index count before it is handed to the C++ accessors.
     if key is None:
         key = ()
     if not isinstance(key, (tuple, list)):
         key = (key, )
     assert len(key) == len(self.shape)
     return key + ((0, ) * (_ti_core.get_max_num_indices() - len(key)))
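
In isolation, the padding step is easier to see with a standalone sketch; the 8-index cap below is an assumption for illustration, whereas the real code queries _ti_core.get_max_num_indices():

    # Minimal standalone sketch of the padding logic above (not Taichi API).
    MAX_NUM_INDICES = 8  # assumed cap; the real code asks _ti_core.get_max_num_indices()

    def pad_key(key, shape):
        if key is None:
            key = ()
        if not isinstance(key, (tuple, list)):
            key = (key, )
        assert len(key) == len(shape)
        return tuple(key) + (0, ) * (MAX_NUM_INDICES - len(key))

    print(pad_key(3, (10, )))       # (3, 0, 0, 0, 0, 0, 0, 0)
    print(pad_key((1, 2), (4, 4)))  # (1, 2, 0, 0, 0, 0, 0, 0)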
Example #2
 def __getitem__(self, key):
     # Python-scope read: materialize the program, normalize and pad the key,
     # then call the generated accessor for this field's dtype.
     impl.get_runtime().materialize()
     self.initialize_accessor()
     if key is None:
         key = ()
     if not isinstance(key, (tuple, list)):
         key = (key, )
     key = key + ((0, ) * (_ti_core.get_max_num_indices() - len(key)))
     return self.getter(*key)
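
The `key is None` branch exists because 0-D fields are indexed with None from Python scope; a brief usage sketch (assuming a plain ti.init()):

    import taichi as ti

    ti.init(arch=ti.cpu)

    s = ti.field(dtype=ti.f32, shape=())  # 0-D field
    s[None] = 2.5    # `key is None` above, padded to an all-zero index tuple
    print(s[None])   # 2.5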
Example #3
 def __setitem__(self, key, value):
     # Python-scope write: same key normalization as __getitem__, plus a shape
     # check, before dispatching to the generated setter.
     impl.get_runtime().materialize()
     self.initialize_accessor()
     if key is None:
         key = ()
     if not isinstance(key, (tuple, list)):
         key = (key, )
     assert len(key) == len(self.shape)
     key = key + ((0, ) * (_ti_core.get_max_num_indices() - len(key)))
     self.setter(value, *key)
Example #4
    def __setitem__(self, key, value):
        """Set value with specified key when the class itself represents GlobalVariableExpression (field) or ExternalTensorExpression internally.

        This will not be directly called from python for vector/matrix fields.
        Python Matrix class will decompose operations into scalar-level first.

        Args:
            key (Union[List[int], int, None]): indices to set
            value (Union[int, float]): value to set
        """
        impl.get_runtime().materialize()
        self.initialize_accessor()
        if key is None:
            key = ()
        if not isinstance(key, (tuple, list)):
            key = (key, )
        assert len(key) == len(self.shape)
        key = key + ((0, ) * (_ti_core.get_max_num_indices() - len(key)))
        self.setter(value, *key)
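
For ordinary n-D fields the same pair of accessors backs plain Python-scope indexing; a minimal usage sketch:

    import taichi as ti

    ti.init(arch=ti.cpu)

    x = ti.field(dtype=ti.f32, shape=(4, 4))
    x[1, 2] = 3.0     # Python-scope write, handled by __setitem__ above
    print(x[1, 2])    # Python-scope read, handled by __getitem__; prints 3.0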
Example #5
        def func__(*args):
            assert len(args) == len(
                self.argument_annotations
            ), '{} arguments needed but {} provided'.format(
                len(self.argument_annotations), len(args))

            tmps = []
            callbacks = []
            has_external_arrays = False

            actual_argument_slot = 0
            launch_ctx = t_kernel.make_launch_context()
            for i, v in enumerate(args):
                needed = self.argument_annotations[i]
                if isinstance(needed, template):
                    continue
                provided = type(v)
                # Note: do not use something like `needed == f32`; that would be slow.
                if id(needed) in primitive_types.real_type_ids:
                    if not isinstance(v, (float, int)):
                        raise KernelArgError(i, needed.to_string(), provided)
                    launch_ctx.set_arg_float(actual_argument_slot, float(v))
                elif id(needed) in primitive_types.integer_type_ids:
                    if not isinstance(v, int):
                        raise KernelArgError(i, needed.to_string(), provided)
                    launch_ctx.set_arg_int(actual_argument_slot, int(v))
                elif isinstance(needed, sparse_matrix_builder):
                    # Pass only the base pointer of the ti.linalg.sparse_matrix_builder() argument
                    launch_ctx.set_arg_int(actual_argument_slot, v.get_addr())
                elif isinstance(needed, any_arr) and (
                        self.match_ext_arr(v)
                        or isinstance(v, taichi.lang._ndarray.Ndarray)):
                    is_ndarray = False
                    if isinstance(v, taichi.lang._ndarray.Ndarray):
                        v = v.arr
                        is_ndarray = True
                    has_external_arrays = True
                    ndarray_use_torch = self.runtime.prog.config.ndarray_use_torch
                    has_torch = util.has_pytorch()
                    is_numpy = isinstance(v, np.ndarray)
                    if is_numpy:
                        tmp = np.ascontiguousarray(v)
                        # Purpose: DO NOT GC |tmp|!
                        tmps.append(tmp)
                        launch_ctx.set_arg_external_array(
                            actual_argument_slot, int(tmp.ctypes.data),
                            tmp.nbytes)
                    elif is_ndarray and not ndarray_use_torch:
                        # Use ndarray's own memory allocator
                        tmp = v
                        launch_ctx.set_arg_external_array(
                            actual_argument_slot, int(tmp.data_ptr()),
                            tmp.element_size() * tmp.nelement())
                    else:

                        def get_call_back(u, v):
                            def call_back():
                                u.copy_(v)

                            return call_back

                        assert util.has_pytorch()
                        assert isinstance(v, torch.Tensor)
                        tmp = v
                        taichi_arch = self.runtime.prog.config.arch

                        if str(v.device).startswith('cuda'):
                            # External tensor on cuda
                            if taichi_arch != _ti_core.Arch.cuda:
                                # copy data back to cpu
                                host_v = v.to(device='cpu', copy=True)
                                tmp = host_v
                                callbacks.append(get_call_back(v, host_v))
                        else:
                            # External tensor on cpu
                            if taichi_arch == _ti_core.Arch.cuda:
                                gpu_v = v.cuda()
                                tmp = gpu_v
                                callbacks.append(get_call_back(v, gpu_v))
                        launch_ctx.set_arg_external_array(
                            actual_argument_slot, int(tmp.data_ptr()),
                            tmp.element_size() * tmp.nelement())

                    shape = v.shape
                    max_num_indices = _ti_core.get_max_num_indices()
                    assert len(
                        shape
                    ) <= max_num_indices, "External array cannot have > {} indices".format(
                        max_num_indices)
                    for ii, s in enumerate(shape):
                        launch_ctx.set_extra_arg_int(actual_argument_slot, ii,
                                                     s)
                else:
                    raise ValueError(
                        f'Argument type mismatch. Expecting {needed}, got {type(v)}.'
                    )
                actual_argument_slot += 1
            # Both the class kernels and the plain-function kernels are unified now.
            # In both cases, |self.grad| is another Kernel instance that computes the
            # gradient. For class kernels, args[0] is always the kernel owner.
            if not self.is_grad and self.runtime.target_tape and not self.runtime.grad_replaced:
                self.runtime.target_tape.insert(self, args)

            t_kernel(launch_ctx)

            ret = None
            ret_dt = self.return_type
            has_ret = ret_dt is not None

            if has_external_arrays or has_ret:
                ti.sync()

            if has_ret:
                if id(ret_dt) in primitive_types.integer_type_ids:
                    ret = t_kernel.get_ret_int(0)
                else:
                    ret = t_kernel.get_ret_float(0)

            if callbacks:
                for c in callbacks:
                    c()

            return ret
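
From the caller's side, all of the marshalling above is triggered by an ordinary kernel call; a minimal sketch passing a NumPy array (ti.ext_arr() is the external-array annotation used in Taichi releases of this vintage; newer versions spell it ti.types.ndarray()):

    import numpy as np
    import taichi as ti

    ti.init(arch=ti.cpu)

    @ti.kernel
    def double(arr: ti.ext_arr()):
        # arr.shape is available here thanks to the set_extra_arg_int calls above.
        for i in range(arr.shape[0]):
            arr[i] = arr[i] * 2.0

    a = np.arange(8, dtype=np.float32)
    double(a)        # goes through np.ascontiguousarray + set_arg_external_array
    print(a)         # [ 0.  2.  4.  6.  8. 10. 12. 14.]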
Example #6
 def getter(*key):
     # Per-dtype accessor closure: expects a key already padded by pad_key.
     assert len(key) == _ti_core.get_max_num_indices()
     return snode.read_uint(key)
Example #7
 def setter(value, *key):
     # Per-dtype accessor closure: writes `value` at the fully padded key.
     assert len(key) == _ti_core.get_max_num_indices()
     snode.write_float(key, value)
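
Putting the pieces together, Python-scope access is roughly: pad the user key (Example #1), then dispatch to a per-dtype accessor closure like Examples #6 and #7. A schematic, self-contained sketch of that flow (FakeSNode, the accessor names, and the 8-index cap are illustrative stand-ins, not Taichi API):

    MAX_NUM_INDICES = 8  # assumed cap; the real code uses _ti_core.get_max_num_indices()

    class FakeSNode:
        """Stand-in for the C++ SNode accessor wrapped by the closures above."""
        def __init__(self):
            self._storage = {}

        def read_uint(self, key):
            return self._storage.get(key, 0)

        def write_float(self, key, value):
            self._storage[key] = value

    def make_accessors(snode):
        # Mirrors Examples #6 and #7: closures that insist on a fully padded key.
        def getter(*key):
            assert len(key) == MAX_NUM_INDICES
            return snode.read_uint(key)

        def setter(value, *key):
            assert len(key) == MAX_NUM_INDICES
            snode.write_float(key, value)

        return getter, setter

    getter, setter = make_accessors(FakeSNode())
    key = (1, 2) + (0, ) * (MAX_NUM_INDICES - 2)  # what pad_key would produce
    setter(4.5, *key)
    print(getter(*key))  # 4.5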