def sym_eig2x2(A, dt):
    """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real symmetric matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.

    Args:
        A (ti.Matrix(2, 2)): input 2x2 symmetric matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        eigenvalues (ti.Vector(2)): The eigenvalues, with ``eigenvalues[0] >= eigenvalues[1]``.
            Each entry stores one eigenvalue.
        eigenvectors (ti.Matrix(2, 2)): The unit eigenvectors. Column ``i`` is the
            eigenvector of ``eigenvalues[i]``.
    """
    tr = A.trace()
    # (a - d)^2 + 4bc is algebraically equal to tr^2 - 4det, but avoids the
    # cancellation of that form and is never negative for a symmetric A
    # (b == c), so the square roots below cannot turn rounding into NaN.
    gap = (A[0, 0] - A[1, 1])**2 + 4 * A[0, 1] * A[1, 0]
    lambda1 = (tr + ops.sqrt(gap)) * 0.5
    lambda2 = (tr - ops.sqrt(gap)) * 0.5
    eigenvalues = Vector([lambda1, lambda2], dt=dt)
    A1 = A - lambda1 * Matrix.identity(dt, 2)
    A2 = A - lambda2 * Matrix.identity(dt, 2)
    v1 = Vector.zero(dt, 2)
    v2 = Vector.zero(dt, 2)
    if all(A1 == Matrix.zero(dt, 2, 2)) and all(A2 == Matrix.zero(dt, 2, 2)):
        # A is a multiple of the identity: any orthonormal basis works.
        v1 = Vector([0.0, 1.0]).cast(dt)
        v2 = Vector([1.0, 0.0]).cast(dt)
    else:
        # Cayley-Hamilton: (A - lambda1 I)(A - lambda2 I) = 0, so every column
        # of A2 is an eigenvector of lambda1 and every column of A1 one of
        # lambda2. Use the column with the larger norm: a single column can
        # be (numerically) zero, e.g. for a diagonal A, and normalizing a zero
        # vector yields NaN.
        if A2[0, 1]**2 + A2[1, 1]**2 > A2[0, 0]**2 + A2[1, 0]**2:
            v1 = Vector([A2[0, 1], A2[1, 1]], dt=dt).normalized()
        else:
            v1 = Vector([A2[0, 0], A2[1, 0]], dt=dt).normalized()
        if A1[0, 1]**2 + A1[1, 1]**2 > A1[0, 0]**2 + A1[1, 0]**2:
            v2 = Vector([A1[0, 1], A1[1, 1]], dt=dt).normalized()
        else:
            v2 = Vector([A1[0, 0], A1[1, 0]], dt=dt).normalized()
    eigenvectors = Matrix.cols([v1, v2])
    return eigenvalues, eigenvectors
def polar_decompose2d(A, dt):
    """Polar-decompose a 2x2 matrix as ``A = U @ P``.

    Background: https://en.wikipedia.org/wiki/Polar_decomposition.

    Args:
        A (ti.Matrix(2, 2)): the 2x2 matrix to decompose.
        dt (DataType): element data type of `A`, typically ti.f32 or ti.f64.

    Returns:
        The pair ``(U, P)`` of 2x2 matrices: `U` is orthogonal and `P` is
        positive (semi-)definite. For the zero matrix the pair is ``(I, A)``.
    """
    zero = ops.cast(0.0, dt)
    # Defaults are the answer for the degenerate zero-matrix input.
    U = Matrix.identity(dt, 2)
    P = ops.cast(A, dt)
    if (A[0, 0] != zero or A[0, 1] != zero or A[1, 0] != zero
            or A[1, 1] != zero):
        det_a = A[0, 0] * A[1, 1] - A[1, 0] * A[0, 1]
        # B is U up to a positive scale: a rotation-shaped matrix by default,
        # and a reflection-shaped one when det(A) is negative.
        B = Matrix([[A[0, 0] + A[1, 1], A[0, 1] - A[1, 0]],
                    [A[1, 0] - A[0, 1], A[1, 1] + A[0, 0]]], dt)
        if det_a < zero:
            B = Matrix([[A[0, 0] - A[1, 1], A[0, 1] + A[1, 0]],
                        [A[1, 0] + A[0, 1], A[1, 1] - A[0, 0]]], dt)
        # det(B) is non-zero for any non-zero A, so this scale is finite.
        det_b = B[0, 0] * B[1, 1] - B[1, 0] * B[0, 1]
        inv_scale = ops.cast(1.0, dt) / ops.sqrt(abs(det_b))
        U = B * inv_scale
        P = (A.transpose() @ A +
             abs(det_a) * Matrix.identity(dt, 2)) * inv_scale
    return U, P
def expr_init(rhs):
    """Create a fresh local from `rhs` in the current AST builder.

    `None` allocates an uninitialized variable. Matrices and structs are
    rebuilt from their entries. Lists, tuples and dicts are processed
    element-wise. Data types, archs, ndranges, mesh proxies and
    data-oriented objects are returned unchanged. Anything else is wrapped
    in `Expr` and declared as a new variable.
    """
    if rhs is None:
        return Expr(get_runtime().prog.current_ast_builder().expr_alloca())
    if isinstance(rhs, Matrix):
        # Vector/matrix subclasses carrying _DIM are rebuilt with their own type.
        if hasattr(rhs, "_DIM"):
            return type(rhs)(*rhs.to_list())
        return Matrix(rhs.to_list())
    if isinstance(rhs, Struct):
        return Struct(rhs.to_dict())
    if isinstance(rhs, list):
        return [expr_init(item) for item in rhs]
    if isinstance(rhs, tuple):
        return tuple(expr_init(item) for item in rhs)
    if isinstance(rhs, dict):
        return {key: expr_init(val) for key, val in rhs.items()}
    passthrough_types = (_ti_core.DataType, _ti_core.Arch, _Ndrange,
                         MeshElementFieldProxy, MeshRelationAccessProxy)
    if isinstance(rhs, passthrough_types) or hasattr(rhs, '_data_oriented'):
        return rhs
    builder = get_runtime().prog.current_ast_builder()
    return Expr(builder.expr_var(Expr(rhs).ptr))
