def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
             entry_expr: impl.Expr):
    """Expose the fields and relations of one mesh element as attributes.

    For every attribute registered in ``mesh.fields[element_type]``, the
    index ``entry_expr`` is converted to another index space (chosen by the
    attribute's ``reorder`` flag). The attribute's field is then subscripted
    with the converted index, and the resulting expression is stored on
    ``self`` under the attribute's key. Afterwards, one relation accessor per
    element type of the mesh is attached, named by ``element_type_name``.

    Args:
        mesh: Mesh instance that owns the fields.
        element_type: Kind of mesh element this proxy refers to.
        entry_expr: Index expression of the element.

    Raises:
        RuntimeError: If one of the element's attributes is a StructField.
    """
    self.mesh = mesh
    self.element_type = element_type
    self.entry_expr = entry_expr

    fields_of_element = self.mesh.fields[self.element_type]
    for name, member_field in fields_of_element.field_dict.items():
        # Transform index space: reordered attributes use ConvType.l2r, all
        # others ConvType.l2g (presumably local->reordered / local->global).
        conv_type = (ConvType.l2r if fields_of_element.attr_dict[name].reorder
                     else ConvType.l2g)
        converted_index = impl.Expr(
            _ti_core.get_index_conversion(self.mesh.mesh_ptr, element_type,
                                          entry_expr, conv_type))
        index_group = impl.make_expr_group(converted_index)

        if isinstance(member_field, MatrixField):
            entries = [
                impl.Expr(_ti_core.subscript(member.ptr, index_group))
                for member in member_field.get_field_members()
            ]
            setattr(self, name,
                    _IntermediateMatrix(member_field.n, member_field.m,
                                        entries))
        elif isinstance(member_field, StructField):
            raise RuntimeError('ti.Mesh has not support StructField yet')
        else:  # any other Field: subscript its first (only) member
            scalar_ptr = member_field.get_field_members()[0].ptr
            setattr(self, name,
                    impl.Expr(_ti_core.subscript(scalar_ptr, index_group)))

    # Use a distinct loop name so the ``element_type`` argument is not rebound.
    for related_type in self.mesh._type.elements:
        setattr(self, element_type_name(related_type),
                impl.mesh_relation_access(self.mesh, self, related_type))
def subscript(self, i, j):
    """Return the expression for component ``(i, j)`` of the accessed element.

    The outer indices stored in ``self.indices_first`` are combined with the
    element-component indices. ``j`` is dropped when the array's
    ``element_shape`` has length 1. For ``Layout.SOA`` arrays the component
    indices come first; otherwise the outer indices come first.

    Args:
        i: Row (or only) component index.
        j: Column component index; ignored for 1-D element shapes.

    Returns:
        An ``Expr`` subscripting ``self.arr.ptr`` with the combined indices.
    """
    element_part = (i, ) if len(self.arr.element_shape) == 1 else (i, j)
    outer_part = self.indices_first
    full_indices = (element_part + outer_part
                    if self.arr.layout == Layout.SOA else
                    outer_part + element_part)
    return Expr(_ti_core.subscript(self.arr.ptr,
                                   make_expr_group(*full_indices)))
