示例#1
0
    def plot(self, fnum, pnum):
        """
        Redraw the inference graph and restore the interactive decorations.

        Refreshes the visual attributes, draws ``self.infr.graph`` using the
        neato layout program, calls ``graph_info`` in verbose mode, enables
        pan/zoom on the current axes, re-highlights every selected annotation
        and finally rebuilds the HUD.

        Args:
            fnum: accepted but not read here; the drawing call uses
                ``self.fnum`` instead.
            pnum: accepted but not read here.
        """
        infr = self.infr
        infr.update_visual_attrs(self.show_cuts)

        # Options handed to the graphviz layout step.
        neato_opts = {'prog': 'neato', 'splines': 'spline', 'sep': 10 / 72}
        draw_opts = {
            'as_directed': False,
            'fnum': self.fnum,
            'layoutkw': neato_opts,
            'modify_ax': False,
            'use_image': self.use_image,
            'verbose': 0,
        }
        self.plotinfo = pt.show_nx(infr.graph, **draw_opts)

        ut.util_graph.graph_info(infr.graph, verbose=True)

        self.enable_pan_and_zoom(pt.gca())
        for selected_aid in self.selected_aids:
            self.highlight_aid(selected_aid, pt.ORANGE)
        self.make_hud()
        logger.info('Finished Plot')
def viz_factor_graph(gm):
    """
    Visualize an opengm graphical model, then show its networkx graph.

    Args:
        gm: graphical model accepted by ``opengm.visualizeGm``; its ``G``
            attribute is drawn afterwards with ``pt.show_nx``.

    Example usage (from the original notes):
        ut.qtensure()
        gm = build_factor_graph(G, nodes, edges , n_annots, n_names, lookup_annot_idx,
                                use_unaries=True, edge_probs=None, operator='multiplier')
    """
    ut.qtensure()
    import networkx
    from networkx.drawing import nx_agraph

    # Re-expose graphviz_layout on the top-level networkx namespace.
    # NOTE(review): presumably opengm looks it up there -- confirm.
    networkx.graphviz_layout = nx_agraph.graphviz_layout

    viz_options = {
        'show': False,
        'layout': 'neato',
        'plotUnaries': True,
        'iterations': 1000,
        'plotFunctions': True,
        'plotNonShared': False,
        'relNodeSize': 1.0,
    }
    opengm.visualizeGm(gm, **viz_options)

    pt.show_nx(gm.G)
