Пример #1
0
def transcribe_phase(dag, field_var_name, field_components, phase_name,
                     sym_operator):
    """Generate a Grudge operator for a Dagrt time integrator phase.

    Arguments:

        dag: The Dagrt code object for the time integrator

        field_var_name: The name of the simulation variable

        field_components: The number of components (fields) in the variable

        phase_name: The name of the phase to transcribe

        sym_operator: The Grudge symbolic expression to substitute for the
            right-hand side evaluation in the Dagrt code

    Returns:

        A tuple ``(output_vars, output_exprs, yielded_states)``:

        - *output_vars*: the Dagrt variable names seeded as inputs (time,
          time step, state, and residual), in insertion order.
        - *output_exprs*: for each name in *output_vars*, the Grudge
          expression bound to it after every statement of the phase has
          been transcribed.
        - *yielded_states*: one ``(time_id, time, component_id, expression)``
          tuple per ``YieldState`` statement, in execution order, where
          *time* and *expression* have been rewritten into Grudge
          expressions.

    Raises:

        NotImplementedError: if a statement has a condition other than
            ``True``; if an assignment target is not a plain variable; if a
            function other than the right-hand side is called; if the
            right-hand side is called with positional arguments or with more
            than one assignee; or if a statement is of any other type.
    """
    sym_operator = gmap.OperatorBinder()(sym_operator)
    phase = dag.phases[phase_name]

    # Seed the context with the phase inputs. Keys are Dagrt variable names;
    # values are the Grudge input variables that stand in for them.
    ctx = {
            "<t>": sym.var("input_t", dof_desc.DD_SCALAR),
            "<dt>": sym.var("input_dt", dof_desc.DD_SCALAR),
            f"<state>{field_var_name}": sym.make_sym_array(
                f"input_{field_var_name}", field_components),
            "<p>residual": sym.make_sym_array(
                "input_residual", field_components),
    }

    rhs_name = f"<func>{field_var_name}"
    # Snapshot the seed names now. Their values after transcription are
    # what gets returned, even though ctx grows with temporaries below.
    output_vars = list(ctx)
    yielded_states = []

    # Presumably isolate_function_calls_in_phase pulls function calls out
    # into separate statements; the result is then ordered by
    # topological_sort starting from the phase's dependencies.
    ordered_stmts = topological_sort(
            isolate_function_calls_in_phase(
                phase,
                dag.get_stmt_id_generator(),
                dag.get_var_name_generator()).statements,
            phase.depends_on)

    for stmt in ordered_stmts:
        if stmt.condition is not True:
            raise NotImplementedError(
                "non-True condition (in statement '%s') not supported"
                % stmt.id)

        if isinstance(stmt, lang.Nop):
            pass

        elif isinstance(stmt, lang.Assign):
            if not isinstance(stmt.lhs, p.Variable):
                raise NotImplementedError("lhs of statement %s is not a variable: %s"
                        % (stmt.id, stmt.lhs))
            # Bind the rewritten right-hand side as a CSE. Its prefix is the
            # variable name with Dagrt's angle-bracket markers removed.
            ctx[stmt.lhs.name] = sym.cse(
                    DagrtToGrudgeRewriter(ctx)(stmt.rhs),
                    (
                        stmt.lhs.name
                        .replace("<", "")
                        .replace(">", "")))

        elif isinstance(stmt, lang.AssignFunctionCall):
            if stmt.function_id != rhs_name:
                raise NotImplementedError(
                        "statement '%s' calls unsupported function '%s'"
                        % (stmt.id, stmt.function_id))

            if stmt.parameters:
                raise NotImplementedError(
                    "statement '%s' calls function '%s' with positional arguments"
                    % (stmt.id, stmt.function_id))

            # Each keyword argument becomes a CSE. GrudgeArgSubstitutor then
            # substitutes these into the bound right-hand-side operator.
            kwargs = {name: sym.cse(DagrtToGrudgeRewriter(ctx)(arg))
                      for name, arg in stmt.kw_parameters.items()}

            if len(stmt.assignees) != 1:
                raise NotImplementedError(
                    "statement '%s' calls function '%s' "
                    "with more than one LHS"
                    % (stmt.id, stmt.function_id))

            assignee, = stmt.assignees
            ctx[assignee] = GrudgeArgSubstitutor(kwargs)(sym_operator)

        elif isinstance(stmt, lang.YieldState):
            d2g = DagrtToGrudgeRewriter(ctx)
            yielded_states.append(
                    (
                        stmt.time_id,
                        d2g(stmt.time),
                        stmt.component_id,
                        d2g(stmt.expression)))

        else:
            raise NotImplementedError(
                    "statement '%s' is of unsupported type '%s'"
                    % (stmt.id, type(stmt).__name__))

    return output_vars, [ctx[ov] for ov in output_vars], yielded_states
