Beispiel #1
0
def interactive_hist(
    arr,
    density=False,
    bins="auto",
    weights=None,
    ax=None,
    slider_formats=None,
    force_ipywidgets=False,
    play_buttons=False,
    controls=None,
    display_controls=True,
    **kwargs,
):
    """
    Control the contents of a histogram using widgets.

    See https://github.com/ianhi/mpl-interactions/pull/73#issue-470638134 for a discussion
    of the limitations of this function. These limitations will be improved once
    https://github.com/matplotlib/matplotlib/pull/18275 has been merged.

    parameters
    ----------
    arr : arraylike or function
        The array, or the function that returns an array, that is to be histogrammed
    density : bool, optional
        whether to plot as a probability density. Passed to np.histogram
    bins : int or sequence of scalars or str, optional
        bins argument to np.histogram
    weights : array_like, optional
        passed to np.histogram
    ax : matplotlib axis, optional
        The axis on which to plot. If None the current axis will be used.
    slider_formats : None, string, or dict
        If None a default value of decimal points will be used. Uses the new {} style formatting
    force_ipywidgets : boolean
        If True ipywidgets will always be used, even if not using the ipympl backend.
        If False the function will try to detect if it is ok to use ipywidgets
        If ipywidgets are not used the function will fall back on matplotlib widgets
        NOTE(review): this implementation does not forward the flag to
        ``gogogo_controls``, so it currently has no effect here - confirm intent.
    play_buttons : bool or str or dict, optional
        Whether to attach an ipywidgets.Play widget to any sliders that get created.
        If a boolean it will apply to all kwargs, if a dictionary you choose which sliders you
        want to attach play buttons to.
            - None: no play buttons
            - True: play buttons on the left
            - False: no play buttons
            - 'left': play buttons on the left
            - 'right': play buttons on the right
    controls : mpl_interactions.controller.Controls
        An existing controls object if you want to tie multiple plot elements to the same set of
        controls
    display_controls : boolean
        Whether the controls should display themselves on creation. Ignored if controls is specified.

    returns
    -------
    controls

    Examples
    --------

    With numpy arrays::

        loc = np.linspace(-5, 5, 500)
        scale = np.linspace(1, 10, 100)
        def f(loc, scale):
            return np.random.randn(1000)*scale + loc
        interactive_hist(f, loc=loc, scale=scale)

    with tuples::

        def f(loc, scale):
            return np.random.randn(1000)*scale + loc
        interactive_hist(f, loc=(-5, 5, 500), scale=(1, 10, 100))
    """

    ipympl = notebook_backend()
    fig, ax = gogogo_figure(ipympl, ax=ax)
    slider_formats = create_slider_format_dict(slider_formats)
    controls, params = gogogo_controls(
        kwargs, controls, display_controls, slider_formats, play_buttons
    )
    pc = PatchCollection([])
    ax.add_collection(pc, autolim=True)

    def _hist(values):
        # Single place that forwards the histogram options so the initial draw
        # and every control update use exactly the same settings.
        return simple_hist(values, density=density, bins=bins, weights=weights)

    def update(params, indices, cache):
        arr_ = callable_else_value(arr, params, cache)
        new_x, new_y, new_patches = _hist(arr_)
        # presumably enlarges the axis limits to fit new_x/new_y
        # (stretch is defined elsewhere - confirm)
        stretch(ax, new_x, new_y)
        pc.set_paths(new_patches)
        ax.autoscale_view()

    controls.register_function(update, fig, params.keys())

    # initial histogram from the controls' starting values
    new_x, new_y, new_patches = _hist(callable_else_value(arr, params))
    sca(ax)
    pc.set_paths(new_patches)
    ax.set_xlim(new_x)
    ax.set_ylim(new_y)

    return controls
