Ejemplo n.º 1
0
def plot_field(fig: plt.Figure, field: np.ndarray, location: tuple, side: int):
    """Draw a square scalar field as an image on the figure's current axes.

    :param fig: figure to draw on; the image goes on ``fig.gca()``.
    :param field: values of an ``n x n`` grid, either flattened (length ``n * n``,
        reshaped row-major into ``(n, n)``) or already shaped ``(n, n)``.
    :param location: ``(x, y)`` centre of the square, in data coordinates.
    :param side: edge length of the square, in data coordinates.
    :return: the same ``fig``.
    :raises ValueError: if a flattened ``field`` does not have a perfect-square length.
    """
    field = np.asarray(field)
    if field.ndim == 2 and field.shape[0] == field.shape[1]:
        # Already a square grid: use it as is instead of re-deriving n from shape[0].
        grid = field
    else:
        # Round rather than truncate so floating-point error in sqrt cannot drop a row.
        n = int(round(float(np.sqrt(field.shape[0]))))
        if n * n != field.shape[0]:
            raise ValueError(
                f'field has {field.shape[0]} entries along its first axis, '
                'which is not a perfect square')
        grid = field.reshape(n, n)

    half = side / 2
    min_x, max_x = location[0] - half, location[0] + half
    min_y, max_y = location[1] - half, location[1] + half

    ax = fig.gca()
    # extent is (left, right, bottom, top); giving bottom=max_y and top=min_y puts
    # row 0 of the grid at min_y (with imshow's default origin='upper').
    ax.imshow(grid,
              extent=[min_x, max_x, max_y, min_y],
              cmap=plt.get_cmap('Greens').reversed())
    ax.set_yticks([])
    ax.set_xticks([])
    return fig
Ejemplo n.º 2
0
def plot_image_hierarchy(fig: plt.Figure, G: nx.Graph, dot_path='test.dot'):
    """Draw ``G`` as a top-down (graphviz ``dot``) hierarchy with an image per node.

    Adapted from https://stackoverflow.com/questions/53967392/creating-a-graph-with-images-as-nodes

    :param fig: figure to draw on. Edges go on ``fig.gca()``, and each node image
        gets its own small inset axes added to ``fig``.
    :param G: graph whose nodes each carry an ``'image'`` attribute that is passed
        to ``plt.imread`` (presumably an image file path — confirm with callers).
    :param dot_path: file to dump ``G`` to in DOT format before drawing. Pass
        ``None`` to skip writing. Defaults to ``'test.dot'`` to keep the old behaviour.
    :return: the same ``fig``.
    """
    ax = fig.gca()
    ax.set_aspect('equal', adjustable='box')

    if dot_path is not None:
        write_dot(G, dot_path)
    pos = graphviz_layout(G, prog='dot', args='-Nsep="+1000,+1000";')
    nx.draw_networkx_edges(G, pos, ax=ax)

    data_to_display = ax.transData.transform
    display_to_figure = fig.transFigure.inverted().transform

    image_size = 0.2  # inset edge length, as a fraction of the figure
    half = image_size / 2.0
    # NOTE(review): insets are placed using the transforms as they are right now;
    # later changes to ax's limits or box (e.g. at draw time) will not move them.
    for node in G:
        x_disp, y_disp = data_to_display(pos[node])  # display (pixel) coordinates
        x_fig, y_fig = display_to_figure((x_disp, y_disp))  # figure-fraction coordinates
        # fig.add_axes (not plt.axes) so the inset lands on ``fig`` even if it is not current.
        inset = fig.add_axes([x_fig - half, y_fig - half, image_size, image_size])
        inset.set_aspect('equal')
        # G.nodes[...]: the G.node accessor was removed in NetworkX 2.4.
        inset.imshow(plt.imread(G.nodes[node]['image']))
        inset.set_xticks([])
        inset.set_yticks([])

    ax.axis('off')
    return fig
