Пример #1
0
    def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
                 entry_expr: impl.Expr):
        """Expose the attributes of one mesh element as Python attributes.

        For every attribute registered in ``mesh.fields[element_type]`` an
        attribute with the same name is set on ``self``, holding that field
        subscripted at ``entry_expr`` after an index-space conversion. Then,
        for each element type of the mesh, an attribute named
        ``element_type_name(<type>)`` is set to the object returned by
        ``impl.mesh_relation_access``.

        Args:
            mesh: Mesh instance whose fields and relations are exposed.
            element_type: Element type this proxy refers to.
            entry_expr: Index expression of the element. It is converted with
                ``ConvType.l2r`` or ``ConvType.l2g`` before use, so it is
                presumably a local index (TODO confirm).

        Raises:
            RuntimeError: If one of the element attributes is a
                ``StructField`` (not supported yet).
        """
        self.mesh = mesh
        self.element_type = element_type
        self.entry_expr = entry_expr

        element_field = self.mesh.fields[self.element_type]
        for key, attr in element_field.field_dict.items():
            # Transform the index space: attributes flagged `reorder` use the
            # l2r conversion, all others use l2g.
            conv_type = (ConvType.l2r
                         if element_field.attr_dict[key].reorder else
                         ConvType.l2g)
            global_entry_expr = impl.Expr(
                _ti_core.get_index_conversion(self.mesh.mesh_ptr,
                                              element_type, entry_expr,
                                              conv_type))
            global_entry_expr_group = impl.make_expr_group(global_entry_expr)
            if isinstance(attr, MatrixField):
                # One subscript expression per matrix component.
                setattr(
                    self, key,
                    _IntermediateMatrix(attr.n, attr.m, [
                        impl.Expr(
                            _ti_core.subscript(e.ptr, global_entry_expr_group))
                        for e in attr.get_field_members()
                    ]))
            elif isinstance(attr, StructField):
                raise RuntimeError('ti.Mesh has not support StructField yet')
            else:  # isinstance(attr, Field)
                var = attr.get_field_members()[0].ptr
                setattr(
                    self, key,
                    impl.Expr(_ti_core.subscript(var,
                                                 global_entry_expr_group)))

        # Use a distinct loop name so the `element_type` argument is not
        # shadowed by the relation types being iterated.
        for rel_type in self.mesh._type.elements:
            setattr(self, element_type_name(rel_type),
                    impl.mesh_relation_access(self.mesh, self, rel_type))
Пример #2
0
def _py_getitem(value, indices):
    """Index a non-Taichi object with a sequence of already-unpacked indices.

    ``obj.__getitem__`` accepts exactly one key, so several indices are packed
    into a single tuple key (``a[i, j]`` is ``a.__getitem__((i, j))``) and an
    empty index sequence becomes ``a[()]``.
    """
    if len(indices) == 1:
        return value[indices[0]]
    return value[tuple(indices)]


