コード例 #1
0
def get_graph_bounding_box(graph):
    """
    Compute the box that encloses every node of a laid-out graph.

    Every node is expected to carry a ``'pos'`` attribute (its center) and a
    ``'size'`` attribute (its width/height). Each node is turned into a bbox
    and then an extent ``(min_x, max_x, min_y, max_y)``; the extents are
    merged by taking the outermost values.

    Args:
        graph (nx.Graph): graph whose nodes have ``pos`` and ``size`` attrs.
            A node missing either attribute presumably makes ``ut.take``
            raise a KeyError -- confirm.

    Returns:
        The merged box, in whatever format ``vt.bbox_from_extent`` returns.
    """
    import utool as ut
    import networkx as nx
    import vtool as vt
    node_order = list(graph.nodes())
    centers = ut.take(nx.get_node_attributes(graph, 'pos'), node_order)
    sizes = ut.take(nx.get_node_attributes(graph, 'size'), node_order)

    per_node_extents = []
    for center_xy, size_wh in zip(centers, sizes):
        node_bbox = vt.bbox_from_center_wh(center_xy, size_wh)
        per_node_extents.append(vt.extent_from_bbox(node_bbox))

    # Rows are nodes; transposing gives one array per extent component.
    xmins, xmaxs, ymins, ymaxs = np.array(per_node_extents).T
    merged_extent = (xmins.min(), xmaxs.max(), ymins.min(), ymaxs.max())
    return vt.bbox_from_extent(merged_extent)
