def plot_it(self, words, postags, nes, deps):
    """Draw a dependency-parse diagram for one sentence and show it.

    The words are laid out left to right on a single row, each POS tag is
    centred on its word, and every dependency is drawn as a labelled arc
    between the head word and the dependent word.

    Args:
        words: Sentence tokens. A synthetic ``'Root'`` token is prepended.
        postags: POS tags aligned with ``words``. An empty tag is prepended
            for the ``'Root'`` token.
        nes: Stored on ``self.nes``; not drawn by this method.
        deps: One item per entry of ``words``. Each item must expose
            ``head`` (index into the Root-prefixed word list) and
            ``relation`` (text shown on the arc).

    Raises:
        IndexError: If ``postags`` or ``deps`` refer to more positions than
            there are words (after the Root token has been prepended).
    """
    # 'Root' is prepended so that dependency i belongs to word i + 1 and a
    # head index addresses the Root-prefixed list directly.
    self.words = ['Root'] + words
    self.postags = [''] + postags
    self.nes = nes
    self.deps = deps
    self.__words = []

    fig, ax = plt.subplots()
    # Some values only exist once the figure has been rendered, e.g. the
    # on-canvas width of a word, so force a draw before measuring anything.
    fig.canvas.draw()
    # ``inverted()`` returns a frozen snapshot, so take it after the draw.
    inv = ax.transData.inverted()
    # ``Axes.get_renderer_cache`` no longer exists in current Matplotlib.
    # Use the canvas renderer when the backend provides one; otherwise pass
    # None so ``Text.get_window_extent`` falls back to the renderer cached
    # by the draw above.
    get_renderer = getattr(fig.canvas, 'get_renderer', None)
    renderer = get_renderer() if get_renderer is not None else None

    # Render the words.
    start_x = INDENT_LEFT
    for word_text in self.words:
        the_text = ax.text(start_x, WORD_Y, word_text)
        # Convert the text bounding box from display to data coordinates
        # to obtain the word width in the units used for layout.
        position_data = inv.transform(the_text.get_window_extent(renderer))
        the_text_width = position_data[1][0] - position_data[0][0]
        self.__words.append({
            'x': start_x,
            'word': word_text,
            # Horizontal centre of the word: arcs and POS tags anchor here.
            'arc_x': start_x + the_text_width / 2,
        })
        start_x = start_x + the_text_width + WORD_GAP
    # Drop the trailing gap and mirror the left indent on the right side.
    ax.set_xlim(0, start_x - WORD_GAP + INDENT_LEFT)

    # Render the POS tags.
    for i, postag in enumerate(self.postags):
        ax.text(self.__words[i]['arc_x'], POS_Y, postag,
                horizontalalignment='center')

    # Render the dependency relations.
    for i, dep in enumerate(self.deps):
        from_index = i + 1  # skip the synthetic Root token
        to_index = dep.head
        start_point = [self.__words[from_index]['arc_x'], ARC_Y]
        end_point = [self.__words[to_index]['arc_x'], ARC_Y]
        # The sign of the curvature depends on which side the head is on.
        rad = 0.5 if end_point[0] - start_point[0] > 0 else -0.5
        # The arrow is drawn from the head word to the dependent word.
        the_arc = FancyArrowPatch(end_point, start_point,
                                  connectionstyle="arc3,rad=%f" % rad,
                                  **ARC_STYLE)
        the_arc = ax.add_patch(the_arc)
        # Put the relation label at the centre of the arc's bounding box.
        position_data = the_arc.get_path().get_extents(transform=None)
        dep_relation_x = position_data.x0 + (position_data.x1 - position_data.x0) / 2
        dep_relation_y = position_data.y0 + (position_data.y1 - position_data.y0) / 2
        ax.text(dep_relation_x, dep_relation_y, dep.relation,
                bbox=RELATION_BACKGROUND_STYLE,
                horizontalalignment='center',
                verticalalignment='center')

    ax.axis('off')
    plt.show()