def subscript(value, *_indices, skip_reordered=False):
    """Evaluate ``value[_indices]`` for Taichi and plain Python values.

    Dispatch depends on the type of ``value``:

    * ``np.ndarray`` values are indexed directly in Python, and so is any
      other non-Taichi object reaching the final fallback. Several indices
      are combined into one tuple key.
    * ``tuple``/``list``/``dict`` values take exactly one index.
    * Taichi classes, mesh element/relation proxies and sparse-matrix proxies
      delegate to ``value.subscript``.
    * Reordered mesh field proxies first convert the index with
      ``ConvType.g2r`` and then subscript again with ``skip_reordered=True``.
    * ``Field``, ``AnyArray`` and ``SNode`` values produce subscript
      expressions after checking that the number of indices matches.

    Except for the ndarray/tuple/list/dict cases, indices that are Taichi
    classes are flattened into their ``entries``. A single ``None`` index is
    treated as "no index" (0-D access).

    Args:
        value: The object being indexed.
        *_indices: The indices, one per dimension.
        skip_reordered: Internal flag that disables the reordered-index
            conversion for mesh field proxies (set by the recursive call).

    Returns:
        The indexed value. For Taichi values this is an ``Expr``,
        ``_IntermediateMatrix`` or ``_IntermediateStruct``.

    Raises:
        RuntimeError: If a field (or its gradient) has not been placed.
        IndexError: If the number of indices does not match the dimension.
    """
    # Local read by Taichi's traceback machinery (presumably hides this
    # frame from user-facing tracebacks) -- keep it.
    _taichi_skip_traceback = 1
    if isinstance(value, np.ndarray):
        return _py_getitem(value, _indices)

    if isinstance(value, (tuple, list, dict)):
        assert len(_indices) == 1
        return value[_indices[0]]

    # Taichi-class indices (e.g. vectors) contribute each of their entries.
    flattened_indices = []
    for _index in _indices:
        if is_taichi_class(_index):
            flattened_indices.extend(_index.entries)
        else:
            flattened_indices.append(_index)
    _indices = tuple(flattened_indices)
    if len(_indices) == 1 and _indices[0] is None:
        _indices = ()
    indices_expr_group = make_expr_group(*_indices)
    index_dim = indices_expr_group.size()

    if is_taichi_class(value):
        return value.subscript(*_indices)
    if isinstance(value, (MeshElementFieldProxy, MeshRelationAccessProxy)):
        return value.subscript(*_indices)
    if isinstance(value,
                  (MeshReorderedScalarFieldProxy,
                   MeshReorderedMatrixFieldProxy)) and not skip_reordered:
        # Convert the index with g2r, then subscript again with the
        # conversion disabled so the recursion stops.
        assert index_dim == 1
        reordered_index = Expr(
            _ti_core.get_index_conversion(value.mesh_ptr, value.element_type,
                                          Expr(_indices[0]).ptr,
                                          ConvType.g2r))
        return subscript(value, reordered_index, skip_reordered=True)
    if isinstance(value, SparseMatrixProxy):
        return value.subscript(*_indices)
    if isinstance(value, Field):
        _var = value.get_field_members()[0].ptr
        if _var.snode() is None:
            if _var.is_primal():
                raise RuntimeError(
                    f"{_var.get_expr_name()} has not been placed.")
            raise RuntimeError(
                f"Gradient {_var.get_expr_name()} has not been placed, check whether `needs_grad=True`"
            )
        field_dim = int(_var.get_attribute("dim"))
        if field_dim != index_dim:
            raise IndexError(
                f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
            )
        if isinstance(value, MatrixField):
            return _IntermediateMatrix(value.n, value.m, [
                Expr(_ti_core.subscript(e.ptr, indices_expr_group))
                for e in value.get_field_members()
            ])
        if isinstance(value, StructField):
            return _IntermediateStruct(
                {k: subscript(v, *_indices)
                 for k, v in value.items})
        return Expr(_ti_core.subscript(_var, indices_expr_group))
    if isinstance(value, AnyArray):
        # TODO: deprecate using get_attribute to get dim
        field_dim = int(value.ptr.get_attribute("dim"))
        element_dim = len(value.element_shape)
        # The array's total dim covers both the indexed dims and the
        # per-element (vector/matrix) dims.
        if field_dim != index_dim + element_dim:
            raise IndexError(
                f'Field with dim {field_dim - element_dim} accessed with indices of dim {index_dim}'
            )
        if element_dim == 0:
            return Expr(_ti_core.subscript(value.ptr, indices_expr_group))
        n = value.element_shape[0]
        m = 1 if element_dim == 1 else value.element_shape[1]
        any_array_access = AnyArrayAccess(value, _indices)
        ret = _IntermediateMatrix(n, m, [
            any_array_access.subscript(i, j) for i in range(n)
            for j in range(m)
        ])
        ret.any_array_access = any_array_access
        return ret
    if isinstance(value, SNode):
        # When reading bit structure we only support the 0-D case for now.
        field_dim = 0
        if field_dim != index_dim:
            raise IndexError(
                f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
            )
        return Expr(_ti_core.subscript(value.ptr, indices_expr_group))
    # Directly evaluate in Python for non-Taichi types
    return _py_getitem(value, _indices)
Пример #3
0
 def __getitem__(self, key):
     """Return the entry at ``key`` as an ``n x m`` ``_IntermediateMatrix``.

     The host accessors are initialized first. ``key`` is then mapped through
     ``self.g2r_field`` (presumably a global-to-reordered index table -- TODO
     confirm) and padded with ``self.pad_key``. The values returned by
     ``self.host_access`` for that key become the matrix entries.
     """
     self.initialize_host_accessors()
     padded = self.pad_key(self.g2r_field[key])
     return _IntermediateMatrix(self.n, self.m, self.host_access(padded))
Пример #4
0
 def __iter__(self):
     """Yield each element of ``self.r`` as a column ``_IntermediateMatrix``.

     Every yielded matrix has ``len(index)`` rows and a single column holding
     the components of ``index``. ``self.r`` is presumably an iterable of
     index tuples (TODO confirm).
     """
     for index in self.r:
         rows = len(index)
         yield _IntermediateMatrix(rows, 1, [*index])