Пример #1
0
def render_frames(frames, figsize=(30, 20)):
    """Display a sequence of frames inline in a Jupyter notebook.

    The kind of frames is decided by the type of the first one:

    * ``StringIO`` buffers (textual renderings): each buffer's contents are
      printed, clearing the previous output first, with a one-second pause
      after every frame.
    * Anything else is treated as an image accepted by ``plt.imshow`` and
      shown as an interactive JavaScript animation.

    Args:
        frames: Sequence of ``StringIO`` buffers or images. An empty sequence
            is a no-op.
        figsize: Matplotlib figure size in inches (image frames only).
    """
    # len() instead of truthiness: ``not frames`` raises for numpy arrays.
    if len(frames) == 0:
        return

    if isinstance(frames[0], StringIO):  # textual frames
        for i, frame in enumerate(frames):
            if i > 0:
                clear_output()
            print(frame.getvalue())
            time.sleep(1)
        return

    # Image frames
    fig = plt.figure(figsize=figsize)
    plt.axis('off')
    plot = plt.imshow(frames[0])

    def init():
        # Nothing to reset: frame 0 is already drawn above.
        pass

    def update(i):
        plot.set_data(frames[i])
        return plot,

    anim = FuncAnimation(fig=fig,
                         func=update,
                         frames=len(frames),
                         init_func=init,
                         interval=20,
                         repeat=True,
                         repeat_delay=20)
    # Close via the public handle (not the private ``anim._fig``) so the
    # notebook does not also render a static copy; the animation can still be
    # exported to HTML afterwards.
    plt.close(fig)
    display(HTML(anim.to_jshtml()))
