示例#1
0
文件: base.py 项目: alcinos/dps
    def visualize(self, n=4):
        """Draw ``n`` samples and show them as a labelled animation.

        When ``self.n_frames`` is 0, a singleton axis is inserted at position 1
        of the image array before animating (presumably so that static images
        become one-frame sequences -- confirm against ``animate``).

        Args:
            n: number of samples to draw via ``self.sample``.
        """
        drawn = self.sample(n)
        frames = drawn["image"]
        # NOTE: drawn["annotations"] is not visualized here.
        captions = drawn["label"]

        if self.n_frames == 0:
            frames = frames[:, None]

        figure, *_unused = animate(frames, labels=captions)

        plt.show()
        plt.close(figure)
示例#2
0
文件: atari.py 项目: alcinos/dps
    def visualize(self, n=4):
        """Draw ``n`` samples and animate them, captioning each with its action/reward.

        Args:
            n: number of samples to draw via ``self.sample``.
        """
        batch = self.sample(n)
        frames = batch["image"]
        # One caption per sample; map() over two iterables stops at the shorter,
        # exactly like zip().
        captions = list(map("actions={}, rewards={}".format,
                            batch["action"], batch["reward"]))

        figure, *_unused = animate(frames, labels=captions)

        plt.show()
        plt.close(figure)
示例#3
0
    def _plot_reconstruction(self, updater, fetched):
        """Animate inputs next to their reconstructions and two error maps.

        Four panels are passed to ``animate``:
        - the input;
        - the output;
        - the normalized absolute error, averaged over the last axis;
        - the normalized cross-entropy, summed over the last axis.

        Args:
            updater: forwarded to ``self.path_for`` to build the output path.
            fetched: mapping with ``'inp'`` and ``'output'`` arrays (assumed to
                share a shape whose last axis is channels -- TODO confirm).
        """
        inp = fetched['inp']
        output = fetched['output']

        fig_height = 20
        fig_width = 4.5 * fig_height

        n_channels = output.shape[-1]
        diff = self.normalize_images(
            np.abs(inp - output).sum(axis=-1, keepdims=True) / n_channels)
        xent = self.normalize_images(
            xent_loss(pred=output, label=inp, tf=False).sum(axis=-1, keepdims=True))

        # Forwarded to animate (presumably where the animation is written).
        path = self.path_for("animation", updater, ext=None)
        fig, *_ = animate(
            inp, output, diff.astype('f'), xent.astype('f'),
            figsize=(fig_width, fig_height), path=path)
        # Close the figure animate returned; a bare plt.close() only closes the
        # *current* figure, which is not guaranteed to be this one.
        plt.close(fig)
示例#4
0
    def __call__(self, updater):
        """Animate inputs, reconstructions and a temporal-mean baseline.

        Fetches ``inp`` and ``output`` via ``self._fetch(updater)`` and passes
        seven panels to ``animate``:
        - the input and the output;
        - the normalized absolute error and cross-entropy of the output
          against the input;
        - the per-sequence mean frame tiled over time;
        - that baseline's normalized absolute error and cross-entropy against
          the input.

        Args:
            updater: forwarded to ``self._fetch`` and ``self.path_for``.
        """
        self.fetches = "inp output"

        fetched = self._fetch(updater)
        fetched = Config(fetched)

        inp = fetched['inp']
        output = fetched['output']
        # The 5-entry tile reps below require a 5-D array indexed (B, T, ...);
        # presumably (batch, time, H, W, C) -- TODO confirm.
        B, T = inp.shape[:2]

        # Baseline "prediction": each sequence's mean frame repeated at every step.
        mean_image = np.tile(inp.mean(axis=1, keepdims=True), (1, T, 1, 1, 1))

        fig_unit_size = 3
        fig_height = B * fig_unit_size
        fig_width = 7 * fig_unit_size  # one unit per panel (7 arrays below)

        diff = self.normalize_images(
            np.abs(inp - output).sum(axis=-1, keepdims=True))
        xent = self.normalize_images(
            xent_loss(pred=output, label=inp, tf=False).sum(axis=-1, keepdims=True))

        # Score the baseline against the ground-truth input, just as the model
        # output is scored above and as xent_mean does; comparing against the
        # model output would show baseline/model disagreement, not baseline error.
        diff_mean = self.normalize_images(
            np.abs(inp - mean_image).sum(axis=-1, keepdims=True))
        xent_mean = self.normalize_images(
            xent_loss(pred=mean_image, label=inp, tf=False).sum(axis=-1, keepdims=True))

        # Forwarded to animate (presumably where the animation is written).
        path = self.path_for("animation", updater, ext=None)

        fig, *_ = animate(inp,
                          output,
                          diff.astype('f'),
                          xent.astype('f'),
                          mean_image,
                          diff_mean.astype('f'),
                          xent_mean.astype('f'),
                          figsize=(fig_width, fig_height),
                          path=path,
                          square_grid=False)
        # Close the figure animate returned rather than whatever is current.
        plt.close(fig)
