Esempio n. 1
0
def _show_plane_dist(args, distances, pose_glob, kp_id_list, frame_diff=25):
    """Interactively browse keypoint distance traces over a sliding window.

    A slider selects a start frame ``i``. The 3D axis overlays the skeleton of
    frame ``i`` (red) and of frame ``i + frame_diff`` (green), both shifted by
    the mean of frame ``i``. The second axis plots
    ``distances[i:i + frame_diff, j]`` for every ``j`` in ``kp_id_list``. When
    ``args.video_file_name`` is set, video frame ``i`` is shown in a third axis.

    Args:
        args: object with a ``video_file_name`` attribute; ``None`` disables
            the video panel.
        distances: array whose first axis is sliced by frame index and whose
            second axis is indexed by the ids in ``kp_id_list``.
        pose_glob: per-frame skeletons; ``pose_glob[i]`` is averaged over
            axis 0 to center it.
        kp_id_list: column indices of ``distances`` to plot.
        frame_diff: window length in frames.

    Raises:
        ValueError: if ``pose_glob`` has no more than ``frame_diff`` frames,
            so no complete window ``[i, i + frame_diff]`` exists.
    """
    from utils.mpl_setup import plt_figure
    from utils.plot_util import plot_skel_3d, plot_origin
    from matplotlib.widgets import Slider

    num_frames = pose_glob.shape[0]
    if num_frames <= frame_diff:
        raise ValueError('Need more than %d frames, got %d.' %
                         (frame_diff, num_frames))

    # set up figure
    num_fig = 2 if args.video_file_name is None else 3
    plt, fig, axes = plt_figure(num_fig, is_3d_axis=[0])
    ax_slider = fig.add_axes(
        [0.2, 0.05, 0.65,
         0.03])  # left, bottom, width, height in fractions of figure w/h
    # The slider's upper bound is inclusive and pose_glob[i + frame_diff] must
    # exist, so the last valid start frame is num_frames - frame_diff - 1.
    slider_ind = Slider(ax_slider,
                        'Fid',
                        0,
                        num_frames - frame_diff - 1,
                        valinit=0,
                        valfmt='%d')

    # Frame index drawn last; lets the callback skip redundant redraws while
    # the slider moves within the same integer frame. Kept in the closure
    # rather than in a module-level global.
    last_i = -1

    def _update(_):
        nonlocal last_i
        # slider values are floats; truncate to a frame index
        i = int(slider_ind.val)
        if i == last_i:
            return
        last_i = i

        # clear content
        for ax in axes:
            ax.clear()

        # center both skeletons on the mean of frame i so their relative
        # displacement over the window stays visible
        skel_n = pose_glob[i].copy()
        mean = skel_n.mean(0, keepdims=True)
        skel_n -= mean
        skel_n_t = pose_glob[i + frame_diff].copy() - mean
        plot_skel_3d(axes[0], model, skel_n, color_fixed='r')
        plot_skel_3d(axes[0], model, skel_n_t, color_fixed='g')
        axes[0].set_title('global (centered)')
        plot_origin(axes[0])
        axes[0].view_init(elev=-60, azim=-90)

        for j in kp_id_list:
            axes[1].plot(distances[i:(i + frame_diff), j], label='%d' % j)
        axes[1].legend()

        if args.video_file_name is not None:
            img = read_vid_frame(args.video_file_name, i)
            # reverse the channel axis for display (presumably BGR -> RGB)
            axes[2].imshow(img[:, :, ::-1])
            axes[2].xaxis.set_visible(False)
            axes[2].yaxis.set_visible(False)
        fig.canvas.draw_idle()

    slider_ind.on_changed(_update)
    # on_changed only fires on user interaction; draw the initial frame now
    _update(None)
    plt.show()
Esempio n. 2
0
def _show_pairwise_dist(pairwise_dist, kp_pair_list, total_num_kp=12):
    from utils.mpl_setup import plt_figure
    dist_all = list()
    for vid_name, vid_data in pairwise_dist.items():
        for p in vid_data:
            if p is not None:
                dist_all.append(p)

    dist_all = np.array(dist_all)

    # figure out pairs
    kp_pair_list = [tuple(x) for x in kp_pair_list]
    cnt = 0
    show_tasks = list()
    for i in range(total_num_kp):
        for j in range(i + 1, total_num_kp):
            if (i, j) in kp_pair_list:
                show_tasks.append([cnt, i, j])
            cnt += 1

    plt, fig, axes = plt_figure(len(kp_pair_list))
    colors = ['r', 'g', 'b', 'c', 'm', 'k']
    for f, (c, (k, i, j)) in enumerate(zip(colors, show_tasks)):
        hist, edges = np.histogram(dist_all[:, k])
        hist = hist / float(np.sum(hist))
        bin_centers = 0.5 * (edges[1:] + edges[:-1])
        axes[f].stem(bin_centers, hist, c, label='%d-%d' % (i, j))
        axes[f].legend()
    plt.show()
