コード例 #1
0
ファイル: field.py プロジェクト: victoriacity/taichi
 def pad_key(self, key):
     """Normalize ``key`` into a full-length index tuple for this field.

     ``None`` becomes the empty index, a single scalar index becomes a
     1-tuple, and a tuple or list is used element-wise. The result is
     right-padded with zeros up to ``_ti_core.get_max_num_indices()``
     entries.

     Args:
         key: ``None``, a single index, or a tuple/list of indices. Its
             length must equal ``len(self.shape)``.

     Returns:
         A tuple of ``_ti_core.get_max_num_indices()`` indices.

     Raises:
         AssertionError: if the number of indices differs from
             ``len(self.shape)``.
     """
     if key is None:
         key = ()
     if not isinstance(key, (tuple, list)):
         key = (key, )
     # Lists are accepted above, but ``list + tuple`` raises TypeError when
     # padding below, so always work with a tuple.
     key = tuple(key)
     # Explicit raise (not ``assert``) so the check survives ``python -O``;
     # the exception type is kept for callers that catch AssertionError.
     if len(key) != len(self.shape):
         raise AssertionError(
             f"Expected {len(self.shape)} indices, got {len(key)}")
     return key + ((0, ) * (_ti_core.get_max_num_indices() - len(key)))
コード例 #2
0
ファイル: field.py プロジェクト: kokizzu/taichi
    def _pad_key(self, key):
        """Normalize ``key`` into a full-length index tuple for this field.

        ``None`` becomes the empty index, a single scalar index becomes a
        1-tuple, and a tuple or list is used element-wise. The result is
        right-padded with zeros up to ``_ti_core.get_max_num_indices()``
        entries.

        Args:
            key: ``None``, a single index, or a tuple/list of indices. Its
                length must equal ``len(self.shape)``.

        Returns:
            A tuple of ``_ti_core.get_max_num_indices()`` indices.

        Raises:
            AssertionError: if the number of indices differs from
                ``len(self.shape)`` (e.g. a partial index / slice).
        """
        if key is None:
            key = ()
        if not isinstance(key, (tuple, list)):
            key = (key, )
        # Lists are accepted above, but ``list + tuple`` raises TypeError when
        # padding below, so always work with a tuple.
        key = tuple(key)

        if len(key) != len(self.shape):
            raise AssertionError("Slicing is not supported on ti.field")

        return key + ((0, ) * (_ti_core.get_max_num_indices() - len(key)))
コード例 #3
0
ファイル: field.py プロジェクト: victoriacity/taichi
 def setter(value, *key):
     """Write ``value`` into ``snode`` at index ``key`` via ``snode.write_int``.

     ``key`` is expected to be already padded to exactly
     ``_ti_core.get_max_num_indices()`` components.
     """
     assert _ti_core.get_max_num_indices() == len(key)
     snode.write_int(key, value)
コード例 #4
0
ファイル: field.py プロジェクト: victoriacity/taichi
 def getter(*key):
     """Return the value ``snode.read_uint`` reads at index ``key``.

     ``key`` is expected to be already padded to exactly
     ``_ti_core.get_max_num_indices()`` components.
     """
     assert _ti_core.get_max_num_indices() == len(key)
     value = snode.read_uint(key)
     return value