コード例 #2
0
ファイル: nx_helpers.py プロジェクト: Erotemic/plottool
def draw_network2(graph, layout_info, ax, as_directed=None, hacknoedge=False,
                  hacknode=False, verbose=None, **kwargs):
    """
    fancy way to draw networkx graphs without directly using networkx

    Builds matplotlib patches for every node, node frame, edge and arrowhead
    described by ``layout_info``, adds them to ``ax`` and draws node labels.

    Args:
        graph: networkx graph. Node attributes read: ``label``, ``alpha``,
            ``color``, ``image``, ``shape``, ``style``, ``framewidth``,
            ``framecolor``. Edge attributes read: ``style`` (``'invis'``
            edges are skipped), ``alpha``, ``implicit``, ``color``, ``lw``,
            ``linestyle``, ``hatch``, ``shadow``, ``stroke``, ``labelcolor``.
        layout_info (dict): precomputed layout providing
            ``['node']['pos']``, ``['node']['size']``,
            ``['graph']['splines']`` and, under ``['edge']``, the optional
            ``ctrl_pts`` (no edges are drawn without it) plus per-edge
            ``end_pt``, ``taillabel``/``tail_lp``, ``headlabel``/``head_lp``
            and ``label``/``lp``.
        ax: matplotlib axes to draw into.
        as_directed (bool): draw arrowheads; defaults to
            ``graph.is_directed()``.
        hacknoedge (bool): if True, edge paths are not added to ``ax``
            (arrowheads and edge labels still are).
        hacknode (bool): if True, nodes with an ``image`` attribute get a
            semi-transparent patch and a text label.
        verbose (bool): print how many node / edge patches are added.
        **kwargs: font options forwarded to ``pt.parse_fontkw``, plus
            ``use_image`` (default True) and ``node_labels``.

    Returns:
        dict: ``patch_frame_dict``, ``node_patch_dict``, ``edge_patch_dict``
        and ``arrow_patch_list`` (a dict keyed by edge despite its name).

    Raises:
        NotImplementedError: for an unknown node ``shape``.
        AssertionError: for an unknown ``splines`` value or an unsupported
            edge ``stroke`` attribute.

    # python -m ibeis.annotmatch_funcs review_tagged_joins --dpath ~/latex/crall-candidacy-2015/ --save figures4/mergecase.png --figsize=15,15 --clipwhite --diskshow
    # python -m dtool --tf DependencyCache.make_graph --show
    """
    import plottool as pt

    patch_dict = {
        'patch_frame_dict': {},
        'node_patch_dict': {},
        'edge_patch_dict': {},
        # NOTE: keyed by edge; the "list" name is kept for existing callers.
        'arrow_patch_list': {},
    }
    # Node labels are collected and drawn last so they sit above all patches.
    text_pseudo_objects = []

    font_prop = pt.parse_fontkw(**kwargs)

    node_pos = layout_info['node']['pos']
    node_size = layout_info['node']['size']
    splines = layout_info['graph']['splines']

    if as_directed is None:
        as_directed = graph.is_directed()

    # Draw nodes
    for node, nattrs in graph.nodes(data=True):
        if nattrs is None:
            nattrs = {}
        label = nattrs.get('label', None)
        alpha = nattrs.get('alpha', 1.0)
        node_color = nattrs.get('color', pt.NEUTRAL_BLUE)
        if node_color is None:
            node_color = pt.NEUTRAL_BLUE
        xy = node_pos[node]
        using_image = kwargs.get('use_image', True) and 'image' in nattrs
        if using_image:
            # Presumably the image itself is drawn by another routine, so the
            # patch is faint (hacknode) or invisible -- confirm.
            patch_alpha = 0.7 if hacknode else 0.0
        else:
            patch_alpha = alpha

        node_color = fix_hex_color(node_color)[0:3]
        patch_kw = dict(alpha=patch_alpha, color=node_color)
        patch = _draw2_node_patch(graph, node, nattrs, xy, node_size,
                                  patch_kw)

        # Add a frame around the node
        framewidth = nattrs.get('framewidth', 0)
        if framewidth > 0:
            framecolor = fix_hex_color(nattrs.get('framecolor', node_color))
            frame_alpha = 1.0
            if framecolor is None:
                framecolor = pt.BLACK
                frame_alpha = 0.0
            if framewidth is True:
                figsize = ut.get_argval('--figsize', type_=list, default=None)
                # HACK: scale the frame with the requested figure size
                framewidth = 3.0 if figsize is None else max(figsize) / 4
            # Frame the node's full size box; computed here for every shape
            # (previously only rect-like shapes defined it).
            frame_w, frame_h = _get_node_size(graph, node, node_size)
            frame_bbox = [xy[0] - frame_w / 2.0, xy[1] - frame_h / 2.0,
                          frame_w, frame_h]
            patch_dict['patch_frame_dict'][node] = pt.make_bbox(
                frame_bbox, bbox_color=framecolor, ax=ax, lw=framewidth,
                alpha=frame_alpha)

        x, y = xy
        text = str(node) if label is None else label
        if kwargs.get('node_labels', hacknode or not using_image):
            text_args = ((x, y, text), dict(ax=ax, ha='center', va='center',
                                            fontproperties=font_prop))
            text_pseudo_objects.append(text_args)
        patch_dict['node_patch_dict'][node] = patch

    # Draw edges
    edge_layout = layout_info['edge']
    edge_ctrl_pts = edge_layout.get('ctrl_pts', None)
    if edge_ctrl_pts is not None:
        # Same for every edge, so compute once instead of once per edge.
        arrow_width, default_lw = _draw2_edge_defaults(node_pos, node_size)
        for edge, ctrl_pts in edge_ctrl_pts.items():
            data = _draw2_edge_data(graph, edge)
            if data.get('style', None) == 'invis':
                continue

            alpha = data.get('alpha', None)
            defaultcolor = pt.BLACK[0:3]
            if alpha is None:
                if data.get('implicit', False):
                    alpha = .5
                    defaultcolor = pt.GREEN[0:3]
                else:
                    alpha = 1.0
            color = data.get('color', defaultcolor)
            if color is None:
                color = defaultcolor
            color = fix_hex_color(color)[0:3]

            start_point = ctrl_pts[0]
            other_points = ctrl_pts[1:].tolist()
            curve_code = _draw2_spline_code(splines)
            verts = [start_point] + other_points
            codes = [mpl.path.Path.MOVETO] + [curve_code] * len(other_points)
            end_pt = edge_layout['end_pt'][edge]
            if end_pt is not None:
                # HACK: extend the edge so it touches the node bounding box
                verts.append(end_pt)
                codes.append(mpl.path.Path.LINETO)
            path = mpl.path.Path(verts, codes)

            lw = data.get('lw', default_lw)
            patch = mpl.patches.PathPatch(
                path, facecolor='none', lw=lw,
                path_effects=_draw2_edge_path_effects(data, lw),
                edgecolor=color,
                linestyle=data.get('linestyle', 'solid'),
                alpha=alpha,
                joinstyle='bevel',
                hatch=data.get('hatch', ''))

            # The arrowhead follows the last path segment; it is anchored at
            # end_pt when there is one, otherwise at the last control point.
            if as_directed and len(verts) >= 2:
                arrow = _draw2_arrow_patch(
                    verts[-1], verts[-2], arrow_width, color,
                    head_starts_at_zero=end_pt is None)
                if arrow is not None:
                    patch_dict['arrow_patch_list'][edge] = arrow

            # TODO: allow label colors independent of the edge color
            labelcolor = fix_hex_color(data.get('labelcolor', color))[0:3]
            for text_key, pos_key in (('taillabel', 'tail_lp'),
                                      ('headlabel', 'head_lp'),
                                      ('label', 'lp')):
                edge_text = edge_layout[text_key][edge]
                if edge_text:
                    ax.annotate(edge_text, xy=edge_layout[pos_key][edge],
                                xycoords='data', color=labelcolor,
                                va='center', ha='center',
                                fontproperties=font_prop)
            patch_dict['edge_patch_dict'][edge] = patch

    if verbose:
        print('Adding %r node patches ' % (len(patch_dict['node_patch_dict']),))
        print('Adding %r edge patches ' % (len(patch_dict['edge_patch_dict']),))

    for frame in patch_dict['patch_frame_dict'].values():
        ax.add_patch(frame)

    for arrow in patch_dict['arrow_patch_list'].values():
        ax.add_patch(arrow)

    for patch in patch_dict['node_patch_dict'].values():
        # Some node "patches" (presumably the 'stack' shape) are collections.
        if isinstance(patch, mpl.collections.PatchCollection):
            ax.add_collection(patch)
        else:
            ax.add_patch(patch)
    if not hacknoedge:
        for patch in patch_dict['edge_patch_dict'].values():
            ax.add_patch(patch)

    for text_args in text_pseudo_objects:
        pt.ax_absolute_text(*text_args[0], **text_args[1])
    return patch_dict


