Exemplo n.º 1
0
    def plot_it(self, words, postags, nes, deps):
        """Draw a dependency-parse diagram for one sentence and show it.

        Words are laid out left to right at height ``WORD_Y``; each word's POS
        tag is centred on it at height ``POS_Y``; each dependency is drawn as
        a ``FancyArrowPatch`` arc at height ``ARC_Y`` between the dependent and
        its head, labelled with the relation name.

        Args:
            words: Word strings of the sentence, without the artificial root.
            postags: POS tags aligned with ``words`` (no root entry either).
            nes: Stored on ``self.nes``; not drawn by this method.
            deps: One entry per word, in word order. Each entry must expose
                ``.head`` (an index into the root-prefixed word list, where
                index 0 is the artificial "Root") and ``.relation`` (the label
                text drawn on the arc).
        """
        # Both sequences get a leading entry for the artificial "Root" token so
        # that dependency head indices line up with list positions.
        self.words = ['Root'] + words
        self.postags = [''] + postags
        self.nes = nes
        self.deps = deps
        self.__words = []

        fig, ax = plt.subplots()
        # Some measurements (e.g. a word's width on the canvas) only exist once
        # the figure has been rendered, so draw it before measuring anything.
        fig.canvas.draw()
        # Take the display -> data transform after the draw so it reflects the
        # rendered axes layout.
        inv = ax.transData.inverted()

        # Words: place them left to right, advancing by each rendered width.
        start_x = INDENT_LEFT
        for word_text in self.words:
            the_text = ax.text(start_x, WORD_Y, word_text)
            # With no renderer argument, Text.get_window_extent falls back to
            # the renderer cached by fig.canvas.draw() above. (The previously
            # used Axes.get_renderer_cache() was removed in Matplotlib 3.5.)
            corners = inv.transform(the_text.get_window_extent())
            text_width = corners[1][0] - corners[0][0]
            self.__words.append({
                'x': start_x,
                'word': word_text,
                # POS tags and arcs are anchored at the word's horizontal centre.
                'arc_x': start_x + text_width / 2,
            })
            start_x = start_x + text_width + WORD_GAP
        # NOTE(review): the widths above were converted with the x-limits in
        # effect before this call; changing the limits rescales word positions
        # but not the rendered text, so words may crowd or spread out for long
        # or short sentences. Confirm this is acceptable.
        ax.set_xlim(0, start_x - WORD_GAP + INDENT_LEFT)

        # POS tags, centred on each word (index 0 is the empty Root tag).
        for i, postag in enumerate(self.postags):
            ax.text(self.__words[i]['arc_x'],
                    POS_Y,
                    postag,
                    horizontalalignment='center')

        # Dependency arcs: deps[i] belongs to word i + 1 (position 0 is Root).
        for from_index, dep in enumerate(self.deps, start=1):
            start_point = [self.__words[from_index]['arc_x'], ARC_Y]
            end_point = [self.__words[dep.head]['arc_x'], ARC_Y]
            # Flip the curvature sign with the arc direction so that arcs
            # drawn in either direction bend to the same side.
            rad = 0.5 if end_point[0] - start_point[0] > 0 else -0.5
            the_arc = ax.add_patch(
                FancyArrowPatch(end_point,
                                start_point,
                                connectionstyle="arc3,rad=%f" % rad,
                                **ARC_STYLE))
            # Put the relation label at the centre of the arc's bounding box.
            extents = the_arc.get_path().get_extents(transform=None)
            dep_relation_x = extents.x0 + (extents.x1 - extents.x0) / 2
            dep_relation_y = extents.y0 + (extents.y1 - extents.y0) / 2
            ax.text(dep_relation_x,
                    dep_relation_y,
                    dep.relation,
                    bbox=RELATION_BACKGROUND_STYLE,
                    horizontalalignment='center',
                    verticalalignment='center')
        ax.axis('off')

        plt.show()