コード例 #5
0
ファイル: kernel_impl.py プロジェクト: YuCrazing/taichi-1
        def func__(*args):
            """Launch the compiled kernel ``t_kernel`` with ``args``.

            Each positional argument is matched against
            ``self.argument_annotations`` at the same position:

            * ``template`` arguments are skipped and consume no launch slot;
            * real / integer scalars are type-checked and written with
              ``set_arg_float`` / ``set_arg_int``;
            * ``sparse_matrix_builder`` arguments are passed by address;
            * ``any_arr`` arguments (NumPy arrays, torch tensors, or taichi
              Ndarrays) are passed by data pointer plus byte size, and each
              dimension of their shape goes into an extra-arg slot.

            After the launch, the kernel's return value (if any) is read back,
            and copy-back callbacks for torch tensors that were migrated between
            CPU and CUDA are run.

            Returns:
                ``t_kernel.get_ret_int(0)`` for integer return types,
                ``t_kernel.get_ret_float(0)`` for other return types, or
                ``None`` when ``self.return_type`` is ``None``.

            Raises:
                AssertionError: wrong argument count (only while asserts are
                    enabled), missing torch, a non-tensor in the torch path,
                    an unsupported arch for torch-based Ndarrays, or an
                    external array with too many dimensions.
                KernelArgError: a scalar argument has the wrong Python type.
                ValueError: an argument does not match its annotation.
            """
            # NOTE(review): this is an ``assert``, so it disappears under
            # ``python -O``. Too many args would then raise IndexError at
            # ``self.argument_annotations[i]``, and too few would launch with
            # the remaining slots never set.
            assert len(args) == len(
                self.argument_annotations
            ), f'{len(self.argument_annotations)} arguments needed but {len(args)} provided'

            # Contiguous NumPy buffers whose raw pointers are handed to the
            # kernel; holding references keeps them from being collected.
            tmps = []
            # Closures run after the launch to copy data back into torch
            # tensors that were migrated to another device.
            callbacks = []
            # Set when any external array is passed; used to decide on a sync
            # after the launch in async mode.
            has_external_arrays = False

            # Launch-context slot index; advanced only for non-template args.
            actual_argument_slot = 0
            launch_ctx = t_kernel.make_launch_context()
            for i, v in enumerate(args):
                needed = self.argument_annotations[i]
                if isinstance(needed, template):
                    # Templates are compile-time arguments: no runtime slot.
                    continue
                provided = type(v)
                # Note: do not use sth like "needed == f32". That would be slow.
                if id(needed) in primitive_types.real_type_ids:
                    if not isinstance(v, (float, int)):
                        raise KernelArgError(i, needed.to_string(), provided)
                    launch_ctx.set_arg_float(actual_argument_slot, float(v))
                elif id(needed) in primitive_types.integer_type_ids:
                    if not isinstance(v, int):
                        raise KernelArgError(i, needed.to_string(), provided)
                    launch_ctx.set_arg_int(actual_argument_slot, int(v))
                elif isinstance(needed, sparse_matrix_builder):
                    # Pass only the base pointer of the ti.linalg.sparse_matrix_builder() argument
                    launch_ctx.set_arg_int(actual_argument_slot, v.get_addr())
                elif isinstance(needed, any_arr) and (
                        self.match_ext_arr(v)
                        or isinstance(v, taichi.lang._ndarray.Ndarray)):
                    is_ndarray = False
                    if isinstance(v, taichi.lang._ndarray.Ndarray):
                        # Unwrap to the underlying array; from here on (including
                        # the shape read below) ``v`` is ``v.arr``.
                        v = v.arr
                        is_ndarray = True
                    has_external_arrays = True
                    ndarray_use_torch = self.runtime.prog.config.ndarray_use_torch
                    has_torch = util.has_pytorch()
                    is_numpy = isinstance(v, np.ndarray)
                    if is_numpy:
                        # NOTE(review): for a non-contiguous ``v`` this is a
                        # copy, and no copy-back callback is registered, so
                        # kernel writes would not reach ``v`` -- confirm intended.
                        tmp = np.ascontiguousarray(v)
                        # Purpose: DO NOT GC |tmp|!
                        tmps.append(tmp)
                        # Presumably the trailing flag distinguishes a device
                        # allocation (True, Ndarray path) from a raw host/device
                        # pointer (False) -- TODO confirm against the C++ API.
                        launch_ctx.set_arg_external_array(
                            actual_argument_slot, int(tmp.ctypes.data),
                            tmp.nbytes, False)
                    elif is_ndarray and not ndarray_use_torch:
                        # Use ndarray's own memory allocator
                        tmp = v
                        launch_ctx.set_arg_external_array(
                            actual_argument_slot,
                            int(tmp.device_allocation_ptr()),
                            tmp.element_size() * tmp.nelement(), True)
                    else:

                        # Factory so each callback binds the current ``u``/``v``
                        # rather than late-binding the loop variable ``v``.
                        def get_call_back(u, v):
                            def call_back():
                                u.copy_(v)

                            return call_back

                        assert has_torch
                        assert isinstance(v, torch.Tensor)
                        tmp = v
                        taichi_arch = self.runtime.prog.config.arch
                        # Ndarray means its memory is allocated on the specified taichi arch.
                        # Since torch only supports CPU & CUDA, torch-base ndarray only supports
                        # taichi cpu/cuda backend as well.
                        # Note I put x64/arm64/cuda here to be more specific.
                        assert not is_ndarray or taichi_arch in (
                            _ti_core.Arch.cuda, _ti_core.Arch.x64,
                            _ti_core.Arch.arm64
                        ), "Torch-based ndarray is only supported on taichi x64/arm64/cuda backend."

                        # When the tensor lives on a different device than the
                        # taichi arch, the kernel runs on a migrated copy. The
                        # copy stays alive through its callback closure and is
                        # copied back into ``v`` after the launch.
                        if str(v.device).startswith('cuda'):
                            # External tensor on cuda
                            if taichi_arch != _ti_core.Arch.cuda:
                                # copy data back to cpu
                                host_v = v.to(device='cpu', copy=True)
                                tmp = host_v
                                callbacks.append(get_call_back(v, host_v))
                        else:
                            # External tensor on cpu
                            if taichi_arch == _ti_core.Arch.cuda:
                                gpu_v = v.cuda()
                                tmp = gpu_v
                                callbacks.append(get_call_back(v, gpu_v))
                        launch_ctx.set_arg_external_array(
                            actual_argument_slot, int(tmp.data_ptr()),
                            tmp.element_size() * tmp.nelement(), False)

                    # Each dimension of the array's shape is passed as an
                    # extra int arg on the same slot, one per dimension index.
                    shape = v.shape
                    max_num_indices = _ti_core.get_max_num_indices()
                    assert len(
                        shape
                    ) <= max_num_indices, f"External array cannot have > {max_num_indices} indices"
                    for ii, s in enumerate(shape):
                        launch_ctx.set_extra_arg_int(actual_argument_slot, ii,
                                                     s)
                else:
                    raise ValueError(
                        f'Argument type mismatch. Expecting {needed}, got {type(v)}.'
                    )
                actual_argument_slot += 1
            # Both the class kernels and the plain-function kernels are unified now.
            # In both cases, |self.grad| is another Kernel instance that computes the
            # gradient. For class kernels, args[0] is always the kernel owner.
            if not self.is_grad and self.runtime.target_tape and not self.runtime.grad_replaced:
                self.runtime.target_tape.insert(self, args)

            t_kernel(launch_ctx)

            ret = None
            ret_dt = self.return_type
            has_ret = ret_dt is not None

            # Sync when a return value must be read back, or in async mode when
            # external arrays were passed (presumably so their contents are
            # final before control returns to the caller).
            if has_ret or (ti.current_cfg().async_mode
                           and has_external_arrays):
                ti.sync()

            if has_ret:
                if id(ret_dt) in primitive_types.integer_type_ids:
                    ret = t_kernel.get_ret_int(0)
                else:
                    ret = t_kernel.get_ret_float(0)

            # Copy results from migrated torch copies back into the caller's
            # original tensors.
            if callbacks:
                for c in callbacks:
                    c()

            return ret