示例#3
0
def draw_em_graph(P, Pn, PL, gam, num_labels):
    """
    Draw the name/annotation graph for one EM state.

    Name nodes are labelled ``N1..Nk`` and annotation nodes ``X1..Xm`` where
    ``k = PL.shape[1]`` and ``m = len(Pn)``. A graph is built from ``P`` with
    ``ut.nx_from_matrix``, reverse duplicates of each edge are dropped, and
    the result is shown with ``pt.show_nx`` using the neato layout program.

    Args:
        P: weight matrix used to build the graph. When the ``--postem``
            command line flag is set it is replaced by a matrix stacked from
            a zero block and the normalized posteriors.
        Pn: array placed to the right of the normalized posteriors; its
            length determines the number of annotation nodes.
        PL: only ``PL.shape[1]`` is read (the number of labels).
        gam: array whose columns from ``num_labels`` onward are transposed,
            offset by 0.01 and row-normalized. It is NOT modified.
        num_labels: ignored; always overwritten with ``PL.shape[1]``.

    CommandLine:
        python -m wbia.algo.hots.testem test_em --show --no-cnn
    """
    import networkx as nx
    import wbia.plottool as pt

    # The passed-in value is discarded; the label count comes from PL.
    num_labels = PL.shape[1]
    name_nodes = ['N%d' % x for x in range(1, num_labels + 1)]
    annot_nodes = ['X%d' % x for x in range(1, len(Pn) + 1)]
    nodes = name_nodes + annot_nodes

    # Allocate a new array here: an in-place ``+=`` on the transposed slice
    # would write through the view and corrupt the caller's ``gam``.
    PL2 = gam[:, num_labels:].T + 0.01
    PL2 = PL2 / PL2.sum(axis=1)[:, None]
    zero_part = np.zeros((num_labels, len(Pn) + num_labels))
    prob_part = np.hstack([PL2, Pn])
    logger.info(ut.hz_str(' PL2 = ', ut.repr2(PL2, precision=2)))
    # Redo P with posteriors
    if ut.get_argflag('--postem'):
        P = np.vstack([zero_part, prob_part])

    graph = ut.nx_from_matrix(P, nodes=nodes)
    graph = graph.to_directed()

    # Keep one direction per node pair: every edge is normalized to
    # (larger, smaller) and that directed edge is removed once the pair has
    # been seen before.
    dup_edges = []
    seen_ = set()
    for u, v in graph.edges():
        if u < v:
            u, v = v, u
        if (u, v) not in seen_:
            seen_.add((u, v))
        else:
            dup_edges.append((u, v))
    graph.remove_edges_from(dup_edges)

    if len(name_nodes) == 3 and len(annot_nodes) == 4:
        # Hand-placed, pinned positions for the 3-name / 4-annot demo case.
        graph.nodes[annot_nodes[0]]['pos'] = (20.0, 200.0)
        graph.nodes[annot_nodes[1]]['pos'] = (220.0, 200.0)
        graph.nodes[annot_nodes[2]]['pos'] = (20.0, 100.0)
        graph.nodes[annot_nodes[3]]['pos'] = (220.0, 100.0)
        graph.nodes[name_nodes[0]]['pos'] = (10.0, 300.0)
        graph.nodes[name_nodes[1]]['pos'] = (120.0, 300.0)
        graph.nodes[name_nodes[2]]['pos'] = (230.0, 300.0)
        nx.set_node_attributes(graph, name='pin', values='true')

        logger.info('annot_nodes = %r' % (annot_nodes,))
        logger.info('name_nodes = %r' % (name_nodes,))

        # Move each annot<->name edge label to the end of the edge that
        # touches the annotation node. ``graph.edges[u, v]`` is the edge
        # accessor matching the ``graph.nodes[...]`` API used above (the old
        # ``graph.edge`` attribute no longer exists in networkx 2.x).
        for u in annot_nodes:
            for v in name_nodes:
                if graph.has_edge(u, v):
                    logger.info('1) u, v = %r' % ((u, v),))
                    edge_data = graph.edges[u, v]
                    edge_data['taillabel'] = edge_data['label']
                    edge_data['color'] = pt.ORANGE
                    edge_data['labelcolor'] = pt.BLUE
                    del edge_data['label']
                elif graph.has_edge(v, u):
                    logger.info('2) u, v = %r' % ((u, v),))
                    edge_data = graph.edges[v, u]
                    edge_data['headlabel'] = edge_data['label']
                    edge_data['color'] = pt.ORANGE
                    edge_data['labelcolor'] = pt.BLUE
                    del edge_data['label']
                else:
                    logger.info((u, v))
                    logger.info('!!')

    nx.set_node_attributes(
        graph, name='color', values={node: pt.RED for node in name_nodes}
    )
    nx.set_node_attributes(
        graph, name='groupid', values={node: 'names' for node in name_nodes}
    )
    nx.set_node_attributes(
        graph, name='groupid', values={node: 'annots' for node in annot_nodes}
    )
    graph.graph['clusterrank'] = 'local'
    ut.nx_delete_edge_attr(graph, 'weight')

    layoutkw = {
        'splines': 'spline',
        'prog': 'neato',
    }
    pt.show_nx(
        graph,
        fontsize=6,
        fontname='Ubuntu',
        layoutkw=layoutkw,
        verbose=0,
        as_directed=False,
    )
    pt.interactions.zoom_factory()