def subscript(value, *indices):
    """Subscript ``value`` with ``indices`` on behalf of Taichi-scope code.

    Dispatch order:

    * NumPy arrays and Python tuples/lists/dicts are indexed directly.
    * For everything else, indices that are Taichi classes are first
      flattened into their ``entries``.
    * Taichi classes then delegate to their own ``subscript``.
    * Global ``Expr`` / ``SNode`` values produce a subscript ``Expr`` after a
      dimension check.
    * Any other object is indexed in Python.

    Args:
        value: Object being subscripted.
        *indices: Index values, e.g. ``(i, j)`` for ``value[i, j]``.

    Returns:
        The indexed item, the result of ``value.subscript(...)``, or an
        ``Expr`` wrapping ``_ti_core.subscript``.

    Raises:
        TypeError: If ``value`` is an ``Expr`` that is not global.
        RuntimeError: If the field (or its gradient) has not been placed.
        IndexError: If the number of indices does not match the field's dim.
    """
    # NOTE(review): presumably a marker read by Taichi's traceback filtering
    # to hide this frame; keep it.
    _taichi_skip_traceback = 1

    if isinstance(value, np.ndarray):
        # ndarray.__getitem__ accepts exactly one key, so ``a[i, j]`` must be
        # forwarded as the tuple ``(i, j)``, not unpacked into two arguments.
        return value[indices[0]] if len(indices) == 1 else value[indices]

    if isinstance(value, (tuple, list, dict)):
        assert len(indices) == 1
        return value[indices[0]]

    # Taichi-class indices (e.g. vectors) contribute each of their entries.
    flattened_indices = []
    for index in indices:
        if is_taichi_class(index):
            flattened_indices += index.entries
        else:
            flattened_indices.append(index)
    indices = tuple(flattened_indices)

    if is_taichi_class(value):
        return value.subscript(*indices)

    if isinstance(value, (Expr, SNode)):
        if isinstance(value, Expr):
            if not value.is_global():
                raise TypeError(
                    'Subscription (e.g., "a[i, j]") only works on fields or external arrays.'
                )
            if not value.ptr.is_external_var() and value.ptr.snode() is None:
                if not value.ptr.is_primal():
                    raise RuntimeError(
                        f"Gradient {value.ptr.get_expr_name()} has not been placed, check whether `needs_grad=True`"
                    )
                raise RuntimeError(
                    f"{value.ptr.get_expr_name()} has not been placed.")
            field_dim = int(value.ptr.get_attribute("dim"))
        else:
            # When reading bit structure we only support the 0-D case for now.
            field_dim = 0
        # A lone ``None`` index (``x[None]``) means "no indices", so it
        # matches a 0-D field in the dimension check below.
        if len(indices) == 1 and indices[0] is None:
            indices = ()
        indices_expr_group = make_expr_group(*indices)
        index_dim = indices_expr_group.size()
        if field_dim != index_dim:
            raise IndexError(
                f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
            )
        return Expr(_ti_core.subscript(value.ptr, indices_expr_group))

    # Plain Python object: pass the key Python itself would pass, i.e. a
    # single index for ``a[i]`` and a tuple for ``a[i, j]``.
    return value[indices[0]] if len(indices) == 1 else value[indices]