def solve(A, b, dt=None):
    """Solve the linear system ``A @ x = b`` by Gauss elimination.

    Args:
        A (ti.Matrix(n, n)): input nxn coefficient matrix `A`; n must be 2 or 3.
        b (ti.Vector(n, 1)): input nx1 right-hand-side vector `b`.
        dt (DataType): The datatype for the `A` and `b`. Defaults to the
            runtime's default floating-point type when omitted.

    Returns:
        x (ti.Vector(n, 1)): the solution of Ax=b.

    Raises:
        AssertionError: if `A` is not square, not 2x2/3x3, or does not match `b`.
    """
    assert A.n == A.m, "Only square matrix is supported"
    assert A.n >= 2 and A.n <= 3, "Only 2D and 3D matrices are supported"
    assert A.m == b.n, "Matrix and Vector dimension mismatch"
    if dt is None:
        dt = impl.get_runtime().default_fp
    nrow, ncol = static(A.n, A.n + 1)
    # Assemble the augmented matrix [A | b] consumed by the elimination helpers.
    Ab = expr_init(Matrix.zero(dt, nrow, ncol))
    lhs = tuple([e.ptr for e in A.entries])
    rhs = tuple([e.ptr for e in b.entries])
    for i in range(nrow):
        for j in range(nrow):
            Ab(i, j)._assign(lhs[nrow * i + j])
    for i in range(nrow):
        Ab(i, nrow)._assign(rhs[i])
    if A.n == 2:
        return _gauss_elimination_2x2(Ab, dt)
    if A.n == 3:
        return _gauss_elimination_3x3(Ab, dt)
    # Only reachable when assertions are disabled (python -O).
    raise Exception("Solver only supports 2D and 3D matrices.")
def expr_init(rhs):
    """Create a fresh local from `rhs`.

    `None` allocates an uninitialized variable. Matrices and structs are
    copied only when `rhs.in_python_scope` is truthy or they are
    intermediate values; otherwise they are returned as-is. Containers are
    processed element-wise. A fixed set of metadata types passes through
    untouched. Everything else becomes a new variable via `expr_var`.
    """
    if rhs is None:
        return Expr(_ti_core.expr_alloca())
    if isinstance(rhs, Matrix):
        must_copy = rhs.in_python_scope or isinstance(rhs, _IntermediateMatrix)
        return Matrix(rhs.to_list()) if must_copy else rhs
    if isinstance(rhs, Struct):
        must_copy = rhs.in_python_scope or isinstance(rhs, _IntermediateStruct)
        return Struct(rhs.to_dict()) if must_copy else rhs
    if isinstance(rhs, list):
        return [expr_init(item) for item in rhs]
    if isinstance(rhs, tuple):
        return tuple(expr_init(item) for item in rhs)
    if isinstance(rhs, dict):
        return {key: expr_init(val) for key, val in rhs.items()}
    passthrough_types = (_ti_core.DataType, _ti_core.Arch, ti.ndrange,
                         MeshElementFieldProxy, MeshRelationAccessProxy)
    if isinstance(rhs, passthrough_types) or hasattr(rhs, '_data_oriented'):
        return rhs
    return Expr(_ti_core.expr_var(Expr(rhs).ptr))
