def _filter(o: Expr, self: Memoizer) -> Expr:
    """Replace every subexpression matched by ``self.predicate`` with a zero.

    Args:
        o: The UFL expression node currently being visited.
        self: The memoizing traverser; its ``predicate`` decides which nodes
            are zeroed, and calling it recurses into an operand.

    Returns:
        A ``Zero`` carrying ``o``'s shape, free indices and index dimensions
        when the predicate matches ``o``; otherwise ``o`` rebuilt from its
        filtered operands via ``ufl_reuse_if_untouched`` (presumably ``o``
        itself when no operand changed).

    Raises:
        AssertionError: If ``o`` is not an ``Expr``.
    """
    if not isinstance(o, Expr):
        raise AssertionError(f"Cannot handle term with type {type(o)}")

    if not self.predicate(o):
        # No match at this node: filter the children and rebuild.
        filtered_operands = [self(operand) for operand in o.ufl_operands]
        return ufl_reuse_if_untouched(o, *filtered_operands)

    # Matched: substitute a zero that is shape- and index-compatible with o.
    return Zero(
        shape=o.ufl_shape,
        free_indices=o.ufl_free_indices,
        index_dimensions=o.ufl_index_dimensions,
    )
def _split_component_tensor(o, self, inct): expressions, multiindices = (self(op, True) for op in o.ufl_operands) result = [] shape_indices = set(i.count() for i in multiindices[0].indices()) for expression, multiindex in zip(expressions, multiindices): if shape_indices <= set(expression.ufl_free_indices): result.append(ufl_reuse_if_untouched(o, expression, multiindex)) else: result.append(expression) return tuple(result)
def _split_indexed(o, self, inct): aggregate, multiindex = o.ufl_operands indices = multiindex.indices() result = [] for agg in self(aggregate, False): ncmp = len(agg.ufl_shape) if ncmp == 0: result.append(agg) elif not inct: idx = indices[:ncmp] indices = indices[ncmp:] mi = multiindex if multiindex.indices() == idx else MultiIndex(idx) result.append(ufl_reuse_if_untouched(o, agg, mi)) else: # shape and inct aggshape = (flatten(agg.ufl_shape) + tuple(itertools.repeat(1, len(aggregate.ufl_shape) - 1))) agg = reshape(agg, aggshape) result.append(ufl_reuse_if_untouched(o, agg, multiindex)) return tuple(result)
def _split_indexed(o, self): aggregate, multiindex = o.ufl_operands indices = multiindex.indices() result = [] for agg in self(aggregate): ncmp = len(agg.ufl_shape) idx = indices[:ncmp] indices = indices[ncmp:] if ncmp == 0: result.append(agg) else: mi = multiindex if multiindex.indices() == idx else MultiIndex(idx) result.append(ufl_reuse_if_untouched(o, agg, mi)) return tuple(result)
def _split_expr(o, self, inct): return tuple( ufl_reuse_if_untouched(o, *ops) for ops in zip(*(self(op, inct) for op in o.ufl_operands)))
def strip_dt(e, self):
    """Generic handler: rebuild ``e`` from its recursively processed operands.

    Nodes without operands are returned unchanged, and ``self`` is not
    called for them. Otherwise each operand is passed through ``self`` and
    ``e`` is rebuilt via ``ufl_reuse_if_untouched``.
    NOTE(review): the name suggests this is the fallback of a pass that
    removes time-derivative wrappers; the body itself only recurses.
    """
    operands = e.ufl_operands
    if not operands:
        return e
    return ufl_reuse_if_untouched(e, *(self(operand) for operand in operands))