Пример #2
0
def animater(buffer, mode="js"):
    """Animate a sequence of image frames in one of three output modes.

    Args:
        buffer: Non-empty sequence of images; ``buffer[0].shape`` must unpack
            to ``(height, width, channels)``.
        mode: ``"html"`` returns an HTML5 ``<video>`` (needs ffmpeg), ``"js"``
            returns a JavaScript animation widget, ``"plot"`` shows the
            animation with ``plt.show()``.

    Returns:
        ``IPython.display.HTML`` for ``"html"``/``"js"``; ``None`` for
        ``"plot"``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    # Validate first so an unsupported mode fails loudly instead of silently
    # returning None after building a figure.
    if mode not in ("html", "js", "plot"):
        raise ValueError(
            "unknown mode {!r}; expected 'html', 'js' or 'plot'".format(mode))

    plt.ioff()
    height, width, _ = buffer[0].shape
    ratio = width / height
    figure, ax = plt.subplots(figsize=(4 * ratio, 4))
    im = plt.imshow(buffer[0])
    ax.axis('off')

    def update(i):
        im.set_array(buffer[i])
        return im,

    ani = FuncAnimation(figure,
                        update,
                        frames=len(buffer),
                        interval=1000 / 60,
                        blit=True,
                        repeat=False)

    if mode == "plot":
        plt.show()
        return None

    html = HTML(ani.to_html5_video() if mode == "html" else ani.to_jshtml())
    # pyplot keeps every open figure alive; release this one once exported.
    plt.close(figure)
    return html
Пример #3
0
def animate(pos, vel, n, nstep, interval=20):
    """Simulate ``nstep`` steps of a particle system and animate the result.

    The simulation runs eagerly: ``step(n, pos, vel)`` is applied ``nstep``
    times, keeping a copy of the positions after each step. The animation
    then replays those snapshots as red markers on the unit square.

    Args:
        pos: Initial positions; column 0 is x, column 1 is y.
        vel: Initial velocities, passed through to ``step``.
        n: Passed through unchanged to ``step``.
        nstep: Number of simulation steps, i.e. animation frames.
        interval: Delay between frames in milliseconds.

    Returns:
        ``IPython.display.HTML`` with the JavaScript animation.
    """
    fig, ax = plt.subplots()
    # Only the HTML rendering is wanted, not a static copy of the figure.
    plt.close(fig)
    ax.set_aspect(aspect=1.0)
    markers, = ax.plot(pos[:, 0], pos[:, 1], 'ro')

    snapshots = []
    for _ in range(nstep):
        pos, vel = step(n, pos, vel)
        snapshots.append(pos.copy())

    def init():
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        return markers,

    def update(frame):
        current = snapshots[frame]
        markers.set_data(current[:, 0], current[:, 1])
        return markers,

    anim = FuncAnimation(fig, update, frames=range(nstep), init_func=init,
                         blit=True, interval=interval)
    return HTML(anim.to_jshtml())
Пример #4
0
def makeAnimation(frames):
    """Display ``frames`` as a ~30 fps JavaScript animation in a notebook.

    The figure size in inches is the first frame's pixel size divided by 72,
    rendered at 72 dpi.

    Args:
        frames: Non-empty sequence of images accepted by ``plt.imshow``.
    """
    height, width = frames[0].shape[:2]
    fig = plt.figure(figsize=(width / 72.0, height / 72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])
        return patch,

    anim = FuncAnimation(fig, animate, frames=len(frames), interval=1000 / 30.0)
    # Close the figure so the notebook does not additionally show it as a
    # static image and pyplot does not keep it alive.
    plt.close(fig)
    display(HTML(anim.to_jshtml()))
Пример #5
0
    def analysis(self, value_buffer, mode="notebook"):
        """Animate per-cell values and greedy actions over ``value_buffer``.

        Each frame draws, over ``self.background``:

        * a semi-transparent ``"Blues"`` heat map (clipped to [0, 1]) of the
          maximum action value of every grid cell, and
        * one arrow per nonzero background cell pointing along the argmax of
          actions 0..3 (directions from the ``arr_u``/``arr_v`` tables).

        Args:
            value_buffer: Sequence of per-frame mappings where
                ``value_buffer[i][(x, y)]`` is indexable by actions 0..3 for
                every cell ``(x, y)`` (row, column) of the background grid.
            mode: ``"notebook"`` returns an HTML/JS animation; ``"plot"``
                shows it with ``plt.show()``.

        Returns:
            ``IPython.display.HTML`` for ``"notebook"``; ``None`` for ``"plot"``.

        Raises:
            ValueError: If ``mode`` is not ``"notebook"`` or ``"plot"``.
        """
        if mode not in ("notebook", "plot"):
            raise ValueError(
                "unknown mode {!r}; expected 'notebook' or 'plot'".format(mode))

        plt.ioff()
        height, width = self.background.shape
        ratio = width / height
        figure, ax = plt.subplots(figsize=(3 * ratio, 3))
        # X[j, k] == k (row index), Y[j, k] == j (column index); shape (width, height).
        X, Y = np.meshgrid(np.arange(height), np.arange(width))
        ax.axis('off')

        plt.imshow(self.background,
                   cmap=self.cmap,
                   norm=self.norm,
                   animated=False)
        im = plt.imshow(np.zeros(shape=self.background.shape),
                        cmap="Blues",
                        vmin=0,
                        vmax=1,
                        animated=True,
                        alpha=0.5)

        coords = np.argwhere(self.background != 0)
        # NOTE(review): U/V are sized like the full grid (X.shape) while there
        # is one arrow per nonzero cell; confirm matplotlib accepts this for
        # the backgrounds in use.
        quiver = plt.quiver(coords[:, 1], coords[:, 0],
                            *([np.ones(shape=X.shape)] * 2))

        # Arrow direction per action index 0..3.
        arr_u = np.array([-1, 0, 1, 0])
        arr_v = np.array([0, 1, 0, -1])

        def update(i):
            frame_values = value_buffer[i]
            values = np.array([
                max(frame_values[(x, y)])
                for (x, y) in zip(X.ravel(), Y.ravel())
            ])
            actions = np.array([
                max(range(4), key=lambda a: frame_values[(x, y)][a])
                for (x, y) in coords
            ])

            quiver.set_UVC(arr_u[actions], arr_v[actions])
            im.set_array(values.reshape(X.shape).transpose())
            # With blit=True (live "plot" mode) only returned artists are
            # redrawn, so the quiver must be returned for the arrows to update.
            return im, quiver

        ani = FuncAnimation(figure,
                            update,
                            frames=len(value_buffer),
                            interval=1000 / 60,
                            blit=True,
                            repeat=True)

        if mode == "plot":
            plt.show()
            return None

        html = HTML(ani.to_jshtml())
        # Release the figure once exported; pyplot would otherwise keep it.
        plt.close(figure)
        return html
Пример #6
0
def plot_animation(data):
    """Animate a line plot of ``data`` with the module-level ``animate`` callback.

    The line starts as column 1 plotted against column 0 of ``data``. The
    ``animate`` function (defined elsewhere, not visible here) is then called
    as ``animate(frame, data.T, line)`` for frames ``1 .. data.shape[1] - 1``,
    one frame per second.

    Args:
        data: 2-D numpy array with at least two columns.

    Returns:
        Tuple ``(video_html, jshtml_html)`` of ``IPython.display.HTML``: the
        animation as an HTML5 video (needs ffmpeg) and as a JavaScript widget.
    """
    fig, ax = plt.subplots(figsize=(5, 3))
    ax.set(ylim=(-40, data.max()))

    line = ax.plot(data[:, 0], data[:, 1], color='k', lw=2)[0]

    anim = FuncAnimation(fig, animate, interval=1000,
                         frames=range(1, len(data[0, :])),
                         fargs=(data.T, line))

    video_html = HTML(anim.to_html5_video())
    jshtml_html = HTML(anim.to_jshtml())
    # Both exports are done; close so the notebook doesn't add a static copy.
    plt.close(fig)
    return video_html, jshtml_html
Пример #7
0
 def animar(self):
     """Animate the ``experimento`` layers, save an MP4 and return HTML.

     First a preview figure is drawn with ``d.crear_capa(2, ax)`` and fixed axis
     limits. Then a separate figure is animated over 100 frames by calling
     ``d.crear_capa(frame, ax)``. The result is saved to ``animation_exp.mp4``
     (ffmpeg, 20 fps) and converted to a JavaScript widget.

     Returns:
         ``IPython.display.HTML`` with the JavaScript animation.
     """
     d = experimento(self.abscisa, self.ordenada)

     # Static preview of layer 2.
     fig = plt.figure()
     ax = fig.gca()
     # NOTE(review): these limits apply only to the preview figure, not to the
     # animated one below; confirm that is intended.
     plt.xlim((-5, 5))
     plt.ylim((-5, -4.825))
     d.crear_capa(2, ax)

     # Separate figure for the animation.
     fig = plt.figure()
     ax = fig.gca()
     anim = FuncAnimation(fig, d.crear_capa, frames=100, fargs=(ax,))
     anim.save('animation_exp.mp4', writer='ffmpeg', fps=20)
     html = HTML(anim.to_jshtml())
     # The animation figure is no longer needed once exported.
     plt.close(fig)
     # Return the HTML so the caller can display it; an HTML object built and
     # dropped inside a function is never rendered.
     return html
Пример #8
0
def animate(*maps, res=150, interval=50, frames=100, titles=None):
    """Animate several maps side-by-side in a Jupyter notebook.

    Each map is a callable evaluated as ``map_(theta=..., x=row_x, y=row_y)``
    one grid row at a time, on a ``res`` x ``res`` grid over [-1, 1]^2, for
    ``frames`` equally spaced ``theta`` values in [0, 360]. All panels share
    one colour scale spanning the NaN-ignoring global min/max.

    Args:
        *maps: Callables to evaluate; one subplot each.
        res: Grid resolution per axis.
        interval: Delay between frames in milliseconds.
        frames: Number of theta values, i.e. animation frames.
        titles: Optional title, or sequence of titles, one per map.
    """
    # Render every map over the grid of thetas up front.
    thetas = np.linspace(0, 360, frames)
    x, y = np.meshgrid(np.linspace(-1, 1, res), np.linspace(-1, 1, res))
    images = []
    for map_ in np.atleast_1d(maps):  # ``map_`` avoids shadowing builtin map
        images.append([
            np.array([map_(theta=theta, x=x[j], y=y[j]) for j in range(res)])
            for theta in thetas
        ])
    images = np.array(images)
    nmaps = images.shape[0]
    if titles is not None:
        titles = np.atleast_1d(titles)

    # Set up the plots
    fig, axes = plt.subplots(1, nmaps, figsize=(4 * nmaps, 4))
    axes = np.atleast_1d(axes)
    for i, ax in enumerate(axes):
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.axis('off')
        if titles is not None:
            ax.set_title(titles[i], y=1.05, fontsize=16)
    kwargs = dict(origin="lower",
                  extent=(-1, 1, -1, 1),
                  cmap="plasma",
                  vmin=np.nanmin(images),
                  vmax=np.nanmax(images))
    ims = [ax.imshow([[]], **kwargs) for ax in axes]

    def init():
        for im in ims:
            im.set_data([[]])
        return ims

    # Named ``update`` so it does not shadow this enclosing function.
    def update(i):
        for j, im in enumerate(ims):
            im.set_data(images[j, i])
        return ims

    ani = FuncAnimation(fig,
                        update,
                        init_func=init,
                        frames=frames,
                        interval=interval,
                        blit=False)
    # Close this specific figure (not whatever is "current") so the notebook
    # does not also show a static copy.
    plt.close(fig)
    display(HTML(ani.to_jshtml()))
Пример #9
0
 def animar(self):
     """Animate the ``simulacion`` layers, save an MP4 and return HTML.

     The simulation is built from ``self.anguloi``, ``self.velanguloi``,
     ``self.longitud`` and ``self.tiempoi``. A preview figure is drawn with
     ``d.crear_capa(2, ax)`` and fixed axis limits. Then a separate figure is
     animated over 100 frames by calling ``d.crear_capa(frame, ax)``. The
     result is saved to ``animation.mp4`` (ffmpeg, 20 fps) and converted to a
     JavaScript widget.

     Returns:
         ``IPython.display.HTML`` with the JavaScript animation.
     """
     d = simulacion(self.anguloi, self.velanguloi, self.longitud,
                    self.tiempoi)

     # Static preview of layer 2.
     fig = plt.figure()
     ax = fig.gca()
     # NOTE(review): these limits apply only to the preview figure, not to the
     # animated one below; confirm that is intended.
     plt.xlim((-5, 5))
     plt.ylim((-5, -4.825))
     d.crear_capa(2, ax)

     # Separate figure for the animation.
     fig = plt.figure()
     ax = fig.gca()
     anim = FuncAnimation(fig, d.crear_capa, frames=100, fargs=(ax, ))
     anim.save('animation.mp4', writer='ffmpeg', fps=20)
     html = HTML(anim.to_jshtml())
     # The animation figure is no longer needed once exported.
     plt.close(fig)
     # Return the HTML so the caller can display it; an HTML object built and
     # dropped inside a function is never rendered.
     return html
Пример #10
0
def animate_footstep_plan(terrain, step_span, position_left, position_right,
                          title=None, n_steps=None):
    """Animate a footstep plan over a terrain of stepping stones.

    Frame ``k`` places the left/right foot markers at ``position_left[k]`` and
    ``position_right[k]``. Each foot's limit square is centred on the other
    foot (the square's ``xy`` is set to that foot's position minus half of
    ``step_span``).

    Args:
        terrain: Object with a ``plot(title=..., ax=...)`` method.
        step_span: Side length of each foot's square limit region.
        position_left: Indexable of 2-D left-foot positions, one per step.
        position_right: Indexable of 2-D right-foot positions, one per step.
        title: Optional title forwarded to ``terrain.plot``.
        n_steps: Last step index to show (frames ``0 .. n_steps``). Defaults
            to ``len(position_left) - 1``.
    """
    # The frame count must come from an explicit value or the data; there is
    # no other outer ``n_steps`` in scope.
    if n_steps is None:
        n_steps = len(position_left) - 1

    # initialize figure for animation
    fig, ax = plt.subplots()

    # plot stepping stones
    terrain.plot(title=title, ax=ax)

    # initial position of the feet
    left_foot = ax.scatter(0, 0, color='r', zorder=3, label='Left foot')
    right_foot = ax.scatter(0, 0, color='b', zorder=3, label='Right foot')

    # initial step limits (assumes plot_rectangle returns a patch with set_xy)
    left_limits = plot_rectangle(
        [0, 0],     # center
        step_span,  # width
        step_span,  # height
        ax=ax,
        edgecolor='b',
        label='Left-foot limits'
    )
    right_limits = plot_rectangle(
        [0, 0],     # center
        step_span,  # width
        step_span,  # height
        ax=ax,
        edgecolor='r',
        label='Right-foot limits'
    )

    # misc settings; close so the notebook shows only the animation
    plt.close(fig)
    ax.legend(loc='upper left', bbox_to_anchor=(0, 1.3), ncol=2)

    def animate(step):

        # scatter feet
        left_foot.set_offsets(position_left[step])
        right_foot.set_offsets(position_right[step])

        # limits of reachable set for each foot, centred on the other foot
        c2c = np.ones(2) * step_span / 2
        right_limits.set_xy(position_left[step] - c2c)
        left_limits.set_xy(position_right[step] - c2c)

    # create and display animation
    ani = FuncAnimation(fig, animate, frames=n_steps + 1, interval=1e3)
    display(HTML(ani.to_jshtml()))
Пример #11
0
def step_5(t, resonance, FID, fig, ax_one, ax_FID, ax_mult, ax_FT, freqs,
           nb_freq, trial, mult, integ, posinteg, neginteg, n, ann):
    """Render step 5 of the FT walkthrough: the animated frequency sweep.

    Adds proxy legend entries for positive/negative areas to ``ax_mult`` and
    creates its legend if it has none. Draws a red arrow from just left of
    ``ax_mult`` to just right of ``ax_FT`` and removes the previous
    annotation ``ann`` (if any). It then animates the module-level
    ``animate`` callback (not visible here) over a subsampled range of
    frequency indices. Finally it updates ``session["message"]`` and sets
    ``session["FT_step"] = 5``.

    Returns:
        The animation rendered as a JavaScript/HTML string (``to_jshtml``).
    """
    # Empty fills used only as legend proxies.
    ax_mult.fill_between([], [], color="blue", label="Positive Areas")
    ax_mult.fill_between([], [], color="red", label="Negative Areas")

    if ax_mult.get_legend() is None:
        ax_mult.legend(loc=1)
    con_3 = ConnectionPatch(xyB=(1.2, 0.5), xyA=(-0.2, 0.5),
                            coordsA="axes fraction", coordsB="axes fraction",
                            axesA=ax_mult, axesB=ax_FT, color="red")
    con_3.set_arrowstyle("simple", head_length=0.5, head_width=1, tail_width=0.3)
    ax_mult.add_patch(con_3)

    # Index stride corresponding (presumably) to session["trace_every"]
    # frequency units. Clamp to 1 because np.arange raises on a zero step,
    # which happens when trace_every is small relative to the frequency span.
    stride = int(nb_freq * session["trace_every"]
                 / (session["freq_max"] - session["freq_min"] + 2))
    freq_range = np.arange(0, nb_freq - 1, max(stride, 1))
    # Always finish on the last frequency index.
    freq_range = np.append(freq_range, nb_freq - 1)

    if ann is not None:
        ann.remove()
    anim = FuncAnimation(fig, animate, frames=freq_range,
                         fargs=(t, resonance, FID, fig, ax_one, ax_FID, ax_mult,
                                ax_FT, freqs, nb_freq, trial, mult, integ,
                                posinteg, neginteg, n),
                         interval=session["time_delay"] * 1000,
                         blit=False, repeat=False)
    movie_data = anim.to_jshtml()
    session["message"]="""You can use the controls below the graph to see the animation.<br><br>
