def process(func, state):
    """
    Apply ``func`` to the IETs in ``state._efuncs``, and update ``state`` accordingly.

    ``func`` is applied to every IET reachable from ``state._efuncs['root']``
    through :class:`Call` nodes. IETs are visited in topological order of a
    call graph whose edges go from callee to caller. ``func`` must return a
    2-tuple ``(iet, metadata)``. The following ``metadata`` keys are consumed:

        * ``'dimensions'``: appended to ``state._dimensions``;
        * ``'includes'``: appended to ``state._includes``, which is then
          filtered through ``filter_ordered``;
        * ``'efuncs'``: new IETs, registered in ``state._efuncs`` by name.
          They are not added to the call graph, so ``func`` is not applied
          to them in this call;
        * ``'args'``: new arguments. They are appended to the parameters of
          the processed IET (if Callable) and of all of its downstream nodes
          in the call graph, and to the arguments of the matching Call sites.
    """
    # Build the call graph by BFS from 'root'. `func` might change an efunc
    # signature; the call graph is then used to propagate such change through
    # the efunc callers
    dag = DAG(nodes=['root'])
    queue = ['root']
    while queue:
        caller = queue.pop(0)
        callees = FindNodes(Call).visit(state._efuncs[caller])
        for callee in filter_ordered([c.name for c in callees]):
            if callee not in state._efuncs:
                # Foreign Call (e.g., an MPI call): not part of the graph
                continue
            try:
                dag.add_node(callee)
            except KeyError:
                # `callee` already in `dag`, hence already queued
                pass
            else:
                queue.append(callee)
            dag.add_edge(callee, caller)
    # Every tracked IET must be reachable from 'root'
    assert dag.size == len(state._efuncs)

    # Apply `func`
    for name in dag.topological_sort():
        state._efuncs[name], metadata = func(state._efuncs[name])

        # Track any new Dimensions introduced by `func`
        state._dimensions.extend(metadata.get('dimensions', []))

        # Track any new #include required by `func`
        state._includes.extend(metadata.get('includes', []))
        state._includes = filter_ordered(state._includes)

        # Track any new ElementalFunctions. The comprehension variable is
        # deliberately distinct from the loop variable `name`, so it cannot
        # clobber it (comprehension variables leak on Python 2)
        state._efuncs.update(
            OrderedDict([(e.name, e) for e in metadata.get('efuncs', [])]))

        # If there's a change to the `args` and the `iet` is an efunc, then
        # we must update the call sites as well, as the arguments dropped down
        # to the efunc have just increased
        args = as_tuple(metadata.get('args'))
        if args:
            def extif(v):
                # Avoid redundant updates to the parameters list, due to
                # multiple children wanting to add the same input argument
                return list(v) + [a for a in args if a not in v]

            stack = [name] + dag.all_downstreams(name)
            for n in stack:
                efunc = state._efuncs[n]
                calls = [c for c in FindNodes(Call).visit(efunc) if c.name in stack]
                mapper = {c: c._rebuild(arguments=extif(c.arguments)) for c in calls}
                efunc = Transformer(mapper).visit(efunc)
                if efunc.is_Callable:
                    efunc = efunc._rebuild(parameters=extif(efunc.parameters))
                state._efuncs[n] = efunc
class State(object):

    """
    Bookkeeping for a sequence of passes over an IET.

    Tracks the root IET (stored under the key ``'main'``), the elemental
    functions (efuncs) introduced by the passes, a call graph over efunc
    names, and the Dimensions, input arguments and ``#include`` headers
    requested by the passes.
    """

    def __init__(self, iet):
        # efunc name -> IET; the root IET is stored under 'main'
        self._efuncs = OrderedDict([('main', iet)])
        self._dimensions = []
        self._input = []
        self._includes = []
        # Call graph over efunc names; `_process` adds edges from each new
        # efunc to the efunc(s) it is attached to ('main' by default)
        self._call_graph = DAG(nodes=['main'])

    def _process(self, func):
        """
        Apply ``func`` to all tracked ``IETs``, in topological order of the call graph.

        ``func`` takes an IET and returns a 2-tuple ``(iet, metadata)``. The
        following ``metadata`` keys are consumed:

            * ``'dimensions'``: appended to the tracked Dimensions;
            * ``'includes'``: appended to the tracked headers, skipping any
              header that is already tracked;
            * ``'input'``: new input arguments. They are appended to the
              parameters of the processed IET (if Callable) and of all of its
              downstream nodes in the call graph, and to the arguments of the
              matching Call sites. They are also recorded in ``input``;
            * ``'efuncs'``: mapping ``efunc -> targets``. Each Callable efunc
              is registered by name, and the call graph gets an edge from it
              to each of its ``targets`` (``'main'`` if none are given).
        """
        for name in self._call_graph.topological_sort():
            self._efuncs[name], metadata = func(self._efuncs[name])

            # Track any new Dimensions introduced by `func`
            self._dimensions.extend(metadata.get('dimensions', []))

            # Track any new #include required by `func`, skipping duplicates
            # (e.g., the same header requested by more than one pass). The
            # list is updated in place, so references returned by `includes`
            # stay valid
            for header in metadata.get('includes', []):
                if header not in self._includes:
                    self._includes.append(header)

            # If there's a change to the `input` and the `iet` is an efunc, then
            # we must update the call sites as well, as the arguments dropped down
            # to the efunc have just increased
            _input = as_tuple(metadata.get('input'))
            if _input:
                def extif(v):
                    # Avoid redundant updates to the parameters list, due to
                    # multiple children wanting to add the same input argument
                    return list(v) + [e for e in _input if e not in v]

                stack = [name] + self._call_graph.all_downstreams(name)
                for n in stack:
                    efunc = self._efuncs[n]
                    calls = [c for c in FindNodes(Call).visit(efunc) if c.name in stack]
                    mapper = {c: c._rebuild(arguments=extif(c.arguments)) for c in calls}
                    efunc = Transformer(mapper).visit(efunc)
                    if efunc.is_Callable:
                        efunc = efunc._rebuild(parameters=extif(efunc.parameters))
                    self._efuncs[n] = efunc
            self._input.extend(_input)

            for new_efunc, targets in metadata.get('efuncs', {}).items():
                # Update the efuncs
                if new_efunc.is_Callable:
                    self._efuncs[new_efunc.name] = new_efunc
                # Update the call graph
                self._call_graph.add_node(new_efunc.name, ignore_existing=True)
                for target in (targets or [None]):
                    self._call_graph.add_edge(new_efunc.name, target or 'main',
                                              force_add=True)

    @property
    def root(self):
        """The root IET (the one stored under ``'main'``)."""
        return self._efuncs['main']

    @property
    def efuncs(self):
        """All tracked IETs except the root, in insertion order."""
        return tuple(v for k, v in self._efuncs.items() if k != 'main')

    @property
    def dimensions(self):
        """The Dimensions requested so far by the processed passes."""
        return self._dimensions

    @property
    def input(self):
        """The input arguments introduced so far by the processed passes."""
        return self._input

    @property
    def includes(self):
        """The ``#include`` headers requested so far by the processed passes."""
        return self._includes