def subscript(value, *indices):
    """Subscript ``value`` with ``indices`` on behalf of Taichi-scope code.

    Dispatch order:

    * NumPy arrays and Python tuples/lists/dicts are indexed directly.
    * Otherwise, indices that are Taichi classes are flattened into their
      ``entries``, and a lone ``None`` index becomes "no indices".
    * Taichi classes delegate to their own ``subscript``.
    * ``Field``, ``AnyArray``, ``ExtArray`` and ``SNode`` values produce
      subscript expressions after a dimension check.
    * Anything else is rejected.

    Args:
        value: Object being subscripted.
        *indices: Index values, e.g. ``(i, j)`` for ``value[i, j]``.

    Returns:
        The indexed item, the result of ``value.subscript(...)``, an ``Expr``,
        or a ``ti.Matrix`` of ``Expr`` entries for matrix fields and
        ``AnyArray`` values with a non-empty ``element_shape``.

    Raises:
        RuntimeError: If the field (or its gradient) has not been placed.
        IndexError: If the number of indices does not match the dimension.
        TypeError: If ``value`` is of an unsupported type.
    """
    # NOTE(review): presumably a marker read by Taichi's traceback filtering
    # to hide this frame; keep it.
    _taichi_skip_traceback = 1

    if isinstance(value, np.ndarray):
        # ndarray.__getitem__ accepts exactly one key, so ``a[i, j]`` must be
        # forwarded as the tuple ``(i, j)``, not unpacked into two arguments.
        return value[indices[0]] if len(indices) == 1 else value[indices]

    if isinstance(value, (tuple, list, dict)):
        assert len(indices) == 1
        return value[indices[0]]

    # Taichi-class indices (e.g. vectors) contribute each of their entries.
    flattened_indices = []
    for index in indices:
        if is_taichi_class(index):
            flattened_indices += index.entries
        else:
            flattened_indices.append(index)
    indices = tuple(flattened_indices)
    # A lone ``None`` index (``x[None]``) means "no indices", so it matches a
    # 0-D container in the dimension checks below.
    if len(indices) == 1 and indices[0] is None:
        indices = ()
    indices_expr_group = make_expr_group(*indices)
    index_dim = indices_expr_group.size()

    if is_taichi_class(value):
        return value.subscript(*indices)

    if isinstance(value, Field):
        var = value.get_field_members()[0].ptr
        if var.snode() is None:
            if var.is_primal():
                raise RuntimeError(
                    f"{var.get_expr_name()} has not been placed.")
            raise RuntimeError(
                f"Gradient {var.get_expr_name()} has not been placed, check whether `needs_grad=True`"
            )
        field_dim = int(var.get_attribute("dim"))
        if field_dim != index_dim:
            raise IndexError(
                f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
            )
        if isinstance(value, MatrixField):
            return ti.Matrix.with_entries(value.n, value.m, [
                Expr(_ti_core.subscript(e.ptr, indices_expr_group))
                for e in value.get_field_members()
            ])
        return Expr(_ti_core.subscript(var, indices_expr_group))

    if isinstance(value, AnyArray):
        # TODO: deprecate using get_attribute to get dim
        field_dim = int(value.ptr.get_attribute("dim"))
        element_dim = len(value.element_shape)
        # The array's dim counts both the outer indices and the element dims.
        if field_dim != index_dim + element_dim:
            raise IndexError(
                f'Field with dim {field_dim - element_dim} accessed with indices of dim {index_dim}'
            )
        if element_dim == 0:
            return Expr(_ti_core.subscript(value.ptr, indices_expr_group))
        n = value.element_shape[0]
        m = 1 if element_dim == 1 else value.element_shape[1]
        any_array_access = AnyArrayAccess(value, indices)
        ret = ti.Matrix.with_entries(n, m, [
            any_array_access.subscript(i, j) for i in range(n)
            for j in range(m)
        ])
        ret.any_array_access = any_array_access
        return ret

    if isinstance(value, (ExtArray, SNode)):
        if isinstance(value, ExtArray):
            field_dim = int(value.ptr.get_attribute("dim"))
        else:
            # When reading bit structure we only support the 0-D case for now.
            field_dim = 0
        if field_dim != index_dim:
            raise IndexError(
                f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
            )
        return Expr(_ti_core.subscript(value.ptr, indices_expr_group))

    raise TypeError(
        'Subscription (e.g., "a[i, j]") only works on fields or external arrays.'
    )