When you finish, we invite you to click on Parameters button. From there, you will be able to tune some parameters and observe the effect on your spectrum.<br>"""
    session["FT_step"] = 5
    return movie_data
Пример #12
0
    def animate(self, mode="js"):
        """Animate ``self.frames`` using the background's colour mapping.

        Each element of ``self.frames`` is either an image or an
        ``(image, text)`` tuple. Images are drawn with ``self.cmap`` and
        ``self.norm``. Text, when given, is shown in a box at the top centre
        of the axes and stays until replaced.

        Args:
            mode: ``"html"`` returns an HTML5 ``<video>`` (needs ffmpeg),
                ``"js"`` returns a JavaScript animation widget, ``"plot"``
                shows the animation with ``plt.show()``.

        Returns:
            ``IPython.display.HTML`` for ``"html"``/``"js"``; ``None`` for
            ``"plot"``.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        if mode not in ("html", "js", "plot"):
            raise ValueError(
                "unknown mode {!r}; expected 'html', 'js' or 'plot'".format(mode))

        plt.ioff()
        height, width = self.background.shape
        ratio = width / height
        figure, ax = plt.subplots(figsize=(3 * ratio, 3))
        im = plt.imshow(self.background,
                        cmap=self.cmap,
                        norm=self.norm,
                        animated=True)
        title = ax.text(0.5,
                        0.90,
                        "",
                        bbox={
                            'facecolor': 'w',
                            'alpha': 0.5,
                            'pad': 5
                        },
                        transform=ax.transAxes,
                        ha="center")
        ax.axis('off')

        def update(i):
            data = self.frames[i]
            if isinstance(data, tuple):
                img, text = data
                title.set_text(text)
            else:
                img = data
            im.set_array(img)
            return im, title

        ani = FuncAnimation(figure,
                            update,
                            frames=len(self.frames),
                            interval=1000 / 60,
                            blit=True,
                            repeat=False)

        if mode == "plot":
            plt.show()
            return None

        html = HTML(ani.to_html5_video() if mode == "html" else ani.to_jshtml())
        # pyplot keeps every open figure alive; release this one once exported.
        plt.close(figure)
        return html
Пример #13
0
def render(episode, env):
    """Render a recorded episode as a JavaScript animation.

    The first image comes from ``env.render(mode='rgb_array')``. Frame ``i``
    of the animation then shows ``episode[i]``.

    Returns:
        ``IPython.display.HTML`` containing the animation.
    """
    fig = plt.figure()
    first_frame = env.render(mode='rgb_array')
    img = plt.imshow(first_frame)
    plt.axis('off')

    def show_frame(idx):
        img.set_data(episode[idx])
        return (img,)

    anim = FuncAnimation(fig, show_frame, frames=len(episode),
                         interval=24, blit=True)
    result = HTML(anim.to_jshtml())

    # The HTML is self-contained; the figure can go.
    plt.close(fig)
    return result
Пример #14
0
def animation_maker(path, save=None):
    '''
    Animate, side by side, the frame images of each reward mode, once per
    game mode, optionally saving each animation.

    For every ``game_mode`` in the module-level ``game_modes``, a figure with
    three panels is animated. Panel ``i`` shows frame ``k`` of
    ``images[reward_modes[i] + '_' + game_mode]`` (image paths from
    ``load_images(path)``). A sequence that runs out holds its last frame, so
    modes with different lengths stay in sync.

    Args:
        path: Directory passed to ``load_images``; output goes to
            ``<path>/videos`` (created if missing when ``save`` is truthy).
        save: ``'jshtml'`` writes ``<game_mode>.vidstr`` files containing the
            JS/HTML string, ``'video'`` writes ``<game_mode>.mp4``; any other
            value writes nothing.
    '''
    if save:
        vid_path = os.path.join(path, 'videos')
        os.makedirs(vid_path, exist_ok=True)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    fig.tight_layout()
    fig.subplots_adjust(top=0.95)
    images = load_images(path)

    def update(frame, *fargs):
        game_mode = fargs[0]
        for i, reward_mode in enumerate(reward_modes):
            axs[i].cla()
            axs[i].set_title(reward_mode, color='b', fontsize=20)
            sequence = images[reward_mode + '_' + game_mode]
            # Hold the last available frame once a shorter sequence runs out.
            img = sequence[min(frame, len(sequence) - 1)]
            axs[i].imshow(plt.imread(img))

    for game_mode in game_modes:
        fig.suptitle(game_mode.title(), fontsize=30)
        # Longest sequence across all reward modes sets the frame count.
        max_len = max(len(images[reward_mode + '_' + game_mode])
                      for reward_mode in reward_modes)
        anim = FuncAnimation(fig, update, frames=range(max_len),
                             fargs=(game_mode,), interval=100)

        if save == 'jshtml':
            with open(os.path.join(vid_path, game_mode + '.vidstr'), 'w',
                      encoding='utf-8') as fh:
                fh.write(anim.to_jshtml())
        elif save == 'video':
            anim.save(os.path.join(vid_path, game_mode + '.mp4'))

    # Release the shared figure once every game mode has been processed.
    plt.close(fig)
Пример #15
0
def animation_route(transport):
    """Animate a marker moving along a shortest route across Montpellier.

    The route runs between two fixed (lat, lon) points, on the OSMnx street
    network of the given ``transport`` type. The animation is saved to
    ``animation_weight.gif`` (imagemagick, 10 fps) and returned as HTML.

    Args:
        transport: OSMnx ``network_type`` (e.g. ``'drive'``, ``'walk'``).

    Returns:
        ``IPython.display.HTML`` with the JavaScript animation.
    """
    # NOTE(review): gdf_from_place / get_nearest_node are deprecated or removed
    # in newer OSMnx releases; check against the installed version.
    city = ox.gdf_from_place('Montpellier, France')

    G = ox.graph_from_place('Montpellier, France', network_type=transport)
    G = ox.add_edge_speeds(G)
    G = ox.add_edge_travel_times(G)

    # (lat, lon) of the start and end points
    start = (43.61032245, 3.8966295)
    end = (43.6309201, 3.8611052550025553)
    start_node = ox.get_nearest_node(G, start)
    end_node = ox.get_nearest_node(G, end)
    # NOTE(review): no weight is passed, so this is the fewest-edges path and
    # the travel times added above are unused; confirm that is intended.
    route = nx.shortest_path(G, start_node, end_node)

    # bounds for axis
    x_min, y_min, x_max, y_max = city.total_bounds

    fig, ax = ox.plot_graph_route(G, route, route_linewidth=1, node_size=0,
                                  edge_linewidth=0.3, show=False, close=False)
    city.plot(ax=ax, edgecolor='black', linewidth=0.5, alpha=0.1)
    ax.set(xlim=(x_min, x_max), ylim=(y_min, y_max))

    first_node = G.nodes[route[0]]
    sc = ax.scatter(first_node['x'], first_node['y'], s=100, c="b", alpha=1)

    def animate(i):
        node = G.nodes[route[i]]
        sc.set_offsets(np.c_[node['x'], node['y']])
        return sc,

    # One frame per node on the route. The edge count (len(route) - 1) would
    # stop one node short of the destination.
    anim = FuncAnimation(fig, animate, frames=len(route))
    anim.save('animation_weight.gif', fps=10, writer='imagemagick')

    return HTML(anim.to_jshtml())
