def remove_placeholder_functions(expr):
    """Inline every ``PlaceholderFunction`` in *expr* and collect its subexpressions.

    The traversal is driven by ``generic_visit(expr, visit)`` with a local visitor
    that behaves as follows:

    * ``Node`` instances are returned unchanged and not descended into.
    * A ``PlaceholderFunction`` is replaced by its ``value``; its
      ``subexpressions`` are appended to the collected list unless an entry with
      the same ``lhs`` was already collected (first occurrence wins).
      NOTE(review): ``value`` itself is not traversed further, so a placeholder
      nested inside another placeholder's ``value`` is left in place — confirm
      this is intended.
    * Anything else is rebuilt via ``e.func(*new_args)`` from its visited
      ``args``; an object with no ``args`` is returned as-is.

    Returns:
        A tuple ``(new_expr, subexpressions)`` where ``new_expr`` is the result of
        ``generic_visit`` and ``subexpressions`` is the de-duplicated list of
        collected subexpressions in discovery order.
    """
    subexpressions = []
    # lhs values already present in ``subexpressions``. Maintained alongside the
    # list so each duplicate check is O(1) instead of rebuilding a set from the
    # whole list for every candidate.
    seen_lhs = set()

    def visit(e):
        if isinstance(e, Node):
            return e
        elif isinstance(e, PlaceholderFunction):
            for se in e.subexpressions:
                if se.lhs not in seen_lhs:
                    seen_lhs.add(se.lhs)
                    subexpressions.append(se)
            return e.value
        else:
            new_args = [visit(a) for a in e.args]
            return e.func(*new_args) if new_args else e

    return generic_visit(expr, visit), subexpressions
# Example #2
# 0
def discretize_spatial(expr, dx, stencil=fd_stencils_standard):
    """Replace derivatives in *expr* by finite-difference approximations.

    Args:
        expr: expression handed to ``generic_visit``.
        dx: grid spacing forwarded to the stencil.
        stencil: either a callable invoked as ``stencil(indices, dx, field_access)``
            or one of the names ``'standard'`` / ``'isotropic'``, which select
            ``fd_stencils_standard`` / ``fd_stencils_isotropic`` respectively.

    Returns:
        The result of ``generic_visit`` with ``Diff`` nodes replaced by the
        stencil's output.

    Raises:
        ValueError: for an unrecognized stencil name, or when a derivative's
            argument is not a ``Field.Access``.
    """
    if isinstance(stencil, str):
        # Resolve a stencil given by name to the corresponding function.
        if stencil == 'isotropic':
            stencil = fd_stencils_isotropic
        elif stencil == 'standard':
            stencil = fd_stencils_standard
        else:
            raise ValueError("Unknown stencil. Supported 'standard' and 'isotropic'")

    def replace_derivatives(term):
        if not isinstance(term, Diff):
            # Not a derivative: rebuild the node from discretized children.
            # The recursion goes through discretize_spatial itself (not this
            # visitor), so each child is routed through generic_visit again.
            children = [discretize_spatial(child, dx, stencil) for child in term.args]
            if not children:
                return term  # no args: nothing to rebuild
            return term.func(*children)

        field_access, *directions = diff_args(term)
        if isinstance(field_access, Field.Access):
            return stencil(directions, dx, field_access)
        raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")

    return generic_visit(expr, replace_derivatives)
# Example #3
# 0
def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard):
    """Discretize derivatives in *expr*, using a staggered scheme for nested ones.

    A ``Diff`` whose argument is a ``Field.Access`` is discretized directly with
    *stencil*. A ``Diff`` with respect to a single coordinate whose argument is a
    compound expression is approximated as
    ``(arg evaluated at +1/2 - arg evaluated at -1/2) / dx`` along that
    coordinate, where the half-shifted evaluations are built by
    ``staggered_visitor``.

    Args:
        expr: expression handed to ``generic_visit``.
        dx: grid spacing used in all difference quotients.
        stencil: callable invoked as ``stencil(indices, dx, field_access)`` to
            discretize a derivative of a field access.

    Returns:
        The result of ``generic_visit(expr, visitor)``.

    Raises:
        ValueError: if an outer derivative over a compound expression has more
            than one index, or if an inner derivative inside such an expression
            has more than one index or a non-``Field.Access`` argument.
    """
    def staggered_visitor(e, coordinate, sign):
        # Rewrite e as evaluated half a cell away along `coordinate`
        # (towards +1/2 for sign == 1, -1/2 for sign == -1).
        if isinstance(e, Diff):
            arg, *indices = diff_args(e)
            if len(indices) != 1:
                raise ValueError("Function supports only up to second derivatives")
            if not isinstance(arg, Field.Access):
                raise ValueError("Argument of inner derivative has to be field access")
            target = indices[0]
            if target == coordinate:
                assert sign in (-1, 1)
                # Derivative along the shift direction: difference between the
                # access and its neighbor in that direction; multiplying by
                # `sign` keeps the orientation consistent for both signs.
                return (arg.neighbor(coordinate, sign) - arg) / dx * sign
            else:
                # Derivative along another direction: average of the stencil
                # applied to the access and to its neighbor along `coordinate`.
                return (stencil(indices, dx, arg.neighbor(coordinate, sign))
                        + stencil(indices, dx, arg)) / 2
        elif isinstance(e, Field.Access):
            # Value at the half-way point: mean of the access and its neighbor.
            return (e.neighbor(coordinate, sign) + e) / 2
        elif isinstance(e, sp.Symbol):
            # presumably returns the coordinate of a loop-counter symbol (and
            # something not equal to any coordinate otherwise) — verify.
            loop_idx = LoopOverCoordinate.is_loop_counter_symbol(e)
            # sp.Rational keeps the half-step exact; the previous `sign / 2`
            # evaluated to the Python float 0.5 and injected a sympy Float.
            return e + sp.Rational(sign, 2) if loop_idx == coordinate else e
        else:
            new_args = [staggered_visitor(a, coordinate, sign) for a in e.args]
            return e.func(*new_args) if new_args else e

    def visitor(e):
        if isinstance(e, Diff):
            arg, *indices = diff_args(e)
            if isinstance(arg, Field.Access):
                return stencil(indices, dx, arg)
            else:
                if len(indices) != 1:
                    raise ValueError("This term is not support by the staggered discretization strategy")
                target = indices[0]
                # Central difference of the compound argument between the two
                # half-shifted evaluations.
                return (staggered_visitor(arg, target, 1) - staggered_visitor(arg, target, -1)) / dx
        else:
            new_args = [visitor(a) for a in e.args]
            return e.func(*new_args) if new_args else e

    return generic_visit(expr, visitor)