def transcribe_phase(dag, field_var_name, field_components, phase_name,
                     sym_operator):
    """Generate a Grudge operator for a Dagrt time integrator phase.

    Arguments:

        dag: The Dagrt code object for the time integrator

        field_var_name: The name of the simulation variable

        field_components: The number of components (fields) in the variable

        phase_name: The name of the phase to transcribe

        sym_operator: The Grudge symbolic expression to substitute for the
            right-hand side evaluation in the Dagrt code

    Returns:

        A tuple ``(output_vars, output_exprs, yielded_states)``:

        - *output_vars* lists the Dagrt names that seed the translation
          context: ``"<t>"``, ``"<dt>"``, ``"<state>{field_var_name}"`` and
          ``"<p>residual"``.
        - *output_exprs* holds, for each name in *output_vars*, the Grudge
          expression it is bound to after all statements of the phase have
          been transcribed.
        - *yielded_states* holds one
          ``(time_id, time, component_id, expression)`` tuple per
          ``YieldState`` statement, in the order the statements were visited.

    Raises:

        NotImplementedError: if a statement has a condition other than
            ``True``, assigns to something other than a plain variable,
            calls any function other than ``<func>{field_var_name}``, passes
            positional arguments, has more than one assignee, or is of a
            type other than ``Nop``, ``Assign``, ``AssignFunctionCall`` or
            ``YieldState``.
    """
    sym_operator = gmap.OperatorBinder()(sym_operator)
    phase = dag.phases[phase_name]

    # Seed the translation context with symbolic inputs for the variables
    # the phase can read before assigning them.
    ctx = {
            "<t>": sym.var("input_t", dof_desc.DD_SCALAR),
            "<dt>": sym.var("input_dt", dof_desc.DD_SCALAR),
            f"<state>{field_var_name}": sym.make_sym_array(
                f"input_{field_var_name}", field_components),
            "<p>residual": sym.make_sym_array(
                "input_residual", field_components),
    }

    rhs_name = f"<func>{field_var_name}"
    # Snapshot the seed names now, before statements add more keys; their
    # final bindings in ctx are what this function returns.
    output_vars = list(ctx)
    yielded_states = []

    # Presumably isolate_function_calls_in_phase puts each function call in
    # its own statement (which the AssignFunctionCall branch below relies
    # on) -- TODO confirm; statements are then visited in dependency order.
    ordered_stmts = topological_sort(
            isolate_function_calls_in_phase(
                phase,
                dag.get_stmt_id_generator(),
                dag.get_var_name_generator()).statements,
            phase.depends_on)

    for stmt in ordered_stmts:
        if stmt.condition is not True:
            raise NotImplementedError(
                "non-True condition (in statement '%s') not supported"
                % stmt.id)

        if isinstance(stmt, lang.Nop):
            pass

        elif isinstance(stmt, lang.Assign):
            if not isinstance(stmt.lhs, p.Variable):
                raise NotImplementedError(
                        "lhs of statement %s is not a variable: %s"
                        % (stmt.id, stmt.lhs))
            # Wrap in a CSE named after the Dagrt variable with its
            # "<...>" markers stripped.
            ctx[stmt.lhs.name] = sym.cse(
                    DagrtToGrudgeRewriter(ctx)(stmt.rhs),
                    (
                        stmt.lhs.name
                        .replace("<", "")
                        .replace(">", "")))

        elif isinstance(stmt, lang.AssignFunctionCall):
            if stmt.function_id != rhs_name:
                raise NotImplementedError(
                        "statement '%s' calls unsupported function '%s'"
                        % (stmt.id, stmt.function_id))

            if stmt.parameters:
                raise NotImplementedError(
                        "statement '%s' calls function '%s' with positional "
                        "arguments" % (stmt.id, stmt.function_id))

            # ctx is not modified while translating the arguments, so a
            # single rewriter serves all of them.
            d2g = DagrtToGrudgeRewriter(ctx)
            kwargs = {name: sym.cse(d2g(arg))
                      for name, arg in stmt.kw_parameters.items()}

            if len(stmt.assignees) != 1:
                raise NotImplementedError(
                        "statement '%s' calls function '%s' "
                        "with more than one LHS"
                        % (stmt.id, stmt.function_id))

            assignee, = stmt.assignees
            # The right-hand-side call becomes the Grudge operator with the
            # call's keyword arguments substituted in.
            ctx[assignee] = GrudgeArgSubstitutor(kwargs)(sym_operator)

        elif isinstance(stmt, lang.YieldState):
            d2g = DagrtToGrudgeRewriter(ctx)
            yielded_states.append(
                    (
                        stmt.time_id,
                        d2g(stmt.time),
                        stmt.component_id,
                        d2g(stmt.expression)))

        else:
            raise NotImplementedError(
                    "statement %s is of unsupported type '%s'"
                    % (stmt.id, type(stmt).__name__))

    return output_vars, [ctx[ov] for ov in output_vars], yielded_states