def subscript(value, *_indices, skip_reordered=False):
    """Subscript ``value`` with ``_indices`` on behalf of Taichi-scope code.

    Dispatch order:

    * NumPy arrays and Python tuples/lists/dicts are indexed directly.
    * Otherwise, indices that are Taichi classes are flattened into their
      ``entries``, and a lone ``None`` index becomes "no indices".
    * Taichi classes, mesh element/relation proxies and ``SparseMatrixProxy``
      delegate to their own ``subscript``.
    * Reordered mesh field proxies first convert their single index with
      ``ConvType.g2r``, then re-enter with ``skip_reordered=True``.
    * ``Field``, ``AnyArray`` and ``SNode`` values produce subscript
      expressions after a dimension check.
    * Anything else is indexed in Python.

    Args:
        value: Object being subscripted.
        *_indices: Index values, e.g. ``(i, j)`` for ``value[i, j]``.
        skip_reordered: Set on the recursive call that follows the reordered
            index conversion, so the conversion is applied only once.

    Returns:
        The indexed item, the result of a delegated ``subscript``, an
        ``Expr``, an ``_IntermediateMatrix`` or an ``_IntermediateStruct``.

    Raises:
        RuntimeError: If a field (or its gradient) has not been placed.
        IndexError: If the number of indices does not match the dimension.
    """
    # NOTE(review): presumably a marker read by Taichi's traceback filtering
    # to hide this frame; keep it.
    _taichi_skip_traceback = 1

    if isinstance(value, np.ndarray):
        # ndarray.__getitem__ accepts exactly one key, so ``a[i, j]`` must be
        # forwarded as the tuple ``(i, j)``, not unpacked into two arguments.
        return value[_indices[0]] if len(_indices) == 1 else value[_indices]

    if isinstance(value, (tuple, list, dict)):
        assert len(_indices) == 1
        return value[_indices[0]]

    # Taichi-class indices (e.g. vectors) contribute each of their entries.
    flattened_indices = []
    for _index in _indices:
        if is_taichi_class(_index):
            flattened_indices += _index.entries
        else:
            flattened_indices.append(_index)
    _indices = tuple(flattened_indices)
    # A lone ``None`` index (``x[None]``) means "no indices", so it matches a
    # 0-D container in the dimension checks below.
    if len(_indices) == 1 and _indices[0] is None:
        _indices = ()
    indices_expr_group = make_expr_group(*_indices)
    index_dim = indices_expr_group.size()

    if is_taichi_class(value):
        return value.subscript(*_indices)
    if isinstance(value, (MeshElementFieldProxy, MeshRelationAccessProxy)):
        return value.subscript(*_indices)
    if isinstance(value, (MeshReorderedScalarFieldProxy,
                          MeshReorderedMatrixFieldProxy)) and not skip_reordered:
        assert index_dim == 1
        reordered_index = Expr(
            _ti_core.get_index_conversion(value.mesh_ptr, value.element_type,
                                          Expr(_indices[0]).ptr,
                                          ConvType.g2r))
        return subscript(value, reordered_index, skip_reordered=True)
    if isinstance(value, SparseMatrixProxy):
        return value.subscript(*_indices)

    if isinstance(value, Field):
        _var = value.get_field_members()[0].ptr
        if _var.snode() is None:
            if _var.is_primal():
                raise RuntimeError(
                    f"{_var.get_expr_name()} has not been placed.")
            raise RuntimeError(
                f"Gradient {_var.get_expr_name()} has not been placed, check whether `needs_grad=True`"
            )
        field_dim = int(_var.get_attribute("dim"))
        if field_dim != index_dim:
            raise IndexError(
                f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
            )
        if isinstance(value, MatrixField):
            return _IntermediateMatrix(value.n, value.m, [
                Expr(_ti_core.subscript(e.ptr, indices_expr_group))
                for e in value.get_field_members()
            ])
        if isinstance(value, StructField):
            return _IntermediateStruct(
                {k: subscript(v, *_indices)
                 for k, v in value.items})
        return Expr(_ti_core.subscript(_var, indices_expr_group))

    if isinstance(value, AnyArray):
        # TODO: deprecate using get_attribute to get dim
        field_dim = int(value.ptr.get_attribute("dim"))
        element_dim = len(value.element_shape)
        # The array's dim counts both the outer indices and the element dims.
        if field_dim != index_dim + element_dim:
            raise IndexError(
                f'Field with dim {field_dim - element_dim} accessed with indices of dim {index_dim}'
            )
        if element_dim == 0:
            return Expr(_ti_core.subscript(value.ptr, indices_expr_group))
        n = value.element_shape[0]
        m = 1 if element_dim == 1 else value.element_shape[1]
        any_array_access = AnyArrayAccess(value, _indices)
        ret = _IntermediateMatrix(n, m, [
            any_array_access.subscript(i, j) for i in range(n)
            for j in range(m)
        ])
        ret.any_array_access = any_array_access
        return ret

    if isinstance(value, SNode):
        # When reading bit structure we only support the 0-D case for now.
        field_dim = 0
        if field_dim != index_dim:
            raise IndexError(
                f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
            )
        return Expr(_ti_core.subscript(value.ptr, indices_expr_group))

    # Directly evaluate in Python for non-Taichi types. Python passes a single
    # key for ``a[i]`` and a tuple for ``a[i, j]``; mirror that instead of
    # unpacking into ``__getitem__``, which takes exactly one key.
    return value[_indices[0]] if len(_indices) == 1 else value[_indices]