def _draw2_node_patch(graph, node, nattrs, xy, node_size, patch_kw):
    """Build the matplotlib patch for one node of ``draw_network2``.

    The node's ``shape`` attribute (default ``'ellipse'``) selects the patch
    type; its dimensions come from ``_get_node_size``. ``patch_kw`` is
    forwarded to the patch constructor.

    Raises:
        NotImplementedError: for an unknown ``shape``.
    """
    node_shape = nattrs.get('shape', 'ellipse')
    if node_shape == 'circle':
        # divide by 2 seems to work for agraph
        radius = min(_get_node_size(graph, node, node_size)) / 2.0
        return mpl.patches.Circle(xy, radius=radius, **patch_kw)

    if node_shape == 'ellipse':
        width, height = _get_node_size(graph, node, node_size)
        return mpl.patches.Ellipse(xy, width, height, **patch_kw)

    if node_shape in ['none', 'box', 'rect', 'rectangle', 'rhombus']:
        width, height = _get_node_size(graph, node, node_size)
        # True division keeps the shape centered on xy for float sizes
        # (floor division shifted it by up to half a unit or more).
        xy_bl = (xy[0] - width / 2.0, xy[1] - height / 2.0)
        style = nattrs.get('style', '') or ''
        if 'rounded' in style:
            rpad = 20
            # NOTE(review): the corner is shifted by rpad but the size shrinks
            # by only rpad, so with the outward pad the outline ends up rpad
            # larger and off-center; 2 * rpad looks intended -- confirm.
            boxstyle = mpl.patches.BoxStyle.Round(pad=rpad)
            patch = mpl.patches.FancyBboxPatch(
                np.array(xy_bl) + rpad, width - rpad, height - rpad,
                boxstyle=boxstyle, **patch_kw)
        elif 'diagonals' in style:
            # Diamond whose corners touch the midpoints of the node box.
            box = list(xy_bl) + [width, height]
            center = np.array(vt.bbox_center(box))
            half_w, half_h = width / 2.0, height / 2.0
            diamond = [
                center + [0, -half_h],
                center + [-half_w, 0],
                center + [0, half_h],
                center + [half_w, 0],
            ]
            patch = mpl.patches.Polygon(diamond, **patch_kw)
        else:
            angle = 45 if node_shape == 'rhombus' else 0
            patch = mpl.patches.Rectangle(xy_bl, width, height, angle=angle,
                                          **patch_kw)
        patch.center = xy
        return patch

    if node_shape == 'stack':
        import plottool as pt
        width, height = _get_node_size(graph, node, node_size)
        xy_bl = (xy[0] - width / 2.0, xy[1] - height / 2.0)
        patch = pt.cartoon_stacked_rects(xy_bl, width, height, **patch_kw)
        patch.xy = xy
        return patch

    raise NotImplementedError('Unknown node_shape=%r' % (node_shape,))


