Example #1
    def get_actions(self, observations: ContinualRLSetting.Observations,
                    action_space: gym.Space) -> ContinualRLSetting.Actions:
        state = observations.x
        # The DQN model expects a stack of 4 consecutive frames as input, so we
        # use a small "hack" here: keep a buffer of the last 4 processed
        # observations, return random actions until it is full, and only then
        # ask the model for a prediction. This assumes that we're being asked
        # to give actions for a sequence of observations.

        # NOTE: Not sure in which order the DQN expects the frames to be.
        # Preprocess: resize the frame to 84x84, convert it to a tensor, and
        # move channels to the first dimension, as the DQN expects.
        state = ProcessFrame84.process(state)
        state = Transforms.to_tensor(state)
        state = Transforms.channels_first_if_needed(state)
        self.test_buffer.append(state)
        if len(self.test_buffer) < 4:
            print("Returning a random action since we don't yet have 4 observations in the buffer.")
            return action_space.sample()

        # Stack the 4 buffered frames into a single batch of shape (1, 4, H, W).
        fake_batch = torch.stack(tuple(self.test_buffer))
        assert fake_batch.shape[0] == 4
        fake_batch = fake_batch.reshape([-1, 4, *fake_batch.shape[2:]])
        with torch.no_grad():
            fake_batch = fake_batch.to(self.model.device)
            values = self.model(fake_batch)

        # Greedily pick the action with the highest predicted Q-value.
        chosen_actions = values.argmax(dim=-1)
        return chosen_actions.cpu().numpy()
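
For context, a minimal sketch of how the frame buffer used above might be set up. The `DQNMethod` class name and constructor are hypothetical; only the buffer handling is meant to match the method above:

from collections import deque

class DQNMethod:  # hypothetical name, for illustration only
    def __init__(self, model):
        self.model = model
        # Holds the 4 most recent processed frames; once full, appending a new
        # frame automatically drops the oldest one, and `get_actions` can
        # start predicting instead of sampling random actions.
        self.test_buffer = deque(maxlen=4)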
Example #2
    def _after_reset(self, observation: ClassIncrementalSetting.Observations):
        image_batch = observation.numpy().x
        # gym's Monitor needs a single image with the right dtype in order to
        # create gifs / videos from it.
        if self.batch_size:
            # Need to tile the image batch so it can be seen as a single image
            # by the Monitor.
            image_batch = tile_images(image_batch)

        image_batch = Transforms.channels_last_if_needed(image_batch)
        if image_batch.dtype == np.float32:
            assert (0 <= image_batch).all() and (image_batch <= 1).all()
            # Scale by 255 (not 256), since a pixel value of exactly 1.0 would
            # otherwise wrap around to 0 when cast to uint8.
            image_batch = (255 * image_batch).astype(np.uint8)

        assert image_batch.dtype == np.uint8
        # NOTE: Instead of calling `super()._after_reset(image_batch)`, the
        # relevant code from gym's Monitor is inlined below.

        # -- Code from Monitor
        if not self.enabled:
            return
        # Reset the stat count
        self.stats_recorder.after_reset(observation)
        if self.config.render:
            self.reset_video_recorder()

        # Bump *after* all reset activity has finished
        self.episode_id += 1

        self._flush()
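
For context, `tile_images` packs a batch of images into a single grid image so the Monitor can treat it as one frame. Below is a minimal sketch of such a helper, assuming channels-last (N, H, W, C) input; the library's actual implementation may differ:

import numpy as np

def tile_images(images: np.ndarray) -> np.ndarray:
    """Pack a batch of (N, H, W, C) images into one roughly-square grid."""
    n, h, w, c = images.shape
    cols = int(np.ceil(np.sqrt(n)))
    rows = int(np.ceil(n / cols))
    # Pad with black images so the grid is completely filled.
    padding = np.zeros((rows * cols - n, h, w, c), dtype=images.dtype)
    grid = np.concatenate([images, padding]).reshape(rows, cols, h, w, c)
    # (rows, cols, H, W, C) -> (rows * H, cols * W, C)
    return grid.transpose(0, 2, 1, 3, 4).reshape(rows * h, cols * w, c)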
Example #3
    def observation_space(self) -> NamedTupleSpace:
        """ The un-batched observation space, based on the choice of dataset and
        the transforms at `self.transforms` (which apply to the train/valid/test
        environments).

        The returned space is a NamedTupleSpace, with the following entries:
        - `x`: observation space (e.g. `Image` space)
        - `task_labels`: Union[Discrete, Sparse[Discrete]]
           The task labels for each sample. When task labels are not available,
           the task label space is Sparse, and entries will be `None`.
        """
        x_space = base_observation_spaces[self.dataset]
        if not self.transforms:
            # NOTE: Even when no transforms are passed, the continuum scenarios
            # still apply at least 'to_tensor', so reflect that in the space.
            x_space = Transforms.to_tensor(x_space)

        # Apply the transforms to the observation space.
        for transform in self.transforms:
            x_space = transform(x_space)
        x_space = add_tensor_support(x_space)

        task_label_space = spaces.Discrete(self.nb_tasks)
        if not self.task_labels_at_train_time:
            # A sparsity of 1.0 means task labels are never given (always None).
            task_label_space = Sparse(task_label_space, 1.0)
        task_label_space = add_tensor_support(task_label_space)

        return NamedTupleSpace(
            x=x_space,
            task_labels=task_label_space,
            dtype=self.Observations,
        )
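
A usage sketch of the resulting space; the constructor arguments and dataset name are illustrative:

setting = ClassIncrementalSetting(dataset="mnist", nb_tasks=5)
obs_space = setting.observation_space
sample = obs_space.sample()
# `sample` is an `Observations` instance:
# - sample.x: an observation drawn from the `x` space
# - sample.task_labels: a task id, or None when task labels are unavailable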
Example #4
    def render(self, mode: str = "rgb_array") -> np.ndarray:
        # Extract the image(s) from the most recently observed batch.
        observations = self._current_batch[0]
        if isinstance(observations, Observations):
            image_batch = observations.x
        else:
            assert isinstance(observations, Tensor)
            image_batch = observations
        if isinstance(image_batch, Tensor):
            image_batch = image_batch.cpu().numpy()

        if self.batch_size:
            # Tile the batch of images into a single grid image.
            image_batch = tile_images(image_batch)

        image_batch = Transforms.channels_last_if_needed(image_batch)
        assert image_batch.shape[-1] in {3, 4}
        if image_batch.dtype == np.float32:
            assert (0 <= image_batch).all() and (image_batch <= 1).all()
            # Scale by 255 (not 256) so that a value of exactly 1.0 doesn't
            # wrap around to 0 when cast to uint8.
            image_batch = (255 * image_batch).astype(np.uint8)
        assert image_batch.dtype == np.uint8

        if mode == "rgb_array":
            # NOTE: Wrappers like Monitor expect a single image in
            # channels-last format with dtype uint8, which is what we built
            # above.
            return image_batch

        if mode == "human":
            if self.viewer is None:
                display = None
                # TODO: There seems to be a bit of a bug: tests sometimes fail
                # with "Can't connect to display: None". If importing the
                # viewer fails (e.g. on a headless machine), start a virtual
                # display and retry the import.
                try:
                    from gym.envs.classic_control.rendering import SimpleImageViewer
                except Exception:
                    from pyvirtualdisplay import Display

                    display = Display(visible=0, size=(1366, 768))
                    display.start()
                    from gym.envs.classic_control.rendering import SimpleImageViewer
                self.viewer = SimpleImageViewer(display=display)

            self.viewer.imshow(image_batch)
            return self.viewer.isopen

        raise NotImplementedError(f"Unsupported mode {mode}")
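
A usage sketch; `env` stands in for an environment exposing this `render` method:

frame = env.render(mode="rgb_array")  # a single (H, W, 3) uint8 grid image
env.render(mode="human")              # displays the image in an on-screen viewer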