コード例 #1
0
    def build(self, mesh_instance, size, g2r_field):
        """Create and place one field per attribute, then wrap them.

        A field is created for every entry of ``self.attr_dict``:
        ``attr.dtype.field(...)`` when the dtype is a ``CompoundType``,
        ``impl.field(attr.dtype, ...)`` otherwise. Fields are then placed
        under ``impl.root`` through ``dense(impl.axes(0), size)``:

        * ``Layout.SOA``: every field, and every requested gradient, gets its
          own ``impl.root.dense(...)`` call.
        * any other layout: all fields are placed in a single ``dense`` call,
          and all requested gradients (if any) in a second one.

        Args:
            mesh_instance: Forwarded unchanged to ``MeshElementField``.
            size: Extent passed to each ``impl.root.dense(impl.axes(0), size)``.
            g2r_field: Forwarded unchanged to ``MeshElementField``.

        Returns:
            A ``MeshElementField`` built from ``self._type``,
            ``self.attr_dict`` and the mapping of attribute key -> new field.
        """
        created = {}
        for name, attr in self.attr_dict.items():
            if isinstance(attr.dtype, CompoundType):
                created[name] = attr.dtype.field(shape=None,
                                                 needs_grad=attr.needs_grad)
                continue
            created[name] = impl.field(attr.dtype,
                                       shape=None,
                                       needs_grad=attr.needs_grad)

        if self.layout == Layout.SOA:
            for name, placed in created.items():
                impl.root.dense(impl.axes(0), size).place(placed)
                if self.attr_dict[name].needs_grad:
                    # The gradient is placed through a separate dense call.
                    impl.root.dense(impl.axes(0), size).place(placed.grad)
        elif created:
            impl.root.dense(impl.axes(0), size).place(*created.values())
            grad_fields = [
                placed.grad for name, placed in created.items()
                if self.attr_dict[name].needs_grad
            ]
            # Skip the second dense call when no attribute requested grads.
            if grad_fields:
                impl.root.dense(impl.axes(0), size).place(*grad_fields)

        return MeshElementField(mesh_instance, self._type, self.attr_dict,
                                created, g2r_field)
コード例 #2
0
ファイル: misc.py プロジェクト: taichi-dev/taichi
import warnings
from copy import deepcopy as _deepcopy

from taichi._lib import core as _ti_core
from taichi._lib.utils import locale_encode
from taichi.lang import impl
from taichi.lang.expr import Expr
from taichi.lang.impl import axes, get_runtime
from taichi.profiler.kernel_profiler import get_default_kernel_profiler
from taichi.types.primitive_types import f32, f64, i32, i64

from taichi import _logging, _snode, _version_check

# Show each DeprecationWarning attributed to a module whose name matches the
# regex "taichi" only on its first occurrence, instead of on every call.
warnings.filterwarnings("once", category=DeprecationWarning, module="taichi")

# ----------------------
i = axes(0)
"""Axis 0. For a multi-dimensional array, this is the direction down the rows.
For a 1-D array, it is the direction along the array.
"""
# ----------------------

j = axes(1)
"""Axis 1. For a multi-dimensional array, this is the direction across the
columns.
"""
# ----------------------

k = axes(2)
"""Axis 2. For an array of dimension `d` >= 3, view each cell as an array of
lower dimension `d - 2`; this axis is the first axis of that cell.
"""
# ----------------------