def _after_reset(self, observation: ClassIncrementalSetting.Observations):
    """Post-reset hook mirroring the base Monitor's ``_after_reset``.

    First builds a single uint8, channels-last image from the observation's
    ``x`` field, tiling the batch when ``self.batch_size`` is set. Float
    images must lie in [0, 1]. Then, if the monitor is enabled, it:

    - updates the stats recorder;
    - resets the video recorder when ``self.config.render`` is set;
    - bumps the episode counter;
    - flushes.

    Args:
        observation: The observation returned by ``reset()``. Its ``.numpy().x``
            is used as the image batch.
    """
    image_batch = observation.numpy().x
    # Need to create a single image with the right dtype for the Monitor
    # from gym to create gifs / videos with it.
    if self.batch_size:
        # Tile the image batch so it can be seen as a single image by the
        # Monitor.
        image_batch = tile_images(image_batch)
    image_batch = Transforms.channels_last_if_needed(image_batch)
    if np.issubdtype(image_batch.dtype, np.floating):
        assert (0 <= image_batch).all() and (image_batch <= 1).all()
        # Scale by 255 (not 256) and round: 256 * 1.0 == 256 does not fit in
        # uint8 and would wrap around to 0, turning the brightest pixels black.
        image_batch = np.round(255 * image_batch).astype(np.uint8)
    assert image_batch.dtype == np.uint8
    # NOTE(review): `image_batch` is not used below; only the assertions above
    # have an observable effect. The apparent intent was
    # `super()._after_reset(image_batch)` (currently disabled while debugging)
    # -- confirm whether the image should be forwarded somewhere.

    # -- Code from Monitor
    if not self.enabled:
        return
    # Reset the stat count
    self.stats_recorder.after_reset(observation)
    if self.config.render:
        self.reset_video_recorder()
    # Bump *after* all reset activity has finished
    self.episode_id += 1
    self._flush()
def render(self, mode="human", **kwargs):
    """Delegate rendering to the parent class, tiling batched frames.

    In "rgb_array" mode with a truthy ``self.batch_size``, the frames returned
    by the parent are combined with ``tile_images`` into a single image. In
    every other case the parent's result is returned unchanged.

    NOTE: This doesn't get called, because the video recorder uses
    self.env.render(), rather than self.render().
    TODO: Render when the 'render' argument in config is set to True.
    """
    frames = super().render(mode=mode, **kwargs)
    needs_tiling = mode == "rgb_array" and bool(self.batch_size)
    return tile_images(frames) if needs_tiling else frames
def render(self, mode: str = "rgb_array") -> np.ndarray:
    """Render the images of the current observation batch.

    The batch comes from ``self._current_batch[0]``. That is either an
    ``Observations`` object, whose ``x`` field is used, or a ``Tensor``. The
    batch is then processed as follows:

    - moved to numpy;
    - tiled into a single image when ``self.batch_size`` is set;
    - converted to channels-last (3 or 4 channels);
    - converted to dtype uint8, with float images expected in [0, 1].

    Args:
        mode: "rgb_array" returns the uint8 image. "human" shows it in a
            ``SimpleImageViewer``, starting a pyvirtualdisplay virtual display
            if the viewer module cannot be imported.

    Returns:
        The uint8 image for "rgb_array"; ``self.viewer.isopen`` for "human".

    Raises:
        NotImplementedError: For any other ``mode``.
    """
    observations = self._current_batch[0]
    if isinstance(observations, Observations):
        image_batch = observations.x
    else:
        assert isinstance(observations, Tensor)
        image_batch = observations
    if isinstance(image_batch, Tensor):
        image_batch = image_batch.cpu().numpy()

    if self.batch_size:
        image_batch = tile_images(image_batch)
    image_batch = Transforms.channels_last_if_needed(image_batch)
    assert image_batch.shape[-1] in {3, 4}
    if np.issubdtype(image_batch.dtype, np.floating):
        assert (0 <= image_batch).all() and (image_batch <= 1).all()
        # Scale by 255 (not 256) and round: 256 * 1.0 == 256 does not fit in
        # uint8 and would wrap around to 0.
        image_batch = np.round(255 * image_batch).astype(np.uint8)
    assert image_batch.dtype == np.uint8

    if mode == "rgb_array":
        # NOTE: Need to create a single image, channels_last format, and
        # possibly even of dtype uint8, in order for things like Monitor to
        # work.
        return image_batch

    if mode == "human":
        if self.viewer is None:
            # TODO: There seems to be a bit of a bug, tests sometime fail
            # because "Can't connect to display: None" etc.
            try:
                from gym.envs.classic_control.rendering import SimpleImageViewer
            except Exception:
                # No usable display: start a virtual one, then retry the
                # import. If this fallback fails too, its error propagates.
                from pyvirtualdisplay import Display

                # Keep a reference so the virtual display outlives this call.
                self._virtual_display = Display(visible=0, size=(1366, 768))
                self._virtual_display.start()
                from gym.envs.classic_control.rendering import SimpleImageViewer
            # The pyvirtualdisplay object is not a valid display spec for
            # SimpleImageViewer (it expects None or a string like ":0").
            # Use the default display instead. Once started, the virtual
            # display presumably becomes the default via $DISPLAY.
            self.viewer = SimpleImageViewer(display=None)
        self.viewer.imshow(image_batch)
        return self.viewer.isopen

    raise NotImplementedError(f"Unsuported mode {mode}")
def render(self, mode="human", **kwargs):
    """Delegate rendering to the parent class and tile batched RGB frames.

    When ``mode`` is "rgb_array" and ``self.batch_size`` is truthy, the
    parent's frames are merged into one image with ``tile_images``.
    Otherwise the parent's result is passed through as-is.

    TODO: This might not be setup right. Need to check.
    """
    rendered = super().render(mode=mode, **kwargs)
    if mode != "rgb_array" or not self.batch_size:
        return rendered
    return tile_images(rendered)