Ejemplo n.º 3
0
def plot_learning_path(fig: plt.Figure, paths):
    """Scatter-plot one or more trajectories on the figure's current axes.

    For each trajectory the first point is blue, the intermediate points are red
    and the final point is black.

    :param fig: figure to draw on (uses ``fig.gca()``).
    :param paths: iterable of trajectories. Each is indexed as ``history[i, 0]`` and
        ``history[i, 1]`` for the x and y of step ``i``. This assumes a 2-D
        NumPy-style array of shape ``(steps, >=2)`` — TODO confirm with callers.
    :return: the same ``fig``.
    """
    ax = fig.gca()
    for history in paths:
        # x/y are passed by keyword: seaborn >= 0.12 makes them keyword-only,
        # and positional x, y raise a TypeError there.
        sns.scatterplot(x=history[:1, 0], y=history[:1, 1], color='blue', ax=ax)
        # Intermediate steps only: exclude the first and the last point.
        if len(history) > 2:
            sns.scatterplot(x=history[1:-1, 0], y=history[1:-1, 1],
                            color='red', ax=ax)
        sns.scatterplot(x=history[-1:, 0], y=history[-1:, 1], color='black', ax=ax)

    return fig
Ejemplo n.º 4
0
    def _fix_shap_force_figure(fig: plt.Figure) -> plt.Figure:
        """
        Replaces the figure annotation in shap force plots with "absent" if value = 0.0 and
        "present" if value = 1.0.

        Only text artists that are direct children of ``fig.gca()`` are inspected.
        Only a trailing " = 1.0" / " = 0.0" is rewritten, to " present" / " absent";
        all other text is left untouched.

        :param fig: a matplotlib.pyplot.Figure as produced by shap.force_plot.
        :return: The same fig, modified in place as described.
        """
        suffix_labels = ((' = 1.0', ' present'), (' = 0.0', ' absent'))
        for artist in fig.gca().get_children():
            if not isinstance(artist, plt.Text):
                continue
            text = artist.get_text()
            for suffix, label in suffix_labels:
                if text.endswith(suffix):
                    # Rewrite only the trailing suffix; str.replace would also alter
                    # any earlier occurrence of the same substring in the text.
                    artist.set_text(text[:-len(suffix)] + label)
                    break
        return fig
Ejemplo n.º 5
0
    def reset_figure(fig: pyplot.Figure, ax: pyplot.Axes) -> None:
        """Clear and reset a canvas.

        The function:
        - clears ``ax`` and restores its title and axis labels;
        - resizes ``fig`` to 8x6 inches;
        - moves ``ax`` to leave room on the right;
        - attaches a legend anchored just outside the top-right corner of ``ax``,
          listing the node-state colours and the signal-range circles.

        :param fig: figure owning ``ax``; only its size is changed here.
        :param ax: axes to clear, relabel, reposition and give the legend.
        """
        # Every axes operation targets ``ax`` (not ``fig.gca()``) so the title,
        # labels, position and legend all end up on the same axes even when
        # ``ax`` is not the figure's current axes.
        ax.cla()
        ax.set_title('Wireless Sensor Networks')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        fig.set_size_inches(8, 6)
        ax.set_position((0.1, 0.11, 0.6, 0.8))

        # (label, colour) of each node state, shown in the legend as a dot marker.
        node_states = (
            ('source', 'red'),
            ('alive', 'green'),
            ('received', 'orange'),
            ('replied', 'yellow'),
            ('sending', 'blue'),
            ('dead', 'black'),
        )
        # (label, colour) of each signal-range patch; there is no entry for dead nodes.
        signal_ranges = (
            ('range of signal\n(source node)', 'red'),
            ('range of signal\n(alive node)', 'green'),
            ('range of signal\n(received node)', 'orange'),
            ('range of signal\n(replied node)', 'yellow'),
            ('range of signal\n(sending node)', 'blue'),
        )
        # Zero-length lines and zero-radius circles are never drawn; they exist
        # only as legend handles.
        legend_elements = tuple(
            pyplot.Line2D(xdata=[],
                          ydata=[],
                          marker='.',
                          linewidth=0,
                          color=color,
                          label=label)
            for label, color in node_states
        ) + tuple(
            pyplot.Circle(xy=(0, 0),
                          radius=0,
                          alpha=0.4,
                          color=color,
                          label=label)
            for label, color in signal_ranges
        )
        ax.legend(handles=legend_elements,
                  loc='upper left',
                  bbox_to_anchor=(1.02, 1),
                  borderaxespad=0)