def interactive_hist(
    f,
    density=False,
    bins="auto",
    weights=None,
    figsize=None,
    ax=None,
    slider_format_string=None,
    display=True,
    force_ipywidgets=False,
    play_buttons=False,
    play_button_pos="right",
    **kwargs,
):
    """
    Control the contents of a histogram using widgets.

    See https://github.com/ianhi/mpl-interactions/pull/73#issue-470638134 for a discussion
    of the limitations of this function. These limitations will be improved once
    https://github.com/matplotlib/matplotlib/pull/18275 has been merged.

    parameters
    ----------
    f : function
        A function that will return a 1d array of which to take the histogram
    density : bool, optional
        whether to plot as a probability density. Passed to np.histogram
    bins : int or sequence of scalars or str, optional
        bins argument to np.histogram
    weights : array_like, optional
        passed to np.histogram
    figsize : tuple or scalar
        If tuple it will be used as the matplotlib figsize. If a number
        then it will be used to scale the current rcParams figsize
    ax : matplotlib axis, optional
        If None a new figure and axis will be created
    slider_format_string : None, string, or dict
        If None a default value of decimal points will be used. For ipywidgets this uses
        str.format-style ({}) formatting.
        For matplotlib widgets you need to use `%` style formatting. A string will be used as the default
        format for all values. A dictionary will allow assigning different formats to different sliders.
        note: For matplotlib >= 3.3 a value of None for slider_format_string will use the matplotlib ScalarFormatter
        object for matplotlib slider values.
    display : boolean
        If True then the output and controls will be automatically displayed
    force_ipywidgets : boolean
        If True ipywidgets will always be used, even if not using the ipympl backend.
        If False the function will try to detect if it is ok to use ipywidgets
        If ipywidgets are not used the function will fall back on matplotlib widgets
    play_buttons : bool or dict, optional
        Whether to attach an ipywidgets.Play widget to any sliders that get created.
        If a boolean it will apply to all kwargs, if a dictionary you choose which sliders you
        want to attach play buttons to.
    play_button_pos : str, or dict, or list(str)
        'left' or 'right'. Whether to position the play widget(s) to the left or right of the slider(s)

    returns
    -------
    fig : matplotlib figure
    ax : matplotlib axis
    controls : list of widgets

    raises
    ------
    ValueError
        If more than one function is passed as ``f``.

    Examples
    --------

    With numpy arrays::

        loc = np.linspace(-5, 5, 500)
        scale = np.linspace(1, 10, 100)
        def f(loc, scale):
            return np.random.randn(1000)*scale + loc
        interactive_hist(f, loc=loc, scale=scale)

    with tuples::

        def f(loc, scale):
            return np.random.randn(1000)*scale + loc
        interactive_hist(f, loc=(-5, 5, 500), scale=(1, 10, 100))
    """

    # current value of every control, keyed by kwarg name; filled by the
    # kwargs_to_*_widgets call below and kept up to date by `update`
    params = {}
    funcs = np.atleast_1d(f)
    # supporting more would require more thought
    if len(funcs) != 1:
        raise ValueError(
            f"Currently only a single function is supported. You passed in {len(funcs)} functions"
        )

    ipympl = notebook_backend()
    fig, ax = gogogo_figure(ipympl, figsize=figsize, ax=ax)
    use_ipywidgets = ipympl or force_ipywidgets

    pc = PatchCollection([])
    ax.add_collection(pc, autolim=True)

    slider_format_strings = create_slider_format_dict(slider_format_string,
                                                      use_ipywidgets)

    def _hist():
        # Histogram of the user function at the current params; shared by the
        # initial draw and every update so the options cannot drift apart.
        return simple_hist(funcs[0](**params),
                           density=density,
                           bins=bins,
                           weights=weights)

    # update plot
    def update(change, key, label):
        if label:
            # continuous: the widget reports an index into kwargs[key]
            value = kwargs[key][change["new"]]
            params[key] = value
            label.value = slider_format_strings[key].format(value)
        else:
            # categorical: the widget reports the value itself
            params[key] = change["new"]
        new_x, new_y, new_patches = _hist()
        # presumably enlarges the axis limits to fit the new histogram
        # (stretch is defined elsewhere - confirm)
        stretch(ax, new_x, new_y)
        pc.set_paths(new_patches)
        ax.autoscale_view()
        fig.canvas.draw_idle()

    # this call implicitly fills the params dict
    if use_ipywidgets:
        # only the container of controls is needed here; the individual
        # sliders, labels and play widgets are not used afterwards
        _sliders, _slabels, controls, _play_widgets = kwargs_to_ipywidgets(
            kwargs, params, update, slider_format_strings, play_buttons,
            play_button_pos)
    else:
        controls = kwargs_to_mpl_widgets(kwargs, params, update,
                                         slider_format_strings)

    new_x, new_y, new_patches = _hist()
    pc.set_paths(new_patches)
    ax.set_xlim(new_x)
    ax.set_ylim(new_y)

    controls = gogogo_display(ipympl, use_ipywidgets, display, controls, fig)
    return fig, ax, controls