示例#4
0
def viz_netx_chipgraph(
    ibs,
    graph,
    fnum=None,
    use_image=False,
    layout=None,
    zoom=None,
    prog='neato',
    as_directed=False,
    augment_graph=True,
    layoutkw=None,
    framewidth=True,
    **kwargs,
):
    r"""
    Draw a chip graph with ``pt.show_nx``.

    DEPRECATE or improve

    Args:
        ibs (IBEISController):  wbia controller object
        graph (nx.DiGraph): graph to draw; every node gets ``shape='rect'``
            set on it (and node images when ``use_image`` is True)
        fnum (int):  figure number(default = None)
        use_image (bool): attach node images via ``ensure_node_images``
            (default = False)
        layout (str): layout name passed to ``pt.show_nx``; ``None`` means
            ``'agraph'`` (default = None)
        zoom (float): not read by this function (default = None)
        prog (str): layout program used when ``layoutkw`` has no ``'prog'``
            entry; when equal to ``'neato'`` the graph is converted to an
            undirected graph before drawing (default = 'neato')
        as_directed (bool): forwarded to ``pt.show_nx`` (default = False)
        augment_graph (bool): not read by this function (default = True)
        layoutkw (dict): extra layout options; the caller's dict is copied,
            never modified (default = None)
        framewidth (bool): forwarded to ``pt.show_nx`` (default = True)
        **kwargs: merged into the layout options, overriding ``layoutkw``

    Returns:
        the value returned by ``pt.show_nx``

    CommandLine:
        python -m wbia --tf viz_netx_chipgraph --show

    Cand:
        wbia review_tagged_joins --save figures4/mergecase.png --figsize=15,15
            --clipwhite --diskshow
        wbia compute_occurrence_groups --save figures4/occurgraph.png
            --figsize=40,40 --clipwhite --diskshow
        ~/code/wbia/wbia/algo/preproc/preproc_occurrence.py

    Example:
        >>> # DISABLE_DOCTEST
        >>> from wbia.viz.viz_graph import *  # NOQA
        >>> import wbia
        >>> ibs = wbia.opendb(defaultdb='PZ_MTEST')
        >>> nid_list = ibs.get_valid_nids()[0:10]
        >>> fnum = None
        >>> use_image = True
        >>> zoom = 0.4
        >>> make_name_graph_interaction(ibs, nid_list, prog='neato')
        >>> ut.show_if_requested()
    """
    import wbia.plottool as pt

    logger.info('[viz_graph] drawing chip graph')
    fnum = pt.ensure_fnum(fnum)
    pt.figure(fnum=fnum, pnum=(1, 1, 1))
    ax = pt.gca()

    if layout is None:
        layout = 'agraph'
    logger.info('layout = %r' % (layout, ))

    if use_image:
        ensure_node_images(ibs, graph)
    nx.set_node_attributes(graph, name='shape', values='rect')

    # Work on a private copy so the caller's dict is not polluted with
    # ``prog`` and the extra keyword arguments.
    layoutkw = {} if layoutkw is None else dict(layoutkw)
    layoutkw.setdefault('prog', prog)
    layoutkw.update(kwargs)

    if prog == 'neato':
        graph = graph.to_undirected()

    plotinfo = pt.show_nx(
        graph,
        ax=ax,
        layout=layout,
        layoutkw=layoutkw,
        as_directed=as_directed,
        framewidth=framewidth,
    )
    return plotinfo
示例#5
0
    def show_graph(
        infr,
        graph=None,
        use_image=False,
        update_attrs=True,
        with_colorbar=False,
        pnum=(1, 1, 1),
        zoomable=True,
        pickable=False,
        **kwargs,
    ):
        r"""
        Draw an inference graph using its stored ('custom') layout.

        Args:
            infr: the inference object (this method's receiver)
            graph (None): graph to draw; ``infr.graph`` when None
                (default = None)
            use_image (bool): forwarded to ``pt.show_nx`` (default = False)
            update_attrs (bool): call ``infr.update_visual_attrs`` on the
                graph before drawing (default = True)
            with_colorbar (bool): currently not read (default = False)
            pnum (tuple):  plot number(default = (1, 1, 1))
            zoomable (bool): install zoom and pan handlers (default = True)
            pickable (bool): connect ``on_pick`` to the figure's pick events
                (default = False)
            **kwargs: verbose, with_labels, fnum, layout, ax, pos, img_dict,
                      title, layoutkw, framewidth, modify_ax, as_directed,
                      hacknoedge, hacknode, node_labels, arrow_width, fontsize,
                      fontweight, fontname, fontfamilty, fontproperties

        CommandLine:
            python -m wbia.algo.graph.mixin_viz GraphVisualization.show_graph --show

        Example:
            >>> # xdoctest: +REQUIRES(module:pygraphviz)
            >>> # ENABLE_DOCTEST
            >>> from wbia.algo.graph.mixin_viz import *  # NOQA
            >>> from wbia.algo.graph import demo
            >>> import wbia.plottool as pt
            >>> infr = demo.demodata_infr(ccs=ut.estarmap(
            >>>    range, [(1, 6), (6, 10), (10, 13), (13, 15), (15, 16),
            >>>            (17, 20)]))
            >>> pnum_ = pt.make_pnum_nextgen(nRows=1, nCols=3)
            >>> infr.show_graph(show_cand=True, simple_labels=True, pickable=True, fnum=1, pnum=pnum_())
            >>> infr.add_feedback((1, 5), INCMP)
            >>> infr.add_feedback((14, 18), INCMP)
            >>> infr.refresh_candidate_edges()
            >>> infr.show_graph(show_cand=True, simple_labels=True, pickable=True, fnum=1, pnum=pnum_())
            >>> infr.add_feedback((17, 18), NEGTV)  # add inconsistency
            >>> infr.apply_nondynamic_update()
            >>> infr.show_graph(show_cand=True, simple_labels=True, pickable=True, fnum=1, pnum=pnum_())
            >>> ut.show_if_requested()
        """
        import wbia.plottool as pt

        target_graph = infr.graph if graph is None else graph

        # All warnings raised while updating and drawing are suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            if update_attrs:
                # ``verbose`` (if given) is still in kwargs at this point.
                infr.update_visual_attrs(graph=target_graph, **kwargs)
            verbosity = kwargs.pop('verbose', infr.verbose)
            fixed_opts = {
                'layout': 'custom',
                'as_directed': False,
                'modify_ax': False,
                'use_image': use_image,
                'pnum': pnum,
                'verbose': verbosity,
            }
            pt.show_nx(target_graph, **fixed_opts, **kwargs)
            if zoomable:
                pt.zoom_factory()
                pt.pan_factory(pt.gca())

        if target_graph.graph.get('dark_background', None):
            pt.dark_background(force=True)

        if pickable:
            canvas = pt.gcf().canvas
            canvas.mpl_connect('pick_event', ut.partial(on_pick, infr=infr))
