def observation_space(self) -> NamedTupleSpace:
    """The un-batched observation space for this setting.

    Built from the chosen dataset and the transforms at `self.transforms`
    (which apply to the train/valid/test environments).

    The returned space is a `NamedTupleSpace` with:
    - `x`: the (transformed) observation space, e.g. an `Image` space;
    - `task_labels`: `Discrete` when task labels are available at train
      time, otherwise `Sparse[Discrete]` (entries will be `None`).
    """
    obs_space = base_observation_spaces[self.dataset]
    if not self.transforms:
        # NOTE: Continuum scenarios always apply at least 'to_tensor', even
        # when no transforms are passed, so mirror that on the space here.
        obs_space = Transforms.to_tensor(obs_space)
    # Apply each transform to the observation space in order.
    for transform in self.transforms:
        obs_space = transform(obs_space)
    obs_space = add_tensor_support(obs_space)

    task_label_space = spaces.Discrete(self.nb_tasks)
    if not self.task_labels_at_train_time:
        # Sparsity of 1.0: task labels are never given.
        task_label_space = Sparse(task_label_space, 1.0)
    task_label_space = add_tensor_support(task_label_space)

    return NamedTupleSpace(
        x=obs_space,
        task_labels=task_label_space,
        dtype=self.Observations,
    )
def _add_task_labels_to_space(observation: X, task_labels: T) -> spaces.Dict:
    """Bundle an observation space together with a task-label space.

    NOTE(review): annotated as returning `spaces.Dict`, but a
    `NamedTupleSpace` is what gets constructed — confirm the intended
    relationship between the two.
    """
    # TODO: Return a dict or NamedTuple at some point:
    combined = NamedTupleSpace(
        x=observation,
        task_labels=task_labels,
        dtype=ObservationsAndTaskLabels,
    )
    return combined
def test_task_id_is_added_even_when_no_known_task_schedule():
    """ Test that even when the env is unknown or there are no task params,
    the task_id is still added correctly and is zero at all times.
    """
    # Breakout doesn't have default task params.
    # Fix: the variable was annotated `CartPoleEnv`, but "Breakout-v0" is an
    # Atari env — annotate with the accurate base type instead.
    original: gym.Env = gym.make("Breakout-v0")
    env = MultiTaskEnvironment(
        original,
        add_task_id_to_obs=True,
    )
    env.seed(123)
    env.reset()
    # With no task schedule there is a single (implicit) task, so the
    # task-label space only has one value.
    assert env.observation_space == NamedTupleSpace(
        x=original.observation_space,
        task_labels=spaces.Discrete(1),
    )
    for step in range(100):
        obs, _, done, info = env.step(env.action_space.sample())
        # env.render()
        x, task_id = obs
        assert task_id == 0
        if done:
            x, task_id = env.reset()
            assert task_id == 0
    env.close()
def test_add_task_id_to_obs():
    """ Test that when `add_task_id_to_obs=True`, the current task id is
    appended to the observations and follows the task schedule.

    (Fix: the previous docstring wrongly claimed this checks the 'info'
    dict; the assertions below check the task id in the observation tuple.)
    """
    original: CartPoleEnv = gym.make("CartPole-v0")
    starting_length = original.length
    starting_gravity = original.gravity
    # Schedule: task params change at steps 10, 20 and 30 (4 tasks total,
    # counting the initial one at step 0).
    task_schedule = {
        10: dict(length=0.1),
        20: dict(length=0.2, gravity=-12.0),
        30: dict(gravity=0.9),
    }
    env = MultiTaskEnvironment(
        original,
        task_schedule=task_schedule,
        add_task_id_to_obs=True,
    )
    env.seed(123)
    env.reset()
    assert env.observation_space == NamedTupleSpace(
        x=original.observation_space,
        task_labels=spaces.Discrete(4),
    )
    for step in range(100):
        obs, _, done, info = env.step(env.action_space.sample())
        # env.render()
        x, task_id = obs
        # The task id and the env params must track the schedule above.
        if 0 <= step < 10:
            assert env.length == starting_length and env.gravity == starting_gravity
            assert task_id == 0, step
        elif 10 <= step < 20:
            assert env.length == 0.1
            assert task_id == 1, step
        elif 20 <= step < 30:
            assert env.length == 0.2 and env.gravity == -12.0
            assert task_id == 2, step
        elif step >= 30:
            assert env.length == starting_length and env.gravity == 0.9
            assert task_id == 3, step
        if done:
            obs = env.reset()
            assert isinstance(obs, tuple)
    env.close()