Beispiel #3
0
class Strat(object):
    """Stratigraphy panel: advances the channel objects and redraws them.

    Draws the active channel and the preserved channel bodies as two
    PatchCollections on ``gui.strat_ax``. Instances are callable; ``__call__(i)``
    is invoked with a frame counter and returns the artists it updated.
    """

    def __init__(self, gui):
        '''
        Initialize the main strat object and create its artists.

        Parameters
        ----------
        gui : object
            Owning GUI. Read here: ``fig``, ``sm``, ``config``, ``strat_ax``;
            read in ``__call__``: ``_paused``.
        '''

        self.gui = gui
        self.fig = gui.fig

        self.sm = gui.sm
        self.config = gui.config

        # basin top; drawn as the dashed BastLine below
        self.Bast = self.sm.Bast

        self.avul_num = 0
        # set True after the first avulsion; gates recoloring of channel bodies
        self.color = False
        self.avulCmap = plt.cm.Set1(range(9))

        strat_ax = self.gui.strat_ax

        # create an active channel and corresponding PatchCollection
        self.activeChannel = ActiveChannel(Bast=self.Bast,
                                           age=0,
                                           Ta=self.sm.Ta,
                                           avul_num=0,
                                           sm=self.sm)
        state = self.activeChannel.state
        self.activeChannelPatchCollection = PatchCollection(
            [Rectangle(state.ll, state.Bc, state.H)])

        # create a (still empty) channel-body list and corresponding PatchCollection
        self.channelBodyList = []
        self.channelBodyPatchCollection = PatchCollection(self.channelBodyList)

        # add PatchCollections
        strat_ax.add_collection(self.channelBodyPatchCollection)
        strat_ax.add_collection(self.activeChannelPatchCollection)

        # set fixed color attributes of PatchCollections
        self.channelBodyPatchCollection.set_edgecolor('0')
        self.activeChannelPatchCollection.set_facecolor('0.6')
        self.activeChannelPatchCollection.set_edgecolor('0')

        # plot basin top across the full basin width (Bbmax is scaled by 1000)
        half_width = self.sm.Bbmax * 1000 / 2
        self.BastLine, = strat_ax.plot([-half_width, half_width],
                                       [self.Bast, self.Bast],
                                       'k--',
                                       animated=False)

        # vertical-exaggeration label. Created on strat_ax explicitly: plt.text
        # would attach it to whichever axes happens to be current.
        self.VE_val = strat_ax.text(0.675,
                                    0.025,
                                    'VE = ' +
                                    str(round(self.sm.Bb / self.sm.yView, 1)),
                                    fontsize=12,
                                    transform=strat_ax.transAxes,
                                    backgroundcolor='white')

    def __call__(self, i):
        '''
        Advance (unless paused) and redraw the strat panel; called every loop.

        Parameters
        ----------
        i : int
            Frame counter; used as the age of a newly created channel and to
            refresh the VE label every 10th frame.

        Returns
        -------
        tuple
            The artists touched this frame.
        '''

        # find new slider vals
        self.sm.get_all()

        if not self.gui._paused:
            self._advance(i)

        # generate new patch lists for updating the PatchCollection objects
        self.activeChannelPatches = [
            Rectangle(s.ll, s.Bc, s.H) for s in self.activeChannel.stateList
        ]
        self.channelBodyPatchList = [
            c.get_patch() for c in self.channelBodyList
        ]

        # set paths of the PatchCollection objects
        self.channelBodyPatchCollection.set_paths(self.channelBodyPatchList)
        self.activeChannelPatchCollection.set_paths(self.activeChannelPatches)

        if self.color:
            self._update_colors()

        # yview and xview
        strat_ax = self.gui.strat_ax
        ylims = utils.new_ylims(yView=self.sm.yView, Bast=self.Bast)
        strat_ax.set_ylim(ylims)
        strat_ax.set_xlim(-self.sm.Bb / 2, self.sm.Bb / 2)

        # vertical exaggeration text, refreshed every 10th frame; uses the
        # axes size in inches so the ratio reflects the on-screen aspect
        if i % 10 == 0:
            self.axbbox = strat_ax.get_window_extent().transformed(
                self.fig.dpi_scale_trans.inverted())
            width, height = self.axbbox.width, self.axbbox.height
            self.VE_val.set_text(
                'VE = ' +
                str(round((self.sm.Bb / width) / (self.sm.yView / height), 1)))

        return (self.BastLine, self.VE_val,
                self.channelBodyPatchCollection,
                self.activeChannelPatchCollection)

    def _advance(self, i):
        '''Timestep all channel objects; handle an avulsion if one occurred.'''
        # subside every preserved channel body
        dz = self.sm.sig * self.sm.dt
        for body in self.channelBodyList:
            body.subside(dz)

        if not self.activeChannel.avulsed:
            # when an avulsion has not occurred:
            self.activeChannel.timestep()
            return

        # once an avulsion has occurred: preserve the old channel as a body
        self.channelBodyList.append(ChannelBody(self.activeChannel))
        self.avul_num += 1
        self.color = True

        # create a new Channel
        self.activeChannel = ActiveChannel(Bast=self.Bast,
                                           age=i,
                                           Ta=self.sm.Ta,
                                           avul_num=self.avul_num,
                                           sm=self.sm)

        # remove outdated channels (entirely below the deepest viewable level);
        # written as `not (... < ...)` to keep the original NaN behavior
        strat_min = self.Bast - self.sm.yViewmax
        self.channelBodyList = [
            body for body in self.channelBodyList
            if not body.polygonYs.max() < strat_min
        ]

    def _update_colors(self):
        '''Color the channel bodies according to ``self.sm.colFlag``.'''
        collection = self.channelBodyPatchCollection
        bodies = self.channelBodyList
        flag = self.sm.colFlag

        if flag == 'age':
            age_array = np.array([c.age for c in bodies])
            # min()/max() fail on an empty array, so skip until a body exists
            if age_array.size > 0:
                collection.set_array(age_array)
                collection.set_clim(vmin=age_array.min(),
                                    vmax=age_array.max())
                collection.set_cmap(plt.cm.viridis)
        elif flag == 'Qw':
            collection.set_array(np.array([c.Qw for c in bodies]))
            collection.set_clim(vmin=self.config.Qwmin,
                                vmax=self.config.Qwmax)
            collection.set_cmap(plt.cm.viridis)
        elif flag == 'avul':
            # cycle through the 9 colors of Set1
            collection.set_array(np.array([c.avul_num % 9 for c in bodies]))
            collection.set_clim(vmin=0, vmax=9)
            collection.set_cmap(plt.cm.Set1)
        elif flag == 'sig':
            collection.set_array(np.array([c.sig for c in bodies]))
            # config sig limits are divided by 1000 to match body.sig units
            # (inferred from this scaling - confirm)
            collection.set_clim(vmin=self.config.sigmin / 1000,
                                vmax=self.config.sigmax / 1000)
            collection.set_cmap(plt.cm.viridis)