def eig2x2(A, dt):
    """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 2x2 real matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix.

    Args:
        A (ti.Matrix(2, 2)): input 2x2 matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        eigenvalues (ti.Matrix(2, 2)): The eigenvalues in complex form. Each row stores one eigenvalue.
            The first number of the eigenvalue represents the real part and the second number
            represents the imaginary part.
        eigenvectors: (ti.Matrix(4, 2)): The eigenvectors in complex form. Each column stores one
            eigenvector. Each eigenvector consists of 2 entries, each of which is represented by
            two numbers for its real part and imaginary part.
    """
    tr = A.trace()
    # (a - d)^2 + 4bc is algebraically tr^2 - 4det, computed without the
    # cancellation of that form (its sign decides real vs complex below).
    gap = (A[0, 0] - A[1, 1])**2 + 4 * A[0, 1] * A[1, 0]
    lambda1 = Vector.zero(dt, 2)
    lambda2 = Vector.zero(dt, 2)
    v1 = Vector.zero(dt, 4)
    v2 = Vector.zero(dt, 4)
    # gap == 0 (a repeated real eigenvalue, e.g. any multiple of the identity)
    # must take the real branch: in the complex branch its eigenvector
    # columns are zero and normalizing them yields NaN.
    if gap >= 0:
        lambda1 = Vector([tr + ops.sqrt(gap), 0.0], dt=dt) * 0.5
        lambda2 = Vector([tr - ops.sqrt(gap), 0.0], dt=dt) * 0.5
        A1 = A - lambda1[0] * Matrix.identity(dt, 2)
        A2 = A - lambda2[0] * Matrix.identity(dt, 2)
        if all(A1 == Matrix.zero(dt, 2, 2)) and all(
                A2 == Matrix.zero(dt, 2, 2)):
            # A is a multiple of the identity: any basis is an eigenbasis.
            v1 = Vector([0.0, 0.0, 1.0, 0.0]).cast(dt)
            v2 = Vector([1.0, 0.0, 0.0, 0.0]).cast(dt)
        else:
            # Cayley-Hamilton: (A - lambda1 I)(A - lambda2 I) = 0, so columns
            # of A2 are eigenvectors of lambda1 and columns of A1 of lambda2.
            # Take the larger-norm column; one column may be zero (diagonal or
            # defective A), which would normalize to NaN.
            if A2[0, 1]**2 + A2[1, 1]**2 > A2[0, 0]**2 + A2[1, 0]**2:
                v1 = Vector([A2[0, 1], 0.0, A2[1, 1], 0.0],
                            dt=dt).normalized()
            else:
                v1 = Vector([A2[0, 0], 0.0, A2[1, 0], 0.0],
                            dt=dt).normalized()
            if A1[0, 1]**2 + A1[1, 1]**2 > A1[0, 0]**2 + A1[1, 0]**2:
                v2 = Vector([A1[0, 1], 0.0, A1[1, 1], 0.0],
                            dt=dt).normalized()
            else:
                v2 = Vector([A1[0, 0], 0.0, A1[1, 0], 0.0],
                            dt=dt).normalized()
    else:
        # Complex-conjugate pair; here A[1, 0] != 0 (gap < 0 needs bc < 0),
        # so the first columns used below are never zero.
        lambda1 = Vector([tr, ops.sqrt(-gap)], dt=dt) * 0.5
        lambda2 = Vector([tr, -ops.sqrt(-gap)], dt=dt) * 0.5
        A1r = A - lambda1[0] * Matrix.identity(dt, 2)
        A1i = -lambda1[1] * Matrix.identity(dt, 2)
        A2r = A - lambda2[0] * Matrix.identity(dt, 2)
        A2i = -lambda2[1] * Matrix.identity(dt, 2)
        v1 = Vector([A2r[0, 0], A2i[0, 0], A2r[1, 0], A2i[1, 0]],
                    dt=dt).normalized()
        v2 = Vector([A1r[0, 0], A1i[0, 0], A1r[1, 0], A1i[1, 0]],
                    dt=dt).normalized()
    eigenvalues = Matrix.rows([lambda1, lambda2])
    eigenvectors = Matrix.cols([v1, v2])
    return eigenvalues, eigenvectors
def produce_injected_args(kernel, symbolic_args=None):
    """Build one placeholder argument per parameter of `kernel`.

    - Ndarray parameters get a small ndarray of shape ``(2,) * field_dim``.
    - Matrix parameters get a zero `Matrix` of the annotated size.
    - Primitive parameters get ``0``.

    When `symbolic_args` is given, its i-th entry describes the i-th
    parameter and supplies element shape / dtype for ndarrays.

    Args:
        kernel: object whose ``arguments`` (each with ``annotation`` and
            ``name``) are inspected.
        symbolic_args (list, optional): per-parameter symbolic descriptions.

    Returns:
        list: the injected placeholder arguments, in parameter order.

    Raises:
        TaichiCompilationError: for a non-ndarray template annotation, or
            when ``element_shape`` / ``field_dim`` cannot be determined.
        RuntimeError: for an unsupported ndarray element dimension, or a
            missing / mismatched symbolic Matrix argument.
    """
    injected_args = []
    for i, arg in enumerate(kernel.arguments):
        anno = arg.annotation
        if isinstance(anno, template_types):
            if not isinstance(anno, NdarrayType):
                raise TaichiCompilationError(
                    f'Expected Ndarray type, got {anno}')
            if symbolic_args is not None:
                element_shape = tuple(symbolic_args[i].element_shape)
                element_dim = len(element_shape)
                dtype = symbolic_args[i].dtype()
            else:
                element_shape = anno.element_shape
                element_dim = anno.element_dim
                dtype = anno.dtype

            if element_shape is None or anno.field_dim is None:
                raise TaichiCompilationError(
                    'Please either specify both `element_shape` and `field_dim` '
                    'in the param annotation, or provide an example '
                    f'ndarray for param={arg.name}')

            # Tiny dummy storage: only the layout matters for tracing.
            shape = (2, ) * anno.field_dim
            if element_dim is None or element_dim == 0:
                injected_args.append(ScalarNdarray(dtype, shape))
            elif element_dim == 1:
                injected_args.append(
                    VectorNdarray(element_shape[0],
                                  dtype=dtype,
                                  shape=shape,
                                  layout=Layout.AOS))
            elif element_dim == 2:
                injected_args.append(
                    MatrixNdarray(element_shape[0],
                                  element_shape[1],
                                  dtype=dtype,
                                  shape=shape,
                                  layout=Layout.AOS))
            else:
                raise RuntimeError(
                    f'Unsupported element_dim={element_dim} for ndarray '
                    f'param={arg.name}; only 0, 1 or 2 is supported.')
        elif isinstance(anno, MatrixType):
            # Without symbolic args, indexing symbolic_args would raise an
            # unrelated TypeError; report the real problem instead.
            if symbolic_args is None or not isinstance(
                    symbolic_args[i], list):
                raise RuntimeError('Expected a symbolic arg with Matrix type.')

            symbolic_mat_n = len(symbolic_args[i])
            symbolic_mat_m = len(symbolic_args[i][0])

            if symbolic_mat_m != anno.m or symbolic_mat_n != anno.n:
                raise RuntimeError(
                    f'Matrix dimension mismatch, expected ({anno.n}, {anno.m}) '
                    f'but dispatched shape ({symbolic_mat_n}, {symbolic_mat_m}).'
                )
            injected_args.append(Matrix([0] * anno.n * anno.m, dt=anno.dtype))
        else:
            # For primitive types, we can just inject a dummy value.
            injected_args.append(0)
    return injected_args