class GraphArrow:
    """Curved arrow between two points on a circle, with an optional label.

    The end points are ``(cos(angle), sin(angle))`` for ``angleA`` and
    ``angleB`` (angles in radians, as taken by ``np.cos``/``np.sin``),
    scaled by ``r`` when the arrow is drawn.
    """

    def __init__(self, angleA, angleB, r=1, inner=False, outer=False,
                 stretch_inner=0.8, stretch_outer=1.1,
                 shrinkA=20, shrinkB=20,
                 arrowstyle='-|>', head_length=3, head_width=2,
                 connectionstyle="arc3,rad=0.3",
                 outer_connectionstyle="arc3,rad=-0.3",
                 edgewidth=1, edgecolor='k', ax=None,
                 label=None, labelpos=0.5,
                 label_facecolor='white', label_edgecolor='white',
                 label_color='k', xytext=None,
                 fontname='DejaVu Sans', fontsize=10):
        """Store the arrow geometry and styling; nothing is drawn yet.

        Args:
            angleA: Angle of the arrow's start point, in radians.
            angleB: Angle of the arrow's end point, in radians.
            r: Scale applied to both end points when drawing.
            inner: If true, the label position is scaled by
                ``stretch_inner`` instead of the default factor 1.12.
            outer: If true, the end points are pushed outwards by
                ``stretch_outer``, the shrink values are scaled by
                ``stretch_outer ** 2`` and ``outer_connectionstyle`` is used.
            stretch_inner: Label scale factor used when ``inner`` is true.
            stretch_outer: End-point scale factor used when ``outer`` is true.
            shrinkA: ``shrinkA`` passed to ``FancyArrowPatch``.
            shrinkB: ``shrinkB`` passed to ``FancyArrowPatch``.
            arrowstyle: Style name passed to ``ArrowStyle``.
            head_length: ``head_length`` passed to ``ArrowStyle``.
            head_width: ``head_width`` passed to ``ArrowStyle``.
            connectionstyle: Connection style used when ``outer`` is false.
            outer_connectionstyle: Connection style used when ``outer`` is
                true.
            edgewidth: Line width of the arrow.
            edgecolor: Colour of the arrow.
            ax: Axes to draw on; may also be supplied to ``draw``.
            label: Optional label text; falsy values disable the label.
            labelpos: Stored on the instance; not used by ``draw``.
            label_facecolor: Face colour of the label box.
            label_edgecolor: Edge colour of the label box.
            label_color: Text colour of the label.
            xytext: Optional explicit label position (any 2-element
                sequence or array); it is scaled like the default position.
            fontname: Font name of the label.
            fontsize: Font size of the label.
        """
        self.angleA = angleA
        self.angleB = angleB
        self.r = r
        # Unit-circle end points; ``draw`` applies ``r`` and ``stretch_outer``.
        self.posA = np.array([np.cos(self.angleA), np.sin(self.angleA)])
        self.posB = np.array([np.cos(self.angleB), np.sin(self.angleB)])
        self.outer = outer
        self.stretch_outer = stretch_outer
        self.shrinkA = shrinkA
        self.shrinkB = shrinkB
        self.arrowstyle = arrowstyle
        self.head_length = head_length
        self.head_width = head_width
        self.connectionstyle = connectionstyle
        self.outer_connectionstyle = outer_connectionstyle
        self.edgewidth = edgewidth
        self.edgecolor = edgecolor
        self.ax = ax
        self.label = label
        self.labelpos = labelpos
        self.label_facecolor = label_facecolor
        self.label_edgecolor = label_edgecolor
        self.label_color = label_color
        self.inner = inner
        self.stretch_inner = stretch_inner
        self.xytext = xytext
        self.fontname = fontname
        self.fontsize = fontsize

    def _label_position(self, curve_point):
        """Return the scaled label anchor as a float array.

        ``xytext`` is used when given, otherwise ``curve_point``. The value
        is converted to an array first so that tuples and lists can be
        scaled by a float factor as well.
        """
        base = curve_point if self.xytext is None else self.xytext
        scale = self.stretch_inner if self.inner else 1.12
        return np.asarray(base, dtype=float) * scale

    def draw(self, ax=None):
        """Add the arrow, and its label if one is set, to an Axes.

        Args:
            ax: Axes to draw on. Falls back to the Axes given at
                construction time, then to ``plt.gca()``. The Axes used is
                stored on ``self.ax``; the patch is stored on ``self.arrow``.
        """
        self.ax = ax or self.ax or plt.gca()

        # "Outer" arrows are pushed away from the circle and shrunk more.
        stretch = self.stretch_outer if self.outer else 1
        shrink_scale = self.stretch_outer ** 2 if self.outer else 1
        style = (self.outer_connectionstyle if self.outer
                 else self.connectionstyle)

        self.arrow = FancyArrowPatch(
            posA=(self.posA * stretch) * self.r,
            posB=(self.posB * stretch) * self.r,
            arrowstyle=ArrowStyle(self.arrowstyle,
                                  head_length=self.head_length,
                                  head_width=self.head_width),
            shrinkA=self.shrinkA * shrink_scale,
            shrinkB=self.shrinkB * shrink_scale,
            connectionstyle=style,
            linewidth=self.edgewidth,
            color=self.edgecolor)
        self.ax.add_patch(self.arrow)

        if self.label:
            # The second path vertex is the default label anchor.
            # NOTE(review): presumably the curve's control point for the
            # "arc3" connection styles - confirm for other styles.
            point = self.arrow.get_path().vertices[1]
            self.ax.text(
                *self._label_position(point),
                self.label,
                horizontalalignment='center',
                verticalalignment='center',
                bbox=dict(facecolor=self.label_facecolor,
                          edgecolor=self.label_edgecolor,
                          alpha=0),
                color=self.label_color,
                fontname=self.fontname,
                fontsize=self.fontsize)