def plot(self, fnum, pnum):
    """
    Redraw the inference graph and re-attach the interactive decorations.

    The graph is always drawn on ``self.fnum``; the ``fnum`` and ``pnum``
    arguments are accepted but not read by this method.

    Args:
        fnum: figure number supplied by the caller (unused here).
        pnum: plot number supplied by the caller (unused here).
    """
    # Refresh the visual attributes before handing the graph to the drawer.
    self.infr.update_visual_attrs(self.show_cuts)

    neato_opts = {'prog': 'neato', 'splines': 'spline', 'sep': 10 / 72}
    self.plotinfo = pt.show_nx(
        self.infr.graph,
        layoutkw=neato_opts,
        fnum=self.fnum,
        use_image=self.use_image,
        as_directed=False,
        modify_ax=False,
        verbose=0,
    )

    ut.util_graph.graph_info(self.infr.graph, verbose=True)

    # Interactive navigation on the freshly drawn axes.
    self.enable_pan_and_zoom(pt.gca())

    # Re-apply the highlight of every currently selected annotation.
    for selected_aid in self.selected_aids:
        self.highlight_aid(selected_aid, pt.ORANGE)

    self.make_hud()
    logger.info('Finished Plot')
def viz_factor_graph(gm):
    """
    Visualize an opengm graphical model and then show its ``gm.G`` graph.

    Args:
        gm: graphical model handed to ``opengm.visualizeGm``; it must also
            expose a ``G`` attribute that ``pt.show_nx`` can draw.

    Example setup (not runnable in isolation):
        ut.qtensure()
        gm = build_factor_graph(G, nodes, edges , n_annots, n_names, lookup_annot_idx,
                                use_unaries=True, edge_probs=None, operator='multiplier')
    """
    ut.qtensure()
    import networkx
    from networkx.drawing import nx_agraph

    # Expose graphviz_layout at the networkx top level before calling opengm
    # (presumably opengm looks it up there -- TODO confirm).
    networkx.graphviz_layout = nx_agraph.graphviz_layout

    visualize_options = {
        'show': False,
        'layout': 'neato',
        'plotUnaries': True,
        'iterations': 1000,
        'plotFunctions': True,
        'plotNonShared': False,
        'relNodeSize': 1.0,
    }
    opengm.visualizeGm(gm, **visualize_options)
    pt.show_nx(gm.G)
