Exemplo n.º 1
0
    def _check_graph(self, out_stream=sys.stdout):
        """
        Report dependency cycles and out-of-order subsystems in every group.

        For each group at or below ``self.root``, the component graph is
        reduced to a graph whose nodes are that group's direct children.
        Cycles among those children are reported when the group's nonlinear
        solver is a `RunOnce`.  The cycles are then broken, and the resulting
        graph is walked in subsystem order to report subsystems that are
        reachable from a subsystem visited at the same time or later.

        Args
        ----
        out_stream : file-like
            Stream that the messages are printed to.

        """
        cgraph = self.root._relevance._cgraph
        for grp in self.root.subgroups(recurse=True, include_self=True):
            if grp.pathname:
                depth = len(grp.pathname.split('.'))
                # Match on 'pathname.' rather than the bare pathname so that
                # a sibling whose name merely starts with this group's name
                # (e.g. 'sub10' vs. 'sub1') is not pulled into the graph.
                prefix = grp.pathname + '.'
            else:
                depth = 0
                prefix = ''  # root group: every node belongs to it

            graph = cgraph.subgraph([n for n in cgraph if n.startswith(prefix)])

            # map each node onto the direct child of grp that contains it
            renames = {}
            for node in graph:
                child = '.'.join(node.split('.')[:depth + 1])
                if child != node:
                    renames[node] = child

            # get the graph of direct children of current group
            nx.relabel_nodes(graph, renames, copy=False)

            # remove self loops created by renaming
            graph.remove_edges_from([(u, v) for u, v in graph.edges()
                                     if u == v])

            strong = [s for s in nx.strongly_connected_components(graph)
                      if len(s) > 1]

            # Cycles in a group without a solver
            if strong and isinstance(grp.nl_solver, RunOnce):
                relstrong = [sorted(name_relative_to(grp.pathname, s)
                                    for s in slist)
                             for slist in strong]
                print("Group '%s' has the following cycles: %s" %
                      (grp.pathname, relstrong), file=out_stream)

            # Components/Systems/Groups are not in the right execution order
            subnames = [s.pathname for s in grp.subsystems()]
            while strong:
                # break cycles to check order: drop every in-cycle edge that
                # feeds the first subsystem (in subsystem order) of the cycle
                lsys = [s for s in subnames if s in strong[0]]
                first = lsys[0]
                # copy the predecessors, since edges are removed while looping
                for p in list(graph.predecessors(first)):
                    if p in lsys:
                        graph.remove_edge(p, first)
                strong = [s for s in nx.strongly_connected_components(graph)
                          if len(s) > 1]

            visited = set()
            out_of_order = set()
            for sub in grp.subsystems():
                visited.add(sub.pathname)
                # anything reachable from sub that was already visited is
                # flagged as out of order
                for u, v in nx.dfs_edges(graph, sub.pathname):
                    if v in visited:
                        out_of_order.add(v)

            if out_of_order:
                print("In group '%s', the following subsystems are out-of-order: %s" %
                      (grp.pathname, sorted([name_relative_to(grp.pathname, n)
                                             for n in out_of_order])), file=out_stream)
Exemplo n.º 2
0
    def _setup_data_transfer(self, my_params, relevance, var_of_interest):
        """
        Build the `DataXfer` objects for the connections this `Group` owns.

        A `DataXfer` is stored in ``self._data_xfer`` under the key
        (subsystem name, mode, var_of_interest) for every subsystem/mode pair
        that has at least one connection.  In addition, one combined
        transfer per mode is stored under the subsystem name ''.

        Args
        ----

        my_params : list
            Pathnames of the parameters that this `Group` is responsible
            for propagating.

        relevance : `Relevance`
            Object used to decide whether a variable is relevant to the
            variable of interest.

        var_of_interest : str or None
            The name of a variable of interest.

        """

        # (subsystem name, mode) -> (source index arrays, dest index arrays,
        #                            vector connections, by-object connections)
        transfers = {}

        for param, unknown in self.connections.items():
            # skip connections where neither end is relevant
            if (not relevance.is_relevant(var_of_interest, param) and
                    not relevance.is_relevant(var_of_interest, unknown)):
                continue

            if param not in my_params:
                continue

            # names relative to this group's pathname, used as the
            # subsystem part of the transfer key
            tgt_sys = name_relative_to(self.pathname, param)
            src_sys = name_relative_to(self.pathname, unknown)

            for mode, sname in (('fwd', tgt_sys), ('rev', src_sys)):
                entry = transfers.setdefault((sname, mode), ([], [], [], []))
                src_idx_list, dest_idx_list, vec_conns, byobj_conns = entry

                urelname = self.unknowns.get_promoted_varname(unknown)
                prelname = self.params.get_promoted_varname(param)
                conn = (prelname, urelname)

                if not self.unknowns.metadata(urelname).get('pass_by_obj'):
                    # pass by vector
                    sidxs, didxs = self._get_global_idxs(urelname, prelname,
                                                         var_of_interest, mode)
                    vec_conns.append(conn)
                    src_idx_list.append(sidxs)
                    dest_idx_list.append(didxs)
                elif mode == 'fwd':
                    # rev is for derivs only, so by_obj passing is fwd only
                    byobj_conns.append(conn)

        # one DataXfer per (subsystem, mode) that actually has connections
        for key, (srcs, tgts, vec_conns, byobj_conns) in transfers.items():
            sname, mode = key
            src_idxs, tgt_idxs = self.unknowns.merge_idxs(srcs, tgts)
            if not (vec_conns or byobj_conns):
                continue
            xfer = self._impl_factory.create_data_xfer(self.dumat[var_of_interest],
                                                       self.dpmat[var_of_interest],
                                                       src_idxs, tgt_idxs,
                                                       vec_conns, byobj_conns)
            self._data_xfer[(sname, mode, var_of_interest)] = xfer

        # For each mode, combine the src_idxs, tgt_idxs and connections of
        # all subsystems into a single DataXfer so that a 'full' scatter to
        # every subsystem can be done at once.  It is stored under the
        # subsystem name ''.
        for mode in ('fwd', 'rev'):
            full_srcs, full_tgts, full_flats, full_byobjs = [], [], [], []
            for (_, direction), (srcs, tgts, flats, byobjs) in transfers.items():
                if direction != mode:
                    continue
                full_srcs += srcs
                full_tgts += tgts
                full_flats += flats
                full_byobjs += byobjs

            src_idxs, tgt_idxs = self.unknowns.merge_idxs(full_srcs, full_tgts)
            xfer = self._impl_factory.create_data_xfer(self.dumat[var_of_interest],
                                                       self.dpmat[var_of_interest],
                                                       src_idxs, tgt_idxs,
                                                       full_flats, full_byobjs)
            self._data_xfer[('', mode, var_of_interest)] = xfer