Example #1
# Assumed context: this function comes from ``nengo_dl.graph_optimizer``.
# The imports below are a best-effort reconstruction of what it relies on
# (exact paths may differ between versions); ``transitive_closure_recurse``
# and ``mergeable`` are helpers defined in the same module.
import logging
from collections import defaultdict

import numpy as np
from nengo.utils.graphs import BidirectionalDAG, toposort
from nengo.utils.simulator import operator_dependency_graph

from nengo_dl import (builder, learning_rule_builders, neuron_builders,
                      op_builders, process_builders, tensor_node)

logger = logging.getLogger(__name__)

def transitive_planner(op_list):
    """
    Create merged execution plan through transitive closure construction.

    This is something like a middle ground between `.greedy_planner` and
    `.tree_planner`; it can improve simulation time over the greedy
    planner, but comes with potentially significant build time increases.

    Parameters
    ----------
    op_list : list of `~nengo.builder.Operator`
        All the ``nengo`` operators in a model (unordered)

    Returns
    -------
    plan : list of tuple of `~nengo.builder.Operator`
        Operators combined into mergeable groups and in execution order
    """

    n_ele = len(op_list)
    merge_groups = {}
    dg = operator_dependency_graph(op_list)
    op_codes = {op: np.uint32(i) for i, op in enumerate(op_list)}
    dg = {op_codes[k]: set(op_codes[x] for x in v) for k, v in dg.items()}
    op_codes = {}  # drop the mapping so it can be garbage collected
    dg = BidirectionalDAG(dg)

    # fail fast here if the op graph has cycles
    toposort(dg.forward)

    builder_types = [builder.Builder.builders[type(op)] for op in op_list]

    # group operators by builder (we'll only be interested in one builder type
    # at a time, because we can't merge operators between builder types anyway)
    ops_by_type = defaultdict(set)
    for i, op in enumerate(op_list):
        ops_by_type[builder_types[i]].add(np.uint32(i))

    # heuristic ordering for builder types (earlier items in the list will
    # have higher priority, meaning that we will choose to merge those ops
    # and potentially break lower-priority groups)
    order = [
        op_builders.SparseDotIncBuilder, op_builders.ElementwiseIncBuilder,
        neuron_builders.SimNeuronsBuilder, process_builders.SimProcessBuilder,
        op_builders.SimPyFuncBuilder, learning_rule_builders.SimOjaBuilder,
        learning_rule_builders.SimVojaBuilder,
        learning_rule_builders.SimBCMBuilder, op_builders.CopyBuilder,
        op_builders.ResetBuilder, tensor_node.SimTensorNodeBuilder]

    for builder_type in order:
        if builder_type not in ops_by_type:
            # no ops of this type in the model
            continue

        ops = ops_by_type[builder_type]

        # compute transitive closure
        trans = [None for _ in range(n_ele)]
        transitive_closure_recurse(dg.forward, ops, trans, builder_type,
                                   builder_types, {})

        # reduce it to the elements we care about (ops of the current
        # builder type)
        trans = {i: v for i, v in enumerate(trans[:len(op_list)]) if i in ops}

        while len(trans) > 0:
            # find all the ops with no remaining downstream dependents of this
            # builder type
            available = set(k for k, v in trans.items() if len(v) == 0)

            # sort those ops into mergeable groups
            groups = []
            for op in available:
                for g in groups:
                    if mergeable(op_list[op], (op_list[g[0]],)):
                        g.append(op)
                        break
                else:
                    groups.append([op])
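            # (first-fit grouping: e.g. if available = {a, b, c} and only
            # a and b are mutually mergeable, this yields [[a, b], [c]])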

            # merge the groups
            for g in groups:
                dg.merge(g, n_ele)
                merge_groups[n_ele] = g
                n_ele += 1

            # remove those ops from the transitive closure
            for op in available:
                del trans[op]

            # remove those ops from the transitive closure of upstream ops
            # note: we first remove all the duplicate aliased transitive sets,
            # to reduce the number of set operations we need to do
            unique_trans = {id(v): v for v in trans.values()}
            for t in unique_trans.values():
                t -= available


        del ops_by_type[builder_type]

    assert len(ops_by_type) == 0

    # toposort the merged graph to come up with execution plan
    plan = toposort(dg.forward)
    plan = [tuple(op_list[x] for x in merge_groups[group]) for group in plan]

    logger.debug("TRANSITIVE PLAN")
    logger.debug("\n%s" * len(plan), *plan)

    return plan
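
For context, here is a hedged usage sketch: in ``nengo_dl`` the planner is
selected through the config system, so a (hypothetical) model could be routed
through ``transitive_planner`` roughly like this:

import nengo
import nengo_dl
from nengo_dl.graph_optimizer import transitive_planner

with nengo.Network() as net:
    # opt into the transitive planner for this network
    nengo_dl.configure_settings(planner=transitive_planner)
    a = nengo.Ensemble(100, 1)
    b = nengo.Ensemble(100, 1)
    nengo.Connection(a, b)

with nengo_dl.Simulator(net) as sim:
    sim.run_steps(10)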
Example #2
# Assumed context: this class comes from ``nengo.builder.optimizer``.
# The imports below are a best-effort reconstruction (exact paths may vary
# between nengo versions); ``OpInfo``, ``OpMerger``, and ``OpsToMerge`` are
# defined in the same module.
from itertools import zip_longest

import numpy as np
from nengo.builder.neurons import SimNeurons
from nengo.builder.operator import Copy, DotInc, ElementwiseInc
from nengo.builder.signal import Signal
from nengo.rc import rc
from nengo.utils.graphs import BidirectionalDAG, transitive_closure
from nengo.utils.stdlib import WeakKeyDefaultDict, WeakSet, groupby

class OpMergePass:
    """Manages a single optimization pass."""

    def __init__(self, dg):
        self.dg = BidirectionalDAG(dg)
        self.might_merge = set(dg)
        self.sig_replacements = {}

        self.sig2ops = WeakKeyDefaultDict(WeakSet)
        self.base2views = WeakKeyDefaultDict(WeakSet)
        for op in self.dg.forward:
            for s in op.all_signals:
                self.sig2ops[s].add(op)
                self.base2views[s.base].add(s)
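        # After this loop, sig2ops[s] contains every op that reads or writes
        # signal s, and base2views[b] contains every signal whose base is b
        # (including b itself).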

        # These variables will be initialized and used on each pass
        self.dependents = None
        self.only_merge_ops_with_view = None

        self.merged = set()
        self.merged_dependents = set()
        self.opinfo = OpInfo()

    def __call__(self, only_merge_ops_with_view):
        """Perform a single optimization pass.

        Parameters
        ----------
        only_merge_ops_with_view : bool
            Limits operator merges to operators with views.
        """

        # --- Initialize pass state
        self.dependents = transitive_closure(self.dg.forward)
        self.only_merge_ops_with_view = only_merge_ops_with_view
        self.merged.clear()
        self.merged_dependents.clear()
        self.opinfo.clear()

        # --- Do an optimization pass
        self.perform_merges()

    def perform_merges(self):
        """Go through all operators and merge them where possible.

        Whether merges are limited to operators with views is controlled by
        ``self.only_merge_ops_with_view``, which is set in ``__call__``.
        """

        # We go through the ops grouped by type as only ops with the same
        # type can be merged.
        by_type = groupby(self.might_merge, type)
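        # e.g. by_type maps each operator class to its instances:
        # {ElementwiseInc: [...], Copy: [...], ...}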

        # Note that we will stop once we merge any operator, so merges are
        # performed on at most one type of operator per pass.
        # The dependency graph and other information will be updated
        # before merging other operator types.

        # We go through ops in a heuristic order to reduce runtime
        firstops = [ElementwiseInc, Copy, DotInc, SimNeurons]
        sortedops = firstops + [op for op in by_type if op not in firstops]
        for optype in sortedops:

            if OpMerger.is_type_mergeable(optype):
                self.perform_merges_for_subset(by_type[optype])

            # If we're not only merging views, merging changes the memory
            # layout and non-views are turned into views. In that case we
            # need to update the signals the operators refer to before
            # trying to merge a different type of operator, so we break
            # the loop after the first type that produced merges.
            if not self.only_merge_ops_with_view and len(self.merged) > 0:
                break

    def perform_merges_for_subset(self, subset):
        """Perform operator merges for a subset of operators.

        Parameters
        ----------
        subset : list
            Subset of operators, all of the same type.
        """
        by_view = groupby(subset, lambda op: self.opinfo.get(op).v_base)
        if self.only_merge_ops_with_view:
            if None in by_view:
                # If an op has no views, v_base will be None.
                # If we're only merging views, then we get rid of this subset.
                del by_view[None]

            for view_subset in by_view.values():
                if len(view_subset) > 1:
                    self.perform_merges_for_view_subset(view_subset)
        elif None in by_view and len(by_view[None]) > 1:
            self.perform_merges_for_view_subset(by_view[None])

    def perform_merges_for_view_subset(self, subset):
        """Perform merges for a subset of operators with the same view base.

        Parameters
        ----------
        subset : list
            Subset of operators. They must all share the same view base for
            the first signal in their ``all_signals`` (the base may be None,
            but then it must be None for all of them).
        """

        # Sort by view offset so that views over sequential memory end up
        # next to each other.
        offsets = np.array(
            [self.opinfo.get(op).v_offset for op in subset], dtype=rc.float_dtype
        )
        sort_indices = np.argsort(offsets)
        offsets = offsets[sort_indices]
        sorted_subset = [subset[i] for i in sort_indices]

        for i, op in enumerate(sorted_subset):
            if op in self.merged:
                # An already merged operator cannot be merged again until the
                # dependency graph has been updated.
                continue

            if op in self.merged_dependents or any(
                o in self.merged for o in self.dependents[op]
            ):
                continue

            tomerge = OpsToMerge(
                op, self.merged, self.merged_dependents, self.dependents
            )

            # For a merge to be possible the view of the next operator has to
            # start where the view of op ends. Because we have sorted the
            # operators by the start of their views we can do a binary search
            # and potentially skip a number of operators at the beginning.
            start = np.searchsorted(
                offsets, offsets[i] + self.opinfo.get(op).v_size, side="left"
            )
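            # e.g. with offsets [0, 2, 4, 8] and op's view covering [0, 4),
            # searchsorted(offsets, 4) returns 2, so the overlapping op at
            # offset 2 is skipped and scanning starts at offset 4.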

            for op2 in sorted_subset[start:]:

                if tomerge.not_sequential(op2):
                    # The view of op2 does not immediately follow the views of
                    # the operators being merged. Because we iterate over the
                    # operators sorted by view offset, there is a gap between
                    # the end of the operators being merged and all remaining
                    # operators in the loop. With such a gap a merge is never
                    # possible, so we can cut the loop short.
                    break

                if op2 in self.merged:
                    continue

                if OpMerger.is_mergeable(op2, tomerge):
                    tomerge.add(op2)

            if len(tomerge.ops) > 1:
                self.merge(tomerge)
            elif self.only_merge_ops_with_view:
                self.might_merge.remove(op)

    def merge(self, tomerge):
        """Merges the given operators.

        This method will also update ``op_replacements``, ``sig_replacements``,
        and the internal list of merged operators to prevent further merges
        on the same operators before all required operators and signals have
        been replaced.
        """
        merged_op, merged_sig = OpMerger.merge(tomerge.ops)
        self.dg.merge(tomerge.ops, merged_op)

        # Update the bookkeeping of what has been merged and what might still
        # be mergeable later
        self.might_merge.difference_update(tomerge.ops)
        self.might_merge.add(merged_op)
        self.merged.update(tomerge.ops)
        self.merged_dependents.update(tomerge.all_dependents)

        for op in tomerge.ops:
            # Mark all operators referencing the same signals as merged
            # (even though they are not) to prevent them from getting
            # merged before their signals have been updated.
            for s in op.all_signals:
                self.merged.update(self.sig2ops[s])

        # Signal-related updates
        self.resolve_views_on_replaced_signals(merged_sig)
        self.sig_replacements.update(merged_sig)
        self.replace_op_signals(merged_sig)
        self.update_signal_indexing(merged_op, merged_sig)

    def resolve_views_on_replaced_signals(self, replaced_signals):
        for sig in list(replaced_signals):
            for view in self.base2views[sig]:
                if view is sig:
                    continue
                assert view.base is sig
                base_replacement = replaced_signals[sig]
                offset = view.offset
                strides = tuple(
                    a // b * c
                    for a, b, c in zip_longest(
                        view.strides,
                        view.base.strides,
                        base_replacement.strides,
                        fillvalue=1,
                    )
                )
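                # e.g. a view with stride 16 over a base with stride 8 steps
                # two base elements at a time; with a replacement base of
                # stride 24 the new stride is 16 // 8 * 24 = 48.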
                if base_replacement.is_view:
                    offset += base_replacement.offset
                    base_replacement = base_replacement.base
                buf = base_replacement.initial_value
                initial_value = np.ndarray(
                    buffer=buf,
                    dtype=view.dtype,
                    shape=view.shape,
                    offset=offset,
                    strides=strides,
                )
                replaced_signals[view] = Signal(
                    initial_value,
                    name=view.name,
                    base=base_replacement,
                    readonly=view.readonly,
                    offset=offset,
                )

    def replace_op_signals(self, replaced_signals):
        ops = (op for s in replaced_signals for op in self.sig2ops[s])
        for v in ops:
            # Update the op's signals
            v.sets = [replaced_signals.get(s, s) for s in v.sets]
            v.incs = [replaced_signals.get(s, s) for s in v.incs]
            v.reads = [replaced_signals.get(s, s) for s in v.reads]
            v.updates = [replaced_signals.get(s, s) for s in v.updates]

    def update_signal_indexing(self, merged_op, replaced_signals):
        for s in merged_op.all_signals:
            self.sig2ops[s].add(merged_op)
            if s.is_view:
                self.base2views[s.base].add(s)

        for from_sig, to_sig in replaced_signals.items():
            self.sig2ops[to_sig] = self.sig2ops[from_sig]
            if to_sig.is_view:
                self.base2views[to_sig.base].add(to_sig)
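
For context, a hedged sketch of how the optimizer might drive this class
(the actual ``optimize`` loop in ``nengo.builder.optimizer`` uses its own
termination heuristics; ``dg`` maps each operator to the set of operators
that depend on it):

op_merge_pass = OpMergePass(dg)

# Alternate between view-only and unrestricted passes until a pass
# stops producing merges.
only_merge_ops_with_view = True
while True:
    op_merge_pass(only_merge_ops_with_view)
    if len(op_merge_pass.merged) == 0:
        if not only_merge_ops_with_view:
            break
        only_merge_ops_with_view = False
    else:
        only_merge_ops_with_view = True

dg = op_merge_pass.dg.forward  # merged dependency graph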