def draw_em_graph(P, Pn, PL, gam, num_labels):
    """
    Draw the name / annotation graph used by the EM test script.

    Name nodes are labelled ``N1..Nk`` and annotation nodes ``X1..Xm`` where
    ``k = PL.shape[1]`` and ``m = len(Pn)``.

    Args:
        P (ndarray): weight matrix over ``name_nodes + annot_nodes`` that is
            turned into a graph with ``ut.nx_from_matrix``.  Replaced by a
            matrix built from ``gam`` and ``Pn`` when the ``--postem`` flag
            is given on the command line.
        Pn (ndarray): stacked next to the posterior block; ``len(Pn)`` is the
            number of annotation nodes.
        PL (ndarray): only ``PL.shape[1]`` (the number of labels) is read.
        gam (ndarray): columns from ``num_labels`` onward form the posterior
            block.  This array is NOT modified.
        num_labels (int): ignored; it is overwritten with ``PL.shape[1]``.

    CommandLine:
        python -m wbia.algo.hots.testem test_em --show --no-cnn
    """
    import wbia.plottool as pt
    import networkx as nx

    num_labels = PL.shape[1]
    name_nodes = ['N%d' % x for x in range(1, num_labels + 1)]
    annot_nodes = ['X%d' % x for x in range(1, len(Pn) + 1)]
    nodes = name_nodes + annot_nodes

    # ``gam[:, num_labels:].T`` is a view; adding out-of-place keeps the
    # smoothing constant from leaking back into the caller's array.
    PL2 = gam[:, num_labels:].T + 0.01
    PL2 = PL2 / PL2.sum(axis=1)[:, None]
    zero_part = np.zeros((num_labels, len(Pn) + num_labels))
    prob_part = np.hstack([PL2, Pn])

    logger.info(ut.hz_str(' PL2 = ', ut.repr2(PL2, precision=2)))
    # Redo p with posteriors
    if ut.get_argflag('--postem'):
        P = np.vstack([zero_part, prob_part])

    graph = ut.nx_from_matrix(P, nodes=nodes)
    graph = graph.to_directed()

    # ``to_directed`` yields both (u, v) and (v, u); drop one direction of
    # every pair so each connection is drawn once.
    dup_edges = []
    seen_ = set()
    for u, v in graph.edges():
        if u < v:
            u, v = v, u
        if (u, v) not in seen_:
            seen_.add((u, v))
        else:
            dup_edges.append((u, v))
    graph.remove_edges_from(dup_edges)

    if len(name_nodes) == 3 and len(annot_nodes) == 4:
        # Hand-tuned pinned layout for the 3-name / 4-annotation demo case.
        graph.nodes[annot_nodes[0]]['pos'] = (20.0, 200.0)
        graph.nodes[annot_nodes[1]]['pos'] = (220.0, 200.0)
        graph.nodes[annot_nodes[2]]['pos'] = (20.0, 100.0)
        graph.nodes[annot_nodes[3]]['pos'] = (220.0, 100.0)
        graph.nodes[name_nodes[0]]['pos'] = (10.0, 300.0)
        graph.nodes[name_nodes[1]]['pos'] = (120.0, 300.0)
        graph.nodes[name_nodes[2]]['pos'] = (230.0, 300.0)
        nx.set_node_attributes(graph, name='pin', values='true')
        logger.info('annot_nodes = %r' % (annot_nodes,))
        logger.info('name_nodes = %r' % (name_nodes,))

        # Move each annot<->name edge label to the annotation end of the edge.
        # ``graph.edges[u, v]`` is the networkx >= 2.0 accessor; the old
        # ``graph.edge[u][v]`` attribute no longer exists.
        for u in annot_nodes:
            for v in name_nodes:
                if graph.has_edge(u, v):
                    logger.info('1) u, v = %r' % ((u, v),))
                    edge_attrs = graph.edges[u, v]
                    edge_attrs['taillabel'] = edge_attrs['label']
                    edge_attrs['color'] = pt.ORANGE
                    edge_attrs['labelcolor'] = pt.BLUE
                    del edge_attrs['label']
                elif graph.has_edge(v, u):
                    logger.info('2) u, v = %r' % ((u, v),))
                    edge_attrs = graph.edges[v, u]
                    edge_attrs['headlabel'] = edge_attrs['label']
                    edge_attrs['color'] = pt.ORANGE
                    edge_attrs['labelcolor'] = pt.BLUE
                    del edge_attrs['label']
                else:
                    logger.info((u, v))
                    logger.info('!!')

    nx.set_node_attributes(
        graph, name='color', values={node: pt.RED for node in name_nodes}
    )
    nx.set_node_attributes(
        graph, name='groupid', values={node: 'names' for node in name_nodes}
    )
    nx.set_node_attributes(
        graph, name='groupid', values={node: 'annots' for node in annot_nodes}
    )
    graph.graph['clusterrank'] = 'local'
    ut.nx_delete_edge_attr(graph, 'weight')

    layoutkw = {
        'splines': 'spline',
        'prog': 'neato',
    }
    pt.show_nx(
        graph,
        fontsize=6,
        fontname='Ubuntu',
        layoutkw=layoutkw,
        verbose=0,
        as_directed=False,
    )
    pt.interactions.zoom_factory()