def _(space: NamedTupleSpace, sample: NamedTuple) -> NamedTuple:
    """Convert a sample from a `NamedTupleSpace` back from tensors.

    Accepts a NamedTuple, a Mapping, or a plain sequence aligned with
    `space.names`; each value whose key names a sub-space is converted with
    `from_tensor`, and the result is rebuilt as `space.dtype`.

    NOTE(review): `isinstance(sample, NamedTuple)` only works if
    `NamedTuple` here is a project class, not `typing.NamedTuple` — confirm
    the import.
    """
    sample_dict: Dict
    if isinstance(sample, NamedTuple):
        sample_dict = sample._asdict()
    elif isinstance(sample, Mapping):
        sample_dict = sample
    else:
        # Positional sample: pair each value with its space name.
        assert len(sample) == len(space.spaces)
        sample_dict = dict(zip(space.names, sample))

    converted = {}
    for key, value in sample_dict.items():
        if key in space.names:
            converted[key] = from_tensor(space[key], value)
        else:
            converted[key] = value
    return space.dtype(**converted)
def test_starting_step_and_max_step():
    """ Test that when the `starting_step` and `max_steps` args are given,
    the env stays within the [starting_step, max_steps] portion of the task
    schedule.

    (Fix: the docstring previously referred to 'start_step' and 'max_step',
    which are not the actual argument names used below.)
    """
    original: CartPoleEnv = gym.make("CartPole-v0")
    starting_length = original.length
    starting_gravity = original.gravity
    task_schedule = {
        10: dict(length=0.1),
        20: dict(length=0.2, gravity=-12.0),
        30: dict(gravity=0.9),
    }
    env = MultiTaskEnvironment(
        original,
        task_schedule=task_schedule,
        add_task_id_to_obs=True,
        starting_step=10,
        max_steps=19,
    )
    env.seed(123)
    env.reset()
    assert env.observation_space == NamedTupleSpace(
        x=original.observation_space,
        task_labels=spaces.Discrete(4),
    )
    # Trying to set the 'steps' to something smaller than the starting step
    # doesn't work.
    env.steps = -123
    assert env.steps == 10
    # Trying to set the 'steps' to something greater than the max_steps
    # doesn't work.
    env.steps = 50
    assert env.steps == 19
    # Here we reset the steps to 10, and also check that this works.
    env.steps = 10
    assert env.steps == 10

    for step in range(100):
        # The environment started at an offset of 10, and is clamped at 19.
        assert env.steps == max(min(step + 10, 19), 10)
        obs, _, done, info = env.step(env.action_space.sample())
        # env.render()
        x, task_id = obs
        # Check that we're always stuck between 10 and 20, i.e. in task 1.
        assert 10 <= env.steps < 20
        assert env.length == 0.1
        assert task_id == 1, step
        if done:
            print(f"Resetting on step {step}")
            obs = env.reset()
            assert isinstance(obs, tuple)
    env.close()
