def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
             entry_expr: impl.Expr):
    """Bind every field attribute and mesh relation of one mesh element.

    Stores ``mesh``, ``element_type`` and ``entry_expr`` on the instance.

    Field attributes: for each ``(name, field)`` registered in
    ``mesh.fields[element_type].field_dict``, ``entry_expr`` is converted
    into the attribute's index space with ``_ti_core.get_index_conversion``.
    The conversion is ``ConvType.l2r`` when the attribute is marked
    ``reorder`` and ``ConvType.l2g`` otherwise (presumably local->reordered /
    local->global; NOTE(review): confirm against ``ConvType``). The
    subscripted entry is then exposed as ``self.<name>``:

    * a ``MatrixField`` yields a ``_MatrixFieldElement``;
    * a ``StructField`` is rejected with ``RuntimeError``;
    * any other field yields a scalar ``impl.Expr`` subscript.

    Mesh relations: for every element type in ``mesh._type.elements``, an
    attribute named ``element_type_name(<type>)`` is set to
    ``impl.mesh_relation_access(mesh, self, <type>)``.

    Raises:
        RuntimeError: if the element carries a ``StructField`` attribute.
    """
    self.mesh = mesh
    self.element_type = element_type
    self.entry_expr = entry_expr

    fields_of_element = self.mesh.fields[self.element_type]
    for attr_name, field in fields_of_element.field_dict.items():
        # Transform the index space; reordered attributes use l2r.
        if fields_of_element.attr_dict[attr_name].reorder:
            conversion = ConvType.l2r
        else:
            conversion = ConvType.l2g
        converted_index = impl.Expr(
            _ti_core.get_index_conversion(self.mesh.mesh_ptr, element_type,
                                          entry_expr, conversion))
        index_group = impl.make_expr_group(converted_index)

        if isinstance(field, MatrixField):
            bound_entry = _MatrixFieldElement(field, index_group)
        elif isinstance(field, StructField):
            raise RuntimeError('ti.Mesh has not support StructField yet')
        else:  # a plain (scalar) Field
            field_ptr = field.get_field_members()[0].ptr
            bound_entry = impl.Expr(_ti_core.subscript(field_ptr, index_group))
        setattr(self, attr_name, bound_entry)

    # Expose each relation (e.g. element -> neighbouring elements) by name.
    for related_type in self.mesh._type.elements:
        setattr(self, element_type_name(related_type),
                impl.mesh_relation_access(self.mesh, self, related_type))
