Example #1
File: plot.py Project: zcl-maker/duat
    def time_2d_animation(self, output_path=None, dataset_selector=None, axes_selector=None, time_selector=None,
                          dpi=200, fps=1, cmap=None, norm=None, rasterized=True, z_min=None,
                          z_max=None, latex_label=True, interval=200):
        """
        Generate a plot of 2D data as a color map animated in time.

        If an output path with a suitable extension is supplied, the method will export the animation. Available
        formats are mp4 and gif. The returned objects allow for further customization and display. For example, in
        Jupyter you might use `IPython.display.HTML(animation.to_html5_video())`, where `animation` is the returned
        `FuncAnimation` instance.

        Note:
            Exporting a high resolution animated gif with many frames might eat your RAM.

        Args:
            output_path (str): The place where the plot is saved. If "" or None, the plot is shown in matplotlib.
            dataset_selector: See :func:`~duat.osiris.plot.Diagnostic.get_generator` method.
            axes_selector: See :func:`~duat.osiris.plot.Diagnostic.get_generator` method.
            time_selector: See :func:`~duat.osiris.plot.Diagnostic.get_generator` method.
            interval (float): Delay between frames in ms. If exporting to mp4, the fps is used instead to generate the
                              file, although the returned objects do use this value.
            dpi (int): The resolution of the frames in dots per inch (only if exporting).
            fps (int): The frames per seconds (only if exporting to mp4).
            latex_label (bool): Whether to use LaTeX code for the plot.
            cmap (str or `matplotlib.colors.Colormap`): The Colormap to use in the plot.
            norm (str or `matplotlib.colors.Normalize`): How to scale the colormap. For advanced manipulation, use some
                           Normalize subclass, e.g., colors.SymLogNorm(0.01). Automatic scales can be selected with
                           the following strings:

                           * "lin": Linear scale from minimum to maximum.
                           * "log": Logarithmic scale from minimum to maximum up to vmax/vmin>1E9, otherwise increasing vmin.


            rasterized (bool): Whether the map is rasterized. This does not apply to axes, title... Note that
                               non-rasterized images with large amounts of data exported to PDF might be challenging
                               to handle.
        Returns:
            (`matplotlib.figure.Figure`, `matplotlib.axes.Axes`, `matplotlib.animation.FuncAnimation`):
            Objects representing the generated plot and its animation.
            
        Raises:
            FileNotFoundError: If export to mp4 was attempted but ffmpeg was not found on the system.

        """
        if output_path:
            ensure_dir_exists(os.path.dirname(output_path))
        axes = self.get_axes(dataset_selector=dataset_selector, axes_selector=axes_selector)
        if len(axes) != 2:
            raise ValueError("Expected 2 axes plot, but %d were provided" % len(axes))

        gen = self.get_generator(dataset_selector=dataset_selector, axes_selector=axes_selector,
                                 time_selector=time_selector)

        # Set plot labels
        fig, ax = plt.subplots()
        fig.set_tight_layout(True)

        x_name = axes[0]["LONG_NAME"]
        x_units = axes[0]["UNITS"]
        y_name = axes[1]["LONG_NAME"]
        y_units = axes[1]["UNITS"]
        title_name = self.data_name
        title_units = self.units

        ax.set_xlabel(_create_label(x_name, x_units, latex_label))
        ax.set_ylabel(_create_label(y_name, y_units, latex_label))

        # Gather the points
        x_min, x_max = axes[0]["MIN"], axes[0]["MAX"]
        y_min, y_max = axes[1]["MIN"], axes[1]["MAX"]
        z = np.transpose(np.asarray(next(gen)))

        time_list = self.get_time_list(time_selector)
        if len(time_list) < 2:
            raise ValueError("At least two time snapshots are needed to make an animation")

        norm = _autonorm(norm, z)

        plot_function = ax.pcolormesh
        if rasterized:
            # Rasterizing in contourf is a bit tricky
            # Cf. http://stackoverflow.com/questions/33250005/size-of-matplotlib-contourf-image-files
            plot = plot_function(axes[0]["LIST"], axes[1]["LIST"], z, norm=norm, cmap=cmap, zorder=-9)
            ax.set_rasterization_zorder(-1)
        else:
            plot = plot_function(axes[0]["LIST"], axes[1]["LIST"], z, norm=norm, cmap=cmap)

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)

        ax.set_title(_create_label(title_name, title_units, latex_label))

        _fix_colorbar(fig.colorbar(plot))

        # Prepare a function for the updates
        def update(i):
            """Update the plot, returning the artists which must be redrawn."""
            try:
                new_dataset = np.transpose(np.asarray(next(gen)))
            except StopIteration:
                logger.warning("Tried to add a frame to the animation, but all data was used.")
                return
            label = 't = {0}'.format(time_list[i])
            # BEWARE: The set_array syntax is rather problematic. Depending on the shading used in pcolormesh, the
            #         following might not work.
            plot.set_array(new_dataset[:-1, :-1].ravel())
            # For more details, check lumbric's answer to
            # https://stackoverflow.com/questions/18797175/animation-with-pcolormesh-routine-in-matplotlib-how-do-i-initialize-the-data
            ax.set_title(label)
            return plot, ax

        anim = FuncAnimation(fig, update, frames=range(1, len(time_list) - 2), interval=interval)

        if not output_path:  # "" or None
            pass
        else:
            filename = os.path.basename(output_path)
            if "." in filename:
                extension = output_path.split(".")[-1].lower()
            else:
                extension = None
            if extension == "gif":
                anim.save(output_path, dpi=dpi, writer='imagemagick')
            elif extension == "mp4":
                metadata = dict(title=os.path.split(self.data_path)[-1], artist='duat', comment=self.data_path)
                writer = FFMpegWriter(fps=fps, metadata=metadata)
                with writer.saving(fig, output_path, dpi):
                    # Iterate over frames
                    for i in range(1, len(time_list) - 1):
                        update(i)
                        writer.grab_frame()
                    # Keep showing the last frame for the fixed time
                    writer.grab_frame()
            else:
                logger.warning("Unknown extension in path %s. No output produced." % output_path)

        plt.close()

        return fig, ax, anim
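
A minimal usage sketch for this method (hypothetical: `diag` stands for a duat `Diagnostic` instance holding 2D data; the path and options are illustrative):

# Hypothetical usage of time_2d_animation; exporting to mp4 requires ffmpeg,
# exporting to gif requires imagemagick.
fig, ax, anim = diag.time_2d_animation("density.mp4", fps=10, dpi=150, norm="log")
# With output_path=None the plot is shown instead; in Jupyter you could do:
# from IPython.display import HTML
# HTML(anim.to_html5_video())
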
Example #2
File: plot.py Project: dream81/grond
    def start(self):
        nfx = 1
        nfy = 1

        problem = self.problem

        ixpar = problem.name_to_index(self.xpar_name)
        iypar = problem.name_to_index(self.ypar_name)

        mpl_init(fontsize=self.fontsize)
        fig = plt.figure(figsize=(9.6, 5.4))
        labelpos = mpl_margins(fig,
                               nw=nfx,
                               nh=nfy,
                               w=7.,
                               h=5.,
                               wspace=7.,
                               hspace=2.,
                               units=self.fontsize)

        xpar = problem.parameters[ixpar]
        ypar = problem.parameters[iypar]

        if xpar.unit == ypar.unit:
            axes = fig.add_subplot(nfy, nfx, 1, aspect=1.0)
        else:
            axes = fig.add_subplot(nfy, nfx, 1)

        labelpos(axes, 2.5, 2.0)

        axes.set_xlabel(xpar.get_label())
        axes.set_ylabel(ypar.get_label())

        axes.get_xaxis().set_major_locator(plt.MaxNLocator(4))
        axes.get_yaxis().set_major_locator(plt.MaxNLocator(4))

        xref = problem.get_reference_model()
        axes.axvline(xpar.scaled(xref[ixpar]), color='black', alpha=0.3)
        axes.axhline(ypar.scaled(xref[iypar]), color='black', alpha=0.3)

        self.fig = fig
        self.problem = problem
        self.xpar = xpar
        self.ypar = ypar
        self.axes = axes
        self.ixpar = ixpar
        self.iypar = iypar
        from matplotlib import colors
        n = self.optimiser.nbootstrap + 1
        hsv = num.vstack((num.random.uniform(0., 1., n),
                          num.random.uniform(0.5, 0.9, n), num.repeat(0.7,
                                                                      n))).T

        self.bcolors = colors.hsv_to_rgb(hsv[num.newaxis, :, :])[0, :, :]
        self.bcolors[0, :] = [0., 0., 0.]

        bounds = self.problem.get_combined_bounds()

        from grond import plot
        self.xlim = plot.fixlim(*xpar.scaled(bounds[ixpar]))
        self.ylim = plot.fixlim(*ypar.scaled(bounds[iypar]))

        self.set_limits()

        from matplotlib.colors import LinearSegmentedColormap

        self.cmap = LinearSegmentedColormap.from_list('probability',
                                                      [(1.0, 1.0, 1.0),
                                                       (0.5, 0.9, 0.6)])

        self.writer = None
        if self.movie_filename:
            from matplotlib.animation import FFMpegWriter

            metadata = dict(title=problem.name, artist='Grond')

            self.writer = FFMpegWriter(fps=30,
                                       metadata=metadata,
                                       codec='libx264',
                                       bitrate=200000,
                                       extra_args=[
                                           '-pix_fmt', 'yuv420p', '-profile:v',
                                           'baseline', '-level', '3', '-an'
                                       ])

            self.writer.setup(self.fig, self.movie_filename, dpi=200)

        if self.show:
            plt.ion()
            plt.show()
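
Note that `writer.setup()` above only opens the output file; the matching `grab_frame()`/`finish()` calls are presumably made elsewhere in the class as the optimiser runs. A minimal sketch of what those companion methods might look like (hypothetical names, not grond's actual API):

# Hypothetical companion methods inside the same class (illustrative only):
def update(self):
    # ... redraw self.axes with the latest bootstrap results ...
    if self.writer is not None:
        self.writer.grab_frame()  # append the current figure as a movie frame

def finish(self):
    if self.writer is not None:
        self.writer.finish()  # flush remaining frames and close the file
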
Example #3
# save the map in the swarm routine
swarm.plant = plant

if kalman_centralized:
    swarm.central_kalman = centralized_kal(swarm)
'''
    SIMULATION
'''
# Initialize the graph
swarm.generate_graph()

# Initialize the animation object
metadata = dict(title='Distributed AGV collision avoidance',
                artist='Fadini & Piazza',
                comment='Now on file!')
writer = FFMpegWriter(fps=int(1 / dt), metadata=metadata)

print(HEADER + '*' * 25 + '  PARAMETERS SUMMARY  ' + '*' * 25 + ENDC)
print(
    '\tT = {:0.2f}s\n\tN_robots = {:1}\n\tAlgorithm = {:2}\n\tMap = {:3}\n\tCentralized Kalman filter = {:4}\n\tMHE filter = {:5}'
    .format(T, n_robots, avoidance_algorithm, map_case,
            str(kalman_centralized), str(mhe_filter)))
print(OKGREEN + '*' * 24 + ' ALL READY PRESS ENTER  ' + '*' * 24 + ENDC)

_ = input()

tic()
timestamp = 0
time = 0.0

plt.figure('Simulation')
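
The simulation loop itself is cut off in this excerpt. With the writer created above, frames would typically be captured inside `writer.saving(...)`; a sketch, assuming a hypothetical `swarm.step(dt)` that advances and redraws the simulation:

# Hypothetical frame-capture loop (the excerpt stops before it):
fig = plt.gcf()  # the 'Simulation' figure created above
with writer.saving(fig, 'swarm_simulation.mp4', dpi=100):
    while time < T:
        swarm.step(dt)       # assumed API: advance the swarm and redraw
        writer.grab_frame()  # capture the current figure state
        time += dt
        timestamp += 1
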
Example #4
pop = 150
length = 250

stream = logging.StreamHandler(sys.stdout)
file = logging.FileHandler(filename='output.txt', mode='w')
logging.basicConfig(
    level=logging.DEBUG,
    handlers=(stream, file))  # set up an output file and also print to console
print('Starting simulation with population of {} and length {}...'.format(
    pop, length))

# Create simulation with grid
sim = Simulation()

# Set up initial scatterplot
fig, ax = pp.subplots(figsize=(12, 12))
ax.set_xlim(0, sim.size_x)
ax.set_ylim(0, sim.size_y)
ax.set_title('COVID simulation, {} individuals, t=0'.format(sim.population()))
ax.grid(b=True)

frames = sim.run(pop=pop, length=length)
scatter, = ax.scatter(frames[0][0], frames[0][1], c=frames[0][2], s=5, lw=1),

# Set up animation
animation = FuncAnimation(fig, animation_step, interval=100, blit=True)
pp.show()

# Save a video
writer = FFMpegWriter(fps=20, bitrate=1800)
animation.save('simulation.mp4', writer=writer)
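
The `animation_step` callback referenced by `FuncAnimation` is not included in this excerpt. A plausible definition, assuming each entry of `frames` holds x positions, y positions, and numeric per-individual states that the scatter maps through its colormap:

# Hypothetical animation callback (not part of the excerpt):
import numpy as np

def animation_step(i):
    f = frames[i % len(frames)]
    scatter.set_offsets(np.column_stack([f[0], f[1]]))  # move the points
    scatter.set_array(np.asarray(f[2]))                 # recolor by infection state
    return scatter,  # blit=True expects an iterable of changed artists
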
Example #5
def save_entity(ani, file_name):
    from matplotlib.animation import FFMpegWriter
    writer = FFMpegWriter(fps=15, metadata=dict(artist='Sham'), bitrate=1800)
    ani.save(f"videomp4/{file_name}.mp4", writer=writer)
Example #6
def animate_top_side(
    whole_solution,
    filename='',
):
    """takes the objects and animates them using funcanimation"""

    # defined quantities
    num_of_frames = int(p.snapshots / p.display_rate)
    frame_interval = 1000 / p.fps

    # fig = plt.figure(figsize=(12, 12))
    # ax = plt.axes()
    fig, (ax_top, ax_side) = plt.subplots(1, 2)
    fig.set_figwidth(12)
    ax_top.axis('scaled')
    ax_side.axis('scaled')
    ax_top.set_title('Top view')
    ax_side.set_title('Side view')
    ax_top.set_xlabel('x'), ax_top.set_ylabel('y')
    ax_side.set_xlabel('x'), ax_side.set_ylabel('z')

    min_x = np.nanmin(whole_solution.rs[:, :, 0])
    max_x = np.nanmax(whole_solution.rs[:, :, 0])
    min_y = np.nanmin(whole_solution.rs[:, :, 1])
    max_y = np.nanmax(whole_solution.rs[:, :, 1])
    min_z = np.nanmin(whole_solution.rs[:, :, 2])
    max_z = np.nanmax(whole_solution.rs[:, :, 2])

    if p.display_size:
        ax_top.set_xlim(-p.display_size, p.display_size)
        ax_top.set_ylim(-p.display_size, p.display_size)
        ax_side.set_xlim(-p.display_size, p.display_size)
        ax_side.set_ylim(-p.display_size, p.display_size)
    else:
        ax_top.set_xlim(min_x, max_x)
        ax_top.set_ylim(min_y, max_y)
        ax_side.set_xlim(min_x, max_x)
        ax_side.set_ylim(min_z, max_z)

    circles_top = [
        plt.Circle((0, 0), find_particle_radius(1))
        for _ in range(p.num_particles)
    ]
    circles_side = [
        plt.Circle((0, 0), find_particle_radius(1))
        for _ in range(p.num_particles)
    ]
    for circle in circles_top:
        ax_top.add_artist(circle)
    for circle in circles_side:
        ax_side.add_artist(circle)

    def init():
        """initialisation function for FuncAnimation"""
        for i in range(len(circles_top)):
            circles_top[i].set_center((0, 0))
            circles_top[i].set_color('b')
            circles_top[i].set_radius(1)
        for i in range(len(circles_side)):
            circles_side[i].set_center((0, 0))
            circles_side[i].set_color('b')
            circles_side[i].set_radius(1)

        return circles_top, circles_side

    def update_screen(i):
        """update function for FuncAnimation"""
        i = i * p.display_rate
        N = int(whole_solution.Ns[i])
        for j in range(N):
            circles_top[j].set_center(whole_solution.rs[i, j, (0, 1)])
            circles_top[j].set_radius(
                find_particle_radius(whole_solution.ms[i, j]))
            circles_side[j].set_center(whole_solution.rs[i, j, (0, 2)])
            circles_side[j].set_radius(
                find_particle_radius(whole_solution.ms[i, j]))

        # remove particles that no longer exist
        for j in range(N, p.num_particles):
            circles_top[j].set_center((max_x + 1000, max_y + 1000))
            circles_side[j].set_center((max_x + 1000, max_z + 1000))

        return circles_top, circles_side

    ani = FuncAnimation(fig,
                        update_screen,
                        num_of_frames,
                        init_func=init,
                        blit=False,
                        interval=frame_interval)

    if filename:
        print("saving animation to {}".format(filename))
        if filename[-4:] == ".gif":
            writer = PillowWriter(fps=p.fps)
        elif filename[-4:] == ".mp4":
            writer = FFMpegWriter(fps=p.fps)
        else:
            raise Exception("Animation file format not allowed")
        ani.save(filename, writer=writer)

    plt.show()
Example #7
def create_video_with_keypoints_only(
    df,
    output_name,
    ind_links=None,
    pcutoff=0.6,
    dotsize=8,
    alpha=0.7,
    background_color="k",
    skeleton_color="navy",
    color_by="bodypart",
    colormap="viridis",
    fps=25,
    dpi=200,
    codec=default_codec,
):
    bodyparts = df.columns.get_level_values("bodyparts")[::3]
    bodypart_names = bodyparts.unique()
    n_bodyparts = len(bodypart_names)
    nx = int(np.nanmax(df.xs("x", axis=1, level="coords")))
    ny = int(np.nanmax(df.xs("y", axis=1, level="coords")))

    n_frames = df.shape[0]
    xyp = df.values.reshape((n_frames, -1, 3))

    if color_by == "bodypart":
        map_ = bodyparts.map(dict(zip(bodypart_names, range(n_bodyparts))))
        cmap = plt.get_cmap(colormap, n_bodyparts)
    elif color_by == "individual":
        try:
            individuals = df.columns.get_level_values("individuals")[::3]
            individual_names = individuals.unique().to_list()
            n_individuals = len(individual_names)
            map_ = individuals.map(dict(zip(individual_names, range(n_individuals))))
            cmap = plt.get_cmap(colormap, n_individuals)
        except KeyError as e:
            raise Exception(
                "Coloring by individuals is only valid for multi-animal data"
            ) from e
    else:
        raise ValueError(f"Invalid color_by={color_by}")

    prev_backend = plt.get_backend()
    plt.switch_backend("agg")
    fig = plt.figure(frameon=False, figsize=(nx / dpi, ny / dpi))
    ax = fig.add_subplot(111)
    scat = ax.scatter([], [], s=dotsize ** 2, alpha=alpha)
    coords = xyp[0, :, :2]
    coords[xyp[0, :, 2] < pcutoff] = np.nan
    scat.set_offsets(coords)
    colors = cmap(map_)
    scat.set_color(colors)
    segs = coords[tuple(zip(*tuple(ind_links))), :].swapaxes(0, 1) if ind_links else []
    coll = LineCollection(segs, colors=skeleton_color, alpha=alpha)
    ax.add_collection(coll)
    ax.set_xlim(0, nx)
    ax.set_ylim(0, ny)
    ax.axis("off")
    ax.add_patch(
        plt.Rectangle(
            (0, 0), 1, 1, facecolor=background_color, transform=ax.transAxes, zorder=-1
        )
    )
    ax.invert_yaxis()
    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

    writer = FFMpegWriter(fps=fps, codec=codec)
    with writer.saving(fig, output_name, dpi=dpi):
        writer.grab_frame()
        for index, _ in enumerate(trange(n_frames - 1), start=1):
            coords = xyp[index, :, :2]
            coords[xyp[index, :, 2] < pcutoff] = np.nan
            scat.set_offsets(coords)
            if ind_links:
                segs = coords[tuple(zip(*tuple(ind_links))), :].swapaxes(0, 1)
            coll.set_segments(segs)
            writer.grab_frame()
    plt.close(fig)
    plt.switch_backend(prev_backend)
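
A usage sketch for this function, assuming a DeepLabCut-style predictions file whose columns carry "bodyparts" and "coords" levels (the path is illustrative):

# Hypothetical usage (the .h5 path is illustrative):
import pandas as pd

df = pd.read_hdf("videos/mouse_test_predictions.h5")
create_video_with_keypoints_only(df, "keypoints_only.mp4",
                                 pcutoff=0.6, color_by="bodypart",
                                 fps=25, dpi=200)
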
Example #8
def main():
    from utils.UtilsMyData import test_framework_MyData as test_framework

    weights = torch.load(args.pretrained_posenet)
    seq_length = int(weights['state_dict']['conv1.0.weight'].size(1) / 3)
    pose_net = PoseExpNet(nb_ref_imgs=seq_length - 1, output_exp=False).to(device)
    pose_net.load_state_dict(weights['state_dict'], strict=False)

    dataset_dir = Path(args.dataset_dir)
    sequences = [args.sequence_idx]
    framework = test_framework(dataset_dir, sequences, seq_length)

    print('{} snippets to test'.format(len(framework)))
    errors = np.zeros((len(framework), 2), np.float32)
    optimized_errors = np.zeros((len(framework), 2), np.float32)
    ICP_iterations = np.zeros(len(framework))
    VO_processing_time = np.zeros(len(framework))
    ICP_iteration_time = np.zeros(len(framework))

    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
        output_dir.makedirs_p()
        predictions_array = np.zeros((len(framework), seq_length, 3, 4))

        '''Absolute pose list initialization'''
        # Scaled absolute poses output by the VO model, aligned to the lidar frame
        abs_VO_poses = np.zeros((len(framework), 12))
        # Scaled absolute poses output by the LO model, aligned to the lidar frame
        abs_LO_poses = np.zeros((len(framework), 12))
        # Pose estimates aligned to the camera frame, compared directly with the ground
        # truth (only applicable when ground truth exists in the camera frame)
        est_poses = np.zeros((len(framework), 12))
        est_poses[0] = np.identity(4)[:3, :].reshape(-1, 12)[0]

        '''Inter-frame pose list initialization'''
        # Scaled inter-frame poses output by the VO model, aligned to the camera frame
        cur_VO_poses_C = np.zeros((len(framework), 12))
        # Scaled inter-frame poses output by the VO model, aligned to the lidar frame
        cur_VO_poses = np.zeros((len(framework), 6))
        # Scaled inter-frame poses output by the LO model, aligned to the lidar frame
        cur_LO_poses = np.zeros((len(framework), 12))
        '''Scale factors'''
        scale_factors = np.zeros((len(framework), 1))

    abs_VO_pose = np.identity(4)[:3, :]

    '''Initialize variables needed for the main loop'''
    # Used to estimate the scale factor
    last_pose = np.identity(4)
    last_gap2_pose = np.identity(4)
    last_VO_pose = np.identity(4)
    # Point clouds involved in each cycle: current, previous and second-to-last
    curr_pts = None
    last_pts = None
    sec_last_pts = None
    # Inter-frame pose estimates between the current node and the previous /
    # second-to-last nodes
    gap1_VO_pose = np.identity(4)
    gap2_VO_pose = np.identity(4)
    last_gap2_VO_pose = np.identity(4)

    # Transformation matrix between the lidar (L) and camera (C) frames,
    # used to align the input poses to the lidar frame
    Transform_matrix_L2C = np.identity(4)
    Transform_matrix_L2C[:3, :3] = np.array([[-1.51482698e-02, -9.99886648e-01, 5.36310553e-03],
                                             [-4.65337018e-03, -5.36307196e-03, -9.99969412e-01],
                                             [9.99870070e-01, -1.56647995e-02, -4.48880010e-03]])
    Transform_matrix_L2C[:3, -1:] = np.array([4.29029924e-03, -6.08539196e-02, -9.20346161e-02]).reshape(3, 1)
    Transform_matrix_L2C = GramSchmidtHelper(Transform_matrix_L2C)
    Transform_matrix_C2L = np.linalg.inv(Transform_matrix_L2C)

    pointClouds = loadPointCloud(args.dataset_dir + "/sequences/" + args.sequence_idx + "/velodyne")

    '''Pose Graph Manager (for back-end optimization) initialization'''
    PGM = PoseGraphManager()
    PGM.addPriorFactor()
    num_frames = len(tqdm(framework))
    save_dir = "result/" + args.sequence_idx
    if not os.path.exists(save_dir): os.makedirs(save_dir)
    ResultSaver = PoseGraphResultSaver(init_pose=PGM.curr_se3,
                                       save_gap=args.save_gap,
                                       num_frames=num_frames,
                                       seq_idx=args.sequence_idx,
                                       save_dir=save_dir)
    '''Scan Context Manager (for loop detection) initialization'''
    SCM = ScanContextManager(shape=[args.num_rings, args.num_sectors],
                             num_candidates=args.num_candidates,
                             threshold=args.loop_threshold)
    '''Mapping initialization'''
    if args.mapping is True:
        Map = MappingManager()

    # to save the result as a video
    fig_idx = 1
    fig = plt.figure(fig_idx)
    writer = FFMpegWriter(fps=15)
    video_name = args.sequence_idx + "_" + str(args.num_icp_points) + "_prop@" + str(
        args.proposal) + "_tolerance@" + str(
        args.tolerance) + "_scm@" + str(args.scm_type) + "_thresh@" + str(args.loop_threshold) + ".mp4"
    num_frames_to_skip_to_show = 5
    num_frames_to_save = np.floor(num_frames / num_frames_to_skip_to_show)
    with writer.saving(fig, video_name, num_frames_to_save):  # this video saving part is optional

        for j, sample in enumerate(tqdm(framework)):
            '''
            *************************************** VO part *******************************************
            '''
            imgs = sample['imgs']
            w, h = imgs[0].size
            if (not args.no_resize) and (h != args.img_height or w != args.img_width):
                imgs = [(np.array(img.resize((args.img_width, args.img_height)))).astype(np.float32) for img in imgs]
            imgs = [np.transpose(img, (2, 0, 1)) for img in imgs]

            ref_imgs = []
            for i, img in enumerate(imgs):
                img = torch.from_numpy(img).unsqueeze(0)
                img = ((img / 255 - 0.5) / 0.5).to(device)
                if i == len(imgs) // 2:
                    tgt_img = img
                else:
                    ref_imgs.append(img)

            startTimeVO = time.time()
            _, poses = pose_net(tgt_img, ref_imgs)
            VO_processing_time[j] = time.time() - startTimeVO

            poses = poses.cpu()[0]
            poses = torch.cat([poses[:len(imgs) // 2], torch.zeros(1, 6).float(), poses[len(imgs) // 2:]])

            inv_transform_matrices = pose_vec2mat(poses, rotation_mode=args.rotation_mode).numpy().astype(np.float64)

            rot_matrices = np.linalg.inv(inv_transform_matrices[:, :, :3])
            tr_vectors = -rot_matrices @ inv_transform_matrices[:, :, -1:]

            transform_matrices = np.concatenate([rot_matrices, tr_vectors], axis=-1)
            # print('**********DeepVO result: time_cost {:.3} s'.format(timeCostVO / (len(imgs) - 1)))
            # Convert the transforms relative to the middle frame (1 of [0 1 2]) into poses relative to frame 0
            first_inv_transform = inv_transform_matrices[0]
            final_poses = first_inv_transform[:, :3] @ transform_matrices
            final_poses[:, :, -1:] += first_inv_transform[:, -1:]

            # gap2_VO_pose takes the third entry of final_poses (frames 0-2)
            gap2_VO_pose[:3, :] = final_poses[2]
            # gap1_VO_pose takes the second entry of final_poses (frames 0-1)
            gap1_VO_pose[:3, :] = final_poses[1]

            # Determining the scale factor: the ratio between the scales of the previous frame's LO and VO output poses is used as the current frame's scale factor (a fixed value is used for the first frame)
            if j == 0:
                scale_factor = 7
            else:
                scale_factor = math.sqrt(np.sum(last_pose[:3, -1] ** 2) / np.sum(last_VO_pose[:3, -1] ** 2))
            scale_factors[j] = scale_factor
            last_VO_pose = copy.deepcopy(gap1_VO_pose)  # note: deep copy
            # Apply the scale correction first, then align coordinate frames; Gram-Schmidt orthogonalization avoids ill-conditioned matrices
            gap2_VO_pose[:3, -1:] = gap2_VO_pose[:3, -1:] * scale_factor
            gap2_VO_pose = Transform_matrix_C2L @ gap2_VO_pose @ np.linalg.inv(Transform_matrix_C2L)
            gap2_VO_pose = GramSchmidtHelper(gap2_VO_pose)

            gap1_VO_pose[:3, -1:] = gap1_VO_pose[:3, -1:] * scale_factor
            cur_VO_poses_C[j] = gap1_VO_pose[:3, :].reshape(-1, 12)[0]
            gap1_VO_pose = Transform_matrix_C2L @ gap1_VO_pose @ np.linalg.inv(Transform_matrix_C2L)
            gap1_VO_pose = GramSchmidtHelper(gap1_VO_pose)

            '''************************* LO part ******************************************'''
            # Initialization
            if j == 0:
                last_pts = random_sampling(pointClouds[j], args.num_icp_points)
                sec_last_pts = last_pts
                SCM.addNode(j, last_pts)
                if args.mapping is True:
                    Map.updateMap(curr_se3=PGM.curr_se3,curr_local_ptcloud=last_pts)

            curr_pts = random_sampling(pointClouds[j + 1], args.num_icp_points)

            from modules.ICP import icp
            # Choose the initial guess for LO: none, the previous frame's pose, or the VO pose
            if args.proposal == 0:
                init_pose_1 = None
                init_pose_2 = None
            elif args.proposal == 1:
                init_pose_1 = last_pose
                init_pose_2 = last_gap2_pose
            elif args.proposal == 2:
                init_pose_1 = gap1_VO_pose
                init_pose_2 = last_gap2_VO_pose
            print("init_pose_1 ")
            print(init_pose_1)
            print("init_pose_2 ")
            print(init_pose_2)
            startTime = time.time()
            icp_odom_transform_1, distances, iterations = icp(curr_pts, last_pts, init_pose=init_pose_1,
                                                              tolerance=args.tolerance,
                                                              max_iterations=50)
            if j > 0:
                icp_odom_transform_2, distances_2, iterations_2 = icp(last_pts, sec_last_pts, init_pose=init_pose_2,
                                                                      tolerance=args.tolerance,
                                                                      max_iterations=50)
            else:
                icp_odom_transform_2 = icp_odom_transform_1
                distances_2 = distances
                iterations_2 = iterations

            ICP_iteration_time[j] = time.time() - startTime
            ICP_iterations[j] = (iterations + iterations_2) / 2

            print("last_pose ")
            print(last_pose)
            '''Update pointers'''
            # Point clouds
            sec_last_pts = last_pts
            last_pts = curr_pts
            # gap2 estimated pose
            last_gap2_VO_pose = copy.deepcopy(gap2_VO_pose)
            # ICP output poses
            last_pose = icp_odom_transform_1
            last_gap2_pose = icp_odom_transform_2

            print("LO优化后的位姿,mean_dis: ", np.asarray(distacnces).mean())
            print(icp_odom_transform_1)

            print("icp_odom_transform_2")
            print(icp_odom_transform_2)

            '''Update loop detection nodes'''
            SCM.addNode(j + 1, curr_pts)
            '''Update the edges and nodes of pose graph'''
            PGM.curr_node_idx = j + 1
            PGM.curr_se3 = np.matmul(PGM.curr_se3, icp_odom_transform_1)
            PGM.addOdometryFactor(icp_odom_transform_1, -1)
            PGM.addOdometryFactor(icp_odom_transform_2, -2)
            PGM.sec_prev_node_idx = PGM.prev_node_idx
            PGM.prev_node_idx = PGM.curr_node_idx

            # Map update
            if args.mapping is True:
                Map.updateMap(curr_se3=PGM.curr_se3,curr_local_ptcloud=curr_pts)

            # loop detection and optimize the graph
            if (PGM.curr_node_idx > 1 and PGM.curr_node_idx % args.try_gap_loop_detection == 0):
                # 1/ loop detection
                loop_idx, loop_dist, yaw_diff_deg = SCM.detectLoop()
                if loop_idx is None:  # NOT FOUND
                    pass
                else:
                    print("Loop event detected: ", PGM.curr_node_idx, loop_idx, loop_dist)
                    # 2-1/ add the loop factor
                    loop_scan_down_pts = SCM.getPtcloud(loop_idx)
                    loop_transform, _, _ = icp(curr_pts, loop_scan_down_pts,
                                               init_pose=yawdeg2se3(yaw_diff_deg), max_iterations=20)
                    PGM.addLoopFactor(loop_transform, loop_idx)

                    # 2-2/ graph optimization
                    PGM.optimizePoseGraph()

                    # 2-2/ save optimized poses
                    ResultSaver.saveOptimizedPoseGraphResult(PGM.curr_node_idx, PGM.graph_optimized)

                    # 2-3/ updateMap
                    if args.mapping is True:
                        Map.optimizeGlobalMap(PGM.graph_optimized, PGM.curr_node_idx)

            # Periodically run pose graph optimization
            if (PGM.curr_node_idx > 1 and PGM.curr_node_idx % args.optimization_period == 0):
                # 2-2/ graph optimization
                PGM.optimizePoseGraph()

                # 2-2/ save optimized poses
                ResultSaver.saveOptimizedPoseGraphResult(PGM.curr_node_idx, PGM.graph_optimized)

                # 2-3/ updateMap
                if args.mapping is True:
                    Map.optimizeGlobalMap(PGM.graph_optimized, PGM.curr_node_idx)

            # save the ICP odometry pose result (no loop closure)
            ResultSaver.saveUnoptimizedPoseGraphResult(PGM.curr_se3, PGM.curr_node_idx)
            if (j % num_frames_to_skip_to_show == 0):
                ResultSaver.vizCurrentTrajectory(fig_idx=fig_idx)
                writer.grab_frame()
            if args.vizmapping is True:
                Map.vizMapWithOpen3D()

            # Absolute VO pose aligned to the lidar frame
            abs_VO_pose[:, :3] = gap1_VO_pose[:3, :3] @ abs_VO_pose[:, :3]
            abs_VO_pose[:, -1:] += gap1_VO_pose[:3, -1:]

            if args.output_dir is not None:
                predictions_array[j] = final_poses
                # cur_VO_poses[j]=cur_VO_pose[:3, :].reshape(-1, 12)[0]
                cur_LO_poses[j] = icp_odom_transform_1[:3, :].reshape(-1, 12)[0]
                abs_VO_poses[j] = abs_VO_pose[:3, :].reshape(-1, 12)[0]
                abs_LO_poses[j] = PGM.curr_se3[:3, :].reshape(-1, 12)[0]
                est_pose = Transform_matrix_L2C @ PGM.curr_se3 @ np.linalg.inv(Transform_matrix_L2C)
                est_poses[j + 1] = est_pose[:3, :].reshape(-1, 12)[0]

        if args.mapping is True:
            Map.saveMap2File('map_' + args.sequence_idx + "_" + str(args.num_icp_points) + "_prop@" + str(
                args.proposal) + "_tolerance@" + str(args.tolerance) + "_scm@" + str(args.scm_type) +
                             "_thresh@" + str(args.loop_threshold) + '.pcd')
        if args.output_dir is not None:
            # np.save(output_dir / 'predictions.npy', predictions_array)
            np.savetxt(output_dir / 'scale_factors.txt', scale_factors)
            np.savetxt(output_dir / 'cur_VO_poses_C.txt', cur_VO_poses_C)
            np.savetxt(output_dir / 'cur_VO_poses.txt', cur_VO_poses)
            np.savetxt(output_dir / 'abs_VO_poses.txt', abs_VO_poses)
            np.savetxt(output_dir / 'cur_LO_poses.txt', cur_LO_poses)
            np.savetxt(output_dir / 'abs_LO_poses.txt', abs_LO_poses)
            np.savetxt(output_dir / 'iterations.txt', ICP_iterations)
            np.savetxt(output_dir / 'est_kitti_{0}_poses.txt'.format(args.sequence_idx), est_poses)

        # Accuracy metrics for the VO output poses
        mean_errors = errors.mean(0)
        std_errors = errors.std(0)
        error_names = ['ATE', 'RE']
        print('')
        print("VO_Results")
        print("\t {:>10}, {:>10}".format(*error_names))
        print("mean \t {:10.4f}, {:10.4f}".format(*mean_errors))
        print("std \t {:10.4f}, {:10.4f}".format(*std_errors))

        # Accuracy metrics after LO optimization
        optimized_mean_errors = optimized_errors.mean(0)
        optimized_std_errors = optimized_errors.std(0)
        optimized_error_names = ['optimized_ATE', 'optimized_RE']
        print('')
        print("LO_optimized_Results")
        print("\t {:>10}, {:>10}".format(*optimized_error_names))
        print("mean \t {:10.4f}, {:10.4f}".format(*optimized_mean_errors))
        print("std \t {:10.4f}, {:10.4f}".format(*optimized_std_errors))

        # Iteration count
        mean_iterations = ICP_iterations.mean()
        std_iterations = ICP_iterations.std()
        _names = ['iteration']
        print('')
        print("LO迭代次数")
        print("\t {:>10}".format(*_names))
        print("mean \t {:10.4f}".format(mean_iterations))
        print("std \t {:10.4f}".format(std_iterations))

        # Iteration time
        mean_iter_time = ICP_iteration_time.mean()
        std_iter_time = ICP_iteration_time.std()
        _names = ['iter_time']
        print('')
        print("LO迭代时间:单位/s")
        print("\t {:>10}".format(*_names))
        print("mean \t {:10.4f}".format(mean_iter_time))
        print("std \t {:10.4f}".format(std_iter_time))

Example #9
def init():  # only required for blitting to give a clean slate.
    for j in range(len(lines)):
        x, y = ellipse(aa[j], er[j], 0.0, tt)
        xy = np.array([[xi, yi] for xi, yi in zip(x, y)])
        lines[j].set_xy(xy)
    return lines


def animate(i):
    for j in range(len(lines)):
        x, y = ellipse(aa[j], er[j], pr[j] * i, tt)
        xy = np.array([[xi, yi] for xi, yi in zip(x, y)])
        lines[j].set_xy(xy)
    return lines


ani = animation.FuncAnimation(fig,
                              animate,
                              init_func=init,
                              interval=1,
                              blit=True,
                              frames=800,
                              save_count=50)

from matplotlib.animation import FFMpegWriter
writer = FFMpegWriter(fps=50,
                      metadata=dict(artist=ARTIST, comment=COMMENT),
                      bitrate=300)
ani.save(OUTPUTFILE, writer=writer)
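
The `ellipse` helper used in `init` and `animate` is not shown. A plausible definition (hypothetical: points of an ellipse with semi-major axis a and eccentricity e, rotated by a precession angle phi and sampled at parameters t):

# Hypothetical helper matching the calls ellipse(aa[j], er[j], pr[j] * i, tt):
import numpy as np

def ellipse(a, e, phi, t):
    b = a * np.sqrt(1.0 - e ** 2)        # semi-minor axis from eccentricity
    x, y = a * np.cos(t), b * np.sin(t)  # ellipse in its own frame
    c, s = np.cos(phi), np.sin(phi)
    return c * x - s * y, s * x + c * y  # rotate by the precession angle
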
Example #10
def update(r):
    text.set_text(r[0])
    k.set_height(r[1] + 0.1)
    g.set_height(r[2] + 0.1)
    c.set_height(r[3] + 0.1)
    o.set_height(r[4] + 0.1)
    return text, k, g, c, o


fig, ax = plt.subplots()
k = patches.Rectangle((0, 0), 0.8, 0.1, fc='C8')
g = patches.Rectangle((1, 0), 0.8, 0.1, fc='C0')
c = patches.Rectangle((2, 0), 0.8, 0.1, fc='C6')
o = patches.Rectangle((3, 0), 0.8, 0.1, fc='C3')
labels = ['Kittiwakes', 'Guillemots', 'Chicks', 'Others']
text = plt.text(0, max_count, '', va='top', size=18)
ax.set_ylim(0, max_count)
ax.set_xlim(-0.2, 4)
plt.xticks([0.4, 1.4, 2.4, 3.4], labels)
values = zip(date, kc, gc, cc, oc)
ani = FuncAnimation(fig,
                    update,
                    values,
                    interval=interval,
                    init_func=init,
                    save_count=len(date))

writer = FFMpegWriter(fps=1000 // interval)
ani.save('seabirdwatch_by_{0}_smooth.mp4'.format(freq), writer=writer, dpi=100)
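
The `init` function passed to `FuncAnimation` is not shown in this excerpt; note also that the four rectangles are never attached to the axes in the visible code. A plausible init that does both (hypothetical):

# Hypothetical init function (not part of the excerpt):
def init():
    for bar in (k, g, c, o):
        ax.add_patch(bar)    # attach the bars to the axes
        bar.set_height(0.1)  # start from (almost) empty bars
    text.set_text('')
    return text, k, g, c, o
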
Example #11
def animation_bar_chart(csv_path,
                        xticklabels=None,
                        ylim=None,
                        sort_by_reward=True,
                        style='whitegrid',
                        save=False,
                        filename_tag='expert_weight',
                        timestamp=False,
                        plot_path=''):
    """Animate bar chart"""

    data = ReadCSV(csv_path)
    data.df.set_index('time', inplace=True)
    expert_weights = data.df['expert_weights']
    n_experts = len(expert_weights.loc[1])
    sns.set_theme(style=style)
    if xticklabels is None:
        xticklabels = np.arange(n_experts)
        rotate_xticklabels = True
    else:
        rotate_xticklabels = False
    x = np.arange(n_experts)
    if ylim is None:
        expert_final_weights = expert_weights.loc[len(expert_weights)]
        ylim = (0, min(1, 1.1 * np.max(expert_final_weights)))
    if sort_by_reward:
        expert_mean_rewards = np.array(
            data.df['expert_mean_rewards'].values.tolist())
        expert_mean_rewards_over_time = np.sum(
            expert_mean_rewards, axis=0) / len(expert_mean_rewards)
        sorted_ids = expert_mean_rewards_over_time.argsort()
        for t in range(1, len(expert_weights) + 1):
            expert_weights.at[t] = expert_weights.at[t][sorted_ids]
        xticklabels = np.array(xticklabels)[sorted_ids]
    width = .6
    fig, ax = plt.subplots()

    def init():
        ax.clear()
        ax.set_xlim(-.5, n_experts - .5)
        ax.set_ylim(*ylim)

    def animate(frame_id):
        init()
        time = frame_id + 1
        ax.bar(x, expert_weights.loc[time], width)
        ax.set_xlabel('Expert')
        ax.set_xticks(x)
        ax.set_ylabel('Weight')
        ax.set_title(f'Time {time}')
        if rotate_xticklabels:
            ax.set_xticklabels(xticklabels)
        else:
            ax.set_xticklabels(xticklabels, rotation=45, ha='right')
        fig.tight_layout()

    anim = FuncAnimation(fig,
                         animate,
                         init_func=init,
                         frames=len(expert_weights),
                         interval=200,
                         repeat=False)

    if save:
        if timestamp:
            filename = f'animated_bar_chart_{filename_tag}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.mp4'
        else:
            filename = f'animated_bar_chart_{filename_tag}.mp4'
        FFwriter = FFMpegWriter(fps=10)
        anim.save(plot_path + filename, writer=FFwriter)
    plt.show()
    plt.pause(3)
    plt.close()
Example #12
def main():
    t = time_logging.start()
    device = torch.device("cuda:0")

    torch.manual_seed(16)
    f = Model().to(device)

    x = np.load("{}/119.npy".format(os.path.dirname(__file__)))
    x = np.pad(x, 30, "constant")
    x = x.astype(np.float32)
    x = torch.tensor(x, device=device, dtype=torch.float32)

    # x = x[::2, ::2, ::2]

    n = 361
    angles = np.linspace(0, np.pi, n)

    fig = plt.figure(figsize=(13, 13))

    ax1 = fig.add_axes([0, 0, .5, .5])  # x
    ax2 = fig.add_axes([0, .5, .5, .5])  # R(x)
    ax3 = fig.add_axes([.5, .5, .5, .5])  # f(R(x))
    ax4 = fig.add_axes([.5, 0, .5, .5])  # R-1(f(R(x)))

    y = f(x)
    image1 = ax1.imshow(project(x), interpolation='none')
    image2 = ax2.imshow(project(x), interpolation='none')
    image3 = ax3.imshow(project(y), interpolation='none')
    image4 = ax4.imshow(project(y), interpolation='none')

    for image in [image1, image2, image3, image4]:
        image.set_cmap("summer")

    ax1.text(0.5,
             0.99,
             'stabilized input',
             horizontalalignment='center',
             verticalalignment='top',
             transform=ax1.transAxes,
             color='white',
             fontsize=30)
    ax2.text(0.5,
             0.99,
             'input',
             horizontalalignment='center',
             verticalalignment='top',
             transform=ax2.transAxes,
             color='white',
             fontsize=30)
    ax3.text(0.5,
             0.99,
             'featuremap',
             horizontalalignment='center',
             verticalalignment='top',
             transform=ax3.transAxes,
             color='white',
             fontsize=30)
    ax4.text(0.5,
             0.99,
             'stabilized featuremap',
             horizontalalignment='center',
             verticalalignment='top',
             transform=ax4.transAxes,
             color='white',
             fontsize=30)
    text = ax2.text(0.5,
                    0.8,
                    '',
                    horizontalalignment='center',
                    verticalalignment='center',
                    transform=ax2.transAxes,
                    color='white',
                    fontsize=40)

    for ax in [ax1, ax2, ax3, ax4]:
        plt.sca(ax)
        plt.axis('off')

    time_logging.end("init", t)

    def init():
        image2.set_data(project(x))
        y = f(x)
        image3.set_data(project(y))
        image4.set_data(project(y))
        text.set_text("")
        return image2, image3, image4, text

    def animate(i):
        print("rendering {} / {}".format(i, n))
        alpha = angles[i]

        rx = rotate(x, alpha)
        frx = f(rx)
        rfrx = rotate(frx, -alpha)

        image2.set_data(project(rx))
        image3.set_data(project(frx))
        image4.set_data(project(rfrx))
        text.set_text(r"$\frac{{ {} \pi }}{{ {} }}$".format(i, n - 1))
        return image2, image3, image4, text

    ani = animation.FuncAnimation(fig,
                                  animate,
                                  init_func=init,
                                  interval=2,
                                  blit=True,
                                  save_count=n)

    from matplotlib.animation import FFMpegWriter
    writer = FFMpegWriter(fps=12, bitrate=3000)
    ani.save("movie.mp4", writer=writer)

    print(time_logging.text_statistics())
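
Neither `project` nor `rotate` appears in this excerpt. A sketch of `project` consistent with how it is called (hypothetical: collapse the 3D feature volume to a 2D image for imshow by summing along one axis):

# Hypothetical helper (illustrative only); `rotate` would analogously rotate
# the voxel grid by a given angle before projection.
def project(volume):
    return volume.sum(dim=0).detach().cpu().numpy()
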
Example #13
def make_real_vs_sampled_movies(ims_recon,
                                ims_recon_samp,
                                conditional,
                                save_file=None,
                                frame_rate=15):
    """Produce movie with (AE) reconstructed video and sampled video.

    Parameters
    ----------
    ims_recon : :obj:`np.ndarray`
        shape (n_frames, y_pix, x_pix)
    ims_recon_samp : :obj:`np.ndarray`
        shape (n_frames, y_pix, x_pix)
    conditional : :obj:`bool`
        conditional vs unconditional samples; for creating reconstruction title
    save_file : :obj:`str`, optional
        full save file (path and filename)
    frame_rate : :obj:`float`, optional
        frame rate of saved movie

    """

    n_frames = ims_recon.shape[0]

    n_plots = 2
    [y_pix, x_pix] = ims_recon[0].shape
    fig_dim_div = x_pix * n_plots / 10  # aiming for dim 1 being 10
    x_dim = x_pix * n_plots / fig_dim_div
    y_dim = y_pix / fig_dim_div
    fig, axes = plt.subplots(1, n_plots, figsize=(x_dim, y_dim))

    for j in range(2):
        axes[j].set_xticks([])
        axes[j].set_yticks([])

    axes[0].set_title('Real Reconstructions\n', fontsize=16)
    if conditional:
        title_str = 'Generative Reconstructions\n(Conditional)'
    else:
        title_str = 'Generative Reconstructions\n(Unconditional)'
    axes[1].set_title(title_str, fontsize=16)
    fig.tight_layout(pad=0)

    im_kwargs = {'cmap': 'gray', 'vmin': 0, 'vmax': 1, 'animated': True}
    ims = []
    for i in range(n_frames):
        ims_curr = []
        im = axes[0].imshow(ims_recon[i], **im_kwargs)
        ims_curr.append(im)
        im = axes[1].imshow(ims_recon_samp[i], **im_kwargs)
        ims_curr.append(im)
        ims.append(ims_curr)

    ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000)
    writer = FFMpegWriter(fps=frame_rate, bitrate=-1)

    if save_file is not None:
        make_dir_if_not_exists(save_file)
        if save_file[-3:] != 'mp4':
            save_file += '.mp4'
        print('saving video to %s...' % save_file, end='')
        ani.save(save_file, writer=writer)
        print('done')
Example #14
def make_syllable_movies(ims_orig,
                         state_list,
                         trial_idxs,
                         save_file=None,
                         max_frames=400,
                         frame_rate=10,
                         n_buffer=5,
                         n_pre_frames=3,
                         n_rows=None,
                         single_syllable=None):
    """Present video clips of each individual syllable in separate panels

    Parameters
    ----------
    ims_orig : :obj:`np.ndarray`
        shape (n_frames, n_channels, y_pix, x_pix)
    state_list : :obj:`list`
        each entry (one per state) contains all occurrences of that discrete state by
        :obj:`[chunk number, starting index, ending index]`
    trial_idxs : :obj:`array-like`
        indices into :obj:`states` for which trials should be plotted
    save_file : :obj:`str`
        full save file (path and filename)
    max_frames : :obj:`int`, optional
        maximum number of frames to animate
    frame_rate : :obj:`float`, optional
        frame rate of saved movie
    n_buffer : :obj:`int`
        number of blank frames between syllable instances
    n_pre_frames : :obj:`int`
        number of behavioral frames to precede each syllable instance
    n_rows : :obj:`int` or :obj:`NoneType`
        number of rows in output movie
    single_syllable : :obj:`int` or :obj:`NoneType`
        choose only a single state for movie

    """

    K = len(state_list)

    # Initialize syllable movie frames
    plt.clf()
    if single_syllable is not None:
        K = 1
        fig_width = 5
        n_rows = 1
    else:
        fig_width = 10  # aiming for dim 1 being 10
    # get video dims
    bs, n_channels, y_dim, x_dim = ims_orig[0].shape
    movie_dim1 = n_channels * y_dim
    movie_dim2 = x_dim
    if n_rows is None:
        n_rows = int(np.floor(np.sqrt(K)))
    n_cols = int(np.ceil(K / n_rows))

    fig_dim_div = movie_dim2 * n_cols / fig_width
    fig_width = (movie_dim2 * n_cols) / fig_dim_div
    fig_height = (movie_dim1 * n_rows) / fig_dim_div
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height))

    for i, ax in enumerate(fig.axes):
        ax.set_yticks([])
        ax.set_xticks([])
        if i >= K:
            ax.set_axis_off()
        elif single_syllable is not None:
            ax.set_title('Syllable %i' % single_syllable, fontsize=16)
        else:
            ax.set_title('Syllable %i' % i, fontsize=16)
    fig.tight_layout(pad=0, h_pad=1.005)

    imshow_kwargs = {'animated': True, 'cmap': 'gray', 'vmin': 0, 'vmax': 1}

    ims = [[] for _ in range(max_frames + bs + 200)]

    # Loop through syllables
    for i_k, ax in enumerate(fig.axes):

        # skip if no syllable in this axis
        if i_k >= K:
            continue
        print('processing syllable %i/%i' % (i_k + 1, K))
        # skip if no syllables are longer than threshold
        if len(state_list[i_k]) == 0:
            continue

        if single_syllable is not None:
            i_k = single_syllable

        i_chunk = 0
        i_frame = 0

        while i_frame < max_frames:

            if i_chunk >= len(state_list[i_k]):
                # show blank if out of syllable examples
                im = ax.imshow(np.zeros((movie_dim1, movie_dim2)),
                               **imshow_kwargs)
                ims[i_frame].append(im)
                i_frame += 1
            else:

                # Get movies/latents
                chunk_idx = state_list[i_k][i_chunk, 0]
                which_trial = trial_idxs[chunk_idx]
                tr_beg = state_list[i_k][i_chunk, 1]
                tr_end = state_list[i_k][i_chunk, 2]
                batch = ims_orig[which_trial]
                movie_chunk = batch[max(tr_beg - n_pre_frames, 0):tr_end]

                movie_chunk = np.concatenate(
                    [movie_chunk[:, j] for j in range(movie_chunk.shape[1])],
                    axis=1)

                # if np.sum(states[chunk_idx][tr_beg:tr_end-1] != i_k) > 0:
                #     raise ValueError('Misaligned states for syllable segmentation')

                # Loop over this chunk
                for i in range(movie_chunk.shape[0]):

                    im = ax.imshow(movie_chunk[i], **imshow_kwargs)
                    ims[i_frame].append(im)

                    # Add red box if start of syllable
                    syllable_start = n_pre_frames if tr_beg >= n_pre_frames else tr_beg

                    if syllable_start <= i < (syllable_start + 2):
                        rect = matplotlib.patches.Rectangle((5, 5),
                                                            10,
                                                            10,
                                                            linewidth=1,
                                                            edgecolor='r',
                                                            facecolor='r')
                        im = ax.add_patch(rect)
                        ims[i_frame].append(im)

                    i_frame += 1

                # Add buffer black frames
                for j in range(n_buffer):
                    im = ax.imshow(np.zeros((movie_dim1, movie_dim2)),
                                   **imshow_kwargs)
                    ims[i_frame].append(im)
                    i_frame += 1

                i_chunk += 1

    print('creating animation...', end='')
    ani = animation.ArtistAnimation(
        fig, [ims[i] for i in range(len(ims)) if ims[i] != []],
        interval=20,
        blit=True,
        repeat=False)
    writer = FFMpegWriter(fps=max(frame_rate, 10), bitrate=-1)
    print('done')

    if save_file is not None:
        # put together file name
        if save_file[-4:] == '.mp4':
            save_file = save_file[:-4]
        if single_syllable is not None:
            state_str = str('_syllable-%02i' % single_syllable)
        else:
            state_str = ''
        save_file += state_str
        save_file += '.mp4'
        make_dir_if_not_exists(save_file)
        print('saving video to %s...' % save_file, end='')
        ani.save(save_file, writer=writer)
        print('done')
Example #15
def init():  # only required for blitting to give a clean slate.
    line.set_ydata(np.zeros(ecg.Time.shape[0]))
    return line,


def animate(t):
    mask = ecg.Time <= t
    line.set_xdata(ecg.Time[mask])  # update the data.
    line.set_ydata(ecg.L1[mask])  # update the data.
    point.set_xdata(ecg.Time[mask].values[-1])
    point.set_ydata(ecg.L1[mask].values[-1])
    return line,


ani = animation.FuncAnimation(fig,
                              animate,
                              init_func=init,
                              interval=0.01,
                              frames=ecg.Time)

# To save the animation, use e.g.
#
# ani.save("ecg.mp4")
#
# or
#
from matplotlib.animation import FFMpegWriter
writer = FFMpegWriter(fps=10, metadata=dict(artist='Me'), bitrate=1800)
ani.save("ecg.mp4", writer=writer)

plt.show()
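
The figure setup assumed by this excerpt is not shown. A plausible version (hypothetical: `ecg` is a DataFrame with Time and L1 columns; the CSV path is illustrative):

# Hypothetical setup for the animation above:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import animation

ecg = pd.read_csv("ecg.csv")
fig, ax = plt.subplots()
line, = ax.plot(ecg.Time, ecg.L1, lw=1)  # the trace revealed over time
point, = ax.plot([], [], 'ro')           # marker tracking the current sample
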
Example #16
def dynamic_death_recover_plot(data, province):
    """ generating a video to record the dynamic changes of deaths and recovers in the given province

       Parameters:
       data (DataFrame): the data read from covid19.csv by pandas
       province (str): the name of the province


       Returns:
       None

      """
    # filter the data by province name
    info_of_province = data[data["prname"] == province]
    # get the column 'date'
    dates = info_of_province.iloc[:, 3].values
    # get the column number of deaths
    num_deaths = info_of_province.iloc[:, 6].values
    # fill in the column 'numrecover' with its previous value
    info_of_province['numrecover'] = info_of_province['numrecover'].fillna(
        method='ffill')
    # get the column number of recover
    num_recovers = info_of_province.iloc[:, 9].values
    length = len(dates)

    fig, ax = plt.subplots()
    ax.set_ylabel('Number of people')
    plt.title("The change of Deaths and Recovers in " + province)
    metadata = dict(title='Dynamic changes of deaths and recovers',
                    artist='Matplotlib',
                    comment='Movie support!')
    # create a writer
    writer = FFMpegWriter(fps=2, metadata=metadata)
    # use writer to save into video
    with writer.saving(fig, "Dynamic-death-change-" + province + '.mp4', 100):
        # plotting pictures and use writer to capture it
        for i in range(0, int(length / 5) + 1):
            i += 1
            right = i * 5
            if right > length:
                right = length
            x = dates[0:right]
            y = num_deaths[0:right]
            y1 = num_recovers[0:right]
            ax.plot(x,
                    y,
                    '-o',
                    markevery=[right - 1],
                    color='black',
                    markerfacecolor='red',
                    label="number of deaths")
            ax.plot(x,
                    y1,
                    '-o',
                    markevery=[right - 1],
                    color='green',
                    markerfacecolor='blue',
                    label="number of recovers")
            if i == 1:
                ax.legend(loc='upper left')
            ax.axis([0, 10, 0, num_recovers[-1] + 2])
            xticks = list(range(0, len(dates), 7))
            xlabels = [dates[i] for i in xticks]
            ax.set_xticks(xticks)
            ax.set_xticklabels(xlabels, rotation=40)
            plt.tight_layout()
            writer.grab_frame()
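
A usage sketch for this function (hypothetical file name and province):

# Hypothetical usage; writes Dynamic-death-change-Ontario.mp4 next to the script.
import pandas as pd

data = pd.read_csv("covid19.csv")
dynamic_death_recover_plot(data, "Ontario")
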
Example #17
def create_overview_video(dpmat_fname,
                          regmat_fname,
                          rawmat_fname,
                          movie_fname,
                          fps: int,
                          metadata={},
                          mc_brighten_factor=1.2,
                          codec='vp9'):
    ''' Write an overview video of the generated calcium traces, motion correction and
    temporal-spatial extracted cell activity. '''

    # configure plotting defaults
    matplotlib.rcParams['figure.autolayout'] = True
    plt.rc('axes.spines', top=False, right=False)
    sns.set()
    sns.set_style('white')

    f_dp = h5py.File(dpmat_fname, 'r')
    imax_mc = np.array(f_dp['imax'])
    pix_w = imax_mc.shape[0]
    pix_h = imax_mc.shape[1]
    log.info('Result image dimensions: {}x{}'.format(pix_w, pix_h))

    with ThreadPoolExecutor(max_workers=4) as e:
        frames_res_future = e.submit(calculate_frames_res, dpmat_fname, pix_w,
                                     pix_h)
        frames_reg_future = e.submit(calculate_frames_reg, regmat_fname,
                                     mc_brighten_factor)
        raw_range_future = e.submit(calculate_rawdata_range, rawmat_fname)

        raw_vmin, raw_vmax = raw_range_future.result()
        frames_reg = frames_reg_future.result()
        frames_res, sigs_zsc = frames_res_future.result()

    log.info('Parallel processing finished.')
    f_raw = h5py.File(rawmat_fname, 'r')
    frames_raw = f_raw['frame_all']

    trace_range = 15 * fps
    trace_hl_range = 8 * fps

    plt.close('all')
    plt.ioff()
    if not metadata:
        metadata = dict(title='Calcium trace analysis overview')

    sns.set_palette('deep')
    plt.style.use('dark_background')

    # prepare figure for animation
    log.info('Preparing video figure template')
    fig = plt.figure(figsize=(20, 10), dpi=60)
    gs = fig.add_gridspec(2, 3)
    ax_raw = fig.add_subplot(gs[0, 0])
    ax_reg = fig.add_subplot(gs[0, 1])
    ax_res = fig.add_subplot(gs[0, 2])
    ax_trace = fig.add_subplot(gs[1, :])

    # prepare axes
    ax_raw.axis('off')
    ax_reg.axis('off')
    ax_res.axis('off')

    ax_trace.get_yaxis().set_visible(False)
    ax_trace.get_xaxis().set_visible(False)
    sns.despine(ax=ax_trace, left=True, top=True, right=True, trim=True)

    # add title placeholder
    fig.suptitle('frame: ??????', fontsize=12, ha='right', x=0.98)

    # data placeholders
    spf_raw = ax_raw.imshow(np.zeros((pix_h, pix_w)),
                            cmap=cmap_kindlmann_extended,
                            vmin=raw_vmin,
                            vmax=raw_vmax)
    ax_raw.set_title('Raw')

    spf_reg = ax_reg.imshow(np.zeros((pix_h, pix_w)),
                            cmap=cmap_kindlmann_extended,
                            vmin=0.0,
                            vmax=1.0)
    ax_reg.set_title('After MC')

    spf_res = ax_res.imshow(np.zeros((pix_h, pix_w)),
                            cmap=cmap_kindlmann_extended,
                            vmin=0.0,
                            vmax=1.0)
    ax_res.set_title('Processed')

    # create color mapping for our traces
    cmap_trace_base = discrete_cmap_for_n(len(sigs_zsc[0, :]), 'rainbow')
    cmap_trace_hl = list(map(lambda x: x * 0.90, cmap_trace_base))
    cmap_trace_bg = list(map(lambda x: x * 0.60, cmap_trace_hl))

    spf_traces = []
    sigs_zsc_plotadj = np.zeros(sigs_zsc.shape, dtype=sigs_zsc.dtype)
    plot_sigs_vismask = np.zeros(sigs_zsc.shape, dtype=bool)
    traces_ymin = 0
    traces_ymax = 0
    for i in range(0, len(sigs_zsc_plotadj[0, :])):
        plot_sigs_vismask[:, i] = np.ma.masked_inside(sigs_zsc[:, i], -1.5,
                                                      1.5)
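        # note: assigning the masked array into this boolean array stores the
        # truth values of the underlying data, not the mask; if the mask itself
        # was intended, `.mask` would be needed here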
        sigs_zsc_plotadj[:, i] = (sigs_zsc[:, i] * 1.5) + (i * 1.5)
        p, = ax_trace.plot(np.nan, np.nan, linewidth=1.5)
        p.set_color(cmap_trace_bg[i])
        spf_traces.append(p)
        if i == 0:
            traces_ymin = np.nanmin(sigs_zsc_plotadj[:, i])
        elif i == len(sigs_zsc_plotadj[0, :]) - 1:
            traces_ymax = np.nanmax(sigs_zsc_plotadj[:, i])

    # we intentionally do not set the limits to the full data range, as peaks that
    # go way above it would otherwise make traces with smaller signals less visible.
    # For a quick overview, this display is good enough
    ax_trace.set_ylim(ymin=traces_ymin, ymax=traces_ymax)
    ax_trace.set_xlim(xmin=trace_range * -1, xmax=(trace_range + 1))
    ax_trace_vline = ax_trace.axvline(x=0,
                                      linestyle=':',
                                      linewidth=1,
                                      zorder=100)
    log.info('Trace plot limits calculated: {} to {}'.format(
        traces_ymin, traces_ymax))

    render_proglog_interval = 60 * fps
    log.info('Writing video')
    frames_n = frames_res.shape[0]
    mwriter = FFMpegWriter(fps=fps, metadata=metadata, codec=codec)
    with mwriter.saving(fig, movie_fname, 60):
        for i in range(0, frames_n):
            fig.suptitle('frame: {:06d}'.format(i),
                         fontsize=12,
                         ha='right',
                         x=0.98)

            # raw input frame
            spf_raw.set_data(np.transpose(frames_raw[i, :, :]))

            # motion-corrected data
            spf_reg.set_data(np.transpose(frames_reg[i, :, :]))

            # visualization of the resulting detected temporal-spatial units
            spf_res.set_data(frames_res[i, :, :].T)

            # plot trace overview
            trace_start = i - trace_range
            trace_end = trace_range + i
            ax_trace.set_xlim(xmin=trace_start, xmax=trace_end)
            if trace_start < 0:
                trace_start = 0
            if trace_end > len(sigs_zsc_plotadj[:, 0]):
                trace_end = len(sigs_zsc_plotadj[:, 0])
            trace_x = np.arange(trace_start, trace_end)
            ax_trace_vline.set_xdata([i])
            for j, spf_trace in enumerate(spf_traces):
                # clamp the window start so a negative index does not wrap around
                hl_start = max(0, i - trace_hl_range)
                if np.count_nonzero(
                        plot_sigs_vismask[hl_start:i + trace_hl_range,
                                          j]) >= (trace_hl_range / 1.5):
                    spf_trace.set_color(cmap_trace_hl[j])
                else:
                    spf_trace.set_color(cmap_trace_bg[j])
                spf_trace.set_data(
                    trace_x, sigs_zsc_plotadj[trace_start:trace_end, j])

            if i % render_proglog_interval == 0:
                log.info('Rendered {} of {} frames'.format(i, frames_n))

            mwriter.grab_frame()
    log.info('Video created successfully ({} frames @ {}fps)'.format(
        frames_n, fps))
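
Since create_overview_video defaults to the vp9 codec, a long render can fail at the very end if the local ffmpeg build lacks the codec or ffmpeg is missing entirely. A small pre-flight check may be worthwhile; this is a sketch (the error message is ours), using matplotlib's FFMpegWriter.isAvailable() classmethod:

from matplotlib.animation import FFMpegWriter

if not FFMpegWriter.isAvailable():
    raise RuntimeError("ffmpeg not found; install it or set "
                       "matplotlib.rcParams['animation.ffmpeg_path']")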
Example #18
0
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
ResultSaver = PoseGraphResultSaver(init_pose=PGM.curr_se3,
                                   save_gap=args.save_gap,
                                   num_frames=num_frames,
                                   seq_idx=args.sequence_idx,
                                   save_dir=save_dir)

# Scan Context Manager (for loop detection) initialization
SCM = ScanContextManager(shape=[args.num_rings, args.num_sectors],
                         num_candidates=args.num_candidates,
                         threshold=args.loop_threshold)

# for saving the results as a video
fig_idx = 1
fig = plt.figure(fig_idx)
writer = FFMpegWriter(fps=15)
video_name = args.sequence_idx + "_" + str(args.num_icp_points) + ".mp4"
num_frames_to_skip_to_show = 5
num_frames_to_save = np.floor(num_frames/num_frames_to_skip_to_show)
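# note: the third positional argument of writer.saving() is dpi, so the frame
# count computed above is actually interpreted as the output resolution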
with writer.saving(fig, video_name, num_frames_to_save): # this video saving part is optional

    # @@@ MAIN @@@: data stream
    for for_idx, scan_path in tqdm(enumerate(scan_paths), total=num_frames, mininterval=5.0):

        # get current information     
        curr_scan_pts = Ptutils.readScan(scan_path) 
        curr_scan_down_pts = Ptutils.random_sampling(curr_scan_pts, num_points=args.num_icp_points)

        # save current node
        PGM.curr_node_idx = for_idx # node indices start at 0
        SCM.addNode(node_idx=PGM.curr_node_idx, ptcloud=curr_scan_down_pts)
Example #19
0
def generate_animation(xs,
                       ys,
                       psis,
                       steers,
                       physical_params,
                       dt,
                       puddle_model=None):
    # Initialize the figure and artists
    fig = plt.figure()
    ax = fig.add_subplot(111)

    if puddle_model:
        plot_puddles(ax, puddle_model)

    cg_to_fa, = ax.plot([], [], color="k")
    cg_to_ra, = ax.plot([], [], color="k")
    patch_front = Rectangle((0.0, 0.0),
                            width=physical_params.wheel_length,
                            height=physical_params.wheel_width,
                            color="k")
    patch_rear = Rectangle((0.0, 0.0),
                           width=physical_params.wheel_length,
                           height=physical_params.wheel_width,
                           color="k")
    ax.add_patch(patch_front)
    ax.add_patch(patch_rear)

    def init():
        ax.set_xlim(-2, 10)
        ax.set_ylim(0, 10)
        return cg_to_fa, cg_to_ra, patch_front, patch_rear

    def animate(i):
        ax.set_xlim(-2, 10)
        ax.set_ylim(0, 10)

        # Centers of the front and rear wheels in (x, y)
        front_wheel_xy = np.array(
            [xs[i], ys[i]]) + physical_params.lf * np.array(
                [math.cos(psis[i]), math.sin(psis[i])])
        rear_wheel_xy = np.array(
            [xs[i], ys[i]]) - physical_params.lr * np.array(
                [math.cos(psis[i]), math.sin(psis[i])])
        front_wheel_xy_bl = utils.center_to_botleft(
            front_wheel_xy, psis[i] + steers[i], physical_params.wheel_length,
            physical_params.wheel_width)
        rear_wheel_xy_bl = utils.center_to_botleft(
            rear_wheel_xy, psis[i], physical_params.wheel_length,
            physical_params.wheel_width)

        # Update the patches by using transforms.
        t1 = Affine2D().rotate(psis[i] + steers[i])
        t1.translate(front_wheel_xy_bl[0], front_wheel_xy_bl[1])
        patch_front.set_transform(t1 + ax.transData)

        t2 = Affine2D().rotate(psis[i])
        t2.translate(rear_wheel_xy_bl[0], rear_wheel_xy_bl[1])
        patch_rear.set_transform(t2 + ax.transData)

        # Update the lines.
        cg_to_fa.set_data([xs[i], front_wheel_xy[0]],
                          [ys[i], front_wheel_xy[1]])
        cg_to_ra.set_data([xs[i], rear_wheel_xy[0]], [ys[i], rear_wheel_xy[1]])
        return cg_to_fa, cg_to_ra, patch_front, patch_rear

    ani = FuncAnimation(fig,
                        animate,
                        frames=steers.size,
                        interval=1e3 * dt,
                        blit=True)
    plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
    writer = FFMpegWriter(fps=1.0 / dt)
    # ani.save("animation.mp4", writer=writer, dpi=100)
    plt.show()
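
The save call in generate_animation is commented out and the function ends with plt.show() without returning the animation; to export instead, one option (file name and dpi are arbitrary) is:

# assuming generate_animation is modified to end with `return ani` instead of plt.show()
ani = generate_animation(xs, ys, psis, steers, physical_params, dt)
writer = FFMpegWriter(fps=1.0 / dt)
ani.save("bicycle_animation.mp4", writer=writer, dpi=100)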
Example #20
0
def main():
    global tgt_img, disp_net
    args = parser.parse_args()
    '''Load the trained model'''
    weights = torch.load(args.pretrained_posenet)
    seq_length = int(weights['state_dict']['conv1.0.weight'].size(1) / 3)
    pose_net = PoseExpNet(nb_ref_imgs=seq_length - 1,
                          output_exp=False).to(device)
    pose_net.load_state_dict(weights['state_dict'], strict=False)
    # MD5 ID of the network model
    net_ID = MD5_ID(args.pretrained_posenet)
    # transformation matrix between L (LiDAR) and C (camera); aligns the input poses to the LiDAR frame
    Transform_matrix_L2C = np.identity(4)
    '''Kitti switch'''
    if args.isKitti:
        if not args.isDynamic:
            from kitti_eval.pose_evaluation_utils import test_framework_KITTI as test_framework
        else:
            from kitti_eval.pose_evaluation_utils_forDynamicTest import test_framework_KITTI as test_framework
        save_dir = os.path.join(args.output_dir, "kitti", args.sequences[0],
                                'net_' + net_ID)
        if args.trainedOnMydataset:
            downsample_img_height = args.img_height
            downsample_img_width = args.img_width
        else:
            # on kitti train set
            downsample_img_height = 128
            downsample_img_width = 416

        Transform_matrix_L2C[:3, :3] = np.array(
            [[7.533745e-03, -9.999714e-01, -6.166020e-04],
             [1.480249e-02, 7.280733e-04, -9.998902e-01],
             [9.998621e-01, 7.523790e-03, 1.480755e-02]])
        Transform_matrix_L2C[:3, -1:] = np.array(
            [-4.069766e-03, -7.631618e-02, -2.717806e-01]).reshape(3, 1)
    else:
        from mydataset_eval.pose_evaluation_utils import test_framework_MYDATASET as test_framework
        save_dir = os.path.join(args.output_dir, "mydataset",
                                args.sequences[0], 'net_' + net_ID)
        if args.trainedOnMydataset:
            downsample_img_height = args.img_height
            downsample_img_width = args.img_width
        else:
            # on kitti train set
            downsample_img_height = 128
            downsample_img_width = 416
        Transform_matrix_L2C[:3, :3] = np.array(
            [[-1.51482698e-02, -9.99886648e-01, 5.36310553e-03],
             [-4.65337018e-03, -5.36307196e-03, -9.99969412e-01],
             [9.99870070e-01, -1.56647995e-02, -4.48880010e-03]])
        Transform_matrix_L2C[:3, -1:] = np.array(
            [4.29029924e-03, -6.08539196e-02, -9.20346161e-02]).reshape(3, 1)
    Transform_matrix_L2C = GramSchmidtHelper(Transform_matrix_L2C)
    Transform_matrix_C2L = np.linalg.inv(Transform_matrix_L2C)
    # ************************* optional (can be removed) *********************************
    # for Mask evaluation in dynamic scenes, the disp net needs to be introduced here
    if args.isDynamic:
        from models import DispNetS
        disp_net = DispNetS().to(device)
        weights = torch.load(args.pretrained_dispnet)
        disp_net.load_state_dict(weights['state_dict'])
        disp_net.eval()

    # normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
    #                                         std=[0.5, 0.5, 0.5])
    # valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])
    # from datasets.sequence_folders import SequenceFolder
    # val_set = SequenceFolder(
    #     '/home/sda/mydataset/preprocessing/formatted/data/',
    #     transform=valid_transform,
    #     seed=0,
    #     train=False,
    #     sequence_length=3,
    # )
    # val_loader = torch.utils.data.DataLoader(
    #     val_set, batch_size=1, shuffle=False,
    #     num_workers=4, pin_memory=True)
    #
    # intrinsics = None
    # for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv) in enumerate(val_loader):
    #     intrinsics = intrinsics.to(device)
    #     break
    # *************************************************************************
    '''Load the test dataset'''
    dataset_dir = Path(args.dataset_dir)
    framework = test_framework(dataset_dir, args.sequences, seq_length)
    print('{} snippets to test'.format(len(framework)))
    errors = np.zeros((len(framework), 2), np.float32)
    '''Data to be written to the output folder'''
    num_poses = len(framework) - (seq_length - 2)
    predictions_array = np.zeros((len(framework), seq_length, 3, 4))
    processing_time = np.zeros((num_poses - 1, 1))
    # output folder
    save_dir = Path(save_dir)
    print('Output files will be saved in: ' + save_dir)
    if not os.path.exists(save_dir): save_dir.makedirs_p()
    # Pose Graph Manager (for back-end optimization) initialization
    PGM = PoseGraphManager()
    PGM.addPriorFactor()
    # Result saver
    num_frames = len(framework)
    ResultSaver = PoseGraphResultSaver(init_pose=PGM.curr_se3,
                                       save_gap=args.save_gap,
                                       num_frames=num_frames,
                                       seq_idx=args.sequences[0],
                                       save_dir=save_dir)

    # for saving the results as a video
    fig_idx = 1
    fig = plt.figure(fig_idx)
    writer = FFMpegWriter(fps=15)
    video_path = save_dir + '/' + args.sequences[0] + ".mp4"
    num_frames_to_skip_to_show = 5
    num_frames_to_save = np.floor(num_frames / num_frames_to_skip_to_show)
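    # note: as above, the third positional argument of writer.saving() is dpi,
    # so num_frames_to_save is actually interpreted as the output resolution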
    with writer.saving(
            fig, video_path,
            num_frames_to_save):  # this video saving part is optional
        for j, sample in enumerate(tqdm(framework)):
            '''
            VO part
            '''
            imgs = sample['imgs']
            w, h = imgs[0].size
            if (not args.no_resize) and (h != downsample_img_height
                                         or w != downsample_img_width):
                imgs = [
                    imresize(img, (downsample_img_height,
                                   downsample_img_width)).astype(np.float32)
                    for img in imgs
                ]
            imgs = [np.transpose(img, (2, 0, 1)) for img in imgs]

            ref_imgs = []
            for i, img in enumerate(imgs):
                img = torch.from_numpy(img).unsqueeze(0)
                img = ((img / 255 - 0.5) / 0.5).to(device)
                if i == len(imgs) // 2:
                    tgt_img = img
                else:
                    ref_imgs.append(img)

            startTimeVO = time.time()
            _, poses = pose_net(tgt_img, ref_imgs)
            processing_time[j] = (time.time() - startTimeVO) / (seq_length - 1)

            # ************************** optional (can be removed) ********************************
            if args.isDynamic:
                '''Test the effect of the photo mask'''
                if args.isKitti:
                    intrinsics = [
                        [2.416744631239935472e+02, 0.000000000000000000e+00, 2.041680103059581199e+02],
                        [0.000000000000000000e+00, 2.462848682666666491e+02, 5.900083200000000261e+01],
                        [0.000000000000000000e+00, 0.000000000000000000e+00, 1.000000000000000000e+00],
                    ]
                else:
                    intrinsics = [[279.1911, 0.0000, 210.8265],
                                  [0.0000, 279.3980, 172.3114],
                                  [0.0000, 0.0000, 1.0000]]
                PhotoMask_Output(_, disp_net, intrinsics, j, poses, ref_imgs,
                                 save_dir)
            # ***************************************************************

            final_poses = pose2tf_mat(args.rotation_mode, imgs, poses)
            predictions_array[j] = final_poses
            # rel_VO_pose takes the 2nd entry of final_poses; over the whole run this yields T10, T21, T32, ...
            rel_VO_pose = np.identity(4)
            rel_VO_pose[:3, :] = final_poses[1]
            # apply a scale factor to correct the monocular VO pose, then align the coordinate frame to the LiDAR frame
            scale_factor = 7
            rel_VO_pose[:3, -1:] = rel_VO_pose[:3, -1:] * scale_factor
            rel_VO_pose = Transform_matrix_C2L @ rel_VO_pose @ np.linalg.inv(
                Transform_matrix_C2L)
            rel_VO_pose = GramSchmidtHelper(rel_VO_pose)
            ResultSaver.saveRelativePose(rel_VO_pose)

            PGM.curr_node_idx = j + 1
            PGM.curr_se3 = np.matmul(PGM.curr_se3, rel_VO_pose)
            PGM.addOdometryFactor(rel_VO_pose)
            PGM.prev_node_idx = PGM.curr_node_idx
            ResultSaver.saveUnoptimizedPoseGraphResult(PGM.curr_se3,
                                                       PGM.curr_node_idx)

            # if (j % num_frames_to_skip_to_show == 0):
            #     ResultSaver.vizCurrentTrajectory(fig_idx=fig_idx)
            #     writer.grab_frame()

            if args.isKitti:
                ATE, RE = compute_pose_error(sample['poses'], final_poses)
                errors[j] = ATE, RE
        '''save output files'''
        if save_dir is not None:
            # np.save(save_dir / 'predictions.npy', predictions_array)
            ResultSaver.saveFinalPoseGraphResult(filename='abs_VO_poses.txt')
            ResultSaver.saveRelativePosesResult(filename='rel_VO_poses.txt')
            np.savetxt(save_dir / 'processing_time.txt', processing_time)
            if args.isKitti:
                np.savetxt(save_dir / 'errors.txt', errors)

        mean_errors = errors.mean(0)
        std_errors = errors.std(0)
        error_names = ['ATE', 'RE']
        print('')
        print("Results")
        print("\t {:>10}, {:>10}".format(*error_names))
        print("mean \t {:10.4f}, {:10.4f}".format(*mean_errors))
        print("std \t {:10.4f}, {:10.4f}".format(*std_errors))
Example #21
0
    secondaryhill.center = (cossx[i] - sinsy[i], sinsx[i] + cossy[i])
    impactorhill.center = (cosix[i] - siniy[i], sinix[i] + cosiy[i])
    text.set_text('{} Years'.format(int(times[i] / (year))))
    return primarydot, secondarydot, impactordot, primaryline, secondaryline, impactorline, text, primaryhill, secondaryhill, impactorhill


anim = animation.FuncAnimation(fig,
                               animate,
                               init_func=init,
                               frames=Noutputs,
                               interval=1,
                               blit=True)
# %%
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
f = 'vid/ps_animation.mp4'
writervideo = FFMpegWriter(fps=10)  # ffmpeg must be installed
anim.save(f, writer=writervideo)
# %%
'''3D ANIMATION OF OUTCOME OF SIMULATION'''
lim = 1
fig = plt.figure(figsize=(9, 9))
axes = fig.add_subplot(111, projection='3d')
axes.set_xlabel("$x/R_\mathrm{h}$")
axes.set_ylabel("$y/R_\mathrm{h}$")
axes.set_zlabel("$z/R_\mathrm{h}$")
axes.set_xlim3d([-lim, lim])
axes.set_ylim3d([-lim, lim])
axes.set_zlim3d([-lim, lim])
primaryline, = axes.plot([], [], [], label="primary", c="teal", lw=lw)
secondaryline, = axes.plot([], [], [], label="secondary", c="hotpink", lw=lw)
impactorline, = axes.plot([], [], [], label="impactor", c="sienna", lw=lw)
Example #22
0
            plt_pos_traj = mins_to_plot[
                min_counter[0]].plot_pos_from_pos_traj_index(i - offset[0])
            return plt_pos_traj

    _save_count = 2000
    anim = FuncAnimation(fig,
                         animate,
                         init_func=init,
                         interval=2,
                         blit=False,
                         save_count=_save_count)

    if save_animation:
        # f = r"c://Users/xx/Desktop/animation.gif"
        writergif = PillowWriter(fps=30)
        writervideo = FFMpegWriter(fps=60)

        # anim.save(f, writer=writergif)

        animation_name_gif = "animation_test.gif"
        animation_name_video = "animation_test123.mp4"
        print(
            "Saving animation. Depending on the choice of 'save_count' this might take some time"
        )
        print(f"Chosen 'save_count' = {_save_count}")
        # anim.save(animation_name, writer=writergif)
        anim.save(animation_name_video, writer=writervideo)
        print(f"Animation saved to {animation_name_video}")
else:
    if plot_propterties:
        env_from_json.plot(ax1_1)
Example #23
0
            pred_score, aleatoric, epistemic, cumsum=False)

        fig, ax = plt.subplots(1, figsize=(24, 3.5))
        fontsize = 25
        plt.ylim(0, 1.1)
        plt.xlim(0, len(xvals) + 1)
        plt.ylabel('Probability', fontsize=fontsize)
        plt.xlabel('Frame (FPS=%d)' % (p.fps), fontsize=fontsize)
        plt.xticks(range(0,
                         len(xvals) + 1, int(p.n_frames / p.fps_display)),
                   fontsize=fontsize)
        plt.yticks(fontsize=fontsize)

        from matplotlib.animation import FFMpegWriter
        curve_writer = FFMpegWriter(fps=p.fps_display,
                                    metadata=dict(title='Movie Test',
                                                  artist='Matplotlib',
                                                  comment='Movie support!'))
        with curve_writer.saving(fig, "demo/curve_video.mp4", 100):
            for t in range(len(xvals)):
                draw_curve(xvals[:(t + 1)], pred_score[:(t + 1)],
                           std_alea[:(t + 1)], std_epis[:(t + 1)])
                curve_writer.grab_frame()
        curve_frames = get_video_frames("demo/curve_video.mp4",
                                        n_frames=p.n_frames)

        # create video writer
        video_writer = cv2.VideoWriter(
            p.vis_file, cv2.VideoWriter_fourcc(*'DIVX'), p.fps_display,
            (video_data[0].shape[1], video_data[0].shape[0]))
        for t, frame in enumerate(video_data):
            det_boxes = detections[t]  # 19 x 6
Example #24
0
    def make_e_movie(self,
                     t_steps,
                     duration,
                     odir,
                     fps=24,
                     name_add='',
                     bed_feedback=True):
        """
        Makes movie of the entrainment field.
        Takes t_steps number of time-steps from *current* state and exports a movie in 'odir' directory that is a maximum of 'duration' seconds long. 
        Note that if the number of frames, in combination with the frames per second, makes a duration less than 20 seconds then it will be 1 fps and will last frames seconds long. 
        You can also add to the end of the name with the command 'name_add=_(your name here)' (make sure to include the underscore).
        """
        # For saving
        # import matplotlib as mpl
        # matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        # Resets any externally exposed parameters for movie (otherwise movie might look weird)
        plt.rcParams.update(plt.rcParamsDefault)
        import matplotlib.animation as animation
        from matplotlib.animation import FFMpegWriter
        from IPython.display import clear_output

        # Calculate how many steps to skip before each save
        dt_frame = np.max((int((t_steps) / (fps * duration)), 1))

        ### Make the data:
        es = [self.e]
        dt = 0
        for frame in tqdm.tqdm(range(t_steps)):
            self.step(bed_feedback=bed_feedback)
            dt += 1
            if dt % dt_frame == 0:
                dt = 0
                es.append(self.e)

        es = np.array(es)

        n_frames = len(es)

        # create the figure
        fig = plt.figure(figsize=(np.min((4 * float(self.Nx / self.Ny), 20)),
                                  4))

        # initialize the image artist
        im_e = plt.imshow(es[0], vmin=0, vmax=1, cmap='binary',
                          aspect=1)  # alternatively: float(self.Ny/self.Nx)

        # set titles and labels
        ax = plt.gca()
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params(axis='both', bottom=False, left=False)
        ax.set_title("Entrainment Field", fontsize=30)
        plt.tight_layout()

        ### Animate function
        def animate(frame):
            """
            Animation function. Takes the current frame number (to select the potion of
            data to plot) and a plot object to update.
            """

            print("Working on frame %s of %s" % (frame + 1, len(es)))
            clear_output(wait=True)

            im_e.set_array(es[frame])

            return im_e

        sim = animation.FuncAnimation(
            # Your Matplotlib Figure object
            fig,
            # The function that does the updating of the Figure
            animate,
            # Frame information (here just frame number)
            np.arange(n_frames),
            # Extra arguments to the animate function
            fargs=[],
            # Frame-time in ms; i.e. for a given frame-rate x, 1000/x
            interval=1000 / fps)

        # Try to set the DPI to the actual number of pixels you're plotting
        writer = FFMpegWriter(fps=fps,
                              metadata=dict(artist='Me'),
                              bitrate=1800)
        name = odir + self.export_name() + name_add + '_e.mp4'
        sim.save(name, dpi=300, writer=writer)

        return
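
A usage sketch for make_e_movie, assuming `model` is an instance of the simulation class and that the output directory already exists:

model.make_e_movie(t_steps=2000, duration=20, odir='movies/', fps=24,
                   name_add='_run1')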
Example #25
0
def make_tracking_video(vid_path,
                        csv_path,
                        output_fname="Tracking.avi",
                        start=0,
                        stop=None,
                        fps=30):
    """
    Makes a video to visualize licking at water ports and position of the animal.

    :parameters
    ---
    vid_path: str
        Full path to the behavior video.

    csv_path: str
        Full path to the csv file from EZTrack.

    output_fname: str
        Desired file name for output. It will be saved to the same folder
        as the data.

    start: int
        Frame to start on.

    stop: int or None
        Frame to stop on or if None, the end of the movie.

    fps: int
        Sampling rate of the behavior camera.
    """
    # Get behavior video.
    vid = cv2.VideoCapture(vid_path)
    if stop is None:
        stop = int(vid.get(7))  # 7 is the index for total frames.

    # Save data to the same folder.
    folder = os.path.split(vid_path)[0]
    output_path = os.path.join(folder, output_fname)

    # Get EZtrack data.
    eztrack = read_eztrack(csv_path)

    # Make video.
    fig, ax = plt.subplots()
    writer = FFMpegWriter(fps=fps)
    with writer.saving(fig, output_path, 100):
        for frame_number in np.arange(start, stop):
            # Plot frame.
            vid.set(1, frame_number)
            ret, frame = vid.read()
            ax.imshow(frame)

            # Plot position.
            x = eztrack.at[frame_number, "x"]
            y = eztrack.at[frame_number, "y"]
            ax.scatter(x, y, marker="+", s=60, c="r")

            ax.text(
                0,
                0,
                "Frame: " + str(frame_number) + "   Time: " +
                str(np.round(frame_number / 30, 1)) + " s",
            )

            ax.set_aspect("equal")
            plt.axis("off")

            writer.grab_frame()

            plt.cla()
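
A usage sketch for make_tracking_video with hypothetical file paths; the output video is written to the folder of the input video:

make_tracking_video("/data/session1/behavCam.avi",
                    "/data/session1/eztrack_output.csv",
                    output_fname="Tracking.avi",
                    start=0,
                    stop=500,
                    fps=30)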
Example #26
0
    def make_panel_movie(self,
                         t_steps,
                         duration,
                         odir,
                         fps=24,
                         name_add='',
                         bed_feedback=True):
        """
        Makes movie of entrainment field, y-averaged height, and bed activity. Note that the entrainment field's aspect ratio will be adjusted to fit.
        Takes t_steps number of time-steps from *current* state and exports a movie in 'odir' directory that is a maximum of 'duration' seconds long. 
        Note that if the number of frames, in combination with the frames per second, makes a duration less than 20 seconds then it will be 1 fps and will last frames seconds long. 
        You can also add to the end of the name with the command 'name_add=_(your name here)' (make sure to include the underscore).
        """
        # For saving
        # import matplotlib as mpl
        # matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        # Resets any externally exposed parameters for movie (otherwise movie might look weird)
        plt.rcParams.update(plt.rcParamsDefault)
        import matplotlib.animation as animation
        from matplotlib.animation import FFMpegWriter
        from IPython.display import clear_output

        # Calculate how many steps to skip before each save
        dt_frame = np.max((int((t_steps) / (fps * duration)), 1))

        ### Make the data:
        zs = [np.mean(self.z, axis=0)]
        es = [self.e]
        qs = [self.bed_activity()]
        ts = [self.t]
        dt = 0
        for frame in tqdm.tqdm(range(t_steps)):
            self.step(bed_feedback=bed_feedback)
            qs.append(self.bed_activity())
            ts.append(self.t)
            dt += 1
            if dt % dt_frame == 0:
                dt = 0
                zs.append(np.mean(self.z, axis=0))
                es.append(self.e)

        zs = np.array(zs)
        es = np.array(es)
        qs = np.array(qs)

        n_frames = len(zs)

        # create a figure with three subplots
        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(4, 4))

        # initialize the artists (one per axes)
        im_e = ax1.imshow(es[0],
                          vmin=0,
                          vmax=1,
                          cmap='binary',
                          aspect=self.Nx / (5 * self.Ny))
        im_z, = ax2.plot(self.x, zs[-1], '-k')
        im_q, = ax3.plot(ts, np.zeros(len(qs)), '-k', lw=1)

        # set titles and labels
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        ax1.tick_params(axis='both', bottom=False, left=False)
        ax1.set_title("Entrainment Field")
        ax2.set_ylabel("Height (grain radii)")
        ax2.set_xlabel(r"$x$")
        ax2.set_xlim(0, 1)
        ax2.set_ylim(np.min(zs[-1]), np.max(zs[-1]))
        ax3.set_ylabel(r"Bed activity")
        ax3.set_xlabel(r"$t$")
        # ax3.axhline(y=self.q_in/self.Ny,ls='--',color='k')
        ax3.set_ylim(0, np.max(qs))
        ax3.set_xlim(ts[0], ts[-1])
        plt.tight_layout()

        ### Animate function
        def animate(frame):
            """
            Animation function. Takes the current frame number (to select the potion of
            data to plot) and a plot object to update.
            """

            print("Working on frame %s of %s" % (frame + 1, len(zs)))
            clear_output(wait=True)

            q_temp = np.zeros(len(qs))
            q_temp[:frame * dt_frame] = qs[:frame * dt_frame]

            im_q.set_ydata(q_temp)
            im_e.set_array(es[frame])
            im_z.set_ydata(zs[frame])

            return im_e, im_z, im_q

        sim = animation.FuncAnimation(
            # Your Matplotlib Figure object
            fig,
            # The function that does the updating of the Figure
            animate,
            # Frame information (here just frame number)
            np.arange(n_frames),
            # Extra arguments to the animate function
            fargs=[],
            # Frame-time in ms; i.e. for a given frame-rate x, 1000/x
            interval=1000 / fps)

        # Try to set the DPI to the actual number of pixels you're plotting
        writer = FFMpegWriter(fps=fps,
                              metadata=dict(artist='Me'),
                              bitrate=1800)
        name = odir + self.export_name() + name_add + '_panel.mp4'
        sim.save(name, dpi=300, writer=writer)

        return
Example #27
0
File: plot.py Project: razh431/Bee
axs[2].plot(x, total_angle)
axs[2].set_title('total angle')

# graphing LEFT antenna length, total angle change, and orientation rate
fig1, axs1 = plt.subplots(3)
fig1.suptitle('plots: ' + str(crop) + ', bee number: ' + str(vid_num))

axs1[0].plot(x, antl_length)
axs1[0].set_title('antl_length')

axs1[1].plot(x, delta_tot_angle)
axs1[1].set_title('change in angle')

axs1[2].plot(x, ori_angle)
axs1[2].set_title('orientation')

plt.xlabel('frame number')
# plt.ylabel('y - axis'

metadata = dict(title='Bee Graphs')
writer = FFMpegWriter(fps=15, metadata=metadata)

#save vid
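# note: writer.saving() captures 'fig' (the first figure), and nothing is redrawn
# inside the loop below, so every grabbed frame is identical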
with writer.saving(fig,
                   str(crop) + ', bee number: ' + str(vid_num) + ".mp4", 100):
    for i in x:
        writer.grab_frame()

plt.show()
Example #28
0
valid_pos = [(0, 0)] * len(tracking_df)
pos_idx = 0
caimg_frame = 0

ca_images = np.concatenate([video.load_images(f) for f in local_mmap_fpaths], axis=0)
ca_movie = ca_images
mean_movie = np.mean(ca_movie, axis=0)

maxmov = int(np.nanpercentile(ca_movie[10:50], 80))
minmov = int(np.nanpercentile(ca_movie[10:50], 2))
caimg_frame = None
caimg_timestamps[0] = -1
caimg_frame_index = 0
caimg_timestamp = -1

writer = FFMpegWriter(fps=15 * 3, metadata=dict(title='Test'))
fig = plt.figure()
vid_ax = plt.subplot(2, 1, 1)
ca_vid_plt = plt.imshow(np.zeros(video_dims).transpose())
vid_ax.set_axis_off()

trace_ax = plt.subplot(2, 1, 2)
window_sec = 4
plt.xlim(-window_sec, window_sec)
C = cnm_obj.estimates.C
cell_contours = video.create_contours(cnm_obj.estimates.A[:,selected_cells],
                                      ca_movie.shape[1:], thr=0.75)
trace_vid_plt = [0] * len(selected_cells)
contour_colours = dict()
# length of the plotted trace. Extend the window by 1 sec because the timestamps aren't evenly spread.
trace_len = 2 * (window_sec + 1) * caimg_frame_rate + 1
Example #29
0
def CreateVideoSlow(
    videooutname,
    clip,
    Dataframe,
    tmpfolder,
    dotsize,
    colormap,
    alphavalue,
    pcutoff,
    trailpoints,
    cropping,
    x1,
    x2,
    y1,
    y2,
    save_frames,
    bodyparts2plot,
    outputframerate,
    Frames2plot,
    bodyparts2connect,
    skeleton_color,
    draw_skeleton,
    displaycropped,
    color_by,
):
    """Creating individual frames with labeled body parts and making a video"""
    # scorer=np.unique(Dataframe.columns.get_level_values(0))[0]
    # bodyparts2plot = list(np.unique(Dataframe.columns.get_level_values(1)))

    if displaycropped:
        ny, nx = y2 - y1, x2 - x1
    else:
        ny, nx = clip.height(), clip.width()

    fps = clip.fps()
    if outputframerate is None:  # by default, the same as the input rate
        outputframerate = fps

    nframes = clip.nframes
    duration = nframes / fps

    print("Duration of video [s]: {}, recorded with {} fps!".format(
        round(duration, 2), round(fps, 2)))
    print(
        "Overall # of frames: {} with cropped frame dimensions: {} {}".format(
            nframes, nx, ny))
    print("Generating frames and creating video.")
    df_x, df_y, df_likelihood = Dataframe.values.reshape(
        (len(Dataframe), -1, 3)).T
    if cropping and not displaycropped:
        df_x += x1
        df_y += y1

    bpts = Dataframe.columns.get_level_values("bodyparts")
    all_bpts = bpts.values[::3]
    if draw_skeleton:
        bpts2connect = get_segment_indices(bodyparts2connect, all_bpts)

    bplist = bpts.unique().to_list()
    nbodyparts = len(bplist)
    if Dataframe.columns.nlevels == 3:
        nindividuals = 1
        map2bp = list(range(len(all_bpts)))
        map2id = [0 for _ in map2bp]
    else:
        nindividuals = len(
            Dataframe.columns.get_level_values("individuals").unique())
        map2bp = [bplist.index(bp) for bp in all_bpts]
        nbpts_per_ind = (
            Dataframe.groupby(level="individuals", axis=1).size().values // 3)
        map2id = []
        for i, j in enumerate(nbpts_per_ind):
            map2id.extend([i] * j)
    keep = np.flatnonzero(np.isin(all_bpts, bodyparts2plot))
    bpts2color = [(ind, map2bp[ind], map2id[ind]) for ind in keep]
    if color_by == "individual":
        colors = visualization.get_cmap(nindividuals, name=colormap)
    else:
        colors = visualization.get_cmap(nbodyparts, name=colormap)

    nframes_digits = int(np.ceil(np.log10(nframes)))
    if nframes_digits > 9:
        raise Exception(
            "Your video has more than 10**9 frames, we recommend chopping it up."
        )

    if Frames2plot is None:
        Index = set(range(nframes))
    else:
        Index = {int(k) for k in Frames2plot if 0 <= k < nframes}

    # Prepare figure
    prev_backend = plt.get_backend()
    plt.switch_backend("agg")
    dpi = 100
    fig = plt.figure(frameon=False, figsize=(nx / dpi, ny / dpi))
    ax = fig.add_subplot(111)

    writer = FFMpegWriter(fps=outputframerate, codec="h264")
    with writer.saving(fig, videooutname,
                       dpi=dpi), np.errstate(invalid="ignore"):
        for index in trange(min(nframes, len(Dataframe))):
            imagename = tmpfolder + "/file" + str(index).zfill(
                nframes_digits) + ".png"
            image = img_as_ubyte(clip.load_frame())
            if index in Index:  # then extract the frame!
                if cropping and displaycropped:
                    image = image[y1:y2, x1:x2]
                ax.imshow(image)

                if draw_skeleton:
                    for bpt1, bpt2 in bpts2connect:
                        if np.all(
                                df_likelihood[[bpt1, bpt2], index] > pcutoff):
                            ax.plot(
                                [df_x[bpt1, index], df_x[bpt2, index]],
                                [df_y[bpt1, index], df_y[bpt2, index]],
                                color=skeleton_color,
                                alpha=alphavalue,
                            )

                for ind, num_bp, num_ind in bpts2color:
                    if df_likelihood[ind, index] > pcutoff:
                        if color_by == "bodypart":
                            color = colors(num_bp)
                        else:
                            color = colors(num_ind)
                        if trailpoints > 0:
                            ax.scatter(
                                df_x[ind][max(0, index - trailpoints):index],
                                df_y[ind][max(0, index - trailpoints):index],
                                s=dotsize**2,
                                color=color,
                                alpha=alphavalue * 0.75,
                            )
                        ax.scatter(
                            df_x[ind, index],
                            df_y[ind, index],
                            s=dotsize**2,
                            color=color,
                            alpha=alphavalue,
                        )
                ax.set_xlim(0, nx)
                ax.set_ylim(0, ny)
                ax.axis("off")
                ax.invert_yaxis()
                fig.subplots_adjust(left=0,
                                    bottom=0,
                                    right=1,
                                    top=1,
                                    wspace=0,
                                    hspace=0)
                if save_frames:
                    fig.savefig(imagename)
                writer.grab_frame()
                ax.clear()

    print("Labeled video {} successfully created.".format(videooutname))
    plt.switch_backend(prev_backend)
Example #30
0
File: plot.py Project: zcl-maker/duat
    def time_1d_animation(self, output_path=None, dataset_selector=None, axes_selector=None, time_selector=None,
                          dpi=200, fps=1, scale_mode="expand",
                          latex_label=True, interval=200):
        """
        Generate a plot of 1d data animated in time.
        
        If an output path with a suitable extension is supplied, the method will export it. Available formats are mp4
        and gif. The returned objects allow for minimal customization and representation. For example in Jupyter you
        might use `IPython.display.HTML(animation.to_html5_video())`, where `animation` is the returned `FuncAnimation`
        instance.
        
        Note:
            Exporting a high resolution animated gif with many frames might eat your RAM.

        Args:
            output_path (str): The place where the plot is saved. If "" or None, the plot is shown in matplotlib.
            dataset_selector: See :func:`~duat.osiris.plot.Diagnostic.get_generator` method.
            axes_selector: See :func:`~duat.osiris.plot.Diagnostic.get_generator` method.
            time_selector: See :func:`~duat.osiris.plot.Diagnostic.get_generator` method.
            interval (float): Delay between frames in ms. If exporting to mp4, the fps is used instead to generate the
                              file, although the returned objects do use this value.
            dpi (int): The resolution of the frames in dots per inch (only if exporting).
            fps (int): The frames per seconds (only if exporting to mp4).
            scale_mode (str): How the scale is changed through time. Available methods are:

                * "expand": The y limits increase when needed, but they don't decrease.
                * "adjust_always": Always change the y limits to those of the data.
                * "max": Use the maximum range from the beginning.

            latex_label (bool): Whether to use LaTeX code for the plot.
            
        Returns:
            (`matplotlib.figure.Figure`, `matplotlib.axes.Axes`, `matplotlib.animation.FuncAnimation`):
            Objects representing the generated plot and its animation.
            
        Raises:
            FileNotFoundError: If tried to export to mp4 but ffmpeg is not found in the system.

        """
        if output_path:
            ensure_dir_exists(os.path.dirname(output_path))
        axes = self.get_axes(dataset_selector=dataset_selector, axes_selector=axes_selector)
        if len(axes) != 1:
            raise ValueError("Expected 1 axis plot, but %d were provided" % len(axes))
        axis = axes[0]

        gen = self.get_generator(dataset_selector=dataset_selector, axes_selector=axes_selector,
                                 time_selector=time_selector)

        # Set plot labels
        fig, ax = plt.subplots()
        fig.set_tight_layout(True)

        x_name = axis["LONG_NAME"]
        x_units = axis["UNITS"]
        y_name = self.data_name
        y_units = self.units

        ax.set_xlabel(_create_label(x_name, x_units, latex_label))
        ax.set_ylabel(_create_label(y_name, y_units, latex_label))

        # Plot the points
        x_min, x_max = axis["MIN"], axis["MAX"]
        plot_data, = ax.plot(axis["LIST"], next(gen))
        ax.set_xlim(x_min, x_max)

        if scale_mode == "max":
            # Get a list (generator) with the mins and maxs in each time step
            min_max_list = map(lambda l: [min(l), max(l)],
                               self.get_generator(dataset_selector=dataset_selector, axes_selector=axes_selector,
                                                  time_selector=time_selector))
            f = lambda mins, maxs: (min(mins), max(maxs))
            y_min, y_max = f(*zip(*min_max_list))
            ax.set_ylim(y_min, y_max)

        time_list = self.get_time_list(time_selector)

        # Prepare a function for the updates
        def update(i):
            """Update the plot, returning the artists which must be redrawn."""
            try:
                new_dataset = next(gen)
            except StopIteration:
                logger.warning("Tried to add a frame to the animation, but all data was used.")
                return
            label = 't = {0}'.format(time_list[i])
            plot_data.set_ydata(new_dataset[:])
            ax.set_title(label)
            if not scale_mode or scale_mode == "max":
                pass
            elif scale_mode == "expand":
                prev = ax.get_ylim()
                data_limit = [min(new_dataset), max(new_dataset)]
                ax.set_ylim(min(prev[0], data_limit[0]), max(prev[1], data_limit[1]))
            elif scale_mode == "adjust_always":
                ax.set_ylim(min(new_dataset), max(new_dataset))
            return plot_data, ax

        anim = FuncAnimation(fig, update, frames=range(1, len(time_list) - 2), interval=interval)

        if not output_path:  # "" or None
            pass
        else:
            filename = os.path.basename(output_path)
            if "." in filename:
                extension = output_path.split(".")[-1].lower()
            else:
                extension = None
            if extension == "gif":
                anim.save(output_path, dpi=dpi, writer='imagemagick')
            elif extension == "mp4":
                metadata = dict(title=os.path.split(self.data_path)[-1], artist='duat', comment=self.data_path)
                writer = FFMpegWriter(fps=fps, metadata=metadata)
                with writer.saving(fig, output_path, dpi):
                    # Iterate over frames
                    for i in range(1, len(time_list) - 1):
                        update(i)
                        writer.grab_frame()
                    # Keep showing the last frame for the fixed time
                    writer.grab_frame()
            else:
                logger.warning("Unknown extension in path %s. No output produced." % output_path)

        plt.close()

        return fig, ax, anim
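
As the docstring notes, the returned FuncAnimation can also be displayed inline in Jupyter; a minimal sketch, assuming `diag` is a Diagnostic instance:

from IPython.display import HTML

fig, ax, anim = diag.time_1d_animation()
HTML(anim.to_html5_video())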