def viz_netx_chipgraph(
    ibs,
    graph,
    fnum=None,
    use_image=False,
    layout=None,
    zoom=None,
    prog='neato',
    as_directed=False,
    augment_graph=True,
    layoutkw=None,
    framewidth=True,
    **kwargs,
):
    r"""
    Draw a chip graph with ``pt.show_nx``.

    DEPRECATE or improve

    Args:
        ibs (IBEISController): wbia controller object
        graph (nx.DiGraph): graph to draw; node attributes are set on it
        fnum (int): figure number(default = None)
        use_image (bool): attach node images first (default = False)
        layout (str): layout name; ``'agraph'`` is used when None
        zoom (float): not read by this function
        prog (str): graphviz program used when ``layoutkw`` has no ``'prog'``
            entry; ``'neato'`` also makes the graph undirected before drawing
        as_directed (bool): forwarded to ``pt.show_nx`` (default = False)
        augment_graph (bool): not read by this function
        layoutkw (dict): extra layout options; the dict is copied, so the
            caller's object is left untouched
        framewidth (bool): forwarded to ``pt.show_nx`` (default = True)
        **kwargs: merged into the layout options

    Returns:
        the value returned by ``pt.show_nx``

    CommandLine:
        python -m wbia --tf viz_netx_chipgraph --show

    Cand:
        wbia review_tagged_joins --save figures4/mergecase.png --figsize=15,15 --clipwhite --diskshow
        wbia compute_occurrence_groups --save figures4/occurgraph.png --figsize=40,40 --clipwhite --diskshow
        ~/code/wbia/wbia/algo/preproc/preproc_occurrence.py

    Example:
        >>> # DISABLE_DOCTEST
        >>> from wbia.viz.viz_graph import *  # NOQA
        >>> import wbia
        >>> ibs = wbia.opendb(defaultdb='PZ_MTEST')
        >>> nid_list = ibs.get_valid_nids()[0:10]
        >>> fnum = None
        >>> use_image = True
        >>> zoom = 0.4
        >>> make_name_graph_interaction(ibs, nid_list, prog='neato')
        >>> ut.show_if_requested()
    """
    import wbia.plottool as pt

    logger.info('[viz_graph] drawing chip graph')
    fnum = pt.ensure_fnum(fnum)
    pt.figure(fnum=fnum, pnum=(1, 1, 1))
    ax = pt.gca()

    if layout is None:
        layout = 'agraph'
    logger.info('layout = %r' % (layout, ))

    if use_image:
        ensure_node_images(ibs, graph)
    nx.set_node_attributes(graph, name='shape', values='rect')

    # Work on a private copy so the caller's dict is never mutated.
    layoutkw = {} if layoutkw is None else dict(layoutkw)
    layoutkw.setdefault('prog', prog)
    layoutkw.update(kwargs)

    if prog == 'neato':
        graph = graph.to_undirected()

    plotinfo = pt.show_nx(
        graph,
        ax=ax,
        layout=layout,
        layoutkw=layoutkw,
        as_directed=as_directed,
        framewidth=framewidth,
    )
    return plotinfo
def show_graph(
    infr,
    graph=None,
    use_image=False,
    update_attrs=True,
    with_colorbar=False,
    pnum=(1, 1, 1),
    zoomable=True,
    pickable=False,
    **kwargs,
):
    r"""
    Draw a graph with its stored ('custom') layout via ``pt.show_nx``.

    Args:
        infr (?): inference object; supplies the default graph, the default
            verbosity and ``update_visual_attrs``
        graph (None): graph to draw; ``infr.graph`` when None (default = None)
        use_image (bool): forwarded to ``pt.show_nx`` (default = False)
        update_attrs (bool): refresh visual attributes first (default = True)
        with_colorbar (bool): currently not read (default = False)
        pnum (tuple): plot number(default = (1, 1, 1))
        zoomable (bool): attach zoom and pan handlers (default = True)
        pickable (bool): connect ``on_pick`` to pick events (default = False)
        **kwargs: verbose, with_labels, fnum, layout, ax, pos, img_dict,
            title, layoutkw, framewidth, modify_ax, as_directed, hacknoedge,
            hacknode, node_labels, arrow_width, fontsize, fontweight,
            fontname, fontfamily, fontproperties

    CommandLine:
        python -m wbia.algo.graph.mixin_viz GraphVisualization.show_graph --show

    Example:
        >>> # xdoctest: +REQUIRES(module:pygraphviz)
        >>> # ENABLE_DOCTEST
        >>> from wbia.algo.graph.mixin_viz import *  # NOQA
        >>> from wbia.algo.graph import demo
        >>> import wbia.plottool as pt
        >>> infr = demo.demodata_infr(ccs=ut.estarmap(
        >>>    range, [(1, 6), (6, 10), (10, 13), (13, 15), (15, 16),
        >>>            (17, 20)]))
        >>> pnum_ = pt.make_pnum_nextgen(nRows=1, nCols=3)
        >>> infr.show_graph(show_cand=True, simple_labels=True, pickable=True, fnum=1, pnum=pnum_())
        >>> infr.add_feedback((1, 5), INCMP)
        >>> infr.add_feedback((14, 18), INCMP)
        >>> infr.refresh_candidate_edges()
        >>> infr.show_graph(show_cand=True, simple_labels=True, pickable=True, fnum=1, pnum=pnum_())
        >>> infr.add_feedback((17, 18), NEGTV)  # add inconsistency
        >>> infr.apply_nondynamic_update()
        >>> infr.show_graph(show_cand=True, simple_labels=True, pickable=True, fnum=1, pnum=pnum_())
        >>> ut.show_if_requested()
    """
    import wbia.plottool as pt

    target_graph = infr.graph if graph is None else graph

    # Everything that draws runs with warnings silenced.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        if update_attrs:
            # The attribute update sees the full kwargs, 'verbose' included.
            infr.update_visual_attrs(graph=target_graph, **kwargs)
        # 'verbose' is passed explicitly below, so take it out of kwargs.
        verbosity = kwargs.pop('verbose', infr.verbose)
        pt.show_nx(
            target_graph,
            layout='custom',
            as_directed=False,
            modify_ax=False,
            use_image=use_image,
            pnum=pnum,
            verbose=verbosity,
            **kwargs,
        )
        if zoomable:
            pt.zoom_factory()
            pt.pan_factory(pt.gca())

    if target_graph.graph.get('dark_background'):
        pt.dark_background(force=True)

    if pickable:
        pick_handler = ut.partial(on_pick, infr=infr)
        pt.gcf().canvas.mpl_connect('pick_event', pick_handler)