Пример #16
0
def visualize(image, **kwargs):
    """Show or save a Mollweide-projection image, or animate a stack of them.

    ``image`` is indexed as ``image[frame]``. ``image.shape[0]`` is the number
    of frames, and ``image.shape[1]`` sets the pixel width ``dx`` used to pad
    the projected extent. More than one frame produces an animation.

    Keyword Args (all optional, popped from ``kwargs``):
        cmap: Colormap (default ``"plasma"``).
        grid: Draw the ellipse outline and latitude/longitude lines
            (default True).
        interval: Animation frame interval in milliseconds (default 75).
        file: Output path. Animations: ``.mp4`` uses ffmpeg, ``.gif`` uses
            imagemagick, other extensions use matplotlib's default writer.
            A single frame is written with ``savefig``. ``None``/``""``
            displays the result instead of saving.
        html5_video: In a Jupyter (zmq) kernel, embed the animation as an
            HTML5 video if True (default), else as a JavaScript widget.
        vmin, vmax: Colour limits; default to the NaN-ignoring min/max of
            ``image``.
        dpi, bitrate: Passed to the animation writer / HTML export.
        figsize: Figure size in inches (default ``(7, 3.75)``); only used
            when ``ax`` is not given.
        colorbar: Append a vertical colorbar on the right (default False).
        shrink: Fractional enlargement of the image extent (default 0.01).
        ax: Existing axes to draw into. When given, the figure is not closed
            after saving, and a single frame is not shown with ``plt.show()``.

    Keys left in ``kwargs`` after popping are silently ignored.
    """
    # Get kwargs
    cmap = kwargs.pop("cmap", "plasma")
    grid = kwargs.pop("grid", True)
    interval = kwargs.pop("interval", 75)
    file = kwargs.pop("file", None)
    html5_video = kwargs.pop("html5_video", True)
    vmin = kwargs.pop("vmin", None)
    vmax = kwargs.pop("vmax", None)
    dpi = kwargs.pop("dpi", None)
    figsize = kwargs.pop("figsize", None)
    bitrate = kwargs.pop("bitrate", None)
    colorbar = kwargs.pop("colorbar", False)
    shrink = kwargs.pop("shrink", 0.01)
    ax = kwargs.pop("ax", None)
    # Remember whether the caller supplied the axes; several branches below
    # avoid closing/showing a figure the caller owns.
    if ax is None:
        custom_ax = False
    else:
        custom_ax = True

    # Animation: more than one frame along axis 0 means animate
    nframes = image.shape[0]
    animated = nframes > 1
    # Static artists, collected so the blit callback can return them too
    borders = []
    latlines = []
    lonlines = []

    # Set up the plot
    if figsize is None:
        figsize = (7, 3.75)
    if ax is None:
        fig, ax = plt.subplots(1, figsize=figsize)
    else:
        fig = ax.figure

    # Mollweide: the map fills an ellipse with semi-axes 2*sqrt(2) (x) and
    # sqrt(2) (y). The extent (left, right, bottom, top) is enlarged by
    # ``shrink``; left/bottom are also padded by dx.
    # NOTE(review): reason for the asymmetric dx padding is not evident here
    # (presumably pixel alignment) - confirm.
    dx = 2.0 / image.shape[1]
    extent = (1 + shrink) * np.array([
        -(1 + dx) * 2 * np.sqrt(2),
        2 * np.sqrt(2),
        -(1 + dx) * np.sqrt(2),
        np.sqrt(2),
    ])
    ax.axis("off")
    ax.set_xlim(-2 * np.sqrt(2) - 0.05, 2 * np.sqrt(2) + 0.05)
    ax.set_ylim(-np.sqrt(2) - 0.05, np.sqrt(2) + 0.05)

    # Anti-aliasing at the edges: white bands just outside the ellipse
    # (zorder -1, above the image at zorder -3) hide pixels spilling past it
    x = np.linspace(-2 * np.sqrt(2), 2 * np.sqrt(2), 10000)
    y = np.sqrt(2) * np.sqrt(1 - (x / (2 * np.sqrt(2)))**2)
    borders += [ax.fill_between(x, 1.1 * y, y, color="w", zorder=-1)]
    borders += [
        ax.fill_betweenx(0.5 * x, 2.2 * y, 2 * y, color="w", zorder=-1)
    ]
    borders += [ax.fill_between(x, -1.1 * y, -y, color="w", zorder=-1)]
    borders += [
        ax.fill_betweenx(0.5 * x, -2.2 * y, -2 * y, color="w", zorder=-1)
    ]

    if grid:
        # Ellipse outline, then latitude and longitude lines
        x = np.linspace(-2 * np.sqrt(2), 2 * np.sqrt(2), 10000)
        a = np.sqrt(2)
        b = 2 * np.sqrt(2)
        y = a * np.sqrt(1 - (x / b)**2)
        borders += ax.plot(x, y, "k-", alpha=1, lw=1.5, zorder=0)
        borders += ax.plot(x, -y, "k-", alpha=1, lw=1.5, zorder=0)
        lats = get_moll_latitude_lines()
        latlines = [None for n in lats]
        for n, l in enumerate(lats):
            (latlines[n], ) = ax.plot(l[0],
                                      l[1],
                                      "k-",
                                      lw=0.5,
                                      alpha=0.5,
                                      zorder=0)
        lons = get_moll_longitude_lines()
        lonlines = [None for n in lons]
        for n, l in enumerate(lons):
            (lonlines[n], ) = ax.plot(l[0],
                                      l[1],
                                      "k-",
                                      lw=0.5,
                                      alpha=0.5,
                                      zorder=0)

    # Plot the first frame of the image
    if vmin is None:
        vmin = np.nanmin(image)
    if vmax is None:
        vmax = np.nanmax(image)
    # Set a minimum contrast (avoids a zero-width colour range for constant
    # images)
    if np.abs(vmin - vmax) < 1e-12:
        vmin -= 1e-12
        vmax += 1e-12

    img = ax.imshow(
        image[0],
        origin="lower",
        extent=extent,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        interpolation="none",
        animated=animated,
        zorder=-3,
    )

    # Add a colorbar
    if colorbar:
        if not custom_ax:
            fig.subplots_adjust(right=0.85)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(img, cax=cax, orientation="vertical")

    # Display or save the image / animation
    if animated:

        def updatefig(i):
            # Swap in frame i; the border/grid artists are returned as well
            # so blitting redraws them along with the new frame.
            img.set_array(image[i])
            return (img, *borders, *latlines, *lonlines)

        ani = FuncAnimation(fig,
                            updatefig,
                            interval=interval,
                            blit=True,
                            frames=image.shape[0])

        # Business as usual: save to file, writer chosen by extension
        if (file is not None) and (file != ""):
            if file.endswith(".mp4"):
                ani.save(file, writer="ffmpeg", dpi=dpi, bitrate=bitrate)
            elif file.endswith(".gif"):
                ani.save(file, writer="imagemagick", dpi=dpi, bitrate=bitrate)
            else:
                # Try and see what happens!
                ani.save(file, dpi=dpi, bitrate=bitrate)
            if not custom_ax:
                if not plt.isinteractive():
                    plt.close()
        else:  # no file: display (note: custom_ax is not checked here)
            try:
                if "zmqshell" in str(type(get_ipython())):
                    # Jupyter kernel: close the figure and embed the
                    # animation as HTML instead.
                    plt.close()
                    # Apply dpi/bitrate to the export only; "figure" and -1
                    # are matplotlib's defaults for these rcParams.
                    with matplotlib.rc_context({
                            "savefig.dpi":
                            dpi if dpi is not None else "figure",
                            "animation.bitrate":
                            bitrate if bitrate is not None else -1,
                    }):
                        if html5_video:
                            display(HTML(ani.to_html5_video()))
                        else:
                            display(HTML(ani.to_jshtml()))
                else:
                    # Not a Jupyter kernel: jump to the plt.show() fallback
                    raise NameError("")
            except NameError:
                # Also reached when get_ipython is undefined (not IPython)
                plt.show()
                if not plt.isinteractive():
                    plt.close()

        # Matplotlib generates an annoying empty
        # file when producing an animation. Delete it.
        try:
            os.remove("None0000000.png")
        except FileNotFoundError:
            pass

    else:
        # Single frame: save with savefig, or show
        if (file is not None) and (file != ""):
            fig.savefig(file, bbox_inches="tight")
            if not custom_ax:
                if not plt.isinteractive():
                    plt.close()
        elif not custom_ax:
            plt.show()
Пример #17
0
import matplotlib.image as mpimg
import matplotlib.animation
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
# plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'


def show_image(i):
    """Draw ``../images/image_{i}.jpg`` on the current axes as animation frame ``i``.

    The axes are cleared first so that each frame replaces the previous
    image. Otherwise every call would stack another image on the axes,
    growing memory and redraw time with each frame.
    """
    img = mpimg.imread('../images/image_{}.jpg'.format(i))
    plt.cla()
    plt.imshow(img, interpolation='nearest')
    plt.axis('off')
    # Figure size is set where the figure is created; assigning
    # ``plt.figsize`` here had no effect.
    plt.tight_layout()


# One figure; show_image draws image i on its current axes for frame i.
fig, ax = plt.subplots(figsize=(10, 8))
animator = FuncAnimation(fig, show_image, frames=range(23))

# Write the 23 frames to MP4 at 0.5 fps (two seconds per image).
FFwriter = matplotlib.animation.FFMpegWriter(fps=0.5)
animator.save('../figures_html/story.mp4', writer=FFwriter, dpi=100)