def polar_decompose2d(A, dt):
    """Perform polar decomposition (A=UP) for 2x2 matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Polar_decomposition.

    Args:
        A (ti.Matrix(2, 2)): input 2x2 matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        Decomposed 2x2 matrices `U` and `P`. `U` is a rotation and
        ``P = U^T @ A`` is symmetric; `P` is positive semi-definite only
        when ``det(A) >= 0``. If ``A[0,0] + A[1,1] == 0`` and
        ``A[1,0] == A[0,1]`` (e.g. the zero matrix), the rotation angle is
        undefined and `U` is the identity, so `P` equals `A`.
    """
    x, y = A(0, 0) + A(1, 1), A(1, 0) - A(0, 1)
    norm_sqr = x * x + y * y
    # cos/sin of the rotation angle; identity when the angle is undefined,
    # which previously divided by zero and produced inf/NaN.
    c = ops.cast(1.0, dt)
    s = ops.cast(0.0, dt)
    if norm_sqr > 0:
        scale = 1.0 / ops.sqrt(norm_sqr)
        c = x * scale
        s = y * scale
    r = Matrix([[c, -s], [s, c]], dt=dt)
    return r, r.transpose() @ A
def svd3d(A, dt, iters=None):
    """Singular value decomposition ``A = U @ S @ V^T`` of a 3x3 matrix.

    Background: https://en.wikipedia.org/wiki/Singular_value_decomposition.

    Args:
        A (ti.Matrix(3, 3)): the 3x3 matrix to decompose.
        dt (DataType): element data type of `A`; must be ti.f32 or ti.f64.
        iters (int): number of iterations of the builder's Sifakis SVD;
            defaults to 5 for f32 and 8 for f64.

    Returns:
        The 3x3 matrices `U`, `S` (diagonal) and `V`.
    """
    assert A.n == 3 and A.m == 3
    inputs = tuple([e.ptr for e in A.entries])
    assert dt in [f32, f64]
    builder = get_runtime().prog.current_ast_builder()
    if dt == f32:
        default_iters, sifakis_svd = 5, builder.sifakis_svd_f32
    else:
        default_iters, sifakis_svd = 8, builder.sifakis_svd_f64
    if iters is None:
        iters = default_iters
    rets = sifakis_svd(*inputs, iters)
    # Layout of the result: 9 entries of U, 9 of V, then 3 singular values.
    assert len(rets) == 21
    u_flat, v_flat, sigma_diag = rets[:9], rets[9:18], rets[18:]
    U = expr_init(Matrix.zero(dt, 3, 3))
    V = expr_init(Matrix.zero(dt, 3, 3))
    sigma = expr_init(Matrix.zero(dt, 3, 3))
    for row in range(3):
        for col in range(3):
            flat = row * 3 + col
            U(row, col)._assign(u_flat[flat])
            V(row, col)._assign(v_flat[flat])
        sigma(row, row)._assign(sigma_diag[row])
    return U, sigma, V
def svd2d(A, dt):
    """Perform singular value decomposition (A=USV^T) for 2x2 matrix.

    Mathematical concept refers to https://en.wikipedia.org/wiki/Singular_value_decomposition.

    The matrix is first polar-decomposed as ``A = R @ S``; `S` is then
    diagonalized by a single Jacobi rotation, giving ``U = R @ V``.

    Args:
        A (ti.Matrix(2, 2)): input 2x2 matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        Decomposed 2x2 matrices `U`, 'S' and `V`. The diagonal of 'S' is
        ordered so that ``S[0, 0] >= S[1, 1]``.
    """
    # NOTE(review): assumes polar_decompose2d returns a symmetric S — the
    # rotation below only uses S[0, 1] as the off-diagonal term; confirm.
    R, S = polar_decompose2d(A, dt)
    c, s = ops.cast(0.0, dt), ops.cast(0.0, dt)
    s1, s2 = ops.cast(0.0, dt), ops.cast(0.0, dt)
    # Absolute threshold on the off-diagonal entry (not scaled by |S|):
    # below it, S is treated as already diagonal.
    if abs(S[0, 1]) < 1e-5:
        c, s = 1, 0
        s1, s2 = S[0, 0], S[1, 1]
    else:
        # Jacobi rotation angle: t = tan(theta), picking the root of smaller
        # magnitude by choosing the sign of the denominator to match tao.
        tao = ops.cast(0.5, dt) * (S[0, 0] - S[1, 1])
        w = ops.sqrt(tao**2 + S[0, 1]**2)
        t = ops.cast(0.0, dt)
        if tao > 0:
            t = S[0, 1] / (tao + w)
        else:
            t = S[0, 1] / (tao - w)
        c = 1 / ops.sqrt(t**2 + 1)
        s = -t * c
        # Diagonal entries of the rotated S.
        s1 = c**2 * S[0, 0] - 2 * c * s * S[0, 1] + s**2 * S[1, 1]
        s2 = s**2 * S[0, 0] + 2 * c * s * S[0, 1] + c**2 * S[1, 1]
    V = Matrix.zero(dt, 2, 2)
    # Swap so that s1 >= s2, permuting the columns of V to match.
    if s1 < s2:
        tmp = s1
        s1 = s2
        s2 = tmp
        V = Matrix([[-s, c], [-c, -s]], dt=dt)
    else:
        V = Matrix([[c, s], [-s, c]], dt=dt)
    U = R @ V
    return U, Matrix([[s1, ops.cast(0, dt)], [ops.cast(0, dt), s2]], dt=dt), V