示例#6
0
def draw_twoday_count(ibs, visit_info_list_):
    """
    Lay out and draw the match graphs of a two-visit (two-day) count.

    NOTE: everything after the initial bookkeeping sits under ``if False:``
    guards. As written, a call only deep-copies the input and computes the
    name ids seen on both days; no graph is built or shown.

    Args:
        ibs: controller object; the disabled section uses it to look up
            annotmatch rows and annotation name ids.
        visit_info_list_: per-visit info dicts providing ``'aids'`` and
            ``'unique_nids'``. The columns are unpacked into exactly two
            names (day 1 / day 2). The disabled section additionally reads
            ``'grouped_aids'`` and stores ``'edges'`` and ``'graph'`` on a
            deep copy, so the input itself is never modified.
    """
    import copy

    visit_info_list = copy.deepcopy(visit_info_list_)

    aids_day1, aids_day2 = ut.take_column(visit_info_list_, 'aids')
    nids_day1, nids_day2 = ut.take_column(visit_info_list_, 'unique_nids')
    resight_nids = ut.isect(nids_day1, nids_day2)

    if False:
        # HACK REMOVE DATA TO MAKE THIS FASTER
        num = 20
        for info in visit_info_list:
            non_resight_nids = list(
                set(info['unique_nids']) - set(resight_nids))
            sample_nids2 = non_resight_nids[0:num] + resight_nids[:num]
            info['grouped_aids'] = ut.dict_subset(info['grouped_aids'],
                                                  sample_nids2)
            info['unique_nids'] = sample_nids2

    # Build a graph of matches
    if False:

        debug = False

        # Per visit: collect the edges inside each name group.
        for info in visit_info_list:
            edges = []
            grouped_aids = info['grouped_aids']

            aids_list = list(grouped_aids.values())
            ams_list = ibs.get_annotmatch_rowids_in_cliques(aids_list)
            aids1_list = ibs.unflat_map(ibs.get_annotmatch_aid1, ams_list)
            aids2_list = ibs.unflat_map(ibs.get_annotmatch_aid2, ams_list)
            for ams, aids, aids1, aids2 in zip(ams_list, aids_list, aids1_list,
                                               aids2_list):
                edge_nodes = set(aids1 + aids2)
                # Annotations referenced by a match but absent from the group
                bad_aids = edge_nodes - set(aids)
                if len(bad_aids) > 0:
                    logger.info('bad_aids = %r' % (bad_aids, ))
                unlinked_aids = set(aids) - edge_nodes
                # Chain the unlinked annotations together and attach the
                # chain to one already-linked node (if there is one).
                mst_links = list(
                    ut.itertwo(list(unlinked_aids) + list(edge_nodes)[:1]))
                bad_aids.add(None)
                user_links = [(u, v) for (u, v) in zip(aids1, aids2)
                              if u not in bad_aids and v not in bad_aids]
                new_edges = mst_links + user_links
                new_edges = [(int(u), int(v)) for u, v in new_edges
                             if u not in bad_aids and v not in bad_aids]
                edges += new_edges
            info['edges'] = edges

        # Add edges between days
        grouped_aids1, grouped_aids2 = ut.take_column(visit_info_list,
                                                      'grouped_aids')
        nids_day1, nids_day2 = ut.take_column(visit_info_list, 'unique_nids')
        resight_nids = ut.isect(nids_day1, nids_day2)

        resight_aids1 = ut.take(grouped_aids1, resight_nids)
        resight_aids2 = ut.take(grouped_aids2, resight_nids)

        ams_list = ibs.get_annotmatch_rowids_between_groups(
            resight_aids1, resight_aids2)
        aids1_list = ibs.unflat_map(ibs.get_annotmatch_aid1, ams_list)
        aids2_list = ibs.unflat_map(ibs.get_annotmatch_aid2, ams_list)

        between_edges = []
        for ams, aids1, aids2, rawaids1, rawaids2 in zip(
                ams_list, aids1_list, aids2_list, resight_aids1,
                resight_aids2):
            link_aids = aids1 + aids2
            rawaids3 = rawaids1 + rawaids2
            badaids = ut.setdiff(link_aids, rawaids3)
            assert not badaids
            user_links = [(int(u), int(v)) for (u, v) in zip(aids1, aids2)
                          if u is not None and v is not None]
            # HACK THIS OFF
            user_links = []
            if len(user_links) == 0:
                # Hack in an edge
                between_edges += [(rawaids1[0], rawaids2[0])]
            else:
                between_edges += user_links

        # Both endpoints of every between-day edge must share a name id.
        assert np.all(0 == np.diff(np.array(
            ibs.unflat_map(ibs.get_annot_nids, between_edges)),
                                   axis=1))

        import wbia.plottool as pt
        import networkx as nx

        # Layout graph
        layoutkw = dict(
            prog='neato',
            draw_implicit=False,
            splines='line',
        )

        def translate_graph_to_origin(graph):
            # Shift the graph so its bounding box starts at (0, 0).
            x, y, w, h = ut.get_graph_bounding_box(graph)
            ut.translate_graph(graph, (-x, -y))

        def stack_graphs(graph_list, vert=False, pad=None):
            # Place copies of the graphs one after another along one axis
            # (centered along the other), pin their nodes and compose them.
            graph_list_ = [g.copy() for g in graph_list]
            for g in graph_list_:
                translate_graph_to_origin(g)
            bbox_list = [ut.get_graph_bounding_box(g) for g in graph_list_]
            # bbox is (x, y, w, h): index 2 is width, index 3 is height
            if vert:
                dim1 = 3
                dim2 = 2
            else:
                dim1 = 2
                dim2 = 3
            dim1_list = np.array([bbox[dim1] for bbox in bbox_list])
            dim2_list = np.array([bbox[dim2] for bbox in bbox_list])
            if pad is None:
                pad = np.mean(dim1_list) / 2
            offset1_list = ut.cumsum([0] + [d + pad for d in dim1_list[:-1]])
            max_dim2 = max(dim2_list)
            offset2_list = [(max_dim2 - d2) / 2 for d2 in dim2_list]
            if vert:
                t_xy_list = [(d2, d1)
                             for d1, d2 in zip(offset1_list, offset2_list)]
            else:
                t_xy_list = [(d1, d2)
                             for d1, d2 in zip(offset1_list, offset2_list)]

            for g, t_xy in zip(graph_list_, t_xy_list):
                ut.translate_graph(g, t_xy)
                nx.set_node_attributes(g, name='pin', values='true')

            new_graph = nx.compose_all(graph_list_)
            return new_graph

        # Construct graph
        for count, info in enumerate(visit_info_list):
            graph = nx.Graph()
            edges = [(int(u), int(v)) for u, v in info['edges']
                     if u is not None and v is not None]
            # Attributes are passed as plain keywords: an ``attr_dict=``
            # keyword would be stored as an attribute literally named
            # 'attr_dict' by networkx 2.x.
            graph.add_edges_from(edges, zorder=10)
            nx.set_node_attributes(graph, name='zorder', values=20)

            # Layout in neato
            _ = pt.nx_agraph_layout(graph, inplace=True, **layoutkw)  # NOQA

            # Extract components and then flatten in nid ordering
            ccs = list(nx.connected_components(graph))
            root_aids = []
            cc_graphs = []
            for cc_nodes in ccs:
                cc = graph.subgraph(cc_nodes)
                try:
                    root_aids.append(
                        list(ut.nx_source_nodes(cc.to_directed()))[0])
                except nx.NetworkXUnfeasible:
                    root_aids.append(list(cc.nodes())[0])
                cc_graphs.append(cc)

            root_nids = ibs.get_annot_nids(root_aids)
            nid2_graph = dict(zip(root_nids, cc_graphs))

            resight_nids_ = set(resight_nids).intersection(set(root_nids))
            noresight_nids_ = set(root_nids) - resight_nids_

            n_graph_list = ut.take(nid2_graph, sorted(noresight_nids_))
            r_graph_list = ut.take(nid2_graph, sorted(resight_nids_))

            if len(n_graph_list) > 0:
                n_graph = nx.compose_all(n_graph_list)
                _ = pt.nx_agraph_layout(n_graph, inplace=True,
                                        **layoutkw)  # NOQA
                n_graphs = [n_graph]
            else:
                n_graphs = []

            r_graphs = [
                stack_graphs(chunk) for chunk in ut.ichunks(r_graph_list, 100)
            ]
            # The second visit is stacked in the opposite order.
            if count == 0:
                new_graph = stack_graphs(n_graphs + r_graphs, vert=True)
            else:
                new_graph = stack_graphs(r_graphs[::-1] + n_graphs, vert=True)

            info['graph'] = new_graph

        graph1_, graph2_ = ut.take_column(visit_info_list, 'graph')
        if False:
            pt.show_nx(graph1_,
                       layout='custom',
                       node_labels=False,
                       as_directed=False)
            pt.show_nx(graph2_,
                       layout='custom',
                       node_labels=False,
                       as_directed=False)

        graph_list = [graph1_, graph2_]
        twoday_graph = stack_graphs(graph_list, vert=True, pad=None)
        nx.set_node_attributes(twoday_graph, name='pin', values='true')

        if debug:
            ut.nx_delete_None_edge_attr(twoday_graph)
            ut.nx_delete_None_node_attr(twoday_graph)
            logger.info('twoday_graph(pre) info' +
                        ut.repr3(ut.graph_info(twoday_graph), nl=2))

        # Hack, no idea why there are nodes that dont exist here
        between_edges_ = [
            edge for edge in between_edges if twoday_graph.has_node(edge[0])
            and twoday_graph.has_node(edge[1])
        ]

        # Keyword attributes (not ``attr_dict=``) so that alpha/zorder are
        # really set on the edges.
        twoday_graph.add_edges_from(between_edges_, alpha=0.2, zorder=0)
        ut.nx_ensure_agraph_color(twoday_graph)

        layoutkw['splines'] = 'line'
        layoutkw['prog'] = 'neato'
        agraph = pt.nx_agraph_layout(twoday_graph,
                                     inplace=True,
                                     return_agraph=True,
                                     **layoutkw)[-1]  # NOQA
        if False:
            fpath = ut.truepath('~/ggr_graph.png')
            agraph.draw(fpath)
            ut.startfile(fpath)

        if debug:
            logger.info('twoday_graph(post) info' +
                        ut.repr3(ut.graph_info(twoday_graph)))

        pt.show_nx(twoday_graph,
                   layout='custom',
                   node_labels=False,
                   as_directed=False)
    def make_graph(self, show=False):
        """
        Build an undirected graph of annotations and name nodes.

        The graph contains every query and database annotation as a node,
        and these edges:

        * a clique between database annotations that share a name id,
        * an edge from a query annotation to every database annotation of
          each name whose probability exceeds ``self.choose_thresh()``,
        * an edge from each database annotation to its ``('nid', nid)``
          name node.

        If ``self.user_feedback`` is not None, each reviewed pair is then
        forced to be connected (``p_same > 0.5``) or disconnected.

        Args:
            show (bool): color/label the nodes and draw the graph with
                ``pt.show_nx`` (default = False)

        Returns:
            nx.Graph: the constructed graph
        """
        import itertools

        import networkx as nx

        cm_list = self.cm_list
        unique_nids, prob_names = self.make_prob_names()
        thresh = self.choose_thresh()

        # Simply cut any edge with a weight less than a threshold
        qaid_list = [cm.qaid for cm in cm_list]
        postcut = prob_names > thresh
        qxs, nxs = np.where(postcut)
        if False:
            # Disabled debug dump. Uses the builtin ``int`` because the
            # ``np.int`` alias no longer exists in current NumPy.
            kw = dict(precision=2, max_line_width=140, suppress_small=True)
            logger.info(
                ut.hz_str('prob_names = ', ut.repr2((prob_names), **kw)))
            logger.info(
                ut.hz_str('postcut = ', ut.repr2((postcut).astype(int),
                                                 **kw)))
        matching_qaids = ut.take(qaid_list, qxs)
        matched_nids = ut.take(unique_nids, nxs)

        qreq_ = self.qreq_

        nodes = ut.unique(qreq_.qaids.tolist() + qreq_.daids.tolist())
        dnids = qreq_.get_qreq_annot_nids(qreq_.daids)
        dnid2_daids = ut.group_items(qreq_.daids, dnids)
        grouped_aids = dnid2_daids.values()
        matched_daids = ut.take(dnid2_daids, matched_nids)
        name_cliques = [
            list(itertools.combinations(aids, 2)) for aids in grouped_aids
        ]
        aid_matches = [
            list(ut.product([qaid], daids))
            for qaid, daids in zip(matching_qaids, matched_daids)
        ]

        graph = nx.Graph()
        graph.add_nodes_from(nodes)
        graph.add_edges_from(ut.flatten(name_cliques))
        graph.add_edges_from(ut.flatten(aid_matches))

        # Tuples are used as name nodes so they are distinct from the
        # annotation ids.
        name_nodes = [('nid', nid) for nid in dnids]
        db_aid_nid_edges = list(zip(qreq_.daids, name_nodes))
        graph.add_edges_from(db_aid_nid_edges)

        if self.user_feedback is not None:
            user_feedback = ut.map_dict_vals(np.array, self.user_feedback)
            p_bg = 0.0
            part1 = user_feedback['p_match'] * (1 - user_feedback['p_notcomp'])
            part2 = p_bg * user_feedback['p_notcomp']
            p_same_list = part1 + part2
            for aid1, aid2, p_same in zip(user_feedback['aid1'],
                                          user_feedback['aid2'], p_same_list):
                if p_same > 0.5:
                    if not graph.has_edge(aid1, aid2):
                        graph.add_edge(aid1, aid2)
                else:
                    if graph.has_edge(aid1, aid2):
                        graph.remove_edge(aid1, aid2)

        if show:
            import wbia.plottool as pt

            # Later assignments override earlier ones: annotations that are
            # both query and database end up LIGHT_PURPLE.
            nx.set_node_attributes(
                graph,
                name='color',
                values={aid: pt.LIGHT_PINK
                        for aid in qreq_.daids})
            nx.set_node_attributes(
                graph,
                name='color',
                values={aid: pt.TRUE_BLUE
                        for aid in qreq_.qaids})
            nx.set_node_attributes(
                graph,
                name='color',
                values={
                    aid: pt.LIGHT_PURPLE
                    for aid in np.intersect1d(qreq_.qaids, qreq_.daids)
                },
            )
            nx.set_node_attributes(
                graph,
                name='label',
                values={node: 'n%r' % (node[1], )
                        for node in name_nodes},
            )
            nx.set_node_attributes(
                graph,
                name='color',
                values={node: pt.LIGHT_GREEN
                        for node in name_nodes},
            )

            pt.show_nx(graph, layoutkw={'prog': 'neato'}, verbose=False)
        return graph