Esempio n. 3
0
def _show_stft(args, stft, pose_glob, frame_diff=25):
    """Interactively browse a spectrogram alongside the skeleton and video.

    A slider selects frame ``i``. The 3D axis shows ``pose_glob[i]`` centered
    on its own mean. The second axis shows ``log(1 + stft)`` over the columns
    ``[i - frame_diff, i + frame_diff)``, clipped to the array bounds. When
    ``args.video_file_name`` is set, video frame ``i`` is shown as well.

    Args:
        args: object with a ``video_file_name`` attribute; ``None`` disables
            the video panel.
        stft: 2D array displayed with rows as frequency and columns as time;
            columns are selected around the current frame index.
        pose_glob: per-frame skeletons.
        frame_diff: half-width of the displayed spectrogram window, in columns.

    Raises:
        ValueError: if ``pose_glob`` contains no frames.
    """
    from utils.mpl_setup import plt_figure
    from utils.plot_util import plot_skel_3d, plot_origin
    from matplotlib.widgets import Slider

    num_frames = pose_glob.shape[0]
    if num_frames < 1:
        raise ValueError('pose_glob contains no frames.')

    # set up figure
    num_fig = 2 if args.video_file_name is None else 3
    plt, fig, axes = plt_figure(num_fig, is_3d_axis=[0])
    ax_slider = fig.add_axes(
        [0.2, 0.05, 0.65,
         0.03])  # left, bottom, width, height in fractions of figure w/h
    # the slider's upper bound is inclusive, so the last valid index is N - 1
    slider_ind = Slider(ax_slider,
                        'Fid',
                        0,
                        num_frames - 1,
                        valinit=0,
                        valfmt='%d')

    # Frame index drawn last; lets the callback skip redundant redraws.
    last_i = -1

    def _update(_):
        nonlocal last_i
        # slider values are floats; truncate to a frame index
        i = int(slider_ind.val)
        if i == last_i:
            return
        last_i = i

        # clear content
        for ax in axes:
            ax.clear()

        skel_n = pose_glob[i].copy()
        skel_n -= skel_n.mean(0, keepdims=True)
        plot_skel_3d(axes[0], model, skel_n)
        axes[0].set_title('global (centered)')
        plot_origin(axes[0])
        axes[0].view_init(elev=-60, azim=-90)

        axes[0].set_xlim([-0.15, 0.15])
        axes[0].set_ylim([-0.15, 0.15])
        axes[0].set_zlim([-0.15, 0.15])
        axes[0].set_xlabel('x')
        axes[0].set_ylabel('y')
        axes[0].set_zlabel('z')

        # the slice end is exclusive, so clamp to shape[1] (not shape[1] - 1)
        # to keep the last spectrogram column reachable
        s = max(0, i - frame_diff)
        e = min(stft.shape[1], i + frame_diff)
        axes[1].imshow(np.log1p(stft[:, s:e]))
        axes[1].set_xlabel('time')
        axes[1].set_ylabel('freq')

        if args.video_file_name is not None:
            img = read_vid_frame(args.video_file_name, i)
            # reverse the channel axis for display (presumably BGR -> RGB)
            axes[2].imshow(img[:, :, ::-1])
            axes[2].xaxis.set_visible(False)
            axes[2].yaxis.set_visible(False)
        fig.canvas.draw_idle()

    slider_ind.on_changed(_update)
    # on_changed only fires on user interaction; draw the initial frame now
    _update(None)
    plt.show()