# Same animation as a self-contained JavaScript/HTML string.
story_jshtml = animator.to_jshtml(fps=0.5)
plt.close()
Пример #18
0
def create_animation(agent, every_n_steps=1, display_mode='gif', fps=30):
    """Render an animation of an agent's training history and display it.

    Top row: five heat maps of the Q/action grid at the current step (Q max,
    Q standard deviation, action gradients, action with Q max, policy).
    Bottom row: episode scores so far, the current training episode's
    trajectory, and the latest test episode's trajectory.

    Args:
        agent: Agent exposing ``history``, ``q_a_frames_spec`` and
            ``preprocess_state``.
        every_n_steps: Render one frame every this many training steps.
        display_mode: ``'js'`` (JavaScript widget), ``'video'`` (embedded
            HTML5 video), ``'video_file'`` (MP4 saved to the working
            directory and shown via a ``<video>`` tag), or anything else
            (GIF via imagemagick, shown via an ``<img>`` tag).
        fps: Frames per second of the animation.

    Raises:
        RuntimeError: For ``'video'``/``'video_file'`` when ffmpeg is not
            available.
    """
    history = agent.history
    fig = plt.figure(figsize=(11, 6))
    fig.set_tight_layout(True)
    main_rows = gridspec.GridSpec(2,
                                  1,
                                  figure=fig,
                                  top=.9,
                                  left=.05,
                                  right=.95,
                                  bottom=.25)

    def create_top_row_im(i, title='', actions_cmap=False):
        """Add heat-map axes ``i`` (of 5) to the top row, with a colorbar."""
        top_row = main_rows[0].subgridspec(1, 5, wspace=.3)
        ax = fig.add_subplot(top_row[i])
        ax.axis('off')
        ax.set_title(title)
        im = ax.imshow(np.zeros((len(history.q_a_frames_spec.ys),
                                 len(history.q_a_frames_spec.xs))),
                       origin='lower')
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        if actions_cmap is True:
            # Fixed colour range from amin..amax (presumably the action bounds).
            im.set_clim(history.q_a_frames_spec.amin,
                        history.q_a_frames_spec.amax)
            im.set_cmap("RdYlGn")
            cb = fig.colorbar(im, cax=cax)
        else:
            cb = fig.colorbar(im, cax=cax, format='%.3g')
        cb.ax.tick_params(labelsize=8)
        return im

    def create_bottom_row_plot(i, title=''):
        """Add line-plot axes ``i`` (of 3) to the bottom row."""
        bottom_row = main_rows[1].subgridspec(1, 3)
        ax = fig.add_subplot(bottom_row[i])
        ax.set_title(title)
        return ax

    Q_max_im = create_top_row_im(0, title='Q max')
    Q_std_im = create_top_row_im(1, title='Q standard deviation')
    action_gradients_im = create_top_row_im(2, title="Action Gradients")
    max_action_im = create_top_row_im(3,
                                      title="Action with Q max",
                                      actions_cmap=True)
    actor_policy_im = create_top_row_im(4, title="Policy", actions_cmap=True)

    scores_ax = create_bottom_row_plot(0, title="Scores")
    scores_ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    scores_ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    training_scores_line, = scores_ax.plot([], 'bo', label='training')
    test_scores_line, = scores_ax.plot([], 'ro', label='test')
    scores_ax.set_xlim(1, len(history.training_episodes))
    scores_combined = np.array([e.score for e in history.training_episodes ]+\
                               [e.score for e in history.test_episodes ])
    scores_ax.set_ylim(scores_combined.min(), scores_combined.max())
    scores_ax.set_xlabel('episode')
    scores_ax.set_ylabel('total reward')
    scores_ax.legend(loc='upper left', bbox_to_anchor=(0, -.1))

    training_episode_ax = create_bottom_row_plot(1)
    training_episode_ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    # TODO: get axis names from q_a_grid_spec
    training_episode_position_line, = training_episode_ax.plot(
        [], 'b-', label='position')
    training_episode_velocity_line, = training_episode_ax.plot(
        [], 'm-', label='velocity')
    training_episode_action_line, = training_episode_ax.plot([],
                                                             'r-',
                                                             label='action')
    training_episode_reward_line, = training_episode_ax.plot([],
                                                             'g-',
                                                             label='reward')
    training_episode_ax.set_ylim((-1.1, 1.1))
    training_episode_ax.axes.get_yaxis().set_visible(False)

    test_episode_ax = create_bottom_row_plot(2)
    test_episode_ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    test_episode_position_line, = test_episode_ax.plot([],
                                                       'b-',
                                                       label='position')
    test_episode_velocity_line, = test_episode_ax.plot([],
                                                       'm-',
                                                       label='velocity')
    test_episode_action_line, = test_episode_ax.plot([], 'r-', label='action')
    test_episode_reward_line, = test_episode_ax.plot([], 'g-', label='reward')
    test_episode_ax.set_ylim((-1.1, 1.1))
    test_episode_ax.axes.get_yaxis().set_visible(False)
    test_episode_ax.legend(loc='upper left', ncol=2, bbox_to_anchor=(-.5, -.1))

    def update(step_idx):
        """Redraw every panel for training step ``step_idx``."""
        # ``last_step`` is assigned below, before the animation runs; it is
        # read here through the closure.
        num_frames = math.ceil(last_step / every_n_steps)
        frame_idx = math.ceil(step_idx / every_n_steps)
        print("Drawing frame: %i/%i, %.2f%%\r"%\
              (frame_idx+1, num_frames, 100*(frame_idx+1)/float(num_frames) ), end='')
        training_episode = history.get_training_episode_for_step(step_idx)
        episode_step_idx = step_idx - training_episode.first_step

        q_a_frames = history.get_q_a_frames_for_step(step_idx)

        Q_max_im.set_data(q_a_frames.Q_max)
        Q_max_im.set_clim(q_a_frames.Q_max.min(), q_a_frames.Q_max.max())
        Q_std_im.set_data(q_a_frames.Q_std)
        Q_std_im.set_clim(q_a_frames.Q_std.min(), q_a_frames.Q_std.max())
        action_gradients_im.set_data(
            q_a_frames.action_gradients.reshape(agent.q_a_frames_spec.ny,
                                                agent.q_a_frames_spec.nx))
        action_gradients_im.set_clim(q_a_frames.action_gradients.min(),
                                     q_a_frames.action_gradients.max())
        max_action_im.set_data(q_a_frames.max_action)
        actor_policy_im.set_data(
            q_a_frames.actor_policy.reshape(agent.q_a_frames_spec.ny,
                                            agent.q_a_frames_spec.nx))

        # Plot scores
        xdata = range(1, training_episode.episode_idx + 1)
        training_scores_line.set_data(xdata, [
            e.score for e in history.training_episodes
        ][:training_episode.episode_idx])
        test_scores_line.set_data(xdata,
                                  [e.score for e in history.test_episodes
                                   ][:training_episode.episode_idx])

        #Plot training episode
        training_episode_ax.set_title(
            "Training episode %i, eps=%.3f, score: %.3f" %
            (training_episode.episode_idx, training_episode.epsilon,
             training_episode.score))

        # Show the episode up to the end of this frame's step window, clamped
        # to the last recorded state.
        current_end_idx = episode_step_idx + every_n_steps
        if current_end_idx >= len(training_episode.states):
            current_end_idx = len(training_episode.states) - 1

        training_xdata = range(0, current_end_idx + 1)
        training_episode_ax.set_xlim(
            training_xdata[0],
            training_episode.last_step - training_episode.first_step + 1)
        episode_states = [
            agent.preprocess_state(s) for s in training_episode.states
        ]
        training_episode_position_line.set_data(training_xdata,
                                                [s[0] for s in episode_states
                                                 ][:current_end_idx + 1])
        training_episode_velocity_line.set_data(training_xdata,
                                                [s[1] for s in episode_states
                                                 ][:current_end_idx + 1])
        training_episode_action_line.set_data(
            training_xdata, training_episode.actions[:current_end_idx + 1])
        training_episode_reward_line.set_data(
            training_xdata, training_episode.rewards[:current_end_idx + 1])

        #Plot test episode
        test_episode = history.get_test_episode_for_step(step_idx)
        if test_episode is not None:
            test_episode_ax.set_title(
                "Test episode %i, score: %.3f" %
                (test_episode.episode_idx, test_episode.score))
            test_xdata = range(1, len(test_episode.states) + 1)
            test_episode_ax.set_xlim(test_xdata[0], test_xdata[-1])
            episode_states = [
                agent.preprocess_state(e) for e in test_episode.states
            ]
            test_episode_position_line.set_data(test_xdata,
                                                [s[0] for s in episode_states])
            test_episode_velocity_line.set_data(test_xdata,
                                                [s[1] for s in episode_states])
            test_episode_action_line.set_data(test_xdata, test_episode.actions)
            test_episode_reward_line.set_data(test_xdata, test_episode.rewards)

    last_step = history.training_episodes[-1].last_step + 1
    anim = FuncAnimation(fig,
                         update,
                         interval=1000 / fps,
                         frames=range(0, last_step, every_n_steps))

    if display_mode == 'video' or display_mode == 'video_file':
        from matplotlib.animation import FFMpegWriter
        writer = FFMpegWriter(fps=fps)
        if writer.isAvailable():
            print("Using ffmpeg at '%s'." % writer.bin_path())
        else:
            # Raising a bare string is a TypeError in Python 3; raise a real
            # exception carrying the intended message.
            raise RuntimeError("FFMpegWriter not available for video output.")
    if display_mode == 'js':
        display(HTML(anim.to_jshtml()))
    elif display_mode == 'video':
        display(HTML(anim.to_html5_video()))
    elif display_mode == 'video_file':
        filename = 'training_animation_%i.mp4' % int(
            datetime.now().timestamp())
        anim.save(filename, writer=writer)
        print("\rVideo saved to %s." % filename)
        display(
            HTML(
                data='''<video alt="training animation" controls loop autoplay>
                        <source src="{0}" type="video/mp4" />
                     </video>'''.format(filename)))
    else:
        filename = 'training_animation_%i.gif' % int(
            datetime.now().timestamp())
        anim.save(filename, dpi=80, writer='imagemagick')
        display(HTML("<img src='%s'/>" % filename))
    plt.close()