def draw_twoday_count(ibs, visit_info_list_):
    """
    Lay out and show a stacked graph of the annotations seen on two days.

    Args:
        ibs: controller used to look up annotmatch rows and annotation name
            ids.
        visit_info_list_ (list): two per-day info dicts.  The keys read here
            are ``'aids'``, ``'unique_nids'``, ``'grouped_aids'`` and
            ``'edges'``.  The input is deep-copied before anything is changed.

    NOTE(review): SOURCE reached this review with its indentation collapsed.
    As reconstructed here, everything from "Build a graph of matches" onward
    sits under an ``if False:`` guard and is therefore disabled -- ``debug``
    and ``info['edges']`` are only assigned inside that guard, which is why
    this reading was chosen.  Confirm against the original file.
    """
    import copy

    # Work on a private copy; only the copy is modified below.
    visit_info_list = copy.deepcopy(visit_info_list_)

    aids_day1, aids_day2 = ut.take_column(visit_info_list_, 'aids')
    nids_day1, nids_day2 = ut.take_column(visit_info_list_, 'unique_nids')
    # Name ids that appear on both days.
    resight_nids = ut.isect(nids_day1, nids_day2)

    if False:
        # HACK REMOVE DATA TO MAKE THIS FASTER
        num = 20
        for info in visit_info_list:
            non_resight_nids = list(
                set(info['unique_nids']) - set(resight_nids))
            sample_nids2 = non_resight_nids[0:num] + resight_nids[:num]
            info['grouped_aids'] = ut.dict_subset(info['grouped_aids'],
                                                  sample_nids2)
            info['unique_nids'] = sample_nids2

    # Build a graph of matches
    if False:

        debug = False

        # Per-day edges: link the annotations inside every name group.
        for info in visit_info_list:
            edges = []
            grouped_aids = info['grouped_aids']

            aids_list = list(grouped_aids.values())
            ams_list = ibs.get_annotmatch_rowids_in_cliques(aids_list)
            aids1_list = ibs.unflat_map(ibs.get_annotmatch_aid1, ams_list)
            aids2_list = ibs.unflat_map(ibs.get_annotmatch_aid2, ams_list)
            for ams, aids, aids1, aids2 in zip(ams_list, aids_list,
                                               aids1_list, aids2_list):
                edge_nodes = set(aids1 + aids2)
                # #if len(edge_nodes) != len(set(aids)):
                #    #logger.info('--')
                #    #logger.info('aids = %r' % (aids,))
                #    #logger.info('edge_nodes = %r' % (edge_nodes,))
                # Annotations referenced by a match but not in this group.
                bad_aids = edge_nodes - set(aids)
                if len(bad_aids) > 0:
                    logger.info('bad_aids = %r' % (bad_aids, ))
                unlinked_aids = set(aids) - edge_nodes
                # Chain the unlinked annotations together, ending on one
                # already-linked annotation (if any).
                mst_links = list(
                    ut.itertwo(list(unlinked_aids) + list(edge_nodes)[:1]))
                # None is treated as a bad endpoint from here on.
                bad_aids.add(None)
                user_links = [(u, v) for (u, v) in zip(aids1, aids2)
                              if u not in bad_aids and v not in bad_aids]
                new_edges = mst_links + user_links
                new_edges = [(int(u), int(v)) for u, v in new_edges
                             if u not in bad_aids and v not in bad_aids]
                edges += new_edges
            info['edges'] = edges

        # Add edges between days
        grouped_aids1, grouped_aids2 = ut.take_column(visit_info_list,
                                                      'grouped_aids')
        nids_day1, nids_day2 = ut.take_column(visit_info_list, 'unique_nids')
        resight_nids = ut.isect(nids_day1, nids_day2)

        resight_aids1 = ut.take(grouped_aids1, resight_nids)
        resight_aids2 = ut.take(grouped_aids2, resight_nids)
        # resight_aids3 = [list(aids1) + list(aids2) for aids1, aids2 in zip(resight_aids1, resight_aids2)]

        ams_list = ibs.get_annotmatch_rowids_between_groups(
            resight_aids1, resight_aids2)
        aids1_list = ibs.unflat_map(ibs.get_annotmatch_aid1, ams_list)
        aids2_list = ibs.unflat_map(ibs.get_annotmatch_aid2, ams_list)

        between_edges = []
        for ams, aids1, aids2, rawaids1, rawaids2 in zip(
                ams_list, aids1_list, aids2_list, resight_aids1,
                resight_aids2):
            link_aids = aids1 + aids2
            rawaids3 = rawaids1 + rawaids2
            badaids = ut.setdiff(link_aids, rawaids3)
            assert not badaids
            user_links = [(int(u), int(v)) for (u, v) in zip(aids1, aids2)
                          if u is not None and v is not None]
            # HACK THIS OFF
            # (the computed user links are discarded, so the fallback
            # single edge below is always the one that gets used)
            user_links = []
            if len(user_links) == 0:
                # Hack in an edge
                between_edges += [(rawaids1[0], rawaids2[0])]
            else:
                between_edges += user_links

        # Both endpoints of every between-day edge must share a name id.
        assert np.all(0 == np.diff(np.array(
            ibs.unflat_map(ibs.get_annot_nids, between_edges)), axis=1))

        import wbia.plottool as pt
        import networkx as nx

        # pt.qt4ensure()
        # len(list(nx.connected_components(graph1)))
        # logger.info(ut.graph_info(graph1))

        # Layout graph
        layoutkw = dict(
            prog='neato',
            draw_implicit=False,
            splines='line',
            # splines='curved',
            # splines='spline',
            # sep=10 / 72,
            # prog='dot', rankdir='TB',
        )

        def translate_graph_to_origin(graph):
            # Shift the graph so its bounding box starts at (0, 0).
            x, y, w, h = ut.get_graph_bounding_box(graph)
            ut.translate_graph(graph, (-x, -y))

        def stack_graphs(graph_list, vert=False, pad=None):
            # Place copies of the graphs one after another along one axis
            # (bbox index 3 when vert, index 2 otherwise), centre them along
            # the other axis, pin the nodes and return the composed graph.
            graph_list_ = [g.copy() for g in graph_list]
            for g in graph_list_:
                translate_graph_to_origin(g)
            bbox_list = [ut.get_graph_bounding_box(g) for g in graph_list_]
            if vert:
                dim1 = 3
                dim2 = 2
            else:
                dim1 = 2
                dim2 = 3
            dim1_list = np.array([bbox[dim1] for bbox in bbox_list])
            dim2_list = np.array([bbox[dim2] for bbox in bbox_list])
            if pad is None:
                # Default gap: half of the mean extent along the stack axis.
                pad = np.mean(dim1_list) / 2
            offset1_list = ut.cumsum([0] + [d + pad for d in dim1_list[:-1]])
            max_dim2 = max(dim2_list)
            offset2_list = [(max_dim2 - d2) / 2 for d2 in dim2_list]
            if vert:
                t_xy_list = [(d2, d1)
                             for d1, d2 in zip(offset1_list, offset2_list)]
            else:
                t_xy_list = [(d1, d2)
                             for d1, d2 in zip(offset1_list, offset2_list)]

            for g, t_xy in zip(graph_list_, t_xy_list):
                ut.translate_graph(g, t_xy)
                nx.set_node_attributes(g, name='pin', values='true')

            new_graph = nx.compose_all(graph_list_)
            # pt.show_nx(new_graph, layout='custom', node_labels=False, as_directed=False)  # NOQA
            return new_graph

        # Construct graph
        for count, info in enumerate(visit_info_list):
            graph = nx.Graph()
            edges = [(int(u), int(v)) for u, v in info['edges']
                     if u is not None and v is not None]
            # NOTE(review): networkx >= 2.0 takes edge attributes as **attr;
            # ``attr_dict=`` would presumably be stored as one attribute
            # literally named 'attr_dict' -- confirm for the installed version.
            graph.add_edges_from(edges, attr_dict={'zorder': 10})
            nx.set_node_attributes(graph, name='zorder', values=20)

            # Layout in neato
            _ = pt.nx_agraph_layout(graph, inplace=True, **layoutkw)  # NOQA

            # Extract components and then flatten in nid ordering
            ccs = list(nx.connected_components(graph))
            root_aids = []
            cc_graphs = []
            for cc_nodes in ccs:
                cc = graph.subgraph(cc_nodes)
                try:
                    root_aids.append(
                        list(ut.nx_source_nodes(cc.to_directed()))[0])
                except nx.NetworkXUnfeasible:
                    # Fall back to an arbitrary node of the component.
                    root_aids.append(list(cc.nodes())[0])
                cc_graphs.append(cc)

            root_nids = ibs.get_annot_nids(root_aids)
            nid2_graph = dict(zip(root_nids, cc_graphs))

            # Split the components into resighted and single-day names.
            resight_nids_ = set(resight_nids).intersection(set(root_nids))
            noresight_nids_ = set(root_nids) - resight_nids_

            n_graph_list = ut.take(nid2_graph, sorted(noresight_nids_))
            r_graph_list = ut.take(nid2_graph, sorted(resight_nids_))

            if len(n_graph_list) > 0:
                n_graph = nx.compose_all(n_graph_list)
                _ = pt.nx_agraph_layout(n_graph, inplace=True, **layoutkw)  # NOQA
                n_graphs = [n_graph]
            else:
                n_graphs = []

            # Resighted components are stacked in chunks of 100.
            r_graphs = [
                stack_graphs(chunk) for chunk in ut.ichunks(r_graph_list, 100)
            ]
            # The stacking order is reversed for the second day.
            if count == 0:
                new_graph = stack_graphs(n_graphs + r_graphs, vert=True)
            else:
                new_graph = stack_graphs(r_graphs[::-1] + n_graphs, vert=True)

            # pt.show_nx(new_graph, layout='custom', node_labels=False, as_directed=False)  # NOQA
            info['graph'] = new_graph

        graph1_, graph2_ = ut.take_column(visit_info_list, 'graph')
        if False:
            pt.show_nx(graph1_, layout='custom', node_labels=False,
                       as_directed=False)
            pt.show_nx(graph2_, layout='custom', node_labels=False,
                       as_directed=False)

        graph_list = [graph1_, graph2_]
        twoday_graph = stack_graphs(graph_list, vert=True, pad=None)
        nx.set_node_attributes(twoday_graph, name='pin', values='true')

        if debug:
            ut.nx_delete_None_edge_attr(twoday_graph)
            ut.nx_delete_None_node_attr(twoday_graph)
            logger.info('twoday_graph(pre) info' +
                        ut.repr3(ut.graph_info(twoday_graph), nl=2))

        # Hack, no idea why there are nodes that dont exist here
        between_edges_ = [
            edge for edge in between_edges
            if twoday_graph.has_node(edge[0]) and twoday_graph.has_node(edge[1])
        ]

        # NOTE(review): same ``attr_dict=`` concern as above.
        twoday_graph.add_edges_from(between_edges_, attr_dict={
            'alpha': 0.2,
            'zorder': 0
        })
        ut.nx_ensure_agraph_color(twoday_graph)

        layoutkw['splines'] = 'line'
        layoutkw['prog'] = 'neato'
        agraph = pt.nx_agraph_layout(twoday_graph,
                                     inplace=True,
                                     return_agraph=True,
                                     **layoutkw)[-1]  # NOQA
        if False:
            fpath = ut.truepath('~/ggr_graph.png')
            agraph.draw(fpath)
            ut.startfile(fpath)

        if debug:
            logger.info('twoday_graph(post) info' +
                        ut.repr3(ut.graph_info(twoday_graph)))
        pt.show_nx(twoday_graph,
                   layout='custom',
                   node_labels=False,
                   as_directed=False)