def subscript(value, *_indices, skip_reordered=False):
    """Lower the subscript expression ``value[_indices]``.

    Dispatch on the type of ``value``:

    * ``np.ndarray``: indexed directly with the raw index tuple.
    * ``tuple`` / ``list`` / ``dict``: exactly one index is expected.
    * Taichi classes: delegated to ``value._subscript``.
    * Mesh element / relation proxies and ``SparseMatrixProxy``: delegated
      to ``value.subscript``.
    * Reordered mesh field proxies (unless ``skip_reordered``): the single
      index is converted with ``ConvType.g2r``, then subscripted again with
      ``skip_reordered=True``.
    * ``Field`` / ``AnyArray`` / ``SNode``: the index count is checked
      against the dimensionality, then lowered via ``_ti_core.subscript``.
    * Anything else: indexed in plain Python.

    Args:
        value: The object being subscripted.
        *_indices: Index expressions. Taichi-class indices (anything
            ``is_taichi_class`` accepts) are flattened into their
            ``entries``. A single ``None`` index is treated as "no indices".
        skip_reordered: If True, skip the reordered-index conversion for
            ``MeshReordered*FieldProxy`` values. The recursive call made
            after that conversion uses this.

    Returns:
        The result of the subscript: an ``Expr``, a matrix/struct element
        wrapper, or the plain Python value for non-Taichi types.

    Raises:
        SyntaxError: If a slice index is applied to a non-``Matrix`` value.
        RuntimeError: If a field (or its gradient) has not been placed.
        IndexError: If the number of indices does not match the
            field / array / SNode dimensionality.
    """
    if isinstance(value, np.ndarray):
        return value.__getitem__(_indices)

    if isinstance(value, (tuple, list, dict)):
        assert len(_indices) == 1
        return value[_indices[0]]

    # Flatten Taichi-class indices into their entries and remember whether
    # any slice was used.
    has_slice = False
    flattened_indices = []
    for _index in _indices:
        if is_taichi_class(_index):
            ind = _index.entries
        elif isinstance(_index, slice):
            ind = [_index]
            has_slice = True
        else:
            ind = [_index]
        flattened_indices += ind
    _indices = tuple(flattened_indices)
    if len(_indices) == 1 and _indices[0] is None:
        _indices = ()

    if has_slice:
        if not isinstance(value, Matrix):
            raise SyntaxError(
                f"The type {type(value)} do not support index of slice type")
        # NOTE(review): no expression group is built for sliced indices.
        # This presumably relies on Matrix being handled by the
        # is_taichi_class branch below, which does not use them. Confirm.
    else:
        indices_expr_group = make_expr_group(*_indices)
        index_dim = indices_expr_group.size()

    if is_taichi_class(value):
        return value._subscript(*_indices)
    if isinstance(value, (MeshElementFieldProxy, MeshRelationAccessProxy)):
        return value.subscript(*_indices)
    if isinstance(value, (MeshReorderedScalarFieldProxy,
                          MeshReorderedMatrixFieldProxy)) and not skip_reordered:
        assert index_dim == 1
        # Map the incoming index into the reordered index space, then
        # subscript again without re-applying the conversion.
        reordered_index = Expr(
            _ti_core.get_index_conversion(value.mesh_ptr, value.element_type,
                                          Expr(_indices[0]).ptr,
                                          ConvType.g2r))
        return subscript(value, reordered_index, skip_reordered=True)
    if isinstance(value, SparseMatrixProxy):
        return value.subscript(*_indices)
    if isinstance(value, Field):
        _var = value._get_field_members()[0].ptr
        if _var.snode() is None:
            if _var.is_primal():
                raise RuntimeError(
                    f"{_var.get_expr_name()} has not been placed.")
            raise RuntimeError(
                f"Gradient {_var.get_expr_name()} has not been placed, check whether `needs_grad=True`"
            )
        field_dim = int(_var.get_attribute("dim"))
        if field_dim != index_dim:
            raise IndexError(
                f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
            )
        if isinstance(value, MatrixField):
            return _MatrixFieldElement(value, indices_expr_group)
        if isinstance(value, StructField):
            return _IntermediateStruct(
                {k: subscript(v, *_indices) for k, v in value._items})
        return Expr(_ti_core.subscript(_var, indices_expr_group))
    if isinstance(value, AnyArray):
        # TODO: deprecate using get_attribute to get dim
        field_dim = int(value.ptr.get_attribute("dim"))
        element_dim = len(value.element_shape)
        if field_dim != index_dim + element_dim:
            raise IndexError(
                f'Field with dim {field_dim - element_dim} accessed with indices of dim {index_dim}'
            )
        if element_dim == 0:
            return Expr(_ti_core.subscript(value.ptr, indices_expr_group))
        n = value.element_shape[0]
        m = 1 if element_dim == 1 else value.element_shape[1]
        any_array_access = AnyArrayAccess(value, _indices)
        ret = _IntermediateMatrix(n, m, [
            any_array_access.subscript(i, j) for i in range(n)
            for j in range(m)
        ])
        ret.any_array_access = any_array_access
        return ret
    if isinstance(value, SNode):
        # When reading bit structure we only support the 0-D case for now.
        field_dim = 0
        if field_dim != index_dim:
            raise IndexError(
                f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
            )
        return Expr(_ti_core.subscript(value.ptr, indices_expr_group))
    # Directly evaluate in Python for non-Taichi types. Python's own
    # ``obj[a, b]`` passes ONE tuple key to ``__getitem__``. Spreading the
    # indices as separate arguments would raise TypeError for any multi-index
    # (or zero-index) access.
    if len(_indices) == 1:
        return value.__getitem__(_indices[0])
    return value.__getitem__(_indices)