コード例 #1
0
    def draw(self, axes, feature, bbox, loc, style_param):
        """Draw a curved, filled-head arrow inside ``bbox`` plus an optional note.

        The arrow is a quadratic Bezier path. It starts on the left edge of
        ``bbox`` at mid-height, uses a control point straight below/above that
        start (offset by ``bbox.height / 2 * self._head_height`` towards
        smaller y), and ends on the right edge at that same offset height.
        The head is an ``ArrowStyle.CurveFilledB`` sized by
        ``self._head_width`` / ``self._head_length``.

        If ``feature.qual`` has a ``"note"`` entry, it is written in black,
        size 9, centred horizontally and a quarter box-height above the centre.

        ``loc`` and ``style_param`` are accepted for interface compatibility
        but are not used here.
        """
        from matplotlib.patches import FancyArrowPatch, ArrowStyle
        from matplotlib.path import Path

        mid_x = bbox.x0 + bbox.width / 2
        mid_y = bbox.y0 + bbox.height / 2

        # Vertical distance between the start point and the arrow's end level.
        drop = bbox.height / 2 * self._head_height
        end_y = mid_y - drop

        # MOVETO to the start, then CURVE3 = (control point, end point).
        bent_curve = Path(
            vertices=[
                (bbox.x0, mid_y),
                (bbox.x0, end_y),
                (bbox.x1, end_y),
            ],
            codes=[Path.MOVETO, Path.CURVE3, Path.CURVE3],
        )
        head = ArrowStyle.CurveFilledB(
            head_width=self._head_width,
            head_length=self._head_length,
        )
        axes.add_patch(
            FancyArrowPatch(
                path=bent_curve,
                arrowstyle=head,
                linewidth=self._line_width,
                color="black",
            )
        )

        if "note" not in feature.qual:
            return
        axes.text(
            mid_x,
            mid_y + bbox.height / 4,
            feature.qual["note"],
            color="black",
            ha="center",
            va="center",
            size=9,
        )
コード例 #2
0
    def plot_world(self, world):
        """
        Draw the agent network of a world on a fresh 12x12 inch figure.

        Nodes are placed on a circular layout and coloured via
        ``self.determine_color_node``. Edges get filled curved arrow heads
        whose size shrinks as the graph grows. ``self.legend`` is attached
        as the legend.

        :param world: class World, the world we want to plot
        :return: fig, ax:  a figure and an axes (matplotlib), these contain the plot of the world
        """
        figure, axis = plt.subplots(1, 1)
        figure.set_size_inches(12, 12)

        graph = world.graph
        positions = nx.circular_layout(graph)
        head = ArrowStyle.CurveFilledB(head_length=0.5, head_width=0.3)
        # Linear shrink of the arrow head with node count, floored at 1.
        head_size = max(-0.04 * graph.order() + 10, 1)
        # NOTE(review): networkx pairs node_color entries with the graph's node
        # order; this assumes world.agents iterates in that same order -- confirm.
        palette = [
            self.determine_color_node(member)
            for member in world.agents.values()
        ]

        nx.draw_networkx(
            graph,
            pos=positions,
            ax=axis,
            arrowstyle=head,
            arrowsize=head_size,
            with_labels=True,
            node_size=500,
            node_color=palette,
            alpha=0.8,
            linewidths=0.0,
            width=0.2,
        )

        axis.legend(handles=self.legend)
        return figure, axis
コード例 #3
0
    def _next_frame(self, t):
        """
        Creates the plot for the next frame in the animation.

        Clears ``self.animation_axis`` and redraws it with the network of the
        world stored at ``self.simulation.simulation_data[t]``. It then adds
        ``self.legend`` and the title ``'Frame <t>'``.

        :param t: integer, the time step in the animation
        :return: list, the result of ``self.animation_axis.plot()`` called
            without data. Per matplotlib's ``Axes.plot`` this is an empty list
            of ``Line2D`` objects, not the drawn network.
        """
        # Progress counter overwritten in place. '\r' does not end a line, so
        # line-buffered stdout would hold the counter back without flush=True.
        print(t, end='\r', flush=True)

        # Clear the axis from the previous plot
        self.animation_axis.clear()

        # Fill the axis with the new plot
        world = self.simulation.simulation_data[t]
        # NOTE(review): networkx pairs node_color entries with the graph's node
        # order; this assumes world.agents iterates in that same order -- confirm.
        nx.draw_networkx(world.graph,
                         pos=nx.circular_layout(world.graph),
                         ax=self.animation_axis,
                         arrowstyle=ArrowStyle.CurveFilledB(head_length=0.5,
                                                            head_width=0.3),
                         # Arrow heads shrink linearly with node count, floored at 1.
                         arrowsize=max(-0.04 * world.graph.order() + 10, 1),
                         with_labels=True,
                         node_size=500,
                         node_color=[
                             self.determine_color_node(agent)
                             for agent in world.agents.values()
                         ],
                         alpha=0.8,
                         linewidths=0.0,
                         width=0.2)

        # Add legend and title
        self.animation_axis.legend(handles=self.legend)
        self.animation_axis.set_title(f'Frame {t}')

        # Called without data, so no lines are added. Kept as in the original,
        # since the call's return value is what this frame function returns.
        return self.animation_axis.plot()
コード例 #4
0
ファイル: _graph.py プロジェクト: nkandhari/cellrank
    def plot_arrows(curves, G, pos, ax, edge_weight_scale):
        """Draw one direction arrow near the middle of each non-loop edge curve.

        :param curves: per-edge vertex arrays, consumed in lockstep with
            ``G.edges.items()``. Rows containing NaN are dropped; the remaining
            values are reshaped to ``(-1, 2)`` x/y pairs (assumed to already be
            ``(n, 2)`` -- TODO confirm against the caller).
        :param G: graph whose ``edges.items()`` yields ``((u, v), data)``.
            ``data["weight"]`` sets the arrow-head size.
        :param pos: mapping from node to its position. It is used to orient
            each curve so that it runs away from the source node ``u``.
        :param ax: matplotlib axes the arrows are added to.
        :param edge_weight_scale: multiplier applied to ``data["weight"]``.
            Head length is ``weight * scale * 4`` and head width is
            ``weight * scale * 2``, both clipped to
            ``[_min_edge_weight, edge_width_limit]``.

        ``_min_edge_weight``, ``edge_width_limit`` and ``edge_alpha`` are not
        defined here; presumably they come from the enclosing function.
        Self-loops and curves with fewer than two finite points get no arrow.
        """
        for line, (edge, val) in zip(curves, G.edges.items()):
            if edge[0] == edge[1]:
                continue

            mask = (~np.isnan(line)).all(axis=1)
            line = line[mask, :]
            line = line.reshape((-1, 2))
            # Two points are needed to define a direction. The curve can be
            # all NaNs, and a single point used to crash the posA/posB unpack.
            if len(line) < 2:
                continue

            X, Y = line[:, 0], line[:, 1]

            # Orient the curve so it starts at the source node edge[0].
            node_start = pos[edge[0]]
            hits = np.where(np.isclose(node_start - line,
                                       [0, 0]).all(axis=1))[0]
            if hits.size:
                reverse = hits[0] != 0
            else:
                # No vertex coincides with the source node. This used to raise
                # IndexError; now reverse if the last vertex is the closer one.
                d_first = ((line[0] - node_start) ** 2).sum()
                d_last = ((line[-1] - node_start) ** 2).sum()
                reverse = d_last < d_first
            if reverse:
                X, Y = X[::-1], Y[::-1]

            # The arrow spans two consecutive vertices around the middle.
            # Clamp so the slice always holds two points: for len >= 3 this is
            # len // 2 as before, and a 2-point curve no longer crashes.
            mid = min(len(X) // 2, len(X) - 2)
            posA, posB = zip(X[mid:mid + 2], Y[mid:mid + 2])  # noqa

            arrow = FancyArrowPatch(
                posA=posA,
                posB=posB,
                # we clip because too small values
                # cause it to crash
                arrowstyle=ArrowStyle.CurveFilledB(
                    head_length=np.clip(
                        val["weight"] * edge_weight_scale * 4,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                    head_width=np.clip(
                        val["weight"] * edge_weight_scale * 2,
                        _min_edge_weight,
                        edge_width_limit,
                    ),
                ),
                color="k",
                # draw above every other artist
                zorder=float("inf"),
                alpha=edge_alpha,
                linewidth=0,
            )
            ax.add_artist(arrow)
コード例 #5
0
    # Finish the cascade-size subplot (s_ax) for this step.
    s_ax.set_xlabel("timestep")
    s_ax.set_ylabel("cascade size")
    s_ax.legend()

    # Redraw the network subplot (g_ax) from scratch for this step.
    g_ax.clear()
    # One colour per agent, in w.agents iteration order; first matching rule wins:
    #   agent in fakenews_spreader       -> coolwarm(1.0)
    #   agent in counternews_spreader    -> coolwarm(0.0)
    #   a.states[0] == AgentState.ACTIVE -> coolwarm(0.97)
    #   a.states[1] == AgentState.ACTIVE -> coolwarm(0.02)
    #   otherwise                        -> black (0, 0, 0)
    # NOTE(review): networkx pairs node_color entries with the graph's node
    # order; this assumes w.agents iterates in the same order as w.graph -- confirm.
    clrs = [
        ColorMaps.coolwarm(1.0) if a in fakenews_spreader.keys() else
        ColorMaps.coolwarm(0.0) if a in counternews_spreader.keys() else
        ColorMaps.coolwarm(0.97) if a.states[0] == AgentState.ACTIVE else
        ColorMaps.coolwarm(0.02) if a.states[1] == AgentState.ACTIVE else
        (0.0, 0.0, 0.0) for a in w.agents.values()
    ]

    # `layout` is computed outside this view (presumably once, so nodes stay
    # put across frames). Arrow heads shrink linearly with node count, floored at 1.
    nx.draw_networkx(w.graph,
                     pos=layout,
                     ax=g_ax,
                     arrowstyle=ArrowStyle.CurveFilledB(head_length=0.5,
                                                        head_width=0.3),
                     arrowsize=max(-0.04 * w.graph.order() + 10, 1),
                     with_labels=False,
                     node_size=200,
                     node_color=clrs,
                     alpha=0.9,
                     linewidths=0.0,
                     width=0.2)
    # One PNG per step; the "animation/" directory presumably must already exist.
    plt.savefig("animation/counternews" + str(step) + ".png")
    # Presumably advances the simulation one step; `x` is only referenced by
    # the disabled early-stop check on the next line.
    x = w.update(verbose=True)
    #if x == 0 and step > max(starting_points.keys()): break
    # Progress output on one line; flushed so it shows up immediately.
    print(step, end=", ", flush=True)
# Runs once, after the loop over steps has finished.
print("done")