Пример #19
0
                     frames=range(0, len(x_test)),
                     interval=interval_display, blit=False, repeat=False)


###############################################################################

# Plot online classification

# Plot complete visu: a dynamic display is required
plt.show()

# Plot only 10s, for animated documentation
try:
    from IPython.display import HTML
except ImportError as err:
    # Chain the original error so the real import failure stays visible.
    raise ImportError(
        "Install IPython to plot animation in documentation") from err

# Maximum size, in MB, of the base64-encoded animation embedded in HTML.
plt.rcParams["animation.embed_limit"] = 10
HTML(visu.to_jshtml(fps=5, default_mode='loop'))


###############################################################################
# References
# ----------
# .. [1] P.T. Fletcher, C. Lu, S.M. Pizer and S. Joshi, "Principal geodesic
#    analysis for the study of nonlinear statistics of shape", IEEE Trans Med
#    Imaging, 2004.
# .. [2] E. K. Kalunga, S. Chevallier, Q. Barthélemy, K. Djouani, E. Monacelli,
#     Y. Hamam, "Online SSVEP-based BCI using Riemannian geometry",
#     Neurocomputing, 2016.
Пример #20
0
In plain English: we set up (or initialize) the figure, then write a function that performs all of the updates for each frame of the animation. Finally, we pass the figure, the function, and the frame numbers to `FuncAnimation()`, and we have our animation.

Here's an example in which we plot a sinusoid at 100 different heights; the slider in the resulting animation widget lets the user step through the heights.

x = linspace(0., 2., 1001)              # 1001 evenly spaced points from 0 to 2.
lines = plot(x, 0. * sin(x * pi))       # Initial (flat) curve; keep the line list in "lines".
axis([0, 2, -1, 1])                     # x and y limits of the plot.
title("plot number = 0")                # Initial title.


def animate(frame):
    """Scale the sine curve to height frame/100 and retitle the plot."""
    amplitude = float(frame) / 100.
    lines[0].set_ydata(amplitude * sin(x * pi))
    title('plot number = ' + str(frame))


fig = FuncAnimation(gcf(), animate, frames=range(100))
HTML(fig.to_jshtml())

## Example 19: Load MATLAB data into Python
For our last example, let's load a MATLAB file in the `.mat` format into Python. Before doing so, let's clear all of the variables and functions we have defined. This command is not necessary, but we run it here so that any new variables we subsequently load are easy to spot.

%reset

Then, let's import the `scipy.io` module, which we'll use to import the `.mat` data,

import scipy.io as sio

Now, let's load a data file using the function `loadmat`,

# Read the MATLAB file; loadmat returns a dict of variable name -> value,
# which the bare ``type(mat)`` expression below displays.
mat = sio.loadmat(file_name='matfiles/sample_data.mat')
type(mat)
Пример #21
0
                       frames=range(train_covs, test_covs_max),
                       interval=interval_display,
                       blit=False,
                       repeat=False)

###############################################################################

# Plot online detection

# Plot complete visu: a dynamic display is required
plt.show()

# Plot only 10s, for animated documentation
try:
    from IPython.display import HTML
except ImportError as err:
    # Chain the original error so the real import failure stays visible.
    raise ImportError(
        "Install IPython to plot animation in documentation") from err

# "animation.embed_limit" is expressed in MB: it caps the size of the
# base64-encoded animation embedded in the HTML output (frames beyond the
# limit are dropped by matplotlib).
plt.rcParams["animation.embed_limit"] = 10
HTML(potato.to_jshtml(fps=5, default_mode='loop'))

###############################################################################
# References
# ----------
# .. [1] A. Barachant, A. Andreev, M. Congedo, "The Riemannian Potato: an
#    automatic and adaptive artifact detection method for online experiments
#    using Riemannian geometry", Proc. TOBI Workshop IV, 2013.
#
# .. [2] Q. Barthélemy, L. Mayaud, D. Ojeda, M. Congedo, "The Riemannian potato
#    field: a tool for online signal quality index of EEG", IEEE TNSRE, 2019.
Пример #22
0
def get_js_html(animation: FuncAnimation,
                ffmpeg_path='C:/FFmpeg/bin/ffmpeg.exe',
                **jshtml_kwargs):
    """Return the JavaScript/HTML representation of *animation*.

    Parameters
    ----------
    animation : FuncAnimation
        The animation to convert.
    ffmpeg_path : str or None, optional
        Value written to the global ``rcParams['animation.ffmpeg_path']``
        before converting. The default keeps the previously hard-coded
        Windows path for backward compatibility. Pass ``None`` to leave the
        global matplotlib setting untouched.
    **jshtml_kwargs
        Forwarded to ``animation.to_jshtml`` (e.g. ``fps``, ``embed_frames``,
        ``default_mode``).

    Returns
    -------
    str
        The HTML string produced by ``animation.to_jshtml``.
    """
    if ffmpeg_path is not None:
        import matplotlib.pyplot as plt
        # NOTE: this mutates global matplotlib state for the whole process.
        # NOTE(review): to_jshtml renders frames with matplotlib's
        # HTMLWriter rather than ffmpeg, so this presumably only matters for
        # later ffmpeg-based saves -- confirm before relying on it.
        plt.rcParams['animation.ffmpeg_path'] = ffmpeg_path
    return animation.to_jshtml(**jshtml_kwargs)
            np.array(poly1)[:, 1]))

clb3 = plt.colorbar(scats[i])
clb3.ax.set_xlabel('Age [d]')

t = np.datetime_as_string(timerange[0], unit='m')
title = axs[0].set_title('Particles at t = ' + t)


def animate(i):
    """Draw frame *i*: particles whose time lies in [timerange[i], timerange[i+1]).

    Updates the title, then sets the (lon, lat) offsets and the colour values
    (age divided by 86400) of every scatter in ``scats`` -- one per entry of
    ``exts``.
    """
    t = np.datetime_as_string(timerange[i], unit='m')
    title.set_text('Particles at t = ' + t)

    time_id = np.where((data_xarray['time'] >= timerange[i])
                       & (data_xarray['time'] < timerange[i + 1]))

    # The selection does not depend on the scatter, so compute it once.
    # All three variables are indexed through ``.values`` (plain NumPy):
    # indexing an xarray DataArray directly with the tuple from ``np.where``
    # is orthogonal (outer) indexing, which would not match the pointwise
    # lon/lat selection when the data has more than one dimension.
    offsets = np.c_[data_xarray['lon'].values[time_id],
                    data_xarray['lat'].values[time_id]]
    # /86400 converts to days to match the 'Age [d]' colorbar label;
    # assumes 'age' is stored in seconds -- confirm.
    ages_days = data_xarray['age'].values[time_id] / 86400

    # Use a separate index so the frame number ``i`` is not shadowed.
    for k, _ext in enumerate(exts):
        scats[k].set_offsets(offsets)
        scats[k].set_array(ages_days)


anim = FuncAnimation(fig, animate, frames=len(timerange) - 1, interval=500)

# In[23]:

if BuildAnim:
    from IPython.display import HTML, display
    # A bare ``HTML(...)`` expression inside an ``if`` block is not the last
    # top-level expression of the cell, so Jupyter would silently discard
    # it; ``display`` renders it explicitly.
    display(HTML(anim.to_jshtml()))
    anim.save('GAL1.mp4', fps=5, extra_args=['-vcodec', 'libx264'])