Exemplo n.º 2
0
class GraphArrow:
    """Curved, optionally labelled arrow between two points on a circle.

    The endpoints lie at angles ``angleA`` and ``angleB`` (radians) on a
    circle of radius ``r`` centred at the origin.

    Args:
        angleA: Angle of the tail point, in radians.
        angleB: Angle of the head point, in radians.
        r: Factor the unit-circle endpoints are multiplied by.
        inner: If true, the label position is scaled by ``stretch_inner``
            instead of ``stretch_label``.
        outer: If true, the endpoints are scaled by ``stretch_outer``, both
            shrink values by ``stretch_outer ** 2``, and
            ``outer_connectionstyle`` replaces ``connectionstyle``.
        stretch_inner: Label scale factor used when ``inner`` is true.
        stretch_outer: Endpoint scale factor used when ``outer`` is true.
        shrinkA: ``shrinkA`` passed to ``FancyArrowPatch``.
        shrinkB: ``shrinkB`` passed to ``FancyArrowPatch``.
        arrowstyle: Style name passed to ``ArrowStyle``.
        head_length: ``head_length`` passed to ``ArrowStyle``.
        head_width: ``head_width`` passed to ``ArrowStyle``.
        connectionstyle: Connection style used when ``outer`` is false.
        outer_connectionstyle: Connection style used when ``outer`` is true.
        edgewidth: Line width of the arrow.
        edgecolor: Colour of the arrow.
        ax: Default axes for ``draw``; ``plt.gca()`` is used if none is given.
        label: Label text; no label is drawn when it is falsy.
        labelpos: Stored on the instance; not read by ``draw``.
        label_facecolor: Face colour of the label's bbox. The bbox is drawn
            with ``alpha=0``, so this is currently not visible.
        label_edgecolor: Edge colour of the label's bbox (also invisible
            because of ``alpha=0``).
        label_color: Text colour of the label.
        xytext: Optional explicit label position (any 2-element sequence or
            array). It is scaled exactly like the automatic anchor.
        fontname: Label font family name.
        fontsize: Label font size.
        stretch_label: Label scale factor used when ``inner`` is false
            (formerly hard-coded as 1.12).
    """

    def __init__(self,
                 angleA,
                 angleB,
                 r=1,
                 inner=False,
                 outer=False,
                 stretch_inner=0.8,
                 stretch_outer=1.1,
                 shrinkA=20,
                 shrinkB=20,
                 arrowstyle='-|>',
                 head_length=3,
                 head_width=2,
                 connectionstyle="arc3,rad=0.3",
                 outer_connectionstyle="arc3,rad=-0.3",
                 edgewidth=1,
                 edgecolor='k',
                 ax=None,
                 label=None,
                 labelpos=0.5,
                 label_facecolor='white',
                 label_edgecolor='white',
                 label_color='k',
                 xytext=None,
                 fontname='DejaVu Sans',
                 fontsize=10,
                 stretch_label=1.12):
        self.angleA = angleA
        self.angleB = angleB
        self.r = r
        # Unit-circle endpoints; scaled by r (and stretch_outer) in draw().
        self.posA = np.array([np.cos(self.angleA), np.sin(self.angleA)])
        self.posB = np.array([np.cos(self.angleB), np.sin(self.angleB)])
        self.outer = outer
        self.stretch_outer = stretch_outer
        self.shrinkA = shrinkA
        self.shrinkB = shrinkB
        self.arrowstyle = arrowstyle
        self.head_length = head_length
        self.head_width = head_width
        self.connectionstyle = connectionstyle
        self.outer_connectionstyle = outer_connectionstyle
        self.edgewidth = edgewidth
        self.edgecolor = edgecolor
        self.ax = ax
        self.label = label
        self.labelpos = labelpos
        self.label_facecolor = label_facecolor
        self.label_edgecolor = label_edgecolor
        self.label_color = label_color
        self.inner = inner
        self.stretch_inner = stretch_inner
        self.xytext = xytext
        self.fontname = fontname
        self.fontsize = fontsize
        self.stretch_label = stretch_label

    def _label_position(self, anchor):
        """Return the label's (x, y): ``xytext`` (or ``anchor``) scaled radially."""
        base = self.xytext if self.xytext is not None else anchor
        factor = self.stretch_inner if self.inner else self.stretch_label
        # np.asarray lets xytext be a plain tuple/list; multiplying a tuple by
        # a float would raise TypeError.
        return np.asarray(base, dtype=float) * factor

    def draw(self, ax=None):
        """Add the arrow (and its label, if any) to an axes.

        Args:
            ax: Target axes. Falls back to the axes given at construction,
                then to ``plt.gca()``. The chosen axes is stored on ``self.ax``.
        """
        self.ax = ax or self.ax or plt.gca()
        # Outer arrows are pushed outward and shrink more.
        pos_scale = self.stretch_outer if self.outer else 1
        shrink_scale = self.stretch_outer**2 if self.outer else 1
        self.arrow = FancyArrowPatch(
            posA=self.posA * pos_scale * self.r,
            posB=self.posB * pos_scale * self.r,
            arrowstyle=ArrowStyle(self.arrowstyle,
                                  head_length=self.head_length,
                                  head_width=self.head_width),
            shrinkA=self.shrinkA * shrink_scale,
            shrinkB=self.shrinkB * shrink_scale,
            connectionstyle=(self.outer_connectionstyle
                             if self.outer else self.connectionstyle),
            linewidth=self.edgewidth,
            color=self.edgecolor)
        self.ax.add_patch(self.arrow)

        if self.label:
            # Anchor on the second vertex of the arrow path (presumably the
            # Bezier control point for arc3 connection styles — verify).
            verts = self.arrow.get_path().vertices
            x, y = self._label_position(verts[1])
            self.ax.text(
                x,
                y,
                self.label,
                horizontalalignment='center',
                verticalalignment='center',
                bbox=dict(facecolor=self.label_facecolor,
                          edgecolor=self.label_edgecolor,
                          alpha=0),
                color=self.label_color,
                fontname=self.fontname,
                fontsize=self.fontsize)