def __init__(self, *args, **kwargs):
    """Create a struct from a single dict argument or from keyword arguments.

    List/tuple members are converted to `Matrix` and dict members to
    `Struct`. Outside Python scope every member is additionally passed
    through `impl.expr_init`.

    Raises:
        TaichiSyntaxError: if neither exactly one dict nor only keyword
            arguments were given.
    """
    # converts lists to matrices and dicts to structs
    if len(args) == 1 and not kwargs and isinstance(args[0], dict):
        # Copy so the member conversion below does not rewrite the caller's
        # dict in place.
        self.entries = dict(args[0])
    elif not args:
        self.entries = kwargs
    else:
        raise TaichiSyntaxError(
            "Custom structs need to be initialized using either dictionary or keyword arguments"
        )
    for k, v in self.entries.items():
        if isinstance(v, (list, tuple)):
            v = Matrix(v)
        if isinstance(v, dict):
            v = Struct(v)
        self.entries[k] = v if in_python_scope() else impl.expr_init(v)
    self._register_members()
def __init__(self, *args, **kwargs):
    """Create a struct from a single dict argument or from keyword arguments.

    List/tuple members are converted to `Matrix` and dict members to
    `Struct`.

    Raises:
        TaichiSyntaxError: if neither exactly one dict nor only keyword
            arguments were given.
    """
    # converts lists to matrices and dicts to structs
    if len(args) == 1 and not kwargs and isinstance(args[0], dict):
        # Copy so the member conversion below does not rewrite the caller's
        # dict in place.
        self.entries = dict(args[0])
    elif not args:
        self.entries = kwargs
    else:
        raise TaichiSyntaxError(
            "Custom structs need to be initialized using either dictionary or keyword arguments"
        )
    for k, v in self.entries.items():
        if isinstance(v, (list, tuple)):
            v = Matrix(v)
        if isinstance(v, dict):
            v = Struct(v)
        self.entries[k] = v
    self.register_members()
    self.local_tensor_proxy = None
    self.any_array_access = None