Пример #24
0
class animate_array():
    """
    Animate a 2D graph over time.
    """
    def __init__(self, array, x_points, times=None):
        """
        Parameters
        ----------
        array : 2d array
            Each row has the values at each x_node for that timestep
        x_points : 1d array
            The x_nodes that array corresponds to
        times : 1d array-like, optional
            The times corresponding to each row in array for text display.
            Defaults to displaying the row index. Lists/tuples are accepted
            as well as arrays.

        Optional Parameters
        -------------------
        (Set these by changing the attributes manually)
        html : bool
            Pass the animation into a html wrapper (for jupyter notebooks)
        frame_interval : int
            The millisecond interval between frames
        frame_skip : int
            Only every ``frame_skip``-th row is animated (good for saving
            animations). Must be a positive integer; 1 animates every row.
        """

        # validate input
        assert array.ndim == 2, 'must be 2d'
        assert len(x_points) == len(array[0]), 'spacial dimension mismatch'

        # setup data
        self.arr = array
        self.N = len(array)
        self.x = x_points

        # set the times (stored as strings for the on-plot label)
        if times is None:
            self.pre_string = 'i='
            self.times = np.arange(self.N).astype(str)
        else:
            assert len(times) == self.N, 'The length of the time array is off'
            self.pre_string = 't='
            # asarray accepts lists/tuples and gives positional indexing
            self.times = np.asarray(times).astype(str)

        # default options
        self.html = False
        self.frame_interval = 20
        self.frame_skip = 1

    def blank(self):
        """Clear the line and label (used as ``init_func`` for blitting)."""
        self.line.set_data([], [])
        self.text.set_text('')
        return self.line, self.text

    def update(self, i):
        """Plot row *i* of the array and label it with its time/index."""
        self.line.set_data(self.x, self.arr[i, :])
        self.text.set_text(self.pre_string + self.times[i])
        return self.line, self.text

    def set_figure(self):
        """Create the figure, line and label; call first to customise them.

        Axis limits are padded by 5% and computed from the first row only,
        so later rows may leave the visible area.
        """
        # initialise figure
        self.fig, self.ax = plt.subplots()
        self.line, = self.ax.plot([], [], lw=3, label='Numerical')
        self.text = self.ax.text(0.05, 0.95, '', transform=self.ax.transAxes)

        # set axis limits
        x_r = (self.x[-1] - self.x[0]) * 0.05
        y_r = (self.arr[0, :].max() - self.arr[0, :].min()) * 0.05
        self.ax.set_xlim(self.x[0] - x_r, self.x[-1] + x_r)
        self.ax.set_ylim(self.arr[0, :].min() - y_r,
                         self.arr[0, :].max() + y_r)

    def animate(self):
        """Build the animation and store it in ``self.ani``.

        Call :meth:`set_figure` first to customise the plot; otherwise it is
        called here. Returns an ``HTML`` object when ``self.html`` is True,
        otherwise None.
        """
        if 'fig' not in vars(self):
            self.set_figure()

        # animate every ``frame_skip``-th row (1 = every row)
        self.ani = FuncAnimation(self.fig,
                                 self.update,
                                 frames=range(0, self.N, self.frame_skip),
                                 interval=self.frame_interval,
                                 blit=True,
                                 init_func=self.blank)

        # html wrapper if wanted; the figure is closed, presumably so the
        # notebook does not also show the static figure
        if self.html:
            plt.close(self.fig)
            self.HTML = HTML(self.ani.to_jshtml())
            return self.HTML

    def save(self, path):
        """Write the animation as JavaScript HTML to ``path + '.html'``.

        Builds the animation first if :meth:`animate` has not been called.
        Rendering every frame can take a few seconds.
        """
        # ensure the animation has been made
        if 'ani' not in vars(self):
            self.animate()

        # the ``with`` block closes the file itself
        with open(path + '.html', 'w', encoding='utf-8') as f:
            f.write(self.ani.to_jshtml())
Пример #25
0
import matplotlib.pyplot as plt
import numpy as np  # was missing: np.pi / np.sin / np.linspace are used below
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

fig, ax = plt.subplots()
xdata, ydata = [], []
ln, = ax.plot([], [], 'ro')


def init():
    """Set the axis limits and return the (empty) artist to blit."""
    ax.set_xlim(0, 2 * np.pi)
    ax.set_ylim(-1, 1)
    return ln,


def update(frame):
    """Append the point (frame, sin(frame)) and redraw the growing trace."""
    xdata.append(frame)
    ydata.append(np.sin(frame))
    ln.set_data(xdata, ydata)
    return ln,


ani = FuncAnimation(fig, update, frames=np.linspace(0, 2 * np.pi, 128),
                    init_func=init, blit=True, interval=50)

HTML(ani.to_jshtml())


# In[ ]:




Пример #26
0
    def show(
            self,
            t,
            cmap="plasma",
            res=300,
            interval=75,
            file=None,
            figsize=(3, 3),
            html5_video=True,
            window_pad=1.0,
    ):
        """Visualize the Keplerian system.

        Args:
            t (scalar or vector): The time(s) at which to evaluate the orbit and
                the map in units of :py:attr:`time_unit`.
            cmap (string or colormap instance, optional): The matplotlib colormap
                to use. Defaults to ``plasma``.
            res (int, optional): The resolution of the map in pixels on a
                side. Defaults to 300.
            interval (int, optional): Interval between frames in milliseconds
                (animated maps only). Defaults to 75.
            file (string, optional): The file name (including the extension)
                to save the animation (or, for a single frame, the figure)
                to. Defaults to None.
            figsize (tuple, optional): Figure size in inches. Defaults to
                (3, 3).
            html5_video (bool, optional): If rendering in a Jupyter notebook,
                display as an HTML5 video? Default is True. If False, displays
                the animation using Javascript (file size will be larger.)
            window_pad (float, optional): Padding around the primary in units
                of the primary radius. Bodies outside of this window will be
                cropped. Default is 1.0.

        Raises:
            NotImplementedError: For spectral maps (``nw`` is not None).

        If ``self._rv`` is set, the RV filter is applied to every map before
        rendering and removed again when this method exits, even on error.
        """
        # Not yet implemented
        if self._primary._map.nw is not None:  # pragma: no cover
            raise NotImplementedError(
                "Method not implemented for spectral maps.")

        # Render the maps & get the orbital positions
        if self._rv:
            self._primary.map._set_RV_filter()
            for sec in self._secondaries:
                sec.map._set_RV_filter()

        # Everything below runs with the RV filters set; ``finally`` ensures
        # they are removed even if rendering or plotting raises.
        try:
            img_pri, img_sec, x, y, z = self.ops.render(
                math.reshape(math.to_array_or_tensor(t), [-1]) * self._time_factor,
                res,
                self._primary._r,
                self._primary._m,
                self._primary._prot,
                self._primary._t0,
                self._primary._theta0,
                self._primary._map._amp,
                self._primary._map._inc,
                self._primary._map._obl,
                self._primary._map._y,
                self._primary._map._u,
                self._primary._map._f,
                self._primary._map._alpha,
                math.to_array_or_tensor([sec._r for sec in self._secondaries]),
                math.to_array_or_tensor([sec._m for sec in self._secondaries]),
                math.to_array_or_tensor([sec._prot for sec in self._secondaries]),
                math.to_array_or_tensor([sec._t0 for sec in self._secondaries]),
                math.to_array_or_tensor([sec._theta0
                                         for sec in self._secondaries]),
                self._get_periods(),
                math.to_array_or_tensor([sec._ecc for sec in self._secondaries]),
                math.to_array_or_tensor([sec._w for sec in self._secondaries]),
                math.to_array_or_tensor([sec._Omega for sec in self._secondaries]),
                math.to_array_or_tensor([sec._inc for sec in self._secondaries]),
                math.to_array_or_tensor(
                    [sec._map._amp for sec in self._secondaries]),
                math.to_array_or_tensor(
                    [sec._map._inc for sec in self._secondaries]),
                math.to_array_or_tensor(
                    [sec._map._obl for sec in self._secondaries]),
                math.to_array_or_tensor([sec._map._y
                                         for sec in self._secondaries]),
                math.to_array_or_tensor([sec._map._u
                                         for sec in self._secondaries]),
                math.to_array_or_tensor([sec._map._f
                                         for sec in self._secondaries]),
                math.to_array_or_tensor(
                    [sec._map._alpha for sec in self._secondaries]),
            )

            # Convert to units of the primary radius
            fac = np.reshape([sec._length_factor for sec in self._secondaries],
                             [-1, 1])
            fac = fac * self._primary._r
            x, y, z = x / fac, y / fac, z / fac
            r = math.to_array_or_tensor([sec._r for sec in self._secondaries])
            r = r / self._primary._r

            # Evaluate if needed
            if config.lazy:
                img_pri = img_pri.eval()
                img_sec = img_sec.eval()
                x = x.eval()
                y = y.eval()
                z = z.eval()
                r = r.eval()

            # We need this to be of shape (nplanet, nframe)
            x = x.T
            y = y.T
            z = z.T

            # Ensure we have an array of frames
            if len(img_pri.shape) == 3:
                nframes = img_pri.shape[0]
            else:  # pragma: no cover
                nframes = 1
                img_pri = np.reshape(img_pri, (1, ) + img_pri.shape)
                img_sec = np.reshape(img_sec, (1, ) + img_sec.shape)
            animated = nframes > 1

            # Set up the plot
            fig, ax = plt.subplots(1, figsize=figsize)
            ax.axis("off")
            ax.set_xlim(-1.0 - window_pad, 1.0 + window_pad)
            ax.set_ylim(-1.0 - window_pad, 1.0 + window_pad)

            # Render the first frame. Index 0 of ``img``/``circ`` is the
            # primary; secondary i lives at index i + 1.
            nbodies = 1 + len(self._secondaries)
            img = [None] * nbodies
            circ = [None] * nbodies
            extent = np.array([-1.0, 1.0, -1.0, 1.0])
            img[0] = ax.imshow(
                img_pri[0],
                origin="lower",
                extent=extent,
                cmap=cmap,
                interpolation="none",
                vmin=np.nanmin(img_pri),
                vmax=np.nanmax(img_pri),
                animated=animated,
                zorder=0.0,
            )
            circ[0] = plt.Circle((0, 0),
                                 1,
                                 color="k",
                                 fill=False,
                                 zorder=1e-3,
                                 lw=2)
            ax.add_artist(circ[0])
            for i, _ in enumerate(self._secondaries):
                extent = np.array([x[i, 0], x[i, 0], y[i, 0], y[i, 0]
                                   ]) + (r[i] * np.array([-1.0, 1.0, -1.0, 1.0]))
                img[i + 1] = ax.imshow(
                    img_sec[i, 0],
                    origin="lower",
                    extent=extent,
                    cmap=cmap,
                    interpolation="none",
                    vmin=np.nanmin(img_sec),
                    vmax=np.nanmax(img_sec),
                    animated=animated,
                    zorder=z[i, 0],
                )
                # Store at i + 1 so the primary's outline at index 0 is not
                # overwritten and no ``None`` is left in the artist list.
                circ[i + 1] = plt.Circle(
                    (x[i, 0], y[i, 0]),
                    r[i],
                    color="k",
                    fill=False,
                    zorder=z[i, 0] + 1e-3,
                    lw=2,
                )
                ax.add_artist(circ[i + 1])

            # Animation
            if animated:

                def updatefig(k):

                    # Update Primary map
                    img[0].set_array(img_pri[k])

                    # Update Secondary maps & positions
                    for i, _ in enumerate(self._secondaries):
                        extent = np.array(
                            [x[i, k], x[i, k], y[i, k], y[i, k]]) + (
                            r[i] * np.array([-1.0, 1.0, -1.0, 1.0]))
                        # Update only if some extent coordinate lies within
                        # the padded plotting window.
                        if np.any(np.abs(extent) < 1.0 + window_pad):
                            img[i + 1].set_array(img_sec[i, k])
                            img[i + 1].set_extent(extent)
                            img[i + 1].set_zorder(z[i, k])
                            circ[i + 1].center = (x[i, k], y[i, k])
                            circ[i + 1].set_zorder(z[i, k] + 1e-3)

                    return img + circ

                ani = FuncAnimation(fig,
                                    updatefig,
                                    interval=interval,
                                    blit=False,
                                    frames=nframes)

                # Business as usual
                if (file is not None) and (file != ""):
                    if file.endswith(".mp4"):
                        ani.save(file, writer="ffmpeg")
                    elif file.endswith(".gif"):
                        ani.save(file, writer="imagemagick")
                    else:  # pragma: no cover
                        # Try and see what happens!
                        ani.save(file)
                    plt.close()
                else:  # pragma: no cover
                    try:
                        if "zmqshell" in str(type(get_ipython())):
                            plt.close()
                            if html5_video:
                                display(HTML(ani.to_html5_video()))
                            else:
                                display(HTML(ani.to_jshtml()))
                        else:
                            raise NameError("")
                    except NameError:
                        plt.show()
                        plt.close()

                # Matplotlib generates an annoying empty
                # file when producing an animation. Delete it.
                try:
                    os.remove("None0000000.png")
                except FileNotFoundError:
                    pass

            else:

                if (file is not None) and (file != ""):
                    fig.savefig(file)
                    plt.close()
                else:  # pragma: no cover
                    plt.show()

        finally:
            if self._rv:
                self._primary.map._unset_RV_filter()
                for sec in self._secondaries:
                    sec.map._unset_RV_filter()
