Example #1
    def observation_space(self) -> NamedTupleSpace:
        """ The un-batched observation space, based on the choice of dataset and
        the transforms at `self.transforms` (which apply to the train/valid/test
        environments).

        The returned space is a NamedTupleSpace, with the following entries:
        - `x`: observation space (e.g. `Image` space)
        - `task_labels`: Union[Discrete, Sparse[Discrete]]
           The task labels for each sample. When task labels are not available,
           the task label space is Sparse, and its entries will be `None`.
        """
        x_space = base_observation_spaces[self.dataset]
        if not self.transforms:
            # NOTE: Even when no transforms are passed, continuum scenarios
            # still apply at least 'to_tensor'.
            x_space = Transforms.to_tensor(x_space)

        # apply the transforms to the observation space.
        for transform in self.transforms:
            x_space = transform(x_space)
        x_space = add_tensor_support(x_space)

        task_label_space = spaces.Discrete(self.nb_tasks)
        if not self.task_labels_at_train_time:
            task_label_space = Sparse(task_label_space, 1.0)
        task_label_space = add_tensor_support(task_label_space)

        return NamedTupleSpace(
            x=x_space,
            task_labels=task_label_space,
            dtype=self.Observations,
        )
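A minimal usage sketch of the kind of space this property returns (the `Box` stand-in for an `Image` space, the field access on the sample, and the `None` samples from a fully-sparse space are assumptions based on the docstring above):

from gym import spaces

# Hypothetical, hand-built equivalent of the returned space:
x_space = spaces.Box(low=0.0, high=1.0, shape=(3, 28, 28))  # e.g. an Image-like space
task_label_space = Sparse(spaces.Discrete(5), 1.0)  # fully sparse: samples are None

observation_space = NamedTupleSpace(
    x=x_space,
    task_labels=task_label_space,
)
obs = observation_space.sample()  # a named tuple with `x` and `task_labels` fields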
Example #2
def _add_task_labels_to_space(observation: X, task_labels: T) -> NamedTupleSpace:
    # TODO: Return a dict or NamedTuple at some point:
    return NamedTupleSpace(
        x=observation,
        task_labels=task_labels,
        dtype=ObservationsAndTaskLabels,
    )
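A hypothetical call, using plain gym spaces for both arguments:

space = _add_task_labels_to_space(
    observation=spaces.Box(low=0.0, high=1.0, shape=(4,)),
    task_labels=spaces.Discrete(3),
)
# The result is a NamedTupleSpace with fields `x` and `task_labels`, as above.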
Example #3
def test_task_id_is_added_even_when_no_known_task_schedule():
    """ Test that even when the env is unknown or there are no task params, the
    task_id is still added correctly and is zero at all times.
    """
    # Breakout doesn't have default task params.
    original: gym.Env = gym.make("Breakout-v0")
    env = MultiTaskEnvironment(
        original,
        add_task_id_to_obs=True,
    )
    env.seed(123)
    env.reset()

    assert env.observation_space == NamedTupleSpace(
        x=original.observation_space,
        task_labels=spaces.Discrete(1),
    )
    for step in range(0, 100):
        obs, _, done, info = env.step(env.action_space.sample())
        # env.render()

        x, task_id = obs
        assert task_id == 0

        if done:
            x, task_id = env.reset()
            assert task_id == 0
    env.close()
Example #4
def test_add_task_id_to_obs():
    """ Test that the 'info' dict contains the task dict. """
    original: CartPoleEnv = gym.make("CartPole-v0")
    starting_length = original.length
    starting_gravity = original.gravity

    task_schedule = {
        10: dict(length=0.1),
        20: dict(length=0.2, gravity=-12.0),
        30: dict(gravity=0.9),
    }
    env = MultiTaskEnvironment(
        original,
        task_schedule=task_schedule,
        add_task_id_to_obs=True,
    )
    env.seed(123)
    env.reset()

    assert env.observation_space == NamedTupleSpace(
        x=original.observation_space,
        task_labels=spaces.Discrete(4),
    )

    for step in range(100):
        obs, _, done, info = env.step(env.action_space.sample())
        # env.render()

        x, task_id = obs

        if 0 <= step < 10:
            assert env.length == starting_length and env.gravity == starting_gravity
            assert task_id == 0, step

        elif 10 <= step < 20:
            assert env.length == 0.1
            assert task_id == 1, step

        elif 20 <= step < 30:
            assert env.length == 0.2 and env.gravity == -12.0
            assert task_id == 2, step

        elif step >= 30:
            assert env.length == starting_length and env.gravity == 0.9
            assert task_id == 3, step

        if done:
            obs = env.reset()
            assert isinstance(obs, tuple)

    env.close()
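The task ids asserted above can be derived from the schedule alone. A minimal sketch (the helper name is hypothetical, not part of the wrapper's API):

from bisect import bisect_right

def task_id_for_step(step: int, task_schedule: dict) -> int:
    # Index of the task active at `step`; task 0 covers the steps before
    # the first boundary in the schedule.
    boundaries = sorted(task_schedule.keys())  # e.g. [10, 20, 30]
    return bisect_right(boundaries, step)

schedule = {10: dict(length=0.1), 20: dict(length=0.2, gravity=-12.0), 30: dict(gravity=0.9)}
assert task_id_for_step(5, schedule) == 0
assert task_id_for_step(25, schedule) == 2
assert task_id_for_step(99, schedule) == 3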
Example #5
def _(space: NamedTupleSpace, sample: NamedTuple) -> NamedTuple:
    sample_dict: Dict
    if isinstance(sample, NamedTuple):
        sample_dict = sample._asdict()
    elif isinstance(sample, Mapping):
        sample_dict = sample
    else:
        assert len(sample) == len(space.spaces)
        sample_dict = dict(zip(space.names, sample))

    return space.dtype(**{
        key: from_tensor(space[key], value) if key in space.names else value
        for key, value in sample_dict.items()
    })
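A hypothetical round-trip, assuming this overload is registered (e.g. via functools.singledispatch) on the `from_tensor` function, with `to_tensor` (see the last example below) as its inverse:

space = NamedTupleSpace(
    x=spaces.Box(low=0.0, high=1.0, shape=(4,)),
    task_labels=spaces.Discrete(2),
)
tensor_sample = to_tensor(space, space.sample())  # leaves become torch.Tensors
restored = from_tensor(space, tensor_sample)      # back to plain values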
Example #6
def test_starting_step_and_max_step():
    """ Test that when start_step and max_step arg given, the env stays within
    the [start_step, max_step] portion of the task schedule.
    """
    original: CartPoleEnv = gym.make("CartPole-v0")
    starting_length = original.length
    starting_gravity = original.gravity

    task_schedule = {
        10: dict(length=0.1),
        20: dict(length=0.2, gravity=-12.0),
        30: dict(gravity=0.9),
    }
    env = MultiTaskEnvironment(
        original,
        task_schedule=task_schedule,
        add_task_id_to_obs=True,
        starting_step=10,
        max_steps=19,
    )
    env.seed(123)
    env.reset()

    assert env.observation_space == NamedTupleSpace(
        x=original.observation_space,
        task_labels=spaces.Discrete(4),
    )

    # Trying to set the 'steps' to something smaller than the starting step
    # doesn't work.
    env.steps = -123
    assert env.steps == 10

    # Trying to set the 'steps' to something greater than the max_steps
    # doesn't work.
    env.steps = 50
    assert env.steps == 19

    # Here we reset the steps to 10, and also check that this works.
    env.steps = 10
    assert env.steps == 10

    for step in range(0, 100):
        # The environment started at an offset of 10.
        assert env.steps == max(min(step + 10, 19), 10)

        obs, _, done, info = env.step(env.action_space.sample())
        # env.render()

        x, task_id = obs

        # Check that we're always stuck between 10 and 20
        assert 10 <= env.steps < 20
        assert env.length == 0.1
        assert task_id == 1, step

        if done:
            print(f"Resetting on step {step}")
            obs = env.reset()
            assert isinstance(obs, tuple)

    env.close()
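The clamping behaviour the test relies on amounts to the sketch below (a hypothetical stand-in for the wrapper's `steps` setter, not its actual code):

def clamp_steps(value: int, starting_step: int, max_steps: int) -> int:
    # Keep the step counter inside [starting_step, max_steps].
    return max(starting_step, min(value, max_steps))

assert clamp_steps(-123, 10, 19) == 10
assert clamp_steps(50, 10, 19) == 19
assert clamp_steps(10, 10, 19) == 10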
Example #7
    def __post_init__(self, *args, **kwargs):
        super().__post_init__(*args, **kwargs)
        self._new_random_task_on_reset: bool = False

        # Post processing of the 'dataset' field:
        if self.dataset in self.available_datasets.keys():
            # The environment name was passed, rather than an id
            # (e.g. 'cartpole' -> 'CartPole-v0').
            self.dataset = self.available_datasets[self.dataset]

        elif self.dataset not in self.available_datasets.values():
            # The passed dataset is assumed to be an environment ID, but it
            # wasn't in the dict of available datasets! We issue a warning, but
            # proceed to let the user use whatever environment they want to.
            logger.warning(
                UserWarning(
                    f"The chosen dataset/environment ({self.dataset}) isn't in the "
                    f"available_datasets dict, so we can't garantee this will work!"
                ))

        if isinstance(self.dataset, gym.Env) and self.batch_size:
            raise RuntimeError(f"Batch size should be None when a gym.Env "
                               f"object is passed as `dataset`.")
        if not isinstance(self.dataset,
                          (str, gym.Env)) and not callable(self.dataset):
            raise RuntimeError(
                f"`dataset` must be either a string, a gym.Env, or a callable. "
                f"(got {self.dataset})")

        # Set the number of tasks depending on the increment, and vice-versa.
        # (as only one of the two should be used).
        assert self.max_steps, "assuming this should always be set, for now."
        # TODO: Clean this up, not super clear what options take precedence on
        # which other options.

        # Load the task schedules from the corresponding files, if present.
        if self.train_task_schedule_path:
            self.train_task_schedule = self.load_task_schedule(
                self.train_task_schedule_path)

        if self.valid_task_schedule_path:
            self.valid_task_schedule = self.load_task_schedule(
                self.valid_task_schedule_path)

        if self.test_task_schedule_path:
            self.test_task_schedule = self.load_task_schedule(
                self.test_task_schedule_path)

        if self.train_task_schedule:
            if self.steps_per_task is not None:
                # If steps_per_task was passed, then we overwrite the keys of
                # the task schedule.
                self.train_task_schedule = {
                    i * self.steps_per_task: self.train_task_schedule[step]
                    for i, step in enumerate(
                        sorted(self.train_task_schedule.keys()))
                }
            else:
                # A task schedule was passed: infer the number of tasks from it.
                change_steps = sorted(self.train_task_schedule.keys())
                assert 0 in change_steps, "Schedule needs a task at step 0."
                # TODO: @lebrice: I guess we have to assume that the interval
                # between steps is constant for now? Do we actually depend on this
                # being the case? I think steps_per_task is only really ever used
                # for creating the task schedule, which we already have in this
                # case.
                assert (
                    len(change_steps) >= 2
                ), "WIP: need a minimum of two tasks in the task schedule for now."
                self.steps_per_task = change_steps[1] - change_steps[0]
                # Double-check that this is the case.
                for i in range(len(change_steps) - 1):
                    if change_steps[
                            i + 1] - change_steps[i] != self.steps_per_task:
                        raise NotImplementedError(
                            f"WIP: This might not work yet if the tasks aren't "
                            f"equally spaced out at a fixed interval.")

            nb_tasks = len(self.train_task_schedule)
            if self.smooth_task_boundaries:
                # NOTE: When in a ContinualRLSetting with smooth task boundaries,
                # the last entry in the schedule represents the state of the env at
                # the end of the "task". When there are clear task boundaries (i.e.
                # when in 'Class'/Task-Incremental RL), the last entry is the start
                # of the last task.
                nb_tasks -= 1
            if self.nb_tasks != 1:
                if self.nb_tasks != nb_tasks:
                    raise RuntimeError(
                        f"Passed number of tasks {self.nb_tasks} doesn't match the "
                        f"number of tasks deduced from the task schedule ({nb_tasks})"
                    )
            self.nb_tasks = nb_tasks

            self.max_steps = max(self.train_task_schedule.keys())
            if not self.smooth_task_boundaries:
                # See above note about the last entry.
                self.max_steps += self.steps_per_task

        elif self.nb_tasks:
            if self.steps_per_task:
                self.max_steps = self.nb_tasks * self.steps_per_task
            elif self.max_steps:
                self.steps_per_task = self.max_steps // self.nb_tasks

        elif self.steps_per_task:
            if self.nb_tasks:
                self.max_steps = self.nb_tasks * self.steps_per_task
            elif self.max_steps:
                self.nb_tasks = self.max_steps // self.steps_per_task

        elif self.max_steps:
            if self.nb_tasks:
                self.steps_per_task = self.max_steps // self.nb_tasks
            elif self.steps_per_task:
                self.nb_tasks = self.max_steps // self.steps_per_task

        if not all([self.nb_tasks, self.max_steps, self.steps_per_task]):
            raise RuntimeError(
                f"You need to provide at least two of 'max_steps', "
                f"'nb_tasks', or 'steps_per_task'.")

        assert self.max_steps == self.nb_tasks * self.steps_per_task

        if self.test_task_schedule:
            if 0 not in self.test_task_schedule:
                raise RuntimeError(
                    "Task schedules needs to include an initial task.")

            if self.test_steps_per_task is not None:
                # If steps per task was passed, then we overwrite the number of steps
                # for each task in the schedule to match.
                self.test_task_schedule = {
                    i * self.test_steps_per_task: self.test_task_schedule[step]
                    for i, step in enumerate(
                        sorted(self.test_task_schedule.keys()))
                }

            change_steps = sorted(self.test_task_schedule.keys())
            assert 0 in change_steps, "Schedule needs to include task at step 0."

            nb_test_tasks = len(change_steps)
            if self.smooth_task_boundaries:
                nb_test_tasks -= 1
            assert (nb_test_tasks == self.nb_tasks
                    ), "nb of tasks should be the same for train and test."

            self.test_steps_per_task = change_steps[1] - change_steps[0]
            for i in range(self.nb_tasks - 1):
                if change_steps[
                        i + 1] - change_steps[i] != self.test_steps_per_task:
                    raise NotImplementedError(
                        "WIP: This might not work yet if the test tasks aren't "
                        "equally spaced out at a fixed interval.")

            self.test_steps = max(change_steps)
            if not self.smooth_task_boundaries:
                # See above note about the last entry.
                self.test_steps += self.test_steps_per_task

        elif self.test_steps_per_task is None:
            # This is basically never the case, since test_steps defaults to 10_000.
            assert (self.test_steps
                    ), "need to set one of test_steps or test_steps_per_task"
            self.test_steps_per_task = self.test_steps // self.nb_tasks
        else:
            # FIXME: This is more complicated than it needs to be.
            # `test_steps` must either be the default value, or consistent with
            # `test_steps_per_task * nb_tasks`.
            assert self.test_steps in {
                10_000, self.test_steps_per_task * self.nb_tasks
            }
            assert (self.test_steps_per_task
                    ), "need to set one of test_steps or test_steps_per_task"
            self.test_steps = self.test_steps_per_task * self.nb_tasks

        assert self.test_steps // self.test_steps_per_task == self.nb_tasks

        if self.smooth_task_boundaries:
            # If we're operating in the 'Online/smooth task transitions' "regime",
            # then there is only one "task", and we don't have task labels.
            # TODO: HOWEVER, the task schedule could/should be able to have more
            # than one non-stationarity! This indicates a need for a distinction
            # between 'tasks' and 'non-stationarities' (changes in the env).
            self.known_task_boundaries_at_train_time = False
            self.known_task_boundaries_at_test_time = False
            self.task_labels_at_train_time = False
            self.task_labels_at_test_time = False
            # self.steps_per_task = self.max_steps

        # Task schedules for training / validation and testing.

        # Create a temporary environment so we can extract the spaces and create
        # the task schedules.
        with self._make_env(self.dataset, self._temp_wrappers(),
                            self.observe_state_directly) as temp_env:
            # FIXME: Replacing the observation space dtypes from their original
            # 'generated' NamedTuples to self.Observations. The alternative
            # would be to add another argument to the MultiTaskEnv wrapper, to
            # pass down a dtype to be set on its observation_space's `dtype`
            # attribute, which would be ugly.
            assert isinstance(temp_env.observation_space, NamedTupleSpace)
            temp_env.observation_space.dtype = self.Observations
            # Populate the task schedules created above.
            if not self.train_task_schedule:
                train_change_steps = list(
                    range(0, self.max_steps, self.steps_per_task))
                if self.smooth_task_boundaries:
                    # Add a last 'task' at the end of the 'epoch', so that the
                    # env changes smoothly right until the end.
                    train_change_steps.append(self.max_steps)
                self.train_task_schedule = self.create_task_schedule(
                    temp_env,
                    train_change_steps,
                )

            assert self.train_task_schedule is not None
            # The validation task schedule is the same as the one used in
            # training by default.
            if not self.valid_task_schedule:
                self.valid_task_schedule = deepcopy(self.train_task_schedule)

            if not self.test_task_schedule:
                # The test task schedule is by default the same as in validation
                # except that the interval between the tasks may be different,
                # depending on the value of `self.test_steps_per_task`.
                valid_steps = sorted(self.valid_task_schedule.keys())
                valid_tasks = [
                    self.valid_task_schedule[step] for step in valid_steps
                ]
                self.test_task_schedule = {
                    i * self.test_steps_per_task: deepcopy(task)
                    for i, task in enumerate(valid_tasks)
                }

            space_dict = dict(temp_env.observation_space.items())
            task_label_space = spaces.Discrete(self.nb_tasks)
            if not self.task_labels_at_train_time or not self.task_labels_at_test_time:
                task_label_space = Sparse(task_label_space)
            space_dict["task_labels"] = task_label_space

            # FIXME: Temporarily, we will actually set the task label space, since there
            # appears to be an error when using monsterkong space.
            observation_space = NamedTupleSpace(spaces=space_dict,
                                                dtype=self.Observations)
            self.observation_space = observation_space
            # Set the spaces using the temp env.
            # self.observation_space = temp_env.observation_space
            self.action_space = temp_env.action_space
            self.reward_range = temp_env.reward_range
            self.reward_space = getattr(
                temp_env,
                "reward_space",
                spaces.Box(low=self.reward_range[0],
                           high=self.reward_range[1],
                           shape=()),
            )

        del temp_env

        self.train_env: gym.Env
        self.valid_env: gym.Env
        self.test_env: gym.Env
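A worked example of the invariant enforced above, with hypothetical numbers: given any two of `nb_tasks`, `max_steps` and `steps_per_task`, the third is derived so that `max_steps == nb_tasks * steps_per_task`:

nb_tasks = 4
max_steps = 100_000
steps_per_task = max_steps // nb_tasks  # -> 25_000
assert max_steps == nb_tasks * steps_per_task
# With clear task boundaries, this matches a train task schedule with
# entries at steps [0, 25_000, 50_000, 75_000].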
Example #8
def _(space: NamedTupleSpace, sample: NamedTuple, device: torch.device = None):
    return space.dtype(**{
        key: to_tensor(space[i], sample[i], device=device)
        for i, key in enumerate(space._spaces.keys())
    })
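A hypothetical call, assuming `space` is a NamedTupleSpace like the ones in the earlier examples:

import torch

tensor_sample = to_tensor(space, space.sample(), device=torch.device("cpu"))
# Every field of the sample is now a torch.Tensor on the requested device.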