def _draw2_edge_data(graph, edge):
    """Return the attribute dict of a layout ``edge`` key (``{}`` if none).

    ``edge`` is a (u, v) pair or (u, v, key) triple. The key of a triple is
    retried as an ``int``; presumably the layout may return multigraph keys
    as strings -- confirm.
    """
    data = graph.get_edge_data(*edge)
    if data is None:
        if len(edge) == 3 and edge[2] is not None:
            data = graph.get_edge_data(edge[0], edge[1], int(edge[2]))
        else:
            data = graph.get_edge_data(edge[0], edge[1])
    return {} if data is None else data


def _draw2_edge_defaults(node_pos, node_size):
    """Return ``(arrow_width, line_width)`` defaults shared by every edge.

    With ``--figsize`` on the command line both scale with the largest figure
    dimension (overridable by ``--arrow-width`` / ``--line-width``).
    Otherwise the line width is 1.0 and the arrow width is estimated from the
    diagonal of the box around all nodes, falling back to 0.5.
    """
    figsize = ut.get_argval('--figsize', type_=list, default=None)
    if figsize is not None:
        # HACK: scale strokes with the requested figure size
        graphsize = max(figsize)
        width = ut.get_argval('--arrow-width', default=graphsize / 15)
        lw = ut.get_argval('--line-width', default=graphsize / 8)
        return width, lw

    width = .5
    try:
        # Compute arrow width using estimated graph size
        if node_size is not None and node_pos is not None:
            xys = np.array(ut.take(node_pos, node_pos.keys())).T
            whs = np.array(ut.take(node_size, node_pos.keys())).T
            bboxes = vt.bbox_from_xywh(xys, whs, [.5, .5])
            extents = vt.extent_from_bbox(bboxes)
            tl_pts = np.array([extents[0], extents[2]]).T
            br_pts = np.array([extents[1], extents[3]]).T
            corner_pts = np.vstack([tl_pts, br_pts])
            extent = vt.get_pointset_extents(corner_pts)
            graph_w, graph_h = vt.bbox_from_extent(extent)[2:4]
            width = np.sqrt(graph_w ** 2 + graph_h ** 2) * .0005
    except Exception:
        # Deliberate best effort: keep the fixed default if estimation fails.
        pass
    return width, 1.0


def _draw2_spline_code(splines):
    """Map the layout's ``splines`` setting to a matplotlib Path code."""
    if splines in ['line', 'polyline', 'ortho']:
        return mpl.path.Path.LINETO
    if splines in ['curved', 'spline']:
        return mpl.path.Path.CURVE4
    raise AssertionError('splines = %r' % (splines,))


def _draw2_edge_path_effects(data, lw):
    """Build path effects from an edge's ``shadow`` / ``stroke`` attributes.

    ``shadow`` is a kwargs dict for ``withSimplePatchShadow``. ``stroke`` may
    be True or a kwargs dict for ``withStroke``; its linewidth (default 3) is
    added on top of the edge line width ``lw``.
    """
    from matplotlib import patheffects
    path_effects = []

    # http://matplotlib.org/1.2.1/examples/api/clippath_demo.html
    shadowkw = data.get('shadow', None)
    if shadowkw:
        path_effects.append(patheffects.withSimplePatchShadow(**shadowkw))

    stroke_info = data.get('stroke', None)
    if stroke_info not in [None, False]:
        if stroke_info is True:
            strokekw = {}
        elif isinstance(stroke_info, dict):
            strokekw = stroke_info.copy()
        else:
            # Explicit raise: a bare assert vanishes under -O and left a stale
            # strokekw from a previous edge in use.
            raise AssertionError('unsupported stroke=%r' % (stroke_info,))
        strokekw['linewidth'] = lw + strokekw.get('linewidth', 3)
        path_effects.append(patheffects.withStroke(**strokekw))
    return path_effects


def _draw2_arrow_patch(tip, prev, width, color, head_starts_at_zero):
    """Return a short FancyArrow at ``tip`` oriented along ``prev -> tip``.

    Returns None when the two points coincide (no defined direction), which
    previously produced a NaN arrow.
    """
    dxy = np.array(tip) - np.array(prev)
    norm = np.sqrt(np.sum(dxy ** 2))
    if not norm > 0:
        return None
    dx, dy = (dxy / norm) * .1
    return mpl.patches.FancyArrow(tip[0], tip[1], dx, dy, width=width,
                                  length_includes_head=True, color=color,
                                  head_starts_at_zero=head_starts_at_zero)