Пример #27
0
class multi_animate():
    """
    Animate one or more 2D graphs over time on shared axes.
    """
    def __init__(self, arrays, x_points, times=None, labels=None):
        """
        Parameters
        ----------
        arrays : 2d array or list/tuple of 2d arrays
            each array has rows with values at each x_node for that timestep;
            all arrays must share the same shape
        x_points : 1d array
            The x_nodes that array corresponds to
        times : 1d array-like, optional
            The times corresponding to each row in array for text display.
            Defaults to displaying the row index.
        labels : list, optional
            The legend labels for each array given. Defaults to
            'Array 0', 'Array 1', ...

        Optional Parameters
        -------------------
        (Set these by changing the attributes manually)
        html : bool
            Pass the animation into a html wrapper (for jupyter notebooks)
        frame_interval : int
            The millisecond interval between frames
        frame_skip : int
            Only every ``frame_skip``-th row is animated (good for saving
            animations). Must be a positive integer; 1 animates every row.
        legend : bool
            Whether to draw a legend
        figsize : list
            Figure size in inches
        titles : list of 3 str
            Axes title, x label and y label
        """

        # validate input
        if not isinstance(arrays, (list, tuple)):
            arrays = [arrays]
        # check dimensionality before indexing ``shape[1]`` so a 1-D input
        # fails validation instead of raising IndexError
        for arr in arrays:
            assert arr.ndim == 2, 'must be 2d arrays'
            assert arr.shape == arrays[0].shape, \
                'all arrays must have same shape'
        assert arrays[0].shape[1] == len(x_points),\
            "array dimensions dont match"

        # setup data
        self.arrs = arrays
        self.N_arrs = len(arrays)
        self.N = len(arrays[0])
        self.x = x_points

        # set the times (stored as strings for the on-plot label)
        if times is None:
            self.pre_string = 'i='
            self.times = np.arange(self.N).astype(str)
        else:
            assert len(times) == self.N, 'The length of the time array os off'
            self.pre_string = 't='
            # asarray accepts lists/tuples and gives positional indexing
            self.times = np.asarray(times).astype(str)

        # set the labels
        if labels is None:
            self.labels = [f'Array {n}' for n in range(self.N_arrs)]
        else:
            assert len(labels) == self.N_arrs, 'wrong number of labels'
            self.labels = labels

        # default options
        self.html = False
        self.frame_interval = 20
        self.frame_skip = 1
        self.legend = True
        self.figsize = [5, 5]
        self.titles = ['', '', '']

        # plotting format string: first array solid, the rest dashed
        self.fmts = ['-'] + ['--'] * (self.N_arrs - 1)

        # set axis limits (5% padding, y range taken from the first row of
        # every array only)
        x_r = (self.x[-1] - self.x[0]) * 0.05
        y_max = max(arr[0, :].max() for arr in self.arrs)
        y_min = min(arr[0, :].min() for arr in self.arrs)
        y_r = (y_max - y_min) * 0.05
        self.xlims = (self.x[0] - x_r, self.x[-1] + x_r)
        self.ylims = (y_min - y_r, y_max + y_r)

    def blank(self):
        """Clear all lines and the label (``init_func`` for blitting)."""
        for line in self.lines:
            line.set_data([], [])
        self.text.set_text('')
        return (self.text, *self.lines)

    def update(self, i):
        """Plot row *i* of every array and label it with its time/index."""
        for line, arr in zip(self.lines, self.arrs):
            line.set_data(self.x, arr[i, :])
        self.text.set_text(self.pre_string + self.times[i])
        return (self.text, *self.lines)

    def set_figure(self):
        """Create the figure, lines, label and legend from the attributes.

        Call this before :meth:`animate` to customise the plot further.
        """
        # initialise figure
        self.fig, self.ax = plt.subplots(figsize=self.figsize)
        self.lines = [self.ax.plot([], [], fmt, lw=3, label=label)[0]
                      for label, fmt in zip(self.labels, self.fmts)]
        self.text = self.ax.text(0.05, 0.95, '', transform=self.ax.transAxes)

        # set axis limits
        self.ax.set_xlim(*self.xlims)
        self.ax.set_ylim(*self.ylims)

        # set the titles
        self.ax.set(title=self.titles[0], xlabel=self.titles[1],
                    ylabel=self.titles[2])

        # set legend
        if self.legend:
            self.ax.legend()

    def animate(self):
        """Build the animation and store it in ``self.ani``.

        Call :meth:`set_figure` first to customise the plot; otherwise it is
        called here (an existing figure is reused, not replaced). Returns an
        ``HTML`` object when ``self.html`` is True, otherwise None.
        """
        if 'fig' not in vars(self):
            self.set_figure()

        # animate every ``frame_skip``-th row (1 = every row)
        self.ani = FuncAnimation(self.fig,
                                 self.update,
                                 frames=range(0, self.N, self.frame_skip),
                                 interval=self.frame_interval,
                                 blit=True,
                                 init_func=self.blank)

        # html wrapper if wanted; the figure is closed, presumably so the
        # notebook does not also show the static figure
        if self.html:
            plt.close(self.fig)
            self.HTML = HTML(self.ani.to_jshtml())
            return self.HTML

    def save(self, path):
        """Write the animation as JavaScript HTML to ``path + '.html'``.

        Builds the animation first if :meth:`animate` has not been called.
        Rendering every frame can take a few seconds.
        """
        # ensure the animation has been made
        if 'ani' not in vars(self):
            self.animate()

        # the ``with`` block closes the file itself
        with open(path + '.html', 'w', encoding='utf-8') as f:
            f.write(self.ani.to_jshtml())