Ejemplo n.º 1
0
def _filter(o: Expr, self: Memoizer) -> Expr:
    """Zero out every subexpression selected by ``self.predicate``.

    A node for which ``self.predicate(node)`` is truthy is replaced by a
    ``Zero`` carrying the node's shape, free indices and index dimensions.
    Any other node is rebuilt from its operands after each operand has been
    processed by ``self``; ``ufl_reuse_if_untouched`` is given the chance
    to return ``o`` itself.

    Raises:
        AssertionError: if ``o`` is not an ``Expr``.
    """
    if isinstance(o, Expr):
        if not self.predicate(o):
            # Recurse into the operands and rebuild (or reuse) this node.
            return ufl_reuse_if_untouched(o, *(self(operand) for operand in o.ufl_operands))
        # Keep the tensor/index structure so the replacement is type-compatible.
        return Zero(
            shape=o.ufl_shape,
            free_indices=o.ufl_free_indices,
            index_dimensions=o.ufl_index_dimensions,
        )
    raise AssertionError(f"Cannot handle term with type {type(o)}")
def _split_component_tensor(o, self, inct):
    expressions, multiindices = (self(op, True) for op in o.ufl_operands)
    result = []
    shape_indices = set(i.count() for i in multiindices[0].indices())
    for expression, multiindex in zip(expressions, multiindices):
        if shape_indices <= set(expression.ufl_free_indices):
            result.append(ufl_reuse_if_untouched(o, expression, multiindex))
        else:
            result.append(expression)
    return tuple(result)
Ejemplo n.º 3
0
def _split_indexed(o, self, inct):
    aggregate, multiindex = o.ufl_operands
    indices = multiindex.indices()
    result = []
    for agg in self(aggregate, False):
        ncmp = len(agg.ufl_shape)
        if ncmp == 0:
            result.append(agg)
        elif not inct:
            idx = indices[:ncmp]
            indices = indices[ncmp:]
            mi = multiindex if multiindex.indices() == idx else MultiIndex(idx)
            result.append(ufl_reuse_if_untouched(o, agg, mi))
        else:
            # shape and inct
            aggshape = (flatten(agg.ufl_shape)
                        + tuple(itertools.repeat(1, len(aggregate.ufl_shape) - 1)))
            agg = reshape(agg, aggshape)
            result.append(ufl_reuse_if_untouched(o, agg, multiindex))
    return tuple(result)
Ejemplo n.º 4
0
def _split_indexed(o, self):
    aggregate, multiindex = o.ufl_operands
    indices = multiindex.indices()
    result = []
    for agg in self(aggregate):
        ncmp = len(agg.ufl_shape)
        idx = indices[:ncmp]
        indices = indices[ncmp:]
        if ncmp == 0:
            result.append(agg)
        else:
            mi = multiindex if multiindex.indices() == idx else MultiIndex(idx)
            result.append(ufl_reuse_if_untouched(o, agg, mi))
    return tuple(result)
def _split_expr(o, self, inct):
    return tuple(
        ufl_reuse_if_untouched(o, *ops)
        for ops in zip(*(self(op, inct) for op in o.ufl_operands)))
Ejemplo n.º 6
0
def strip_dt(e, self):
    """Rebuild ``e`` from its operands after each is processed by ``self``.

    A node without operands is a leaf and is returned unchanged. Otherwise
    ``ufl_reuse_if_untouched`` receives the processed operands and may hand
    back ``e`` itself.
    """
    operands = e.ufl_operands
    if not operands:
        return e
    return ufl_reuse_if_untouched(e, *(self(operand) for operand in operands))