Beispiel #4
0
class SceneVisualizer:
    """Context manager for social-navigation scene visualization.

    Typical use::

        with SceneVisualizer(scene, output="out") as sv:
            sv.animate()   # or sv.plot()

    On exit, the animation is saved as ``<output>.gif`` (or the static plot as
    ``<output>.png``) when ``output`` is set, and the figure is closed.
    """
    def __init__(self,
                 scene,
                 output=None,
                 limits=None,
                 writer="imagemagick",
                 cmap="viridis",
                 agent_colors=None,
                 **kwargs):
        """
        :param scene: scene providing ``get_states``, ``get_length``,
            ``get_obstacles``, ``get_fires``, ``get_exits`` and ``peds``
        :param output: output path without extension; nothing is saved if falsy
        :param limits: optional ``(x_min, x_max, y_min, y_max)`` axis limits
            (a margin is added); computed from the states when None
        :param writer: writer passed to ``FuncAnimation.save``
        :param cmap: colormap for the group and human collections
        :param agent_colors: optional face colors, one per agent
        :param kwargs: forwarded to ``plt.subplots``
        """
        self.scene = scene
        self.states, self.group_states = self.scene.get_states()
        self.cmap = cmap
        self.agent_colors = agent_colors
        self.frames = self.scene.get_length()
        self.output = output
        self.writer = writer
        self.limits = limits

        self.fig, self.ax = plt.subplots(**kwargs)

        self.ani = None

        self.group_actors = None
        self.group_collection = PatchCollection([])
        self.group_collection.set(
            animated=True,
            alpha=0.2,
            cmap=self.cmap,
            facecolors="none",
            edgecolors="purple",
            linewidth=2,
            clip_on=True,
        )

        self.human_actors = None
        self.human_collection = PatchCollection([])
        self.human_collection.set(animated=True,
                                  alpha=0.6,
                                  cmap=self.cmap,
                                  clip_on=True)

    def plot(self):
        """Main method to create a static plot of all trajectories."""
        self.plot_obstacles()
        self.plot_fires()
        self.plot_exits()
        groups = self.group_states[0]  # static group for now
        if not groups:
            for ped in range(self.scene.peds.size()):
                x = self.states[:, ped, 0]
                y = self.states[:, ped, 1]
                self.ax.plot(x, y, "-o", label=f"ped {ped}", markersize=2.5)
        else:
            # one color per group
            colors = plt.cm.rainbow(np.linspace(0, 1, len(groups)))

            for i, group in enumerate(groups):
                for ped in group:
                    x = self.states[:, ped, 0]
                    y = self.states[:, ped, 1]
                    self.ax.plot(x,
                                 y,
                                 "-o",
                                 label=f"ped {ped}",
                                 markersize=2.5,
                                 color=colors[i])
        self.ax.legend()
        return self.fig

    def animate(self):
        """Main method to create the animation (saved on ``__exit__``)."""

        self.ani = mpl_animation.FuncAnimation(
            self.fig,
            init_func=self.animation_init,
            func=self.animation_update,
            frames=self.frames,
            blit=True,
            interval=200,
        )

        return self.ani

    def __enter__(self):
        logger.info("Start plotting.")
        self.fig.set_tight_layout(True)
        self.ax.grid(linestyle="dotted")
        self.ax.set_aspect("equal")
        self.ax.margins(2.0)
        self.ax.set_axisbelow(True)
        self.ax.set_xlabel("x [m]")
        self.ax.set_ylabel("y [m]")

        plt.rcParams["animation.html"] = "jshtml"

        # x, y limit from states, only for animation
        margin = 2.0
        if self.limits is None:
            xy_limits = np.array([minmax(state) for state in self.states
                                  ])  # (x_min, y_min, x_max, y_max)
            xy_min = np.min(xy_limits[:, :2], axis=0) - margin
            xy_max = np.max(xy_limits[:, 2:4], axis=0) + margin
            self.ax.set(xlim=(xy_min[0], xy_max[0]),
                        ylim=(xy_min[1], xy_max[1]))
        else:
            limits = self.limits
            self.ax.set(xlim=(limits[0] - margin, limits[1] + margin),
                        ylim=(limits[2] - margin, limits[3] + margin))

        return self

    def __exit__(self, exception_type, exception_value, traceback):
        if exception_type:
            logger.error(
                "Exception type: %s; Exception value: %s; Traceback: %s",
                exception_type, exception_value, traceback)
        logger.info("Plotting ends.")
        if self.output:
            if self.ani:
                output = self.output + ".gif"
                logger.info("Saving animation as %s", output)
                self.ani.save(output, writer=self.writer)
            else:
                output = self.output + ".png"
                logger.info("Saving plot as %s", output)
                self.fig.savefig(output, dpi=300)
        plt.close(self.fig)

    def plot_human(self, step=-1):
        """Update (creating on first call) the circle patch of every agent.

        :param step: index of state, default is the latest
        :raises ValueError: if ``agent_colors`` does not have one entry per agent
        """
        states, _ = self.scene.get_states()
        current_state = states[step]
        n_agents = current_state.shape[0]
        # constant radius for every agent
        radius = [0.2] * n_agents
        if self.human_actors:
            for i, human in enumerate(self.human_actors):
                human.center = current_state[i, :2]
                human.set_radius(radius[i])
        else:
            self.human_actors = [
                Circle(pos, radius=r)
                for pos, r in zip(current_state[:, :2], radius)
            ]
        self.human_collection.set_paths(self.human_actors)
        # explicit None/len check: `not agent_colors` raises for numpy arrays
        if self.agent_colors is None or len(self.agent_colors) == 0:
            # color agents through the colormap by their index
            self.human_collection.set_array(np.arange(n_agents))
        else:
            # set colors for each agent; raise instead of assert so the check
            # survives `python -O`
            if len(self.human_actors) != len(self.agent_colors):
                raise ValueError(
                    "agent_colors must be the same length as the agents")
            self.human_collection.set_facecolor(self.agent_colors)

    def plot_groups(self, step=-1):
        """Update (creating on first call) the polygon patch of every group.

        :param step: index of state, default is the latest
        """
        states, group_states = self.scene.get_states()
        current_state = states[step]
        current_groups = group_states[step]
        if self.group_actors:  # update patches, else create
            points = [current_state[g, :2] for g in current_groups]
            for i, p in enumerate(points):
                self.group_actors[i].set_xy(p)
        else:
            self.group_actors = [
                Polygon(current_state[g, :2]) for g in current_groups
            ]

        self.group_collection.set_paths(self.group_actors)

    def _add_box(self, corners, **patch_kwargs):
        """Add an axis-aligned Rectangle spanning the first and last corner point.

        ``corners`` is indexed as ``corners[:, 0]`` (x) and ``corners[:, 1]`` (y),
        presumably an (n, 2) array - confirm against the scene API.
        """
        xs, ys = corners[:, 0], corners[:, 1]
        self.ax.add_patch(
            Rectangle((xs[0], ys[0]), xs[-1] - xs[0], ys[-1] - ys[0],
                      **patch_kwargs))

    def plot_obstacles(self):
        """Draw every obstacle as a filled black rectangle."""
        for obstacle in self.scene.get_obstacles():
            self._add_box(obstacle, fill=True, color="black")

    def plot_fires(self):
        """Draw the first fire filled red and, if present, the second as a red outline."""
        fires = self.scene.get_fires()
        if fires is None:
            return
        self._add_box(fires[0], fill=True, color="red")
        if len(fires) > 1:
            self._add_box(fires[1], fill=False, color="red")

    def plot_exits(self):
        """Draw every exit ``(x, y, radius)`` as a translucent green circle."""
        exits = self.scene.get_exits()
        if exits is None:
            return
        for e in exits:
            self.ax.add_patch(
                Circle((e[0], e[1]),
                       e[2],
                       fill=True,
                       color="green",
                       alpha=0.1))

    def plot_smoke(self, step=-1):
        """Add a faint black smoke circle centered on the first fire.

        :param step: index into the smoke radii, default is the latest

        NOTE(review): a new patch is added on every call and never removed, so
        smoke accumulates over the animation; it is also not returned from
        ``animation_update``, so with ``blit=True`` it may only show on full
        redraws - confirm intent.
        """
        fires = self.scene.get_fires()
        if fires is None:
            return
        f = fires[0]
        fcx = f[:, 0][0] + (f[:, 0][-1] - f[:, 0][0]) / 2
        fcy = f[:, 1][0] + (f[:, 1][-1] - f[:, 1][0]) / 2
        rad = self.scene.peds.get_smoke_radii()[step]
        self.ax.add_patch(
            Circle((fcx, fcy), rad, fill=True, color="black", alpha=0.005))

    def animation_init(self):
        """FuncAnimation init: draw static scenery and attach the collections."""
        self.plot_obstacles()
        self.plot_fires()
        self.plot_exits()
        self.ax.add_collection(self.group_collection)
        self.ax.add_collection(self.human_collection)

        return (self.group_collection, self.human_collection)

    def animation_update(self, i):
        """FuncAnimation frame callback for frame ``i``."""
        self.plot_groups(i)
        self.plot_human(i)
        self.plot_smoke(i)
        return (self.group_collection, self.human_collection)

    def plot_data(self):
        """Plot per-timestep escaped/dead percentages and average health/panic.

        The figure is saved as ``<output>_data.png`` and then closed so it does
        not stay registered with pyplot. When ``output`` is not set, nothing is
        plotted and a warning is logged.
        """
        if not self.output:
            logger.warning("No output path set; skipping data plot.")
            return

        peds = self.scene.peds
        n_peds = peds.get_nr_peds()

        def percent_of_peds(count):
            # an empty scene yields 0% instead of ZeroDivisionError
            return count / n_peds * 100 if n_peds else 0.0

        escaped = [percent_of_peds(i) for i in peds.escaped]
        health = [i * 100 for i in peds.av_health]
        panic = [i * 100 for i in peds.av_panic]
        dead = [percent_of_peds(i) for i in peds.dead]
        timesteps = list(range(len(escaped)))

        fig, ax = plt.subplots()
        ax.plot(timesteps, escaped, color="green", label="escaped")
        ax.plot(timesteps, health, color="red", label="health")
        ax.plot(timesteps, panic, color="purple", label="panic")
        ax.plot(timesteps, dead, color="black", label="dead")
        ax.set_xlabel("Timestep")
        ax.set_ylabel("[%]")
        ax.set_ylim([0, 100])
        ax.set_title(
            "Number of escaped and dead people and average health and panic.")
        ax.legend()
        ax.grid(linestyle="dotted")
        fig.savefig(self.output + "_data.png")
        plt.close(fig)
        logger.info("Created plot of data.")