Esempio n. 4
0
def _show_coord_frame(args, name, pose_glob, pose_local):
    """Interactively compare a skeleton in global and local coordinates.

    A slider selects frame ``i``. Three 3D axes show ``pose_glob[i]`` as is,
    ``pose_glob[i]`` centered on its mean, and ``pose_local[i]``. When
    ``args.video_file_name`` is set, video frame ``i`` is shown in a fourth
    axis.

    Args:
        args: object with a ``video_file_name`` attribute; ``None`` disables
            the video panel.
        name: label of the local coordinate frame, used in the third title.
        pose_glob: per-frame skeletons in global coordinates.
        pose_local: per-frame skeletons in the local frame, indexed like
            ``pose_glob``.

    Raises:
        ValueError: if ``pose_glob`` contains no frames.
    """
    from utils.mpl_setup import plt_figure
    from utils.plot_util import plot_skel_3d, plot_origin, plot_setup
    from matplotlib.widgets import Slider

    num_frames = pose_glob.shape[0]
    if num_frames < 1:
        raise ValueError('pose_glob contains no frames.')

    # set up figure
    num_fig = 3 if args.video_file_name is None else 4
    plt, fig, axes = plt_figure(num_fig, is_3d_axis=[0, 1, 2])
    ax_slider = fig.add_axes(
        [0.2, 0.05, 0.65,
         0.03])  # left, bottom, width, height in fractions of figure w/h
    # the slider's upper bound is inclusive, so the last valid index is N - 1
    slider_ind = Slider(ax_slider,
                        'Fid',
                        0,
                        num_frames - 1,
                        valinit=0,
                        valfmt='%d')

    # Frame index drawn last; lets the callback skip redundant redraws.
    last_i = -1

    def _update(_):
        nonlocal last_i
        # slider values are floats; truncate to a frame index
        i = int(slider_ind.val)
        if i == last_i:
            return
        last_i = i

        # clear content
        for ax in axes:
            ax.clear()

        skel = pose_glob[i].copy()
        plot_skel_3d(axes[0], model, skel)
        axes[0].set_title('global')
        plot_setup(axes[0])
        plot_origin(axes[0])
        axes[0].view_init(elev=-60, azim=-90)

        skel_n = pose_glob[i].copy()
        skel_n -= skel_n.mean(0, keepdims=True)
        plot_skel_3d(axes[1], model, skel_n)
        axes[1].set_title('global (centered)')
        plot_origin(axes[1])
        axes[1].view_init(elev=-60, azim=-90)

        skel_loc = pose_local[i].copy()
        plot_skel_3d(axes[2], model, skel_loc)
        axes[2].set_title('local: %s' % name)
        plot_origin(axes[2])

        # the centered and local views share a fixed, symmetric extent
        for ax in (axes[1], axes[2]):
            ax.set_xlim([-0.15, 0.15])
            ax.set_ylim([-0.15, 0.15])
            ax.set_zlim([-0.15, 0.15])
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            ax.set_zlabel('z')

        if args.video_file_name is not None:
            img = read_vid_frame(args.video_file_name, i)
            # reverse the channel axis for display (presumably BGR -> RGB)
            axes[3].imshow(img[:, :, ::-1])
            axes[3].xaxis.set_visible(False)
            axes[3].yaxis.set_visible(False)
        fig.canvas.draw_idle()

    slider_ind.on_changed(_update)
    # on_changed only fires on user interaction; draw the initial frame now
    _update(None)
    plt.show()
    def get_thumb(self, tid):
        """Return the thumbnail stored in ``self.thumbs`` for frame ``tid``.

        ``tid`` must lie in ``[0, self.vid_size)``; this is checked with an
        ``assert``, so the check is skipped when Python runs with ``-O``.
        """
        assert tid >= 0 and tid < self.vid_size, 'Out of video range.'
        thumb = self.thumbs[tid]
        return thumb

    def get_fs(self, tid):
        """Read frame ``tid`` from the video at ``self.path``.

        Unlike ``get_thumb``, this calls ``read_vid_frame`` on every
        invocation. ``tid`` must lie in ``[0, self.vid_size)``; this is
        checked with an ``assert``, so the check is skipped under ``-O``.
        """
        assert tid >= 0 and tid < self.vid_size, 'Out of video range.'
        frame = read_vid_frame(self.path, tid)
        return frame


if __name__ == '__main__':
    # Manual benchmark: for randomly drawn frame ids, time fetching the stored
    # thumbnail against reading the full frame, and show both side by side.
    # reader = VideoThumbnailReader('/misc/lmbraid18/zimmermc/datasets/ExampleData/run00_cam1.mp4')
    # 17684 frames read in ~ 1min
    reader = VideoThumbnailReader(
        '/misc/lmbraid18/zimmermc/datasets/RatTrack_set4/Peller_dispencer/run00_cam1.mp4'
    )

    import time
    from utils.mpl_setup import plt_figure
    import numpy as np

    while True:
        # sample within the loaded video instead of a hard-coded frame count
        i = np.random.randint(reader.vid_size)

        plt, fig, axes = plt_figure(2)
        # perf_counter is monotonic and high-resolution, unlike time.time
        s = time.perf_counter()
        axes[0].imshow(reader.get_thumb(i))
        print('Time for thumb %.3f' % (time.perf_counter() - s))
        s = time.perf_counter()
        axes[1].imshow(reader.get_fs(i))
        print('Time for full size %.3f' % (time.perf_counter() - s))
        plt.show()