def __post_init__(self, *args, **kwargs):
    """Validate and normalize the fields of this Setting after init.

    In order:
    1. Resolve `self.dataset` (name -> env id) and validate its type.
    2. Load task schedules from files if the `*_task_schedule_path` fields
       are set.
    3. Reconcile `nb_tasks`, `max_steps` and `steps_per_task` (and the
       `test_*` counterparts) so they agree, inferring missing values.
    4. When `smooth_task_boundaries` is set, disable task boundaries /
       labels at both train and test time.
    5. Use a temporary environment to create any missing task schedules and
       to set `observation_space` / `action_space` / `reward_space`.
    """
    super().__post_init__(*args, **kwargs)
    self._new_random_task_on_reset: bool = False
    # Post processing of the 'dataset' field:
    if self.dataset in self.available_datasets.keys():
        # the environment name was passed, rather than an id
        # (e.g. 'cartpole' -> 'CartPole-v0").
        self.dataset = self.available_datasets[self.dataset]
    elif self.dataset not in self.available_datasets.values():
        # The passed dataset is assumed to be an environment ID, but it
        # wasn't in the dict of available datasets! We issue a warning, but
        # proceed to let the user use whatever environment they want to.
        # NOTE(review): a `UserWarning` instance is passed to
        # `logger.warning` (its str() gets logged); confirm whether
        # `warnings.warn` was intended instead.
        logger.warning(
            UserWarning(
                f"The chosen dataset/environment ({self.dataset}) isn't in the "
                f"available_datasets dict, so we can't garantee this will work!"
            ))

    if isinstance(self.dataset, gym.Env) and self.batch_size:
        raise RuntimeError(f"Batch size should be None when a gym.Env "
                           f"object is passed as `dataset`.")
    if not isinstance(self.dataset, (str, gym.Env)) and not callable(self.dataset):
        raise RuntimeError(
            f"`dataset` must be either a string, a gym.Env, or a callable. "
            f"(got {self.dataset})")

    # Set the number of tasks depending on the increment, and vice-versa.
    # (as only one of the two should be used).
    assert self.max_steps, "assuming this should always be set, for now."
    # TODO: Clean this up, not super clear what options take precedence on
    # which other options.

    # Load the task schedules from the corresponding files, if present.
    if self.train_task_schedule_path:
        self.train_task_schedule = self.load_task_schedule(
            self.train_task_schedule_path)
    if self.valid_task_schedule_path:
        self.valid_task_schedule = self.load_task_schedule(
            self.valid_task_schedule_path)
    if self.test_task_schedule_path:
        self.test_task_schedule = self.load_task_schedule(
            self.test_task_schedule_path)

    if self.train_task_schedule:
        if self.steps_per_task is not None:
            # If steps per task was passed, then we overwrite the keys of the tasks
            # schedule.
            self.train_task_schedule = {
                i * self.steps_per_task: self.train_task_schedule[step]
                for i, step in enumerate(
                    sorted(self.train_task_schedule.keys()))
            }
        else:
            # A task schedule was passed: infer the number of tasks from it.
            change_steps = sorted(self.train_task_schedule.keys())
            assert 0 in change_steps, "Schedule needs a task at step 0."
            # TODO: @lebrice: I guess we have to assume that the interval
            # between steps is constant for now? Do we actually depend on this
            # being the case? I think steps_per_task is only really ever used
            # for creating the task schedule, which we already have in this
            # case.
            assert (
                len(change_steps) >= 2
            ), "WIP: need a minimum of two tasks in the task schedule for now."
            self.steps_per_task = change_steps[1] - change_steps[0]
            # Double-check that this is the case.
            for i in range(len(change_steps) - 1):
                if change_steps[
                        i + 1] - change_steps[i] != self.steps_per_task:
                    raise NotImplementedError(
                        f"WIP: This might not work yet if the tasks aren't "
                        f"equally spaced out at a fixed interval.")

        nb_tasks = len(self.train_task_schedule)
        if self.smooth_task_boundaries:
            # NOTE: When in a ContinualRLSetting with smooth task boundaries,
            # the last entry in the schedule represents the state of the env at
            # the end of the "task". When there are clear task boundaries (i.e.
            # when in 'Class'/Task-Incremental RL), the last entry is the start
            # of the last task.
            nb_tasks -= 1
        # A non-default `nb_tasks` must agree with the value deduced from
        # the schedule.
        if self.nb_tasks != 1:
            if self.nb_tasks != nb_tasks:
                raise RuntimeError(
                    f"Passed number of tasks {self.nb_tasks} doesn't match the "
                    f"number of tasks deduced from the task schedule ({nb_tasks})"
                )
        self.nb_tasks = nb_tasks

        self.max_steps = max(self.train_task_schedule.keys())
        if not self.smooth_task_boundaries:
            # See above note about the last entry.
            self.max_steps += self.steps_per_task

    elif self.nb_tasks:
        # No schedule: derive the missing one of max_steps/steps_per_task.
        if self.steps_per_task:
            self.max_steps = self.nb_tasks * self.steps_per_task
        elif self.max_steps:
            self.steps_per_task = self.max_steps // self.nb_tasks
    elif self.steps_per_task:
        if self.nb_tasks:
            self.max_steps = self.nb_tasks * self.steps_per_task
        elif self.max_steps:
            self.nb_tasks = self.max_steps // self.steps_per_task
    elif self.max_steps:
        if self.nb_tasks:
            self.steps_per_task = self.max_steps // self.nb_tasks
        elif self.steps_per_task:
            self.nb_tasks = self.max_steps // self.steps_per_task

    if not all([self.nb_tasks, self.max_steps, self.steps_per_task]):
        raise RuntimeError(
            f"You need to provide at least two of 'max_steps', "
            f"'nb_tasks', or 'steps_per_task'.")

    assert self.max_steps == self.nb_tasks * self.steps_per_task

    if self.test_task_schedule:
        if 0 not in self.test_task_schedule:
            raise RuntimeError(
                "Task schedules needs to include an initial task.")

        if self.test_steps_per_task is not None:
            # If steps per task was passed, then we overwrite the number of steps
            # for each task in the schedule to match.
            self.test_task_schedule = {
                i * self.test_steps_per_task: self.test_task_schedule[step]
                for i, step in enumerate(
                    sorted(self.test_task_schedule.keys()))
            }

        change_steps = sorted(self.test_task_schedule.keys())
        assert 0 in change_steps, "Schedule needs to include task at step 0."

        nb_test_tasks = len(change_steps)
        if self.smooth_task_boundaries:
            nb_test_tasks -= 1
        assert (nb_test_tasks == self.nb_tasks
                ), "nb of tasks should be the same for train and test."

        # Like for training, assume a fixed interval between test tasks.
        self.test_steps_per_task = change_steps[1] - change_steps[0]
        for i in range(self.nb_tasks - 1):
            if change_steps[
                    i + 1] - change_steps[i] != self.test_steps_per_task:
                raise NotImplementedError(
                    "WIP: This might not work yet if the test tasks aren't "
                    "equally spaced out at a fixed interval.")

        self.test_steps = max(change_steps)
        if not self.smooth_task_boundaries:
            # See above note about the last entry.
            self.test_steps += self.test_steps_per_task

    elif self.test_steps_per_task is None:
        # This is basically never the case, since the test_steps defaults to 10_000.
        assert (self.test_steps
                ), "need to set one of test_steps or test_steps_per_task"
        self.test_steps_per_task = self.test_steps // self.nb_tasks
    else:
        # FIXME: This is too complicated for what is is.
        # Check that the test steps must either be the default value, or the right
        # value to use in this case.
        assert self.test_steps in {
            10_000, self.test_steps_per_task * self.nb_tasks
        }
        assert (self.test_steps_per_task
                ), "need to set one of test_steps or test_steps_per_task"
        self.test_steps = self.test_steps_per_task * self.nb_tasks

    assert self.test_steps // self.test_steps_per_task == self.nb_tasks

    if self.smooth_task_boundaries:
        # If we're operating in the 'Online/smooth task transitions' "regime",
        # then there is only one "task", and we don't have task labels.
        # TODO: HOWEVER, the task schedule could/should be able to have more
        # than one non-stationarity! This indicates a need for a distinction
        # between 'tasks' and 'non-stationarities' (changes in the env).
        self.known_task_boundaries_at_train_time = False
        self.known_task_boundaries_at_test_time = False
        self.task_labels_at_train_time = False
        self.task_labels_at_test_time = False
        # self.steps_per_task = self.max_steps

    # Task schedules for training / validation and testing.

    # Create a temporary environment so we can extract the spaces and create
    # the task schedules.
    with self._make_env(self.dataset, self._temp_wrappers(),
                        self.observe_state_directly) as temp_env:
        # FIXME: Replacing the observation space dtypes from their original
        # 'generated' NamedTuples to self.Observations. The alternative
        # would be to add another argument to the MultiTaskEnv wrapper, to
        # pass down a dtype to be set on its observation_space's `dtype`
        # attribute, which would be ugly.
        assert isinstance(temp_env.observation_space, NamedTupleSpace)
        temp_env.observation_space.dtype = self.Observations

        # Populate the task schedules created above.
        if not self.train_task_schedule:
            train_change_steps = list(
                range(0, self.max_steps, self.steps_per_task))
            if self.smooth_task_boundaries:
                # Add a last 'task' at the end of the 'epoch', so that the
                # env changes smoothly right until the end.
                train_change_steps.append(self.max_steps)
            self.train_task_schedule = self.create_task_schedule(
                temp_env,
                train_change_steps,
            )

        assert self.train_task_schedule is not None

        # The validation task schedule is the same as the one used in
        # training by default.
        if not self.valid_task_schedule:
            self.valid_task_schedule = deepcopy(self.train_task_schedule)

        if not self.test_task_schedule:
            # The test task schedule is by default the same as in validation
            # except that the interval between the tasks may be different,
            # depending on the value of `self.test_steps_per_task`.
            valid_steps = sorted(self.valid_task_schedule.keys())
            valid_tasks = [
                self.valid_task_schedule[step] for step in valid_steps
            ]
            self.test_task_schedule = {
                i * self.test_steps_per_task: deepcopy(task)
                for i, task in enumerate(valid_tasks)
            }

        # Build the final observation space: env's spaces + task labels
        # (made Sparse when labels are unavailable at train or test time).
        space_dict = dict(temp_env.observation_space.items())
        task_label_space = spaces.Discrete(self.nb_tasks)
        if not self.task_labels_at_train_time or not self.task_labels_at_test_time:
            task_label_space = Sparse(task_label_space)
        space_dict["task_labels"] = task_label_space

        # FIXME: Temporarily, we will actually set the task label space, since there
        # appears to be an error when using monsterkong space.
        observation_space = NamedTupleSpace(spaces=space_dict,
                                            dtype=self.Observations)
        self.observation_space = observation_space
        # Set the spaces using the temp env.
        # self.observation_space = temp_env.observation_space
        self.action_space = temp_env.action_space
        self.reward_range = temp_env.reward_range
        self.reward_space = getattr(
            temp_env,
            "reward_space",
            spaces.Box(low=self.reward_range[0],
                       high=self.reward_range[1],
                       shape=()),
        )
    del temp_env

    # Bare annotations (no assignment): presumably populated later by the
    # train/valid/test environment factories — TODO confirm.
    self.train_env: gym.Env
    self.valid_env: gym.Env
    self.test_env: gym.Env
def _(space: NamedTupleSpace, sample: NamedTuple, device: torch.device = None):
    """Convert each field of `sample` to a tensor for `space`, returning
    an instance of `space.dtype`.

    Fix: use the public `space.names` accessor and key-based indexing
    (`space[key]`), consistent with the `from_tensor` handler for
    `NamedTupleSpace`, instead of reaching into the private
    `space._spaces` mapping and indexing sub-spaces by position.
    """
    return space.dtype(
        **{
            key: to_tensor(space[key], sample[index], device=device)
            for index, key in enumerate(space.names)
        })