Пример #2
0
def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None,
        dumper=lambda name, sym_operator: None):
    """Run the fixed preprocessing pipeline on a symbolic operator.

    The operator passes through a sequence of
    :mod:`grudge.symbolic.mappers` transformations. At each stage, *dumper*
    is called with a stage label and the operator as it stands at that
    point, so intermediate forms can be inspected.

    :arg discrwb: the discretization the operator is processed for. This
        function reads its ``mesh``, ``mpi_communicator``,
        ``quad_tag_to_group_factory``, ``real_dtype``, ``complex_dtype`` and
        ``ambient_dim`` attributes, and calls its
        ``get_management_rank_index`` and ``discr_from_dd`` methods.
    :arg sym_operator: the symbolic operator expression to process.
    :arg post_bind_mapper: optional callable applied to the operator after
        binding and the MPI broadcast, before the remaining passes.
    :arg dumper: callable ``dumper(stage_name, sym_operator)`` invoked at
        each stage. The default ignores its arguments.
    :returns: the processed symbolic operator.
    :raises ValueError: under MPI, if the operator this rank was given
        differs from the one given to the management rank.
    """
    as_received = sym_operator
    import grudge.symbolic.mappers as mappers

    def sync_with_management_rank(local_op):
        # Every rank contributes (input operator, processed operator) and
        # receives the management rank's pair back. The input operators are
        # compared to catch ranks that were asked to bind different
        # operators. The management rank's processed operator is then used
        # everywhere.
        root_input, root_processed = \
                discrwb.mpi_communicator.bcast(
                    (as_received, local_op),
                    discrwb.get_management_rank_index())

        from pytools.obj_array import is_equal as is_oa_equal
        if is_oa_equal(root_input, as_received):
            return root_processed

        raise ValueError("rank %d received a different symbolic "
                "operator to bind from rank %d"
                % (discrwb.mpi_communicator.Get_rank(),
                    discrwb.get_management_rank_index()))

    dumper("before-bind", as_received)
    op = mappers.OperatorBinder()(as_received)

    # Run the checker for its side effects (presumably raising on an
    # invalid operator); its return value is discarded.
    mappers.ErrorChecker(discrwb.mesh)(op)

    op = mappers.OppositeInteriorFaceSwapUniqueIDAssigner()(op)

    # Only done when a communicator is present. Because the broadcast
    # happens after unique-ID assignment, every rank ends up with the IDs
    # assigned on the management rank.
    if discrwb.mpi_communicator is not None:
        op = sync_with_management_rank(op)

    if post_bind_mapper is not None:
        dumper("before-postbind", op)
        op = post_bind_mapper(op)

    dumper("before-empty-flux-killer", op)
    op = mappers.EmptyFluxKiller(discrwb.mesh)(op)

    dumper("before-cfold", op)
    op = mappers.CommutativeConstantFoldingMapper()(op)

    dumper("before-qcheck", op)
    op = mappers.QuadratureCheckerAndRemover(
            discrwb.quad_tag_to_group_factory)(op)

    # Workaround for https://github.com/numpy/numpy/issues/9438.
    # A constant such as 1j has to survive as an expression until code
    # generation. If it were evaluated and combined with other constants,
    # its sized type (np.complex64/128) would have to be fixed inside the
    # expression, and that numpy bug means sized numbers rarely survive
    # expression building. So constants are converted to the
    # discretization's sized numpy types here instead.
    dumper("before-csize", op)
    to_sized_constants = mappers.ConstantToNumpyConversionMapper(
            real_type=discrwb.real_dtype.type,
            complex_type=discrwb.complex_dtype.type,
            )
    op = to_sized_constants(op)

    dumper("before-global-to-reference", op)
    op = mappers.GlobalToReferenceMapper(discrwb.ambient_dim)(op)

    dumper("before-distributed", op)

    vol_mesh = discrwb.discr_from_dd("vol").mesh
    from meshmode.distributed import get_connected_partitions
    neighbor_parts = get_connected_partitions(vol_mesh)

    # The distributed rewrite is applied only when connected partitions
    # exist.
    if neighbor_parts:
        op = mappers.DistributedMapper(neighbor_parts)(op)

    dumper("before-imass", op)
    op = mappers.InverseMassContractor()(op)

    dumper("before-cfold-2", op)
    op = mappers.CommutativeConstantFoldingMapper()(op)

    # FIXME: Reenable derivative joiner
    # dumper("before-derivative-join", op)
    # op = mappers.DerivativeJoiner()(op)

    dumper("process-finished", op)

    return op