def process_sym_operator(discrwb, sym_operator, post_bind_mapper=None,
        dumper=lambda name, sym_operator: None):
    """Run the symbolic-operator preparation pipeline for *discrwb*.

    The operator is bound, checked, optionally synchronized across MPI ranks,
    and then passed through a fixed sequence of ``grudge.symbolic.mappers``
    transformations.

    :arg discrwb: the discretization object the operator is prepared for.
    :arg sym_operator: the symbolic operator to process.
    :arg post_bind_mapper: optional callable applied to the operator right
        after binding and MPI synchronization.
    :arg dumper: callable invoked as ``dumper(stage_name, operator)`` before
        each stage (and once at the end); it does nothing by default.
    :returns: the transformed symbolic operator.
    :raises ValueError: if this rank's input operator differs from the one
        on the management rank.
    """
    original_op = sym_operator

    import grudge.symbolic.mappers as mappers

    dumper("before-bind", sym_operator)
    op = mappers.OperatorBinder()(sym_operator)
    # Called only for its checks; the return value is not used.
    mappers.ErrorChecker(discrwb.mesh)(op)
    op = mappers.OppositeInteriorFaceSwapUniqueIDAssigner()(op)

    # {{{ broadcast the management rank's sym_operator

    # Also verifies that every rank started from the same input operator.
    comm = discrwb.mpi_communicator
    if comm is not None:
        root_original, root_processed = comm.bcast(
                (original_op, op),
                discrwb.get_management_rank_index())

        from pytools.obj_array import is_equal as is_oa_equal
        if not is_oa_equal(root_original, original_op):
            raise ValueError("rank %d received a different symbolic "
                    "operator to bind from rank %d"
                    % (comm.Get_rank(),
                        discrwb.get_management_rank_index()))

        op = root_processed

    # }}}

    if post_bind_mapper is not None:
        dumper("before-postbind", op)
        op = post_bind_mapper(op)

    def run_stages(expr, stages):
        # Dump first, then build the stage's mapper and apply it, so each
        # dumper call sees the input to that stage.
        for stage_name, make_mapper in stages:
            dumper(stage_name, expr)
            expr = make_mapper()(expr)
        return expr

    op = run_stages(op, (
        ("before-empty-flux-killer",
            lambda: mappers.EmptyFluxKiller(discrwb.mesh)),
        ("before-cfold",
            lambda: mappers.CommutativeConstantFoldingMapper()),
        ("before-qcheck",
            lambda: mappers.QuadratureCheckerAndRemover(
                discrwb.quad_tag_to_group_factory)),
        # Work around https://github.com/numpy/numpy/issues/9438
        #
        # The idea is that we need 1j as an expression to survive
        # until code generation time. If it is evaluated and combined
        # with other constants, we will need to determine its size
        # (as np.complex64/128) within the expression. But because
        # of the above numpy bug, sized numbers are not likely to survive
        # expression building--so that's why we step in here to fix that.
        ("before-csize",
            lambda: mappers.ConstantToNumpyConversionMapper(
                real_type=discrwb.real_dtype.type,
                complex_type=discrwb.complex_dtype.type,
                )),
        ("before-global-to-reference",
            lambda: mappers.GlobalToReferenceMapper(discrwb.ambient_dim)),
        ))

    dumper("before-distributed", op)

    vol_mesh = discrwb.discr_from_dd("vol").mesh
    from meshmode.distributed import get_connected_partitions
    neighbor_parts = get_connected_partitions(vol_mesh)
    # Only insert distributed communication when this partition has
    # neighbors.
    if neighbor_parts:
        op = mappers.DistributedMapper(neighbor_parts)(op)

    op = run_stages(op, (
        ("before-imass",
            lambda: mappers.InverseMassContractor()),
        ("before-cfold-2",
            lambda: mappers.CommutativeConstantFoldingMapper()),
        ))

    # FIXME: Reenable derivative joiner
    # dumper("before-derivative-join", sym_operator)
    # sym_operator = mappers.DerivativeJoiner()(sym_operator)

    dumper("process-finished", op)
    return op