def make_graph(self, show=False):
    """
    Build the undirected match graph for the current query request.

    Nodes are the query and database annotation ids plus one ``('nid', nid)``
    node per database annotation name.  Edges are:

        * every pair of database annotations that share a name,
        * each query annotation to all database annotations of every name
          whose probability exceeds ``self.choose_thresh()``,
        * each database annotation to its name node.

    When ``self.user_feedback`` is set, each reviewed pair is added (if its
    probability of being the same is above 0.5) or removed (otherwise).

    Args:
        show (bool): color / label the nodes and draw the graph
            (default = False)

    Returns:
        nx.Graph: the constructed graph
    """
    import networkx as nx
    import itertools

    cm_list = self.cm_list
    unique_nids, prob_names = self.make_prob_names()
    thresh = self.choose_thresh()

    # Simply cut any edge with a weight less than a threshold
    qaid_list = [cm.qaid for cm in cm_list]
    postcut = prob_names > thresh
    qxs, nxs = np.where(postcut)
    if False:
        kw = {'precision': 2, 'max_line_width': 140, 'suppress_small': True}
        logger.info(
            ut.hz_str('prob_names = ', ut.repr2((prob_names), **kw)))
        logger.info(
            ut.hz_str('postcut = ', ut.repr2((postcut).astype(np.int), **kw)))
    matching_qaids = ut.take(qaid_list, qxs)
    matched_nids = ut.take(unique_nids, nxs)

    qreq_ = self.qreq_
    all_aids = ut.unique(qreq_.qaids.tolist() + qreq_.daids.tolist())

    # Group the database annotations by name.
    dnids = qreq_.get_qreq_annot_nids(qreq_.daids)
    dnid2_daids = ut.group_items(qreq_.daids, dnids)
    matched_daids = ut.take(dnid2_daids, matched_nids)

    # Edges among database annotations that already share a name.
    name_cliques = [
        list(itertools.combinations(same_name_aids, 2))
        for same_name_aids in dnid2_daids.values()
    ]
    # Edges from each query annotation to the names it matched.
    aid_matches = [
        list(ut.product([qaid], daids))
        for qaid, daids in zip(matching_qaids, matched_daids)
    ]
    # Edges tying each database annotation to its name node.
    name_nodes = [('nid', nid) for nid in dnids]
    db_aid_nid_edges = list(zip(qreq_.daids, name_nodes))

    graph = nx.Graph()
    graph.add_nodes_from(all_aids)
    graph.add_edges_from(ut.flatten(name_cliques))
    graph.add_edges_from(ut.flatten(aid_matches))
    graph.add_edges_from(db_aid_nid_edges)

    if self.user_feedback is not None:
        feedback = ut.map_dict_vals(np.array, self.user_feedback)
        p_bg = 0.0
        comparable_part = feedback['p_match'] * (1 - feedback['p_notcomp'])
        background_part = p_bg * feedback['p_notcomp']
        p_same_list = comparable_part + background_part
        reviewed = zip(feedback['aid1'], feedback['aid2'], p_same_list)
        for aid1, aid2, p_same in reviewed:
            if p_same > 0.5:
                if not graph.has_edge(aid1, aid2):
                    graph.add_edge(aid1, aid2)
            elif graph.has_edge(aid1, aid2):
                graph.remove_edge(aid1, aid2)

    if show:
        import wbia.plottool as pt

        # Later assignments override earlier ones for shared nodes.
        nx.set_node_attributes(
            graph,
            name='color',
            values={aid: pt.LIGHT_PINK for aid in qreq_.daids},
        )
        nx.set_node_attributes(
            graph,
            name='color',
            values={aid: pt.TRUE_BLUE for aid in qreq_.qaids},
        )
        both_aids = np.intersect1d(qreq_.qaids, qreq_.daids)
        nx.set_node_attributes(
            graph,
            name='color',
            values={aid: pt.LIGHT_PURPLE for aid in both_aids},
        )
        nx.set_node_attributes(
            graph,
            name='label',
            values={node: 'n%r' % (node[1], ) for node in name_nodes},
        )
        nx.set_node_attributes(
            graph,
            name='color',
            values={node: pt.LIGHT_GREEN for node in name_nodes},
        )
        pt.show_nx(graph, layoutkw={'prog': 'neato'}, verbose=False)

    return graph
