def build(self, mesh_instance, size, g2r_field):
    """Allocate a field per attribute, place them in the SNode tree, and wrap them.

    One field is created for every entry of ``self.attr_dict`` (in its
    iteration order). Attributes whose dtype is a ``CompoundType`` build their
    field through ``dtype.field(...)``; all others go through ``impl.field``.
    Every field is created with ``shape=None`` and the attribute's
    ``needs_grad`` flag, and is then placed under ``impl.root``:

    * ``Layout.SOA``: each field gets its own 1-D dense SNode of length
      ``size``; if the attribute has ``needs_grad``, its ``.grad`` field gets
      another dense SNode of its own.
    * any other layout: all fields share one dense SNode of length ``size``,
      and the ``.grad`` fields of attributes with ``needs_grad`` share a second
      one. Nothing is placed when there are no attributes.

    Args:
        mesh_instance: Passed through unchanged to ``MeshElementField``.
        size: Length of the dense axis every placed field lives on.
        g2r_field: Passed through unchanged to ``MeshElementField``.

    Returns:
        A ``MeshElementField`` built from the created fields (keyed like
        ``self.attr_dict``).
    """

    def _make_field(attr):
        # CompoundType dtypes know how to build their own field; plain dtypes
        # are handed to impl.field.
        if isinstance(attr.dtype, CompoundType):
            return attr.dtype.field(shape=None, needs_grad=attr.needs_grad)
        return impl.field(attr.dtype, shape=None, needs_grad=attr.needs_grad)

    def _new_dense():
        # A fresh 1-D dense SNode of length `size` directly under the root.
        return impl.root.dense(impl.axes(0), size)

    # Creation order matters for the SNode tree, so it follows attr_dict order.
    fields = {name: _make_field(attr) for name, attr in self.attr_dict.items()}

    if self.layout == Layout.SOA:
        # One dense SNode per field (and per gradient field).
        for name, fld in fields.items():
            _new_dense().place(fld)
            if self.attr_dict[name].needs_grad:
                _new_dense().place(fld.grad)
    elif fields:
        # All primal fields share one dense SNode; all gradients share another.
        _new_dense().place(*fields.values())
        grad_fields = [
            fld.grad
            for name, fld in fields.items()
            if self.attr_dict[name].needs_grad
        ]
        if grad_fields:
            _new_dense().place(*grad_fields)

    return MeshElementField(mesh_instance, self._type, self.attr_dict, fields, g2r_field)
from copy import deepcopy as _deepcopy

from taichi._lib import core as _ti_core
from taichi._lib.utils import locale_encode
from taichi.lang import impl
from taichi.lang.expr import Expr
from taichi.lang.impl import axes, get_runtime
from taichi.profiler.kernel_profiler import get_default_kernel_profiler
from taichi.types.primitive_types import f32, f64, i32, i64

from taichi import _logging, _snode, _version_check

# Report each distinct DeprecationWarning raised from modules whose name
# starts with "taichi" only once, instead of on every call site.
warnings.filterwarnings("once", category=DeprecationWarning, module="taichi")

# ----------------------

i = axes(0)
"""Axis 0. For multi-dimensional arrays it is the direction down the rows.
For a 1d array it is the direction along the array.
"""

# ----------------------

j = axes(1)
"""Axis 1. For multi-dimensional arrays it is the direction across the columns.
"""

# ----------------------

k = axes(2)
"""Axis 2. For an array of dimension `d` >= 3, treat axes 0 and 1 as indexing
a grid of cells, where each cell is itself an array of dimension `d` - 2;
axis 2 is the first axis of that cell.
"""

# ----------------------