Esempio n. 1
0
File: impl.py Progetto: k-ye/taichi
def subscript(value, *_indices, skip_reordered=False, get_ref=False):
    """Evaluate ``value[_indices]`` for NumPy, Python and Taichi objects.

    Dispatches on the type of ``value``:

    * ``np.ndarray`` is indexed with the whole index tuple.
    * ``tuple`` / ``list`` / ``dict`` must receive exactly one index.
    * Taichi classes, mesh proxies, sparse-matrix proxies, fields and
      any-arrays use their Taichi-specific subscript paths.
    * Anything else falls back to Python's ``__getitem__``.

    Args:
        value: The object being indexed.
        *_indices: Index components. Any index for which ``is_taichi_class``
            is true is flattened into its ``entries``. A lone ``None`` index
            (after flattening) is treated as "no indices".
        skip_reordered: If True, skip the global-to-reordered index
            conversion for reordered mesh field proxies. The recursive call
            made after that conversion passes True so it is not applied twice.
        get_ref: Forwarded to ``value._subscript`` for Taichi classes.

    Returns:
        The subscripted element; its type depends on ``value``.

    Raises:
        SyntaxError: A slice index is used on something other than a Matrix.
        RuntimeError: A field's SNode has not been placed (or its gradient
            has not been placed).
        IndexError: The number of indices does not match the field / array
            dimension.
        AssertionError: A Python container gets other than one index, or a
            reordered mesh proxy gets other than one index dimension.
    """
    if isinstance(value, np.ndarray):
        return value.__getitem__(_indices)

    if isinstance(value, (tuple, list, dict)):
        assert len(_indices) == 1
        return value[_indices[0]]

    # Flatten Taichi-class indices into their entries and record whether
    # any slice index was used.
    has_slice = False
    flattened_indices = []
    for _index in _indices:
        if is_taichi_class(_index):
            flattened_indices.extend(_index.entries)
        else:
            if isinstance(_index, slice):
                has_slice = True
            flattened_indices.append(_index)
    _indices = tuple(flattened_indices)
    if len(_indices) == 1 and _indices[0] is None:
        _indices = ()

    if has_slice:
        if not isinstance(value, Matrix):
            raise SyntaxError(
                f"The type {type(value)} do not support index of slice type")
        # NOTE(review): indices_expr_group / index_dim stay unbound here; this
        # relies on a sliced Matrix returning via the is_taichi_class branch
        # below before either name is read -- confirm Matrix is a Taichi class.
    else:
        indices_expr_group = make_expr_group(*_indices)
        index_dim = indices_expr_group.size()

    if is_taichi_class(value):
        return value._subscript(*_indices, get_ref=get_ref)
    if isinstance(value, MeshElementFieldProxy):
        return value.subscript(*_indices)
    if isinstance(value, MeshRelationAccessProxy):
        return value.subscript(*_indices)
    if isinstance(value,
                  (MeshReorderedScalarFieldProxy,
                   MeshReorderedMatrixFieldProxy)) and not skip_reordered:
        assert index_dim == 1
        # Convert the global index to the reordered one, then re-enter with
        # skip_reordered=True so the conversion is not applied again.
        reordered_index = Expr(
            _ti_core.get_index_conversion(value.mesh_ptr,
                                          value.element_type,
                                          Expr(_indices[0]).ptr,
                                          ConvType.g2r))
        return subscript(value, reordered_index, skip_reordered=True)
    if isinstance(value, SparseMatrixProxy):
        return value.subscript(*_indices)
    if isinstance(value, Field):
        _var = value._get_field_members()[0].ptr
        if _var.snode() is None:
            if _var.is_primal():
                raise RuntimeError(
                    f"{_var.get_expr_name()} has not been placed.")
            else:
                raise RuntimeError(
                    f"Gradient {_var.get_expr_name()} has not been placed, check whether `needs_grad=True`"
                )
        field_dim = int(_var.get_attribute("dim"))
        if field_dim != index_dim:
            raise IndexError(
                f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
            )
        if isinstance(value, MatrixField):
            return _MatrixFieldElement(value, indices_expr_group)
        if isinstance(value, StructField):
            # Subscript every member with the same indices and bundle them.
            entries = {k: subscript(v, *_indices) for k, v in value._items}
            entries['__struct_methods'] = value.struct_methods
            return _IntermediateStruct(entries)
        return Expr(_ti_core.subscript(_var, indices_expr_group))
    if isinstance(value, AnyArray):
        # TODO: deprecate using get_attribute to get dim
        field_dim = int(value.ptr.get_attribute("dim"))
        element_dim = len(value.element_shape)
        # The indices address the outer dims only; element dims come after.
        if field_dim != index_dim + element_dim:
            raise IndexError(
                f'Field with dim {field_dim - element_dim} accessed with indices of dim {index_dim}'
            )
        if element_dim == 0:
            return Expr(_ti_core.subscript(value.ptr, indices_expr_group))
        n = value.element_shape[0]
        m = 1 if element_dim == 1 else value.element_shape[1]
        any_array_access = AnyArrayAccess(value, _indices)
        ret = _IntermediateMatrix(n, m, [
            any_array_access.subscript(i, j) for i in range(n)
            for j in range(m)
        ])
        ret.any_array_access = any_array_access
        return ret
    # Directly evaluate in Python for non-Taichi types. ``__getitem__`` takes
    # exactly one key, so follow Python's own subscription protocol: a single
    # index is passed as-is, several (or none) are packed into one tuple,
    # just like ``obj[i]`` vs ``obj[i, j]`` / ``obj[()]``.
    if len(_indices) == 1:
        return value.__getitem__(_indices[0])
    return value.__getitem__(_indices)