def draw_bayesian_model(model,
                        evidence=None,
                        soft_evidence=None,
                        fnum=None,
                        pnum=None,
                        **kwargs):
    """
    Draw a Bayesian model with per-node text annotations and a figure title.

    Args:
        model: a pgmpy ``BayesianModel`` or an object with a
            ``to_bayesian_model()`` method.  ``ttype2_template`` and
            ``ttype2_cpds`` are read from it.
        evidence (dict): forwarded to ``get_node_viz_attrs``; a non-empty
            value also enables the colorbars (default = None, meaning empty)
        soft_evidence (dict): forwarded to ``get_node_viz_attrs``
            (default = None, meaning empty)
        fnum (int): figure number forwarded to ``pt.show_nx``
        pnum (tuple): plot number forwarded to ``pt.show_nx``
        **kwargs: ``factor_list`` (its presence enables the colorbars),
            ``top_assignments``, ``method``, ``show_title``; the remaining
            entries are forwarded to ``get_node_viz_attrs``.
    """
    from pgmpy.models import BayesianModel

    if not isinstance(model, BayesianModel):
        model = model.to_bayesian_model()

    import wbia.plottool as pt
    import networkx as nx

    # Fresh dicts per call instead of shared mutable defaults.
    evidence = {} if evidence is None else evidence
    soft_evidence = {} if soft_evidence is None else soft_evidence

    kwargs = kwargs.copy()
    # Must be checked BEFORE the pop below removes the key.
    has_inferred = bool(evidence) or 'factor_list' in kwargs
    factor_list = kwargs.pop('factor_list', [])

    ttype_colors, ttype_scalars = make_colorcodes(model)

    textprops = {
        'horizontalalignment': 'left',
        'family': 'monospace',
        'size': 8,
    }

    # build graph attrs
    tup = get_node_viz_attrs(model, evidence, soft_evidence, factor_list,
                             ttype_colors, **kwargs)
    node_color, pos_list, pos_dict, takws = tup

    # draw graph
    # BE VERY CAREFUL: the copy is re-classed to a plain DiGraph in place.
    graph = model.copy()
    graph.__class__ = nx.DiGraph
    graph.graph['groupattrs'] = ut.ddict(dict)
    if getattr(graph, 'ttype2_cpds', None) is not None:
        # Add invis edges and ttype groups
        for ttype, ttype_cpds in model.ttype2_cpds.items():
            # use defined ordering
            ttype_nodes = ut.list_getattr(ttype_cpds, 'variable')
            invis_edges = list(ut.itertwo(ttype_nodes))
            graph.add_edges_from(invis_edges)
            nx.set_edge_attributes(
                graph,
                name='style',
                values={edge: 'invis' for edge in invis_edges},
            )
            nx.set_node_attributes(
                graph,
                name='groupid',
                values={node: ttype for node in ttype_nodes},
            )
            graph.graph['groupattrs'][ttype]['rank'] = 'same'
            graph.graph['groupattrs'][ttype]['cluster'] = False

    # ``layoutkw`` is the keyword every other show_nx call in this file uses.
    pt.show_nx(graph, layoutkw={'prog': 'dot'}, fnum=fnum, pnum=pnum,
               verbose=0)
    pt.zoom_factory()
    fig = pt.gcf()
    ax = pt.gca()

    # draw_text_annotations returns callables that are invoked further down,
    # after the axes limits and the title have been set.
    hacks = [
        pt.draw_text_annotations(textprops=textprops, **takw)
        for takw in takws if takw
    ]

    xmin, ymin = np.array(pos_list).min(axis=0)
    xmax, ymax = np.array(pos_list).max(axis=0)
    if 'name' in model.ttype2_template:
        num_names = len(model.ttype2_template['name'].basis)
        num_annots = len(model.ttype2_cpds['name'])
        if num_annots > 4:
            ax.set_xlim((xmin - 40, xmax + 40))
            ax.set_ylim((ymin - 50, ymax + 50))
            fig.set_size_inches(30, 7)
        else:
            ax.set_xlim((xmin - 42, xmax + 42))
            ax.set_ylim((ymin - 50, ymax + 50))
            fig.set_size_inches(23, 7)
        title = 'num_names=%r, num_annots=%r' % (
            num_names,
            num_annots,
        )
    else:
        title = ''

    def word_insert(text):
        return '' if len(text) == 0 else text + ' '

    top_assignments = kwargs.get('top_assignments', None)
    if top_assignments is not None:
        map_assign, map_prob = top_assignments[0]
        if map_assign is not None:
            title += '\n%sMAP: ' % (word_insert(kwargs.get('method', '')))
            title += map_assign + ' @' + '%.2f%%' % (100 * map_prob, )
    if kwargs.get('show_title', True):
        pt.set_figtitle(title, size=14)

    for hack in hacks:
        hack()

    if has_inferred:
        # Hack in colorbars
        keys = ['name', 'score']
        locs = ['left', 'right']
        for key, loc in zip(keys, locs):
            if key in ttype_colors:
                basis = model.ttype2_template[key].basis
                colors = ttype_colors[key]
                scalars = ttype_scalars[key]
                pt.colorbar(scalars,
                            colors,
                            lbl=key,
                            ticklabels=basis,
                            ticklocation=loc)