Example #1
    def _after_reset(self, observation: ClassIncrementalSetting.Observations):
        image_batch = observation.numpy().x
        # Need to create a single image with the right dtype for the Monitor
        # from gym to create gifs / videos with it.
        if self.batch_size:
            # Need to tile the image batch so it can be seen as a single image
            # by the Monitor.
            image_batch = tile_images(image_batch)

        image_batch = Transforms.channels_last_if_needed(image_batch)
        if image_batch.dtype == np.float32:
            assert (0 <= image_batch).all() and (image_batch <= 1).all()
            # Scale [0, 1] floats to uint8; use 255 so a pixel of exactly 1.0
            # doesn't wrap around to 0 when cast.
            image_batch = (255 * image_batch).astype(np.uint8)

        assert image_batch.dtype == np.uint8
        # NOTE: super()._after_reset(image_batch) is bypassed here while debugging;
        # the relevant Monitor code is inlined below instead.
        # super()._after_reset(image_batch)

        # -- Code from Monitor
        if not self.enabled:
            return
        # Reset the stat count
        self.stats_recorder.after_reset(observation)
        if self.config.render:
            self.reset_video_recorder()

        # Bump *after* all reset activity has finished
        self.episode_id += 1

        self._flush()
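
For reference, here is a minimal sketch of the kind of grid tiling that tile_images is assumed to perform in these examples (packing a (N, H, W, C) batch into one roughly square image, in the style of the baselines helper). The name tile_images_sketch and the exact behaviour are assumptions, not the library's actual implementation:

    import numpy as np

    def tile_images_sketch(images: np.ndarray) -> np.ndarray:
        """Pack a (N, H, W, C) batch of frames into one roughly square grid image."""
        n, h, w, c = images.shape
        cols = int(np.ceil(np.sqrt(n)))
        rows = int(np.ceil(n / cols))
        # Pad with black frames so the batch fills the grid exactly.
        padding = np.zeros((rows * cols - n, h, w, c), dtype=images.dtype)
        grid = np.concatenate([images, padding], axis=0)
        # (rows * cols, H, W, C) -> (rows, cols, H, W, C) -> (rows * H, cols * W, C)
        grid = grid.reshape(rows, cols, h, w, c)
        return grid.transpose(0, 2, 1, 3, 4).reshape(rows * h, cols * w, c)

With 4 frames of shape (64, 64, 3), this produces a single (128, 128, 3) image that the Monitor can treat like any other frame.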
Example #2
    def render(self, mode="human", **kwargs):
        # NOTE: This doesn't get called, because the video recorder uses
        # self.env.render(), rather than self.render()
        # TODO: Render when the 'render' argument in config is set to True.
        image_batch = super().render(mode=mode, **kwargs)
        if mode == "rgb_array" and self.batch_size:
            image_batch = tile_images(image_batch)
        return image_batch
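
Since the video recorder calls self.env.render() rather than the wrapper's own render() (per the NOTE above), one workaround is to do the tiling in a separate wrapper placed underneath the Monitor. Below is a minimal sketch under that assumption; TiledRenderWrapper is a hypothetical name, and the inner env is assumed to return a (N, H, W, C) array for mode="rgb_array":

    import gym
    import numpy as np

    class TiledRenderWrapper(gym.Wrapper):
        """Tile a batched env's rgb_array frames into one image, so that anything
        calling env.render("rgb_array") (e.g. a video recorder) sees a single frame."""

        def render(self, mode="human", **kwargs):
            frame = self.env.render(mode=mode, **kwargs)
            if mode == "rgb_array" and isinstance(frame, np.ndarray) and frame.ndim == 4:
                n, h, w, c = frame.shape
                cols = int(np.ceil(np.sqrt(n)))
                rows = int(np.ceil(n / cols))
                # Pad to a full grid, then rearrange into a (rows * H, cols * W, C) image.
                pad = np.zeros((rows * cols - n, h, w, c), dtype=frame.dtype)
                frame = np.concatenate([frame, pad], axis=0).reshape(rows, cols, h, w, c)
                frame = frame.transpose(0, 2, 1, 3, 4).reshape(rows * h, cols * w, c)
            return frame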
Example #3
    def render(self, mode: str = "rgb_array") -> np.ndarray:
        observations = self._current_batch[0]
        if isinstance(observations, Observations):
            image_batch = observations.x
        else:
            assert isinstance(observations, Tensor)
            image_batch = observations
        if isinstance(image_batch, Tensor):
            image_batch = image_batch.cpu().numpy()

        if self.batch_size:
            image_batch = tile_images(image_batch)

        image_batch = Transforms.channels_last_if_needed(image_batch)
        assert image_batch.shape[-1] in {3, 4}
        if image_batch.dtype == np.float32:
            assert (0 <= image_batch).all() and (image_batch <= 1).all()
            # Scale [0, 1] floats to uint8; use 255 so a pixel of exactly 1.0
            # doesn't wrap around to 0 when cast.
            image_batch = (255 * image_batch).astype(np.uint8)
        assert image_batch.dtype == np.uint8

        if mode == "rgb_array":
            # NOTE: Need to create a single image, channels_last format, and
            # possibly even of dtype uint8, in order for things like Monitor to
            # work.
            return image_batch

        if mode == "human":
            # return plt.imshow(image_batch)
            if self.viewer is None:
                display = None
                # TODO: There seems to be a bug here: tests sometimes fail with
                # "Can't connect to display: None", etc.
                try:
                    from gym.envs.classic_control.rendering import SimpleImageViewer
                except Exception:
                    from pyvirtualdisplay import Display

                    display = Display(visible=0, size=(1366, 768))
                    display.start()
                    from gym.envs.classic_control.rendering import SimpleImageViewer
                finally:
                    self.viewer = SimpleImageViewer(display=display)

            self.viewer.imshow(image_batch)
            return self.viewer.isopen

        raise NotImplementedError(f"Unsuported mode {mode}")
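
Transforms.channels_last_if_needed is used above to guarantee a channels-last layout before the shape and dtype checks. Here is a rough sketch of that kind of conversion, as an assumption about the behaviour rather than the library's actual implementation (the function name below is a stand-in):

    import numpy as np

    def channels_last_if_needed_sketch(x: np.ndarray) -> np.ndarray:
        """Move a leading channel axis (size 3 or 4) to the last position, if present."""
        channel_sizes = (3, 4)
        if x.ndim == 3 and x.shape[0] in channel_sizes and x.shape[-1] not in channel_sizes:
            return np.moveaxis(x, 0, -1)   # (C, H, W) -> (H, W, C)
        if x.ndim == 4 and x.shape[1] in channel_sizes and x.shape[-1] not in channel_sizes:
            return np.moveaxis(x, 1, -1)   # (N, C, H, W) -> (N, H, W, C)
        return x                           # already channels-last (or not an image batch)

After such a conversion, image_batch.shape[-1] is 3 or 4, which is what the assertion right after the call in the example above checks.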
Example #4
    def render(self, mode="human", **kwargs):
        # TODO: This might not be set up right. Need to check.
        image_batch = super().render(mode=mode, **kwargs)
        if mode == "rgb_array" and self.batch_size:
            return tile_images(image_batch)
        return image_batch