Esempio n. 2
0
 def wrapped(a):
     """Apply ``imp_foo`` to ``a``; element-wise when ``a`` is a Taichi class."""
     # NOTE(review): unused here but looks like an intentional frame-local
     # marker (presumably inspected by Taichi's traceback filtering) -- keep.
     _taichi_skip_traceback = 1
     return a.element_wise_unary(imp_foo) if is_taichi_class(a) else imp_foo(a)
Esempio n. 3
0
 def wrapped(a, b):
     """Apply the binary operation ``imp_foo`` to ``a`` and ``b``.

     ``a`` is checked first: if it is a Taichi class, ``imp_foo`` is applied
     element-wise through ``a`` with ``b`` as the other operand. Otherwise, if
     ``b`` is a Taichi class, the reversed operation ``rev_foo`` is applied
     element-wise through ``b`` with ``a`` as the other operand. Two plain
     operands go straight to ``imp_foo``.
     """
     if not is_taichi_class(a):
         if not is_taichi_class(b):
             return imp_foo(a, b)
         return b.element_wise_binary(rev_foo, a)
     return a.element_wise_binary(imp_foo, b)
Esempio n. 4
0
 def wrapped(a):
     """Apply ``imp_foo`` to ``a``; element-wise when ``a`` is a Taichi class."""
     if not is_taichi_class(a):
         return imp_foo(a)
     return a.element_wise_unary(imp_foo)
Esempio n. 5
0
def bit_cast(obj, dtype):
    """Bit-cast ``obj`` to ``dtype`` via ``_ti_core.bits_cast``.

    Args:
        obj: The value to cast; it is wrapped in ``expr.Expr`` first.
        dtype: Target type; normalized with ``cook_dtype`` before use.

    Returns:
        An ``expr.Expr`` wrapping the result of ``_ti_core.bits_cast``.

    Raises:
        ValueError: If ``obj`` is a Taichi class (not supported here).
    """
    target_dtype = cook_dtype(dtype)
    if is_taichi_class(obj):
        raise ValueError('Cannot apply bit_cast on Taichi classes')
    source_ptr = expr.Expr(obj).ptr
    return expr.Expr(_ti_core.bits_cast(source_ptr, target_dtype))
Esempio n. 6
0
def cast(obj, dtype):
    """Value-cast ``obj`` to ``dtype``.

    Args:
        obj: The value to cast. Taichi classes delegate to their own
            ``cast`` method; anything else is wrapped in ``expr.Expr``.
        dtype: Target type; normalized with ``cook_dtype`` before use.

    Returns:
        ``obj.cast(...)`` for Taichi classes, otherwise an ``expr.Expr``
        wrapping the result of ``_ti_core.value_cast``.
    """
    target_dtype = cook_dtype(dtype)
    if not is_taichi_class(obj):
        source_ptr = expr.Expr(obj).ptr
        return expr.Expr(_ti_core.value_cast(source_ptr, target_dtype))
    # TODO: unify with element_wise_unary
    return obj.cast(target_dtype)