Ejemplo n.º 1
0
def process(func, state):
    """
    Apply ``func`` to the IETs in ``state._efuncs``, and update ``state`` accordingly.

    ``func`` is applied to every efunc reachable from ``'root'`` through Calls.
    The order follows a call graph whose edges go callee -> caller, so that
    (presumably, given the usual DAG conventions) callees are processed before
    their callers. If ``func`` extends an efunc's signature through the ``'args'``
    metadata, the new arguments are first pushed to every (transitive) caller,
    and only then does ``func`` visit those callers.

    Parameters
    ----------
    func : callable
        Takes an IET and returns a pair ``(new_iet, metadata)``. ``metadata`` is
        queried with ``.get`` for the optional keys ``'dimensions'``,
        ``'includes'``, ``'efuncs'`` and ``'args'``.
    state : object
        Provides ``_efuncs`` (name -> IET, must contain ``'root'``),
        ``_dimensions`` and ``_includes``. All three are updated in place.
        ``_includes`` is also rebound to the result of ``filter_ordered``,
        presumably a de-duplicated copy.
    """
    from collections import deque

    # Create a Call graph. `func` will be applied to each node in the Call graph.
    # `func` might change an `efunc` signature; the Call graph will be used to
    # propagate such change through the `efunc` callers. Edges go callee ->
    # caller, so `dag.all_downstreams(x)` are the (transitive) callers of `x`.
    dag = DAG(nodes=['root'])
    queue = deque(['root'])  # BFS frontier; deque gives O(1) pops from the left
    while queue:
        caller = queue.popleft()
        callees = FindNodes(Call).visit(state._efuncs[caller])
        for callee in filter_ordered([c.name for c in callees]):
            if callee not in state._efuncs:
                # Foreign Calls (e.g., MPI calls) are not part of the Call graph
                continue
            try:
                dag.add_node(callee)
            except KeyError:
                # `callee` already in `dag`, hence already queued
                pass
            else:
                queue.append(callee)
            dag.add_edge(callee, caller)
    # Sanity check: assuming `dag.size` counts nodes, every tracked efunc must
    # be reachable from 'root'
    assert dag.size == len(state._efuncs)

    # Apply `func`
    for name in dag.topological_sort():
        state._efuncs[name], metadata = func(state._efuncs[name])

        # Track any new Dimensions introduced by `func`
        state._dimensions.extend(list(metadata.get('dimensions', [])))

        # Track any new #include required by `func`
        state._includes.extend(list(metadata.get('includes', [])))
        state._includes = filter_ordered(state._includes)

        # Track any new ElementalFunctions (a distinct loop variable is used so
        # that `name` can never be clobbered)
        for new_efunc in metadata.get('efuncs', []):
            state._efuncs[new_efunc.name] = new_efunc

        # If there's a change to the `args` and the `iet` is an efunc, then
        # we must update the call sites as well, as the arguments dropped down
        # to the efunc have just increased
        args = as_tuple(metadata.get('args'))
        if not args:
            continue

        def extif(v):
            # Avoid redundant updates to the parameters list, due to multiple
            # children wanting to add the same input argument
            return list(v) + [e for e in args if e not in v]

        stack = [name] + dag.all_downstreams(name)
        targets = set(stack)  # O(1) membership for the call-site filter below
        for n in stack:
            efunc = state._efuncs[n]
            mapper = {c: c._rebuild(arguments=extif(c.arguments))
                      for c in FindNodes(Call).visit(efunc) if c.name in targets}
            efunc = Transformer(mapper).visit(efunc)
            if efunc.is_Callable:
                efunc = efunc._rebuild(parameters=extif(efunc.parameters))
            state._efuncs[n] = efunc
Ejemplo n.º 2
0
class State(object):

    """
    Bookkeeping for a root IET and the efuncs derived from it.

    The root IET is tracked under the name ``'main'``. Efuncs created by the
    passes run through ``_process`` are tracked by name. A call graph over
    those names is used to propagate signature changes from an efunc to its
    callers.
    """

    def __init__(self, iet):
        # name -> IET, in insertion order; the root IET lives under 'main'
        self._efuncs = OrderedDict(main=iet)
        # Metadata accumulated from the passes applied via `_process`
        self._dimensions = []
        self._input = []
        self._includes = []

        # Call graph over efunc names (edges are added callee -> caller)
        self._call_graph = DAG(nodes=['main'])

    def _process(self, func):
        """
        Apply ``func`` to every IET in the call graph, in topological order.

        ``func`` maps an IET to a pair ``(new_iet, metadata)``. ``metadata`` is
        queried with ``.get`` for these optional entries:

        * ``'dimensions'`` / ``'includes'``: appended to the tracked lists.
        * ``'input'``: extra arguments needed by the IET. Those not already
          present are appended to the parameters of the IET and of all its
          transitive callers, and to the arguments of the Calls along that
          chain. They are also recorded in ``self._input``.
        * ``'efuncs'``: a mapping ``{efunc: callers}``. Callable efuncs are
          stored by name. Every key becomes a call-graph node with an edge
          towards each caller name; ``'main'`` is used when ``callers`` is
          empty/``None`` or an entry is ``None``.
        """
        graph = self._call_graph

        for name in graph.topological_sort():
            new_iet, metadata = func(self._efuncs[name])
            self._efuncs[name] = new_iet

            # Record new Dimensions and #includes reported by `func`
            self._dimensions.extend(list(metadata.get('dimensions', [])))
            self._includes.extend(list(metadata.get('includes', [])))

            # `func` may have added arguments to this IET: every (transitive)
            # caller must receive them too, both in its own signature and at
            # the call sites along the chain
            new_args = as_tuple(metadata.get('input'))
            if new_args:
                def extend_unique(seq):
                    # Several callees may request the same argument; add it once
                    return list(seq) + [a for a in new_args if a not in seq]

                chain = [name] + graph.all_downstreams(name)
                for node in chain:
                    iet = self._efuncs[node]
                    mapper = {}
                    for call in FindNodes(Call).visit(iet):
                        if call.name in chain:
                            mapper[call] = call._rebuild(
                                arguments=extend_unique(call.arguments))
                    iet = Transformer(mapper).visit(iet)
                    if iet.is_Callable:
                        iet = iet._rebuild(parameters=extend_unique(iet.parameters))
                    self._efuncs[node] = iet
                self._input.extend(list(new_args))

            # Register newly created efuncs and hook them into the call graph
            for efunc, callers in metadata.get('efuncs', {}).items():
                if efunc.is_Callable:
                    self._efuncs[efunc.name] = efunc
                graph.add_node(efunc.name, ignore_existing=True)
                for caller in (callers or [None]):
                    graph.add_edge(efunc.name, caller or 'main', force_add=True)

    @property
    def root(self):
        """The root IET, tracked under the name ``'main'``."""
        return self._efuncs['main']

    @property
    def efuncs(self):
        """Every tracked IET except the root, in insertion order."""
        return tuple(self._efuncs[name] for name in self._efuncs if name != 'main')

    @property
    def dimensions(self):
        """The live list of Dimensions collected so far (not a copy)."""
        return self._dimensions

    @property
    def input(self):
        """The live list of extra input arguments collected so far (not a copy)."""
        return self._input

    @property
    def includes(self):
        """The live list of #includes collected so far (not a copy)."""
        return self._includes