def produce_injected_args(kernel, symbolic_args=None):
    """Build one placeholder argument per parameter of `kernel`.

    - Ndarray parameters get a small ndarray of shape ``(2,) * field_dim``.
    - Texture parameters get a `Texture` described by the symbolic arg.
    - Matrix parameters get a zero `Matrix` of the annotated size.
    - Primitive parameters get ``0``.

    When `symbolic_args` is given, its i-th entry describes the i-th
    parameter and is checked against the kernel's annotation.

    Args:
        kernel: object whose ``arguments`` (each with ``annotation`` and
            ``name``) are inspected.
        symbolic_args (list, optional): per-parameter symbolic descriptions.

    Returns:
        list: the injected placeholder arguments, in parameter order.

    Raises:
        TaichiCompilationError: for a non-ndarray template annotation,
            missing ``element_shape`` / ``field_dim``, or a field_dim /
            dtype mismatch between symbolic arg and annotation.
        RuntimeError: for an unsupported ndarray element dimension, or a
            missing / mismatched symbolic texture or Matrix argument.
    """
    injected_args = []
    for i, arg in enumerate(kernel.arguments):
        anno = arg.annotation
        if isinstance(anno, template_types):
            if not isinstance(anno, NdarrayType):
                raise TaichiCompilationError(
                    f'Expected Ndarray type, got {anno}')
            if symbolic_args is not None:
                element_shape = tuple(symbolic_args[i].element_shape)
                element_dim = len(element_shape)
                field_dim = symbolic_args[i].field_dim
                dtype = symbolic_args[i].dtype()
            else:
                element_shape = anno.element_shape
                element_dim = anno.element_dim
                field_dim = anno.field_dim
                dtype = anno.dtype

            if element_shape is None or field_dim is None:
                raise TaichiCompilationError(
                    'Please either specify both `element_shape` and `field_dim` '
                    'in the param annotation, or provide an example '
                    f'ndarray for param={arg.name}')

            if anno.field_dim is not None and field_dim != anno.field_dim:
                raise TaichiCompilationError(
                    f'{field_dim} from Arg {arg.name} doesn\'t match kernel\'s annotated field_dim={anno.field_dim}'
                )

            if anno.dtype is not None and not check_type_match(
                    dtype, anno.dtype):
                raise TaichiCompilationError(
                    f' Arg {arg.name}\'s dtype {dtype.to_string()} doesn\'t match kernel\'s annotated dtype={anno.dtype.to_string()}'
                )

            # Tiny dummy storage: only the layout matters for tracing.
            shape = (2, ) * field_dim
            if element_dim is None or element_dim == 0:
                injected_args.append(ScalarNdarray(dtype, shape))
            elif element_dim == 1:
                injected_args.append(
                    VectorNdarray(element_shape[0],
                                  dtype=dtype,
                                  shape=shape,
                                  layout=Layout.AOS))
            elif element_dim == 2:
                injected_args.append(
                    MatrixNdarray(element_shape[0],
                                  element_shape[1],
                                  dtype=dtype,
                                  shape=shape,
                                  layout=Layout.AOS))
            else:
                raise RuntimeError(
                    f'Unsupported element_dim={element_dim} for ndarray '
                    f'param={arg.name}; only 0, 1 or 2 is supported.')
        elif isinstance(anno, (TextureType, RWTextureType)):
            if symbolic_args is None:
                raise RuntimeError(
                    'Texture type annotation doesn\'t have enough information for aot. Please either specify the channel_format, shape and num_channels in the graph arg declaration.'
                )
            texture_shape = tuple(symbolic_args[i].texture_shape)
            channel_format = symbolic_args[i].channel_format()
            num_channels = symbolic_args[i].num_channels
            injected_args.append(
                Texture(channel_format, num_channels, texture_shape))
        elif isinstance(anno, MatrixType):
            # Without symbolic args, indexing symbolic_args would raise an
            # unrelated TypeError; report the real problem instead.
            if symbolic_args is None or not isinstance(
                    symbolic_args[i], list):
                raise RuntimeError('Expected a symbolic arg with Matrix type.')

            symbolic_mat_n = len(symbolic_args[i])
            symbolic_mat_m = len(symbolic_args[i][0])

            if symbolic_mat_m != anno.m or symbolic_mat_n != anno.n:
                raise RuntimeError(
                    f'Matrix dimension mismatch, expected ({anno.n}, {anno.m}) '
                    f'but dispatched shape ({symbolic_mat_n}, {symbolic_mat_m}).'
                )
            injected_args.append(Matrix([0] * anno.n * anno.m, dt=anno.dtype))
        else:
            if symbolic_args is not None:
                dtype = symbolic_args[i].dtype()
            else:
                dtype = anno
            if not check_type_match(dtype, anno):
                raise TaichiCompilationError(
                    f' Arg {arg.name}\'s dtype {dtype.to_string()} doesn\'t match kernel\'s annotated dtype={anno.to_string()}'
                )
            # For primitive types, we can just inject a dummy value.
            injected_args.append(0)
    return injected_args
def sym_eig3x3(A, dt):
    """Compute the eigenvalues and right eigenvectors (Av=lambda v) of a 3x3 real symmetric matrix using Cardano's method.

    Mathematical concept refers to https://www.mpi-hd.mpg.de/personalhomes/globes/3x3/.

    Args:
        A (ti.Matrix(3, 3)): input 3x3 symmetric matrix `A`.
        dt (DataType): data type of elements in matrix `A`, typically accepts ti.f32 or ti.f64.

    Returns:
        eigenvalues (ti.Vector(3)): The eigenvalues, in descending order by
            construction (``c >= s >= 0`` below). Each entry stores one eigenvalue.
        eigenvectors (ti.Matrix(3, 3)): The eigenvectors. Each column stores one eigenvector.
    """
    M_SQRT3 = 1.73205080756887729352744634151
    # Coefficients of the characteristic polynomial
    # lambda^3 - m*lambda^2 + c1*lambda + c0 (only the upper triangle of A is read).
    m = A.trace()
    dd = A[0, 1] * A[0, 1]
    ee = A[1, 2] * A[1, 2]
    ff = A[0, 2] * A[0, 2]
    c1 = A[0, 0] * A[1, 1] + A[0, 0] * A[2, 2] + A[1, 1] * A[2, 2] - (dd + ee + ff)
    c0 = A[2, 2] * dd + A[0, 0] * ee + A[1, 1] * ff - A[0, 0] * A[1, 1] * A[2, 2] - 2.0 * A[0, 2] * A[0, 1] * A[1, 2]

    # Trigonometric solution of the cubic; atan2 with a non-negative first
    # argument keeps phi in [0, pi/3].
    p = m * m - 3.0 * c1
    q = m * (p - 1.5 * c1) - 13.5 * c0
    sqrt_p = ops.sqrt(ops.abs(p))
    phi = 27.0 * (0.25 * c1 * c1 * (p - c1) + c0 * (q + 6.75 * c0))
    phi = (1.0 / 3.0) * ops.atan2(ops.sqrt(ops.abs(phi)), q)
    c = sqrt_p * ops.cos(phi)
    s = (1.0 / M_SQRT3) * sqrt_p * ops.sin(phi)

    eigenvalues = Vector([0.0, 0.0, 0.0], dt=dt)
    # eigenvalues[2] temporarily holds the shared base (m - c) / 3.
    eigenvalues[2] = (1.0 / 3.0) * (m - c)
    eigenvalues[1] = eigenvalues[2] + s
    eigenvalues[0] = eigenvalues[2] + c
    eigenvalues[2] = eigenvalues[2] - s

    # NOTE(review): t and u are computed but never used. In the reference
    # dsyevv3 they set an error threshold that triggers a QL fallback when the
    # eigenvector norms below are too small; this port has no fallback.
    t = ops.abs(eigenvalues[0])
    u = ops.abs(eigenvalues[1])
    if u > t:
        t = u
    u = ops.abs(eigenvalues[2])
    if u > t:
        t = u
    if t < 1.0:
        u = t
    else:
        u = t * t

    Q = Matrix.zero(dt, 3, 3)
    # Terms shared by the first two eigenvectors; Q[2, 1] temporarily holds dd.
    Q[0, 1] = A[0, 1] * A[1, 2] - A[0, 2] * A[1, 1]
    Q[1, 1] = A[0, 2] * A[0, 1] - A[1, 2] * A[0, 0]
    Q[2, 1] = A[0, 1] * A[0, 1]

    # First eigenvector: cross product of the first two columns of
    # (A - eigenvalues[0] * I), then normalized.
    Q[0, 0] = Q[0, 1] + A[0, 2] * eigenvalues[0]
    Q[1, 0] = Q[1, 1] + A[1, 2] * eigenvalues[0]
    Q[2, 0] = (A[0, 0] - eigenvalues[0]) * (A[1, 1] - eigenvalues[0]) - Q[2, 1]
    norm = Q[0, 0] * Q[0, 0] + Q[1, 0] * Q[1, 0] + Q[2, 0] * Q[2, 0]
    # NOTE(review): norm can be 0 (e.g. some diagonal inputs), making this
    # inf and the eigenvector NaN — confirm whether callers guard against it.
    norm = ops.sqrt(1.0 / norm)
    Q[0, 0] *= norm
    Q[1, 0] *= norm
    Q[2, 0] *= norm

    # Second eigenvector: same construction with eigenvalues[1].
    Q[0, 1] = Q[0, 1] + A[0, 2] * eigenvalues[1]
    Q[1, 1] = Q[1, 1] + A[1, 2] * eigenvalues[1]
    Q[2, 1] = (A[0, 0] - eigenvalues[1]) * (A[1, 1] - eigenvalues[1]) - Q[2, 1]
    norm = Q[0, 1] * Q[0, 1] + Q[1, 1] * Q[1, 1] + Q[2, 1] * Q[2, 1]
    norm = ops.sqrt(1.0 / norm)
    Q[0, 1] *= norm
    Q[1, 1] *= norm
    Q[2, 1] *= norm

    # Third eigenvector: cross product of the first two columns.
    Q[0, 2] = Q[1, 0] * Q[2, 1] - Q[2, 0] * Q[1, 1]
    Q[1, 2] = Q[2, 0] * Q[0, 1] - Q[0, 0] * Q[2, 1]
    Q[2, 2] = Q[0, 0] * Q[1, 1] - Q[1, 0] * Q[0, 1]
    return eigenvalues, Q