示例#5
0
文件: atari.py 项目: lqiang2003cn/dps
    def visualize(self, n=None):
        """Animate ``self.batch_size`` streams stitched from consecutive samples.

        Draws ``self.batch_size * self.n_batches`` samples. Stream ``j`` joins,
        along their first axis, the images at indices
        ``j, j + batch_size, j + 2 * batch_size, ...``.

        Args:
            n: accepted for interface compatibility; ignored.
        """
        bs, nb = self.batch_size, self.n_batches
        drawn = self.sample(bs * nb)

        streams = [
            np.concatenate([drawn['image'][i * bs + j] for i in range(nb)])
            for j in range(bs)
        ]
        images = np.array(streams)

        fig, *_ = animate(images)
        plt.subplots_adjust(top=0.95, bottom=0, left=0, right=1,
                            wspace=0.05, hspace=0.1)

        plt.show()
        plt.close(fig)
示例#6
0
文件: atari.py 项目: lqiang2003cn/dps
    def visualize(self, n=4):
        """Draw ``n`` samples and animate them with action/reward captions and annotations.

        Args:
            n: number of samples to draw via ``self.sample``.
        """
        batch = self.sample(n)
        frames = batch["image"]
        acts = batch["action"]
        rews = batch["reward"]
        annots = variable_shape_array_to_list(batch["annotations"])

        captions = list(map("actions={}, rewards={}".format, acts, rews))
        margins = dict(top=0.95, bottom=0, left=0, right=1,
                       wspace=0.05, hspace=0.1)

        figure, *_unused = animate(frames, labels=captions, annotations=annots)
        plt.subplots_adjust(**margins)

        plt.show()
        plt.close(figure)
示例#7
0
文件: gqn.py 项目: lqiang2003cn/dps
    def visualize(self, n=4, path='gqn_visualization.mp4'):
        """Animate ``n`` samples with per-frame pose text, save the clip, then show it.

        Each frame's text gives the sample index, the time step and the
        matching ``pose_r`` / ``pose_t`` entries, formatted inside
        ``numpy_print_options(precision=2)``. The animation is saved with the
        ffmpeg writer (``hevc`` codec, ``ultrafast`` preset) before display.

        Args:
            n: number of samples to draw via ``self.sample``.
            path: output video file. Defaults to the previously hard-coded
                ``'gqn_visualization.mp4'``, which is relative to the current
                working directory.
        """
        sample = self.sample(n)
        images = sample["image"]
        pose_r = sample["pose_r"]
        pose_t = sample["pose_t"]
        indices = sample["idx"]

        # The text is built inside the context so the print options apply
        # (presumably the pose entries are numpy arrays whose str() honours
        # numpy print options -- confirm).
        with numpy_print_options(precision=2):
            text = [[
                "idx={}\nt={}\npose_r={}\npose_t={}".format(i, t, _pr, _pt)
                for t, (_pr, _pt) in enumerate(zip(pr, pt))
            ] for i, pr, pt in zip(indices, pose_r, pose_t)]

        fig, _, anim, _ = animate(images, text=text, fig_unit_size=4)

        try:
            anim.save(path,
                      writer='ffmpeg',
                      codec='hevc',
                      extra_args=['-preset', 'ultrafast'])
            plt.show()
        finally:
            # Release the figure even if saving fails (e.g. ffmpeg unavailable).
            plt.close(fig)