示例#8
0
def draw_bayesian_model(model,
                        evidence=None,
                        soft_evidence=None,
                        fnum=None,
                        pnum=None,
                        **kwargs):
    """
    Draw a Bayesian model with graphviz and annotate it with text and colors.

    Args:
        model: a pgmpy ``BayesianModel``, or any object providing
            ``to_bayesian_model()``.
        evidence (dict): forwarded to ``get_node_viz_attrs``; a non-empty
            value also enables the colorbars. ``None`` means an empty dict
            (default = None)
        soft_evidence (dict): forwarded to ``get_node_viz_attrs``. ``None``
            means an empty dict (default = None)
        fnum: figure number forwarded to ``pt.show_nx`` (default = None)
        pnum: plot number forwarded to ``pt.show_nx`` (default = None)
        **kwargs: ``factor_list`` is removed and forwarded positionally to
            ``get_node_viz_attrs``; passing it also enables the colorbars.
            ``top_assignments``, ``method`` and ``show_title`` control the
            figure title. Everything except ``factor_list`` is forwarded to
            ``get_node_viz_attrs`` as keywords.
    """
    from pgmpy.models import BayesianModel

    import networkx as nx
    import wbia.plottool as pt

    # Fresh dicts per call instead of shared mutable default arguments.
    evidence = {} if evidence is None else evidence
    soft_evidence = {} if soft_evidence is None else soft_evidence

    if not isinstance(model, BayesianModel):
        model = model.to_bayesian_model()

    kwargs = kwargs.copy()
    # Must be decided BEFORE 'factor_list' is popped; testing afterwards
    # would always see the key as missing.
    has_inferred = bool(evidence) or 'factor_list' in kwargs
    factor_list = kwargs.pop('factor_list', [])

    ttype_colors, ttype_scalars = make_colorcodes(model)

    textprops = {
        'horizontalalignment': 'left',
        'family': 'monospace',
        'size': 8,
    }

    # build graph attrs
    tup = get_node_viz_attrs(model, evidence, soft_evidence, factor_list,
                             ttype_colors, **kwargs)
    node_color, pos_list, pos_dict, takws = tup

    # draw graph
    # BE VERY CAREFUL: the copy is re-typed in place as a plain DiGraph.
    graph = model.copy()
    graph.__class__ = nx.DiGraph
    graph.graph['groupattrs'] = ut.ddict(dict)
    if getattr(graph, 'ttype2_cpds', None) is not None:
        # Add invis edges and ttype groups
        for ttype in model.ttype2_cpds.keys():
            ttype_cpds = model.ttype2_cpds[ttype]
            # use defined ordering
            ttype_nodes = ut.list_getattr(ttype_cpds, 'variable')
            invis_edges = list(ut.itertwo(ttype_nodes))
            graph.add_edges_from(invis_edges)
            nx.set_edge_attributes(
                graph,
                name='style',
                values={edge: 'invis'
                        for edge in invis_edges},
            )
            nx.set_node_attributes(
                graph,
                name='groupid',
                values={node: ttype
                        for node in ttype_nodes},
            )
            graph.graph['groupattrs'][ttype]['rank'] = 'same'
            graph.graph['groupattrs'][ttype]['cluster'] = False

    # ``layoutkw`` is the spelling used by the other show_nx callers.
    pt.show_nx(graph,
               layoutkw={'prog': 'dot'},
               fnum=fnum,
               pnum=pnum,
               verbose=0)
    pt.zoom_factory()
    fig = pt.gcf()
    ax = pt.gca()

    # draw_text_annotations returns callables that are invoked only after
    # the axes limits and the title have been set below.
    hacks = [
        pt.draw_text_annotations(textprops=textprops, **takw) for takw in takws
        if takw
    ]

    pos_arr = np.array(pos_list)
    xmin, ymin = pos_arr.min(axis=0)
    xmax, ymax = pos_arr.max(axis=0)
    if 'name' in model.ttype2_template:
        num_names = len(model.ttype2_template['name'].basis)
        num_annots = len(model.ttype2_cpds['name'])
        if num_annots > 4:
            ax.set_xlim((xmin - 40, xmax + 40))
            ax.set_ylim((ymin - 50, ymax + 50))
            fig.set_size_inches(30, 7)
        else:
            ax.set_xlim((xmin - 42, xmax + 42))
            ax.set_ylim((ymin - 50, ymax + 50))
            fig.set_size_inches(23, 7)
        title = 'num_names=%r, num_annots=%r' % (
            num_names,
            num_annots,
        )
    else:
        title = ''

    def word_insert(text):
        # Append a separating space to non-empty text.
        return '' if len(text) == 0 else text + ' '

    top_assignments = kwargs.get('top_assignments', None)
    if top_assignments is not None:
        map_assign, map_prob = top_assignments[0]
        if map_assign is not None:
            title += '\n%sMAP: ' % (word_insert(kwargs.get('method', '')))
            title += map_assign + ' @' + '%.2f%%' % (100 * map_prob, )
    if kwargs.get('show_title', True):
        pt.set_figtitle(title, size=14)

    for hack in hacks:
        hack()

    if has_inferred:
        # Hack in colorbars
        keys = ['name', 'score']
        locs = ['left', 'right']
        for key, loc in zip(keys, locs):
            if key in ttype_colors:
                basis = model.ttype2_template[key].basis
                colors = ttype_colors[key]
                scalars = ttype_scalars[key]
                pt.colorbar(scalars,
                            colors,
                            lbl=key,
                            ticklabels=basis,
                            ticklocation=loc)