def __call__(self, ctx, node):
    method = getattr(self, 'build_' + node.__class__.__name__, None)
    try:
        if method is None:
            error_msg = f'Unsupported node "{node.__class__.__name__}"'
            raise TaichiSyntaxError(error_msg)
        return method(ctx, node)
    except Exception as e:
        # Attach source-position info only once, at the innermost
        # statement/expression that raised; re-raises from outer frames
        # pass through with their tracebacks stripped.
        if ctx.raised or not isinstance(node, (ast.stmt, ast.expr)):
            raise e.with_traceback(None)
        ctx.raised = True
        e = handle_exception_from_cpp(e)
        if not isinstance(e, TaichiCompilationError):
            msg = ctx.get_pos_info(node) + traceback.format_exc()
            raise TaichiCompilationError(msg) from None
        msg = ctx.get_pos_info(node) + str(e)
        raise type(e)(msg) from None
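
The dispatcher above resolves a handler purely by node class name, so supporting a new AST construct only means adding a matching build_* method. Below is a minimal, self-contained sketch of the same pattern; the MiniTransformer class and its build_* handlers are illustrative, not Taichi's actual ones:

import ast

class MiniTransformer:
    """Dispatch on the AST node's class name, as in the builder above."""

    def __call__(self, node):
        method = getattr(self, 'build_' + node.__class__.__name__, None)
        if method is None:
            raise SyntaxError(f'Unsupported node "{node.__class__.__name__}"')
        return method(node)

    def build_Expression(self, node):
        return self(node.body)

    def build_Constant(self, node):
        return node.value

    def build_BinOp(self, node):
        # Only addition is handled in this sketch.
        if isinstance(node.op, ast.Add):
            return self(node.left) + self(node.right)
        raise SyntaxError('only + is supported here')

print(MiniTransformer()(ast.parse('1 + 2', mode='eval')))  # prints 3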
Example #2
        def func__(*args):
            assert len(args) == len(
                self.argument_annotations
            ), f'{len(self.argument_annotations)} arguments needed but {len(args)} provided'

            tmps = []
            callbacks = []
            has_external_arrays = False
            has_torch = has_pytorch()

            actual_argument_slot = 0
            launch_ctx = t_kernel.make_launch_context()
            for i, v in enumerate(args):
                needed = self.argument_annotations[i]
                if isinstance(needed, template):
                    continue
                provided = type(v)
                # Note: do not use something like "needed == f32" here; that would be slow.
                if id(needed) in primitive_types.real_type_ids:
                    if not isinstance(v, (float, int)):
                        raise TaichiRuntimeTypeError.get(
                            i, needed.to_string(), provided)
                    launch_ctx.set_arg_float(actual_argument_slot, float(v))
                elif id(needed) in primitive_types.integer_type_ids:
                    if not isinstance(v, int):
                        raise TaichiRuntimeTypeError.get(
                            i, needed.to_string(), provided)
                    launch_ctx.set_arg_int(actual_argument_slot, int(v))
                elif isinstance(needed, sparse_matrix_builder):
                    # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
                    launch_ctx.set_arg_int(actual_argument_slot, v._get_addr())
                elif isinstance(needed,
                                ndarray_type.NdarrayType) and isinstance(
                                    v, taichi.lang._ndarray.Ndarray):
                    has_external_arrays = True
                    v = v.arr
                    launch_ctx.set_arg_ndarray(actual_argument_slot, v)
                elif isinstance(
                        needed,
                        ndarray_type.NdarrayType) and (self.match_ext_arr(v)):
                    has_external_arrays = True
                    is_numpy = isinstance(v, np.ndarray)
                    if is_numpy:
                        tmp = np.ascontiguousarray(v)
                        # Keep a reference so |tmp| is not garbage-collected before the launch.
                        tmps.append(tmp)
                        launch_ctx.set_arg_external_array_with_shape(
                            actual_argument_slot, int(tmp.ctypes.data),
                            tmp.nbytes, v.shape)
                    else:
                        is_ndarray = False
                        tmp, torch_callbacks = self.get_torch_callbacks(
                            v, has_torch, is_ndarray)
                        callbacks += torch_callbacks
                        launch_ctx.set_arg_external_array_with_shape(
                            actual_argument_slot, int(tmp.data_ptr()),
                            tmp.element_size() * tmp.nelement(), v.shape)

                elif isinstance(needed, MatrixType):
                    if id(needed.dtype) in primitive_types.real_type_ids:
                        for a in range(needed.n):
                            for b in range(needed.m):
                                if not isinstance(v[a, b], (int, float)):
                                    raise TaichiRuntimeTypeError.get(
                                        i, needed.dtype.to_string(),
                                        type(v[a, b]))
                                launch_ctx.set_arg_float(
                                    actual_argument_slot, float(v[a, b]))
                                actual_argument_slot += 1
                    elif id(needed.dtype) in primitive_types.integer_type_ids:
                        for a in range(needed.n):
                            for b in range(needed.m):
                                if not isinstance(v[a, b], int):
                                    raise TaichiRuntimeTypeError.get(
                                        i, needed.dtype.to_string(),
                                        type(v[a, b]))
                                launch_ctx.set_arg_int(actual_argument_slot,
                                                       int(v[a, b]))
                                actual_argument_slot += 1
                    else:
                        raise ValueError(
                            f'Matrix dtype {needed.dtype} is not integer type or real type.'
                        )
                    continue
                else:
                    raise ValueError(
                        f'Argument type mismatch. Expecting {needed}, got {type(v)}.'
                    )
                actual_argument_slot += 1
            # Both the class kernels and the plain-function kernels are unified now.
            # In both cases, |self.grad| is another Kernel instance that computes the
            # gradient. For class kernels, args[0] is always the kernel owner.
            if not self.is_grad and self.runtime.target_tape and not self.runtime.grad_replaced:
                self.runtime.target_tape.insert(self, args)
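            # For example (hypothetical user code): inside `with ti.Tape(loss=loss):`
            # the runtime sets |target_tape|, each forward launch is recorded here,
            # and the tape later replays |self.grad| for every recorded (kernel, args).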

            if actual_argument_slot > 8 and impl.current_cfg().arch in (
                    _ti_core.opengl, _ti_core.cc):
                raise TaichiRuntimeError(
                    f"The number of elements in kernel arguments is too large! Do not exceed 8 on the {_ti_core.arch_name(impl.current_cfg().arch)} backend."
                )

            if actual_argument_slot > 64 and impl.current_cfg().arch not in (
                    _ti_core.opengl, _ti_core.cc):
                raise TaichiRuntimeError(
                    f"The number of elements in kernel arguments is too large! Do not exceed 64 on the {_ti_core.arch_name(impl.current_cfg().arch)} backend."
                )

            try:
                t_kernel(launch_ctx)
            except Exception as e:
                e = handle_exception_from_cpp(e)
                raise e from None

            ret = None
            ret_dt = self.return_type
            has_ret = ret_dt is not None

            if has_ret or (impl.current_cfg().async_mode
                           and has_external_arrays):
                runtime_ops.sync()

            if has_ret:
                if id(ret_dt) in primitive_types.integer_type_ids:
                    ret = t_kernel.get_ret_int(0)
                elif id(ret_dt) in primitive_types.real_type_ids:
                    ret = t_kernel.get_ret_float(0)
                # Non-primitive return types here are matrices; rebuild the
                # Matrix from the flattened return tensor.
                elif id(ret_dt.dtype) in primitive_types.integer_type_ids:
                    it = iter(t_kernel.get_ret_int_tensor(0))
                    ret = Matrix([[next(it) for _ in range(ret_dt.m)]
                                  for _ in range(ret_dt.n)])
                else:
                    it = iter(t_kernel.get_ret_float_tensor(0))
                    ret = Matrix([[next(it) for _ in range(ret_dt.m)]
                                  for _ in range(ret_dt.n)])
            if callbacks:
                for c in callbacks:
                    c()

            return ret
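
For context, func__ is the launcher behind every compiled kernel call: each Python-side argument is validated and marshalled into the launch context as shown above. A minimal usage sketch, assuming a CPU backend and the ti.types.ndarray annotation from this era of Taichi (the kernel below is illustrative, not from this file):

import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def scale(arr: ti.types.ndarray(), k: ti.f32):
    for i in range(arr.shape[0]):
        arr[i] = arr[i] * k

a = np.arange(4, dtype=np.float32)
scale(a, 2.0)  # `a` takes the external-array branch; 2.0 takes set_arg_float
print(a)       # [0. 2. 4. 6.]

Because a contiguous NumPy array is passed through np.ascontiguousarray unchanged, the kernel writes directly into its buffer and the modification is visible in place.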