def __getitem__(self, key):
    """Return a `Matrix` (``is_ref=True``) wrapping the host accessors for `key`.

    `key` is first translated through ``self.g2r_field`` and then padded
    with ``self._pad_key`` before the lookup.
    """
    self._initialize_host_accessors()
    padded_key = self._pad_key(self.g2r_field[key])
    return Matrix(self._host_access(padded_key), is_ref=True)
def decl_matrix_arg(matrixtype):
    """Declare one scalar argument per entry of `matrixtype` and wrap them in a `Matrix`.

    Entries are declared in row-major order (row by row, columns within a
    row), each with ``matrixtype.dtype``.
    """
    rows = []
    for _row in range(matrixtype.n):
        rows.append(
            [decl_scalar_arg(matrixtype.dtype) for _col in range(matrixtype.m)])
    return Matrix(rows)
def func__(*args):
    """Pack `args` into a launch context, run `t_kernel`, and return its result.

    `self` and `t_kernel` come from the enclosing scope (not visible here).

    Steps:
    - Check the argument count against ``self.argument_annotations``.
    - Write each non-template argument into consecutive launch-context slots;
      a MatrixType argument consumes ``n * m`` slots.
    - Record the call on ``self.runtime.target_tape`` when applicable.
    - Enforce the per-backend slot limit: 8 on OpenGL/CC, 64 elsewhere.
    - Launch the kernel.
    - Sync when there is a return value, or when running in async mode with
      external arrays.
    - Decode the return value: a scalar, or a `Matrix` for matrix returns.
    - Finally run any callbacks collected for torch tensors.

    Raises:
        TaichiRuntimeTypeError: if a primitive or matrix-element argument
            has the wrong Python type.
        ValueError: for an unsupported matrix dtype or argument type.
        TaichiRuntimeError: if too many argument slots are used for the backend.
    """
    assert len(args) == len(
        self.argument_annotations
    ), f'{len(self.argument_annotations)} arguments needed but {len(args)} provided'

    tmps = []  # contiguous numpy copies that must outlive the launch
    callbacks = []  # run after the launch; collected from get_torch_callbacks
    has_external_arrays = False
    has_torch = has_pytorch()

    actual_argument_slot = 0
    launch_ctx = t_kernel.make_launch_context()
    for i, v in enumerate(args):
        needed = self.argument_annotations[i]
        # Template arguments are not passed through the launch context.
        if isinstance(needed, template):
            continue
        provided = type(v)
        # Note: do not use sth like "needed == f32". That would be slow.
        if id(needed) in primitive_types.real_type_ids:
            if not isinstance(v, (float, int)):
                raise TaichiRuntimeTypeError.get(
                    i, needed.to_string(), provided)
            launch_ctx.set_arg_float(actual_argument_slot, float(v))
        elif id(needed) in primitive_types.integer_type_ids:
            if not isinstance(v, int):
                raise TaichiRuntimeTypeError.get(
                    i, needed.to_string(), provided)
            launch_ctx.set_arg_int(actual_argument_slot, int(v))
        elif isinstance(needed, sparse_matrix_builder):
            # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
            launch_ctx.set_arg_int(actual_argument_slot, v._get_addr())
        elif isinstance(needed, ndarray_type.NdarrayType) and isinstance(
                v, taichi.lang._ndarray.Ndarray):
            has_external_arrays = True
            v = v.arr
            launch_ctx.set_arg_ndarray(actual_argument_slot, v)
        elif isinstance(
                needed,
                ndarray_type.NdarrayType) and (self.match_ext_arr(v)):
            has_external_arrays = True
            is_numpy = isinstance(v, np.ndarray)
            if is_numpy:
                # The kernel receives a raw pointer, so the data must be
                # contiguous; np.ascontiguousarray may return a copy.
                tmp = np.ascontiguousarray(v)
                # Purpose: DO NOT GC |tmp|!
                tmps.append(tmp)
                launch_ctx.set_arg_external_array_with_shape(
                    actual_argument_slot, int(tmp.ctypes.data),
                    tmp.nbytes, v.shape)
            else:
                # Non-numpy external array (torch path via get_torch_callbacks).
                is_ndarray = False
                tmp, torch_callbacks = self.get_torch_callbacks(
                    v, has_torch, is_ndarray)
                callbacks += torch_callbacks
                launch_ctx.set_arg_external_array_with_shape(
                    actual_argument_slot, int(tmp.data_ptr()),
                    tmp.element_size() * tmp.nelement(), v.shape)

        elif isinstance(needed, MatrixType):
            # Matrix arguments are flattened row-major, one slot per element.
            if id(needed.dtype) in primitive_types.real_type_ids:
                for a in range(needed.n):
                    for b in range(needed.m):
                        if not isinstance(v[a, b], (int, float)):
                            raise TaichiRuntimeTypeError.get(
                                i, needed.dtype.to_string(),
                                type(v[a, b]))
                        launch_ctx.set_arg_float(
                            actual_argument_slot, float(v[a, b]))
                        actual_argument_slot += 1
            elif id(needed.dtype) in primitive_types.integer_type_ids:
                for a in range(needed.n):
                    for b in range(needed.m):
                        if not isinstance(v[a, b], int):
                            raise TaichiRuntimeTypeError.get(
                                i, needed.dtype.to_string(),
                                type(v[a, b]))
                        launch_ctx.set_arg_int(actual_argument_slot,
                                               int(v[a, b]))
                        actual_argument_slot += 1
            else:
                raise ValueError(
                    f'Matrix dtype {needed.dtype} is not integer type or real type.'
                )
            # Slots were already advanced per element above.
            continue
        else:
            raise ValueError(
                f'Argument type mismatch. Expecting {needed}, got {type(v)}.'
            )
        actual_argument_slot += 1
    # Both the class kernels and the plain-function kernels are unified now.
    # In both cases, |self.grad| is another Kernel instance that computes the
    # gradient. For class kernels, args[0] is always the kernel owner.
    if not self.is_grad and self.runtime.target_tape and not self.runtime.grad_replaced:
        self.runtime.target_tape.insert(self, args)

    if actual_argument_slot > 8 and (
            impl.current_cfg().arch == _ti_core.opengl
            or impl.current_cfg().arch == _ti_core.cc):
        raise TaichiRuntimeError(
            f"The number of elements in kernel arguments is too big! Do not exceed 8 on {_ti_core.arch_name(impl.current_cfg().arch)} backend."
        )

    if actual_argument_slot > 64 and (
        (impl.current_cfg().arch != _ti_core.opengl
         and impl.current_cfg().arch != _ti_core.cc)):
        raise TaichiRuntimeError(
            f"The number of elements in kernel arguments is too big! Do not exceed 64 on {_ti_core.arch_name(impl.current_cfg().arch)} backend."
        )

    try:
        t_kernel(launch_ctx)
    except Exception as e:
        # Translate the C++-side exception and hide the original traceback.
        e = handle_exception_from_cpp(e)
        raise e from None

    ret = None
    ret_dt = self.return_type
    has_ret = ret_dt is not None

    if has_ret or (impl.current_cfg().async_mode
                   and has_external_arrays):
        runtime_ops.sync()

    if has_ret:
        if id(ret_dt) in primitive_types.integer_type_ids:
            ret = t_kernel.get_ret_int(0)
        elif id(ret_dt) in primitive_types.real_type_ids:
            ret = t_kernel.get_ret_float(0)
        elif id(ret_dt.dtype) in primitive_types.integer_type_ids:
            # Matrix return: rebuild row-major from the flat return tensor.
            it = iter(t_kernel.get_ret_int_tensor(0))
            ret = Matrix([[next(it) for _ in range(ret_dt.m)]
                          for _ in range(ret_dt.n)])
        else:
            it = iter(t_kernel.get_ret_float_tensor(0))
            ret = Matrix([[next(it) for _ in range(ret_dt.m)]
                          for _ in range(ret_dt.n)])
    if callbacks:
        # NOTE(review): presumably these copy results back into the
        # original torch tensors — confirm in get_torch_callbacks.
        for c in callbacks:
            c()

    return ret