Beispiel #5
0
class SceneVisualizer:
    """Context manager for social-navigation scene visualization.

    Typical use::

        with SceneVisualizer(scene, output="out") as sv:
            sv.animate()   # or sv.plot()

    On exit, the animation is saved as ``<output>.gif`` (or the static plot as
    ``<output>.png``) when ``output`` is set, and the figure is closed.
    """
    def __init__(self,
                 scene,
                 output=None,
                 writer="imagemagick",
                 cmap="viridis",
                 **kwargs):
        """
        :param scene: scene providing ``get_states``, ``get_length``,
            ``get_obstacles`` and ``peds``
        :param output: output path without extension; nothing is saved if falsy
        :param writer: writer passed to ``FuncAnimation.save``
        :param cmap: colormap for the group and human collections
        :param kwargs: forwarded to ``plt.subplots``
        """
        self.scene = scene
        self.states, self.group_states = self.scene.get_states()
        self.cmap = cmap
        self.frames = self.scene.get_length()
        self.output = output
        self.writer = writer

        self.fig, self.ax = plt.subplots(**kwargs)

        self.ani = None

        self.group_actors = None
        self.group_collection = PatchCollection([])
        self.group_collection.set(
            animated=True,
            alpha=0.2,
            cmap=self.cmap,
            facecolors="none",
            edgecolors="purple",
            linewidth=2,
            clip_on=True,
        )

        self.human_actors = None
        self.human_collection = PatchCollection([])
        self.human_collection.set(animated=True,
                                  alpha=0.6,
                                  cmap=self.cmap,
                                  clip_on=True)

    def plot(self):
        """Main method to create a static plot of all trajectories."""
        self.plot_obstacles()
        groups = self.group_states[0]  # static group for now
        if not groups:
            for ped in range(self.scene.peds.size()):
                x = self.states[:, ped, 0]
                y = self.states[:, ped, 1]
                self.ax.plot(x, y, "-o", label=f"ped {ped}", markersize=2.5)
        else:
            # one color per group
            colors = plt.cm.rainbow(np.linspace(0, 1, len(groups)))

            for i, group in enumerate(groups):
                for ped in group:
                    x = self.states[:, ped, 0]
                    y = self.states[:, ped, 1]
                    self.ax.plot(x,
                                 y,
                                 "-o",
                                 label=f"ped {ped}",
                                 markersize=2.5,
                                 color=colors[i])
        self.ax.legend()
        return self.fig

    def animate(self):
        """Main method to create the animation (saved on ``__exit__``)."""

        self.ani = mpl_animation.FuncAnimation(
            self.fig,
            init_func=self.animation_init,
            func=self.animation_update,
            frames=self.frames,
            blit=True,
        )

        return self.ani

    def __enter__(self):
        logger.info("Start plotting.")
        self.fig.set_tight_layout(True)
        self.ax.grid(linestyle="dotted")
        self.ax.set_aspect("equal")
        self.ax.margins(2.0)
        self.ax.set_axisbelow(True)
        self.ax.set_xlabel("x [m]")
        self.ax.set_ylabel("y [m]")

        plt.rcParams["animation.html"] = "jshtml"

        # x, y limit from states (plus a margin), only for animation
        margin = 2.0
        xy_limits = np.array([minmax(state) for state in self.states
                              ])  # (x_min, y_min, x_max, y_max)
        xy_min = np.min(xy_limits[:, :2], axis=0) - margin
        xy_max = np.max(xy_limits[:, 2:4], axis=0) + margin
        self.ax.set(xlim=(xy_min[0], xy_max[0]), ylim=(xy_min[1], xy_max[1]))

        return self

    def __exit__(self, exception_type, exception_value, traceback):
        if exception_type:
            logger.error(
                "Exception type: %s; Exception value: %s; Traceback: %s",
                exception_type, exception_value, traceback)
        logger.info("Plotting ends.")
        if self.output:
            if self.ani:
                output = self.output + ".gif"
                logger.info("Saving animation as %s", output)
                self.ani.save(output, writer=self.writer)
            else:
                output = self.output + ".png"
                logger.info("Saving plot as %s", output)
                self.fig.savefig(output, dpi=300)
        plt.close(self.fig)

    def plot_human(self, step=-1):
        """Update (creating on first call) the circle patch of every agent.

        :param step: index of state, default is the latest
        """
        states, _ = self.scene.get_states()
        current_state = states[step]
        n_agents = current_state.shape[0]
        # constant radius for every agent; used both when creating and when
        # updating the patches so the two paths cannot disagree
        radius = [0.2] * n_agents
        if self.human_actors:
            for i, human in enumerate(self.human_actors):
                human.center = current_state[i, :2]
                human.set_radius(radius[i])
        else:
            self.human_actors = [
                Circle(pos, radius=r)
                for pos, r in zip(current_state[:, :2], radius)
            ]
        self.human_collection.set_paths(self.human_actors)
        # color agents through the colormap by their index
        self.human_collection.set_array(np.arange(n_agents))

    def plot_groups(self, step=-1):
        """Update (creating on first call) the polygon patch of every group.

        :param step: index of state, default is the latest
        """
        states, group_states = self.scene.get_states()
        current_state = states[step]
        current_groups = group_states[step]
        if self.group_actors:  # update patches, else create
            points = [current_state[g, :2] for g in current_groups]
            for i, p in enumerate(points):
                self.group_actors[i].set_xy(p)
        else:
            self.group_actors = [
                Polygon(current_state[g, :2]) for g in current_groups
            ]

        self.group_collection.set_paths(self.group_actors)

    def plot_obstacles(self):
        """Draw every obstacle as a black line through its points."""
        for s in self.scene.get_obstacles():
            self.ax.plot(s[:, 0], s[:, 1], "-o", color="black", markersize=2.5)

    def animation_init(self):
        """FuncAnimation init: draw obstacles and attach the collections."""
        self.plot_obstacles()
        self.ax.add_collection(self.group_collection)
        self.ax.add_collection(self.human_collection)

        return (self.group_collection, self.human_collection)

    def animation_update(self, i):
        """FuncAnimation frame callback for frame ``i``."""
        self.plot_groups(i)
        self.plot_human(i)
        return (self.group_collection, self.human_collection)