def test_sample_from_simplex(self, n, dim, vmin):
    """Test `sample_from_simplex`.

    Checks that each of the `n` sampled rows sums to one and that every
    coordinate lies within `[vmin, 1 - vmin]`.
    """
    x = utils.sample_from_simplex(n, dim=dim, vmin=vmin)
    np.testing.assert_allclose(np.sum(x, axis=1), np.ones(n))
    # `np.alltrue` was deprecated in NumPy 1.25 and removed in NumPy 2.0;
    # `np.all` is the supported equivalent.
    self.assertTrue(np.all(x <= 1. - vmin))
    self.assertTrue(np.all(x >= vmin))
# Example #2
    def streamplot(self,
                   dynamics,
                   initial_points=None,
                   dt=0.01,
                   density=1.,
                   min_length=0.4,
                   linewidth=None,
                   color="k",
                   **kwargs):
        """Visualizes the dynamics as a streamline plot.

        Mimics the visuals of `Axes.streamplot` for simplex plots.

        Args:
          dynamics: Population dynamics of type
            `dynamics.SinglePopulationDynamics`. When velocities are needed it
            is applied row-wise to each trajectory point.
          initial_points: Optional array-like of shape (k, 3) with starting
            points for streamlines. If None, three points near the simplex
            corners plus 100 points from `utils.sample_from_simplex` are used.
          dt: Integration step.
          density: Controls the density of streamlines in the plot.
          min_length: Streamlines with arc length < min_length are discarded,
            and the mask cells they occupied are reset to False.
          linewidth: In `{None, float, "velocity"}`, optional, default: None.
            If `linewidth="velocity"`, the line width is `3 * v + 0.5`, where
            `v` is the speed of the dynamics min-max normalized to [0, 1]
            across all kept streamlines. Other values are passed to
            `self._linecollection` unchanged (presumably falling back to
            `rcParams` for None).
          color: In `{None, string, (r,g,b), (r,g,b,a), "velocity"}`,
            default: "k". If `color="velocity"`, the normalized speed is mapped
            through the `rcParams["image.cmap"]` colormap to color both the
            streamlines and their arrows. Other values are passed through
            unchanged.
          **kwargs: Accepted for `Axes.streamplot` signature compatibility;
            currently unused.

        Returns:
          The `SimplexStreamMask`.
        """
        mask = SimplexStreamMask(density=density)
        # Each entry is a (trajectory, cumulative arc length) pair.
        trajectories = []

        if initial_points is None:
            eps = 0.1
            initial_points = np.array([[1. - eps, eps / 2., eps / 2.],
                                       [eps / 2., 1. - eps, eps / 2.],
                                       [eps / 2., eps / 2., 1. - eps]])
            initial_points = np.vstack(
                (initial_points, utils.sample_from_simplex(100)))
            # TODO(author10): add heuristic for initial points
        else:
            initial_points = np.array(initial_points)
            assert initial_points.ndim == 2
            assert initial_points.shape[1] == 3

        # generate trajectories
        for p in initial_points:
            # center initial point on grid cell
            p = mask.point(mask.index(p))
            res = self._integrate(p, dynamics, mask, dt=dt)
            if res is None:
                continue
            t, cells = res
            # cum_len[k] is the Euclidean arc length from t[0] to t[k + 1].
            cum_len = np.cumsum(
                np.sqrt(
                    np.diff(t[:, 0])**2 + np.diff(t[:, 1])**2 +
                    np.diff(t[:, 2])**2))
            # A trajectory with fewer than two points has no length (and no
            # segment to put an arrow on); treat it as too short.
            if cum_len.size == 0 or cum_len[-1] < min_length:
                # Un-mark the cells this discarded streamline occupied
                # (presumably so other streamlines may use them).
                for cell in cells:
                    mask[mask.point(cell)] = False
                continue
            trajectories.append((t, cum_len))

        # isinstance guards keep array-valued linewidth/color from triggering
        # an element-wise comparison.
        velocity_width = isinstance(linewidth, str) and linewidth == "velocity"
        velocity_color = isinstance(color, str) and color == "velocity"

        lc_color = arrow_color = color
        lc_linewidth = linewidth

        # Per-trajectory speeds, min-max normalized to [0, 1] over all kept
        # trajectories (all zeros if every speed is identical).
        normalized_velocities = []
        if (velocity_width or velocity_color) and trajectories:
            velocities = []
            for t, _ in trajectories:
                dx = np.apply_along_axis(dynamics, 1, t)
                velocities.append(np.sqrt(np.sum(dx**2, axis=1)))
            vel_min = min(np.min(v) for v in velocities)
            vel_max = max(np.max(v) for v in velocities)
            vel_range = vel_max - vel_min
            for v in velocities:
                if vel_range > 0:
                    normalized_velocities.append((v - vel_min) / vel_range)
                else:
                    normalized_velocities.append(np.zeros_like(v))

        if velocity_color:
            # Look the colormap up once; `matplotlib.cm.get_cmap` was removed
            # in matplotlib 3.9 in favour of the `matplotlib.colormaps`
            # registry (available since 3.5).
            cmap_name = rcParams["image.cmap"]
            if hasattr(matplotlib, "colormaps"):
                cmap = matplotlib.colormaps[cmap_name]
            else:
                cmap = matplotlib.cm.get_cmap(cmap_name)

        # add trajectories to plot
        for i, (t, cum_len) in enumerate(trajectories):
            # searchsorted finds the segment containing the halfway point. It
            # is 0 when that is the first segment, in which case
            # t[mid_idx - 1] would wrap around to t[-1]; clamp to 1.
            mid_idx = max(int(np.searchsorted(cum_len, cum_len[-1] / 2.)), 1)

            if velocity_width or velocity_color:
                vel = normalized_velocities[i]

                if velocity_width:
                    lc_linewidth = 3. * vel + 0.5

                if velocity_color:
                    lc_color = cmap(vel)
                    arrow_color = cmap(vel[mid_idx])

            lc = self._linecollection(t,
                                      linewidth=lc_linewidth,
                                      color=lc_color)
            self.add_collection(lc)

            # add arrow centered on trajectory
            arrow_tail = self._simplex_transform.transform(t[mid_idx - 1])
            arrow_head = self._simplex_transform.transform(t[mid_idx])
            arrow_kw = dict(arrowstyle="-|>", mutation_scale=10 * 1.)
            arrow_patch = FancyArrowPatch(arrow_tail,
                                          arrow_head,
                                          linewidth=None,
                                          color=arrow_color,
                                          zorder=3,
                                          **arrow_kw)
            self.add_patch(arrow_patch)
        return mask