コード例 #1
class TaskIncrementalRLSetting(IncrementalRLSetting):
    """ Continual RL setting with clear task boundaries and task labels.

    The task labels are given at both train and test time.
    task_labels_at_train_time: bool = constant(True)
    task_labels_at_test_time: bool = constant(True)
コード例 #2
class TaskIncrementalAssumption(FullyObservableContextAssumption,
    """ Assumption (mixin) for Settings where the task labels are available at
    both train and test time.
    task_labels_at_train_time: bool = constant(True)
    task_labels_at_test_time: bool = constant(True)
コード例 #3
class PartiallyObservableContextAssumption(HiddenContextAssumption):
    """Context assumption whose fields mark the train-time task labels and the
    train/test-time task boundaries as available (all three default to True).
    """
    # Whether the task labels are observable during training.
    task_labels_at_train_time: bool = constant(True)
    # Whether we get informed when reaching the boundary between two tasks during
    # training.
    known_task_boundaries_at_train_time: bool = constant(True)
    # Whether we get informed when reaching the boundary between two tasks during
    # testing.
    known_task_boundaries_at_test_time: bool = flag(True)
コード例 #4
class IncrementalRLSetting(ContinualRLSetting):
    """ Continual RL setting the data is divided into 'tasks' with clear boundaries.

    By default, the task labels are given at train time, but not at test time.

    TODO: Decide how to implement the train procedure, if we give a single
    dataloader, we might need to call the agent's `on_task_switch` when we reach
    the task boundary.. Or, we could produce one dataloader per task, and then
    implement a custom `fit` procedure in the CLTrainer class, that loops over
    the tasks and calls the `on_task_switch` when needed.
    # Number of tasks.
    nb_tasks: int = 1
    # Whether the task boundaries are smooth or sudden.
    smooth_task_boundaries: bool = constant(False)
    # Whether to give access to the task labels at train time.
    task_labels_at_train_time: bool = True
    # Whether to give access to the task labels at test time.
    task_labels_at_test_time: bool = False

    # Class variable that holds the dict of available environments.
    available_datasets: ClassVar[Dict[str, str]] = dict_union(
            "monsterkong": "MetaMonsterKong-v0",
    dataset: str = choice(available_datasets, default="cartpole")

    def __post_init__(self, *args, **kwargs):
        """Run the parent initialisation, then apply the MonsterKong default.

        All arguments are forwarded to the parent `__post_init__`. Afterwards,
        when `dataset` is "MetaMonsterKong-v0" and `max_episode_steps` is falsy,
        `max_episode_steps` falls back to 500.
        """
        super().__post_init__(*args, **kwargs)

        # Nothing else to configure for any other dataset.
        if self.dataset != "MetaMonsterKong-v0":
            return

        # TODO: Limit the episode length in monsterkong?
        # TODO: Actually end episodes when reaching a task boundary, to force the
        # level to change?
        episode_limit = self.max_episode_steps
        self.max_episode_steps = episode_limit if episode_limit else 500

    def phases(self) -> int:
        """The number of training 'phases', i.e. how many times `method.fit` will be
        In this Incremental-RL Setting, fit is called once per task.
        (Same as ClassIncrementalSetting in SL).
        return self.nb_tasks

    def create_task_schedule(self, temp_env: MultiTaskEnvironment,
                             change_steps: List[int]) -> Dict[int, Dict]:
        task_schedule: Dict[int, Dict] = {}
        if monsterkong_installed:
            if isinstance(temp_env.unwrapped, MetaMonsterKongEnv):
                for i, task_step in enumerate(change_steps):
                    task_schedule[task_step] = {"level": i}
                return task_schedule
        return super().create_task_schedule(temp_env=temp_env,
コード例 #5
class TraditionalSetting(TaskIncrementalAssumption):
    """ Assumption (mixin) for Settings where the data is stationary (only one
    nb_tasks: int = constant(1)
    def phases(self) -> int:
        """The number of training 'phases', i.e. how many times `method.fit` will be
        Defaults to the number of tasks, but may be different, for instance in so-called
        Multi-Task Settings, this is set to 1.
        return 1
コード例 #6
class DiscreteContextAssumption(ContinuousContextAssumption):
    """Context assumption under which task boundaries are not smooth:
    `smooth_task_boundaries` is declared as `constant(False)`.
    """
    # Whether we have clear boundaries between tasks, or if the transitions are smooth.
    # Equivalent to whether the context variable is discrete vs continuous.
    smooth_task_boundaries: bool = constant(False)
コード例 #7
class IncrementalRLSetting(ContinualRLSetting):
    """ Continual RL setting the data is divided into 'tasks' with clear boundaries.

    By default, the task labels are given at train time, but not at test time.

    TODO: Decide how to implement the train procedure, if we give a single
    dataloader, we might need to call the agent's `on_task_switch` when we reach
    the task boundary.. Or, we could produce one dataloader per task, and then
    implement a custom `fit` procedure in the CLTrainer class, that loops over
    the tasks and calls the `on_task_switch` when needed.

    # The number of tasks. By default 0, which means that it will be set
    # depending on other fields in __post_init__, or eventually be just 1.
    nb_tasks: int = field(0, alias=["n_tasks", "num_tasks"])
    # Whether the task boundaries are smooth or sudden.
    smooth_task_boundaries: bool = constant(False)
    # Whether to give access to the task labels at train time.
    task_labels_at_train_time: bool = True
    # Whether to give access to the task labels at test time.
    task_labels_at_test_time: bool = False

    # Class variable that holds the dict of available environments.
    available_datasets: ClassVar[Dict[str, str]] = dict_union(
        {"monsterkong": "MetaMonsterKong-v0"},
    dataset: str = "CartPole-v0"

    def __post_init__(self, *args, **kwargs):
        if not self.nb_tasks:
            # TODO: In case of the metaworld envs, we could decide the 'default' nb of
            # tasks to use based on the number of available tasks

        super().__post_init__(*args, **kwargs)

        if self.dataset == "MetaMonsterKong-v0":
            # TODO: Limit the episode length in monsterkong?
            # TODO: Actually end episodes when reaching a task boundary, to force the
            # level to change?
            self.max_episode_steps = self.max_episode_steps or 500

        # FIXME: Really annoying little bugs with these three arguments!
        self.nb_tasks = self.max_steps // self.steps_per_task

    def _setup_fields_using_temp_env(self, temp_env: MultiTaskEnvironment):
        """ Setup some of the fields on the Setting using a temporary environment.

        This temporary environment only lives during the __post_init__() call.
        # TODO: If the dataset has a `max_path_length` attribute, then it's probably
        # a Mujoco / metaworld / etc env, and so we set a limit on the episode length to
        # avoid getting an error.
        max_path_length: Optional[int] = getattr(temp_env, "max_path_length",
        if self.max_episode_steps is None and max_path_length is not None:
            assert max_path_length > 0
            self.max_episode_steps = temp_env.max_path_length

    def phases(self) -> int:
        """The number of training 'phases', i.e. how many times `method.fit` will be

        In this Incremental-RL Setting, fit is called once per task.
        (Same as ClassIncrementalSetting in SL).
        return self.nb_tasks

    def _make_env(
        base_env: Union[str, gym.Env, Callable[[], gym.Env]],
        wrappers: List[Callable[[gym.Env], gym.Env]] = None,
        observe_state_directly: bool = False,
    ) -> gym.Env:
        """ Helper function to create a single (non-vectorized) environment.

        This is also used to create the env whenever `self.dataset` is a string that
        isn't registered in gym. This happens for example when using an environment from
        meta-world (or mtenv).
        # Check if the env is registered in a known 'third party' gym-like package, and
        # if needed, create the base env in the way that package requires.
        if isinstance(base_env, str):
            env_id = base_env

            # Check if the id belongs to mtenv
            if mtenv_installed and env_id in mtenv_envs:
                from mtenv import make

                base_env = make(env_id)
                # Add a wrapper that will remove the task information, because we use
                # the same MultiTaskEnv wrapper for all the environments.
                wrappers.insert(0, MTEnvAdapterWrapper)

            if metaworld_installed and env_id in metaworld_envs:
                # TODO: Should we use a particular benchmark here?
                # For now, we find the first benchmark that has an env with this name.
                for benchmark_class in [metaworld.ML10]:
                    benchmark = benchmark_class()
                    if env_id in benchmark.train_classes.keys():
                        # TODO: We can either let the base_env be an env type, or
                        # actually instantiate it.
                        base_env: Type[MetaWorldEnv] = benchmark.train_classes[
                        # NOTE: (@lebrice) Here I believe it's better to just have the
                        # constructor, that way we re-create the env for each task.
                        # I think this might be better, as I don't know for sure that
                        # the `set_task` can be called more than once in metaworld.
                        # base_env = base_env_type()
                    raise NotImplementedError(
                        f"Can't find a metaworld benchmark that uses env {env_id}"

        return ContinualRLSetting._make_env(

    def create_task_schedule(self, temp_env: MultiTaskEnvironment,
                             change_steps: List[int]) -> Dict[int, Dict]:
        task_schedule: Dict[int, Dict] = {}

        if monsterkong_installed:
            if isinstance(temp_env.unwrapped, MetaMonsterKongEnv):
                for i, task_step in enumerate(change_steps):
                    task_schedule[task_step] = {"level": i}
                return task_schedule

        if isinstance(temp_env.unwrapped, MTEnv):
            for i, task_step in enumerate(change_steps):
                task_schedule[task_step] = operator.methodcaller(
                    "set_task_state", i)
            return task_schedule

        if isinstance(temp_env.unwrapped, (MetaWorldEnv, MujocoEnv)):
            # TODO: Which benchmark to choose?
            base_env = temp_env.unwrapped
            found = False
            # Find the benchmark that contains this type of env.
            for benchmark_class in [metaworld.ML10]:
                benchmark = benchmark_class()
                for env_name, env_class in benchmark.train_classes.items():
                    if isinstance(base_env, env_class):
                        # Found the right benchmark that contains this env class, now
                        # create the task schedule using
                        # the tasks.
                        found = True
                if found:
            if not found:
                raise NotImplementedError(
                    f"Can't find a benchmark with env class {type(base_env)}!")

            # `benchmark` is here the right benchmark to use to create the tasks.
            training_tasks = [
                task for task in benchmark.train_tasks
                if task.env_name == env_name
            task_schedule = {
                step: operator.methodcaller("set_task", task)
                for step, task in zip(change_steps, training_tasks)
            return task_schedule

        return super().create_task_schedule(temp_env=temp_env,

    def create_train_wrappers(self):
        """Return the train wrappers produced by the parent class, unchanged."""
        return super().create_train_wrappers()
コード例 #8
class Setting2(Setting1):
    """Subclass of `Setting1` that declares a `bar` field as `constant(1)`."""

    bar: int = constant(1)

    def __post_init__(self):
        """Print a trace message that includes this instance's string form.

        NOTE(review): this override does not chain to
        `super().__post_init__()` -- confirm that is intentional.
        """
        message = f"Setting2 __init__ ({self})"
        print(message)
コード例 #9
ファイル: incremental.py プロジェクト: gopeshh/Sequoia
class IncrementalSetting(ContinualSetting):
    """ Mixin that defines methods that are common to all 'incremental'
    settings, where the data is separated into tasks, and where you may not
    always get the task labels.

    Concretely, this holds the train and test loops that are common to the
    ClassIncrementalSetting (highest node on the Passive side) and ContinualRL
    (highest node on the Active side), therefore this setting, while abstract,
    is quite important.


    Results: ClassVar[Type[Results]] = IncrementalResults

    class Observations(Setting.Observations):
        """ Observations produced by an Incremental setting.

        Adds the 'task labels' to the base Observation.

        task_labels: Union[Optional[Tensor], Sequence[Optional[Tensor]]] = None

    # TODO: Actually add the 'smooth' task boundary case.
    # Whether we have clear boundaries between tasks, or if the transition is
    # smooth.
    smooth_task_boundaries: bool = constant(False)  # constant for now.

    # Whether task labels are available at train time.
    # NOTE: Forced to True at the moment.
    task_labels_at_train_time: bool = flag(default=True)
    # Whether task labels are available at test time.
    task_labels_at_test_time: bool = flag(default=False)
    # Whether we get informed when reaching the boundary between two tasks during
    # training. Only used when `smooth_task_boundaries` is False.

    # TODO: Setting constant for now, but we could add task boundary detection
    # later on!
    known_task_boundaries_at_train_time: bool = constant(True)
    # Whether we get informed when reaching the boundary between two tasks during
    # testing. Only used when `smooth_task_boundaries` is False.
    known_task_boundaries_at_test_time: bool = True

    # The number of tasks. By default 0, which means that it will be set
    # depending on other fields in __post_init__, or eventually be just 1.
    nb_tasks: int = field(0, alias=["n_tasks", "num_tasks"])

    # Attributes (not parsed through the command-line):
    _current_task_id: int = field(default=0, init=False)

    # WIP: When True, a Monitor-like wrapper will be applied to the training environment
    # and monitor the 'online' performance during training. Note that in SL, this will
    # also cause the Rewards (y) to be withheld until actions are passed to the `send`
    # method of the Environment.
    monitor_training_performance: bool = False

    # Options related to Weights & Biases (wandb). Turned Off by default. Passing any of
    # its arguments will enable wandb.
    wandb: Optional[WandbConfig] = None

    def __post_init__(self, *args, **kwargs):
        """Initialise the attributes that are set here rather than declared as
        fields.

        All arguments are forwarded to the parent `__post_init__`. Afterwards
        the environment handles, the wandb run handle and the timing / logging
        bookkeeping attributes are reset to their initial (empty) values.
        """
        super().__post_init__(*args, **kwargs)

        # Placeholders for the train / validation / test environments.
        self.train_env: Environment = None  # type: ignore
        self.val_env: Environment = None  # type: ignore
        self.test_env: TestEnvironment = None  # type: ignore

        # Assigned by `main_loop` when `self.wandb` is set.
        self.wandb_run: Optional[Run] = None

        # Written by `main_loop` at the start and end of the run.
        self._start_time: Optional[float] = None
        self._end_time: Optional[float] = None
        self._setting_logged_to_wandb: bool = False

    def phases(self) -> int:
        """The number of training 'phases', i.e. how many times `method.fit` will be

        Defaults to the number of tasks, but may be different, for instance in so-called
        Multi-Task Settings, this is set to 1.
        return self.nb_tasks

    def current_task_id(self) -> Optional[int]:
        """ Get the current task id.

        TODO: Do we want to return None if the task labels aren't currently
        available? (at either Train or Test time?) Or if we 'detect' if
        this is being called from the method?

        TODO: This property doesn't really make sense in the Multi-Task SL or RL
        return self._current_task_id

    def current_task_id(self, value: int) -> None:
        """Set the current task id by storing `value` in `_current_task_id`.

        NOTE(review): as shown, this shares its name with the getter defined
        just above, and no `@property` / `.setter` decorators are visible, even
        though `main_loop` assigns `self.current_task_id = task_id`. Confirm
        the decorators exist in the real file; without them this definition
        simply replaces the getter.
        """
        self._current_task_id = value

    def task_boundary_reached(self, method: Method, task_id: int,
                              training: bool):
        known_task_boundaries = (self.known_task_boundaries_at_train_time
                                 if training else
        task_labels_available = (self.task_labels_at_train_time if training
                                 else self.task_labels_at_test_time)

        if known_task_boundaries:
            # Inform the model of a task boundary. If the task labels are
            # available, then also give the id of the new task to the
            # method.
            # TODO: Should we also inform the method of wether or not the
            # task switch is occuring during training or testing?
            if not hasattr(method, "on_task_switch"):
                        f"On a task boundary, but since your method doesn't "
                        f"have an `on_task_switch` method, it won't know about "
                        f"it! "))
            elif not task_labels_available:
            elif self.phases == 1:
                # NOTE: on_task_switch won't be called if there is only one task.

    def main_loop(self, method: Method) -> IncrementalResults:
        """ Runs an incremental training loop, wether in RL or CL. """
        # TODO: Add ways of restoring state to continue a given run?
        # For each training task, for each test task, a list of the Metrics obtained
        # during testing on that task.
        # NOTE: We could also just store a single metric for each test task, but then
        # we'd lose the ability to create a plots to show the performance within a test
        # task.
        # IDEA: We could use a list of IIDResults! (but that might cause some circular
        # import issues)
        results = self.Results()
        if self.monitor_training_performance:
            results._online_training_performance = []

        # TODO: Fix this up, need to set the '_objective_scaling_factor' to a different
        # value depending on the 'dataset' / environment.
        results._objective_scaling_factor = self._get_objective_scaling_factor(

        if self.wandb:
            # Init wandb, and then log the setting's options.
            self.wandb_run = self.setup_wandb(method)


        self._start_time = time.process_time()

        for task_id in range(self.phases):
            logger.info(f"Starting training" +
                        (f" on task {task_id}." if self.nb_tasks > 1 else "."))
            self.current_task_id = task_id
            self.task_boundary_reached(method, task_id=task_id, training=True)

            # Creating the dataloaders ourselves (rather than passing 'self' as
            # the datamodule):
            task_train_env = self.train_dataloader()
            task_valid_env = self.val_dataloader()


            if self.monitor_training_performance:

            logger.info(f"Finished Training on task {task_id}.")
            test_metrics: TaskSequenceResults = self.test_loop(method)

            # Add a row to the transfer matrix.
                f"Resulting objective of Test Loop: {test_metrics.objective}")

            if wandb.run:
                d = add_prefix(test_metrics.to_log_dict(),
                # d = add_prefix(test_metrics.to_log_dict(), prefix="Test", sep="/")
                d["current_task"] = task_id

        self._end_time = time.process_time()
        runtime = self._end_time - self._start_time
        results._runtime = runtime
        logger.info(f"Finished main loop in {runtime} seconds.")
        self.log_results(method, results)
        return results

    def setup_wandb(self, method: Method) -> Run:
        """Call wandb.init, log the experiment configuration to the config dict.

        This assumes that `self.wandb` is not None. This happens when one of the wandb
        arguments is passed.

        method : Method
            Method to be applied.
        assert isinstance(self.wandb, WandbConfig)
        method_name: str = method.get_name()
        setting_name: str = self.get_name()

        if not self.wandb.run_name:
            # Set the default name for this run.
            run_name = f"{method_name}-{setting_name}"
            dataset = getattr(self, "dataset", None)
            if isinstance(dataset, str):
                run_name += f"-{dataset}"
            if self.nb_tasks > 1:
                run_name += f"_{self.nb_tasks}t"
            self.wandb.run_name = run_name

        run: Run = self.wandb.wandb_init()
        run.config["setting"] = setting_name
        run.config["method"] = method_name
        for k, value in self.to_dict().items():
            if not k.startswith("_"):
                run.config[f"setting/{k}"] = value

        run.summary["setting"] = self.get_name()
        run.summary["method"] = method.get_name()
        assert wandb.run is run
        return run

    def log_results(self, method: Method, results: IncrementalResults) -> None:
        TODO: Create the tabs we need to show up in wandb:
        1. Final
            - Average "Current/Online" performance (scalar)
            - Average "Final" performance (scalar)
            - Runtime
        2. Test
            - Task i (evolution over time (x axis is the task id, if possible))

        if wandb.run:
            wandb.summary["method"] = method.get_name()
            wandb.summary["setting"] = self.get_name()
            dataset = getattr(self, "dataset", "")
            if dataset and isinstance(dataset, str):
                wandb.summary["dataset"] = dataset


            # BUG: Sometimes logging a matplotlib figure causes a crash:
            # File "/home/fabrice/miniconda3/envs/sequoia/lib/python3.8/site-packages/plotly/matplotlylib/mplexporter/utils.py", line 246, in get_grid_style
            # if axis._gridOnMajor and len(gridlines) > 0:
            # AttributeError: 'XAxis' object has no attribute '_gridOnMajor'


    def test_loop(self, method: Method) -> "IncrementalSetting.Results":
        """ (WIP): Runs an incremental test loop and returns the Results.

        The idea is that this loop should be exactly the same, regardless of if
        you're on the RL or the CL side of the tree.

        NOTE: If `self.known_task_boundaries_at_test_time` is `True` and the
        method has the `on_task_switch` callback defined, then a callback
        wrapper is added that will invoke the method's `on_task_switch` and pass
        it the task id (or `None` if `not self.task_labels_available_at_test_time`)
        when a task boundary is encountered.

        This `on_task_switch` 'callback' wrapper gets added the same way for
        Supervised or Reinforcement learning settings.
        test_env = self.test_dataloader()

        test_env: TestEnvironment

        was_training = method.training

        if self.known_task_boundaries_at_test_time and self.nb_tasks > 1:

            def _on_task_switch(step: int, *arg) -> None:
                # TODO: This attribute isn't on IncrementalSetting itself, it's defined
                # on ContinualRLSetting.
                if step not in test_env.boundary_steps:
                if not hasattr(method, "on_task_switch"):
                            f"On a task boundary, but since your method doesn't "
                            f"have an `on_task_switch` method, it won't know about "
                            f"it! "))

                if self.task_labels_at_test_time:
                    # TODO: Should this 'test boundary' step depend on the batch size?
                    task_steps = sorted(test_env.boundary_steps)
                    # TODO: If the ordering of tasks were different (shuffled
                    # tasks for example), then this wouldn't work, we'd need a
                    # list of the task ids or something like that.
                    task_id = task_steps.index(step)
                        f"Calling `method.on_task_switch({task_id})` "
                        f"since task labels are available at test-time.")
                    logger.debug(f"Calling `method.on_task_switch(None)` "
                                 f"since task labels aren't available at "
                                 f"test-time, but task boundaries are known.")

            test_env = StepCallbackWrapper(test_env,

            # If the Method has `test` defined, use it.
            test_env: TestEnvironment
            # Get the metrics from the test environment
            test_results: Results = test_env.get_results()

        except NotImplementedError:
            logger.debug(f"Will query the method for actions at each step, "
                         f"since it doesn't implement a `test` method.")
            obs = test_env.reset()

            # TODO: Do we always have a maximum number of steps? or of episodes?
            # Will it work the same for Supervised and Reinforcement learning?
            max_steps: int = getattr(test_env, "step_limit", None)

            # Reset on the last step is causing trouble, since the env is closed.
            pbar = tqdm.tqdm(itertools.count(), total=max_steps, desc="Test")
            episode = 0

            for step in pbar:
                if obs is None:
                # NOTE: The env might not be closed, while `obs` is actually still there.
                # if test_env.is_closed():
                #     logger.debug(f"Env is closed")
                #     break
                # logger.debug(f"At step {step}")

                # BUG: Need to pass an action space that actually reflects the batch
                # size, even for the last batch!

                # BUG: This doesn't work if the env isn't batched.
                action_space = test_env.action_space
                batch_size = getattr(test_env, "num_envs",
                                     getattr(test_env, "batch_size", 0))
                env_is_batched = batch_size is not None and batch_size >= 1
                if env_is_batched:
                    # NOTE: Need to pass an action space that actually reflects the batch
                    # size, even for the last batch!
                    obs_batch_size = obs.x.shape[0] if obs.x.shape else None
                    action_space_batch_size = (test_env.action_space.shape[0]
                                               if test_env.action_space.shape
                                               else None)
                    if (obs_batch_size is not None
                            and obs_batch_size != action_space_batch_size):
                        action_space = batch_space(
                            test_env.single_action_space, obs_batch_size)

                action = method.get_actions(obs, action_space)

                # logger.debug(f"action: {action}")
                # TODO: Remove this:
                if isinstance(action, Actions):
                    action = action.y_pred
                if isinstance(action, Tensor):
                    action = action.cpu().numpy()

                if test_env.is_closed():

                obs, reward, done, info = test_env.step(action)

                if done and not test_env.is_closed():
                    # logger.debug(f"end of test episode {episode}")
                    obs = test_env.reset()
                    episode += 1

            test_results: TaskSequenceResults = test_env.get_results()

        # Restore 'training' mode, if it was set at the start.
        if was_training:

        return test_results
        # return test_results
        # if not self.task_labels_at_test_time:
        #     # TODO: move this wrapper to common/wrappers.
        #     test_env = RemoveTaskLabelsWrapper(test_env)

    def train_dataloader(
        self, *args, **kwargs
    ) -> Environment["IncrementalSetting.Observations", Actions, Rewards]:
        """Return the DataLoader/Environment for the current train task.

        All arguments are forwarded unchanged to the parent implementation.
        """
        return super().train_dataloader(*args, **kwargs)

    def val_dataloader(
        self, *args, **kwargs
    ) -> Environment["IncrementalSetting.Observations", Actions, Rewards]:
        """ Returns the DataLoader/Environment used for validation on the
        current task.
        return super().val_dataloader(*args, **kwargs)

    def test_dataloader(
        self, *args, **kwargs
    ) -> Environment["IncrementalSetting.Observations", Actions, Rewards]:
        """Return the Test Environment (for all the tasks).

        All arguments are forwarded unchanged to the parent implementation.
        """
        return super().test_dataloader(*args, **kwargs)

    def _get_objective_scaling_factor(self) -> float:
        """Return the objective scaling factor; this implementation always
        returns 1.0.

        `main_loop` stores the result on `results._objective_scaling_factor`.
        """
        return 1.0
コード例 #10
# NOTE(review): this copy of the class header looks truncated by extraction: the
# base-class list below ends on a trailing comma with no closing `):`, and the class
# docstring has no closing triple quote. Restore both from the upstream file before
# relying on this definition.
class IncrementalRLSetting(IncrementalAssumption,
    """ Continual RL setting in which:
    - Changes in the environment's context occur suddenly (same as in Discrete, Task-Agnostic RL)
    - Task boundary information (and task labels) are given at training time
    - Task boundary information is given at test time, but task identity is not.

    Observations: ClassVar[Type[Observations]] = Observations
    Actions: ClassVar[Type[Actions]] = Actions
    Rewards: ClassVar[Type[Rewards]] = Rewards

    # The function used to create the tasks for the chosen env.
    _task_sampling_function: ClassVar[Callable[
        ..., IncrementalTask]] = make_incremental_task
    Results: ClassVar[Type[Results]] = IncrementalRLResults

    # Class variable that holds the dict of available environments.
    available_datasets: ClassVar[Dict[str, str]] = available_datasets
    # Which dataset/environment to use for training, validation and testing.
    dataset: str = choice(available_datasets, default="CartPole-v0")

    # # The number of tasks. By default 0, which means that it will be set
    # # depending on other fields in __post_init__, or eventually be just 1.
    # nb_tasks: int = field(0, alias=["n_tasks", "num_tasks"])

    # (Copied from the assumption, just for clarity:)
    # TODO: Shouldn't these kinds of properties be on the class, rather than on the
    # instance?

    # Whether the task boundaries are smooth or sudden.
    smooth_task_boundaries: Final[bool] = constant(False)
    # Whether to give access to the task labels at train time.
    task_labels_at_train_time: Final[bool] = constant(True)
    # Whether to give access to the task labels at test time.
    task_labels_at_test_time: bool = False

    # Optional per-task environments (ids or zero-argument factories). When
    # `train_envs` is non-empty, `__post_init__` sets `nb_tasks` to its length.
    train_envs: List[Union[str, Callable[[], gym.Env]]] = list_field()
    val_envs: List[Union[str, Callable[[], gym.Env]]] = list_field()
    test_envs: List[Union[str, Callable[[], gym.Env]]] = list_field()

    def __post_init__(self):
        # Post-construction hook. As visible here it (1) optionally builds per-task
        # train/val/test splits from the MetaWorld "MT10" benchmark, (2) when
        # `train_envs` is set, derives `nb_tasks`, `dataset` and default val/test env
        # lists from it, and (3) defaults `max_episode_steps` from the env's
        # `max_path_length` attribute when one exists.
        # NOTE(review): several statements in this copy are cut short by extraction
        # (an `if` with no body, unclosed brackets/calls, a missing `else:`); each
        # spot is flagged below. `self._temp_train_env` is read near the end but is
        # not assigned anywhere in this method as shown -- presumably a call to the
        # parent `__post_init__` was lost; confirm against upstream.
        # NOTE(review): no statement follows in this branch as shown.
        if not self.nb_tasks:
            # TODO: In case of the metaworld envs, we could determine the 'default' nb
            # of tasks to use based on the number of available tasks

        if self.dataset == "MT10":
            from metaworld import MT10, Task, MetaWorldEnv

            self._benchmark = MT10(
                seed=self.config.seed if self.config else None)
            envs: Dict[str, Type[MetaWorldEnv]] = self._benchmark.train_classes
            # Group the benchmark's train tasks by the name of the env they belong to.
            # NOTE(review): this dict comprehension is not closed in this copy (the
            # inner `]`, the iterable after `in`, and the final `}` are missing).
            env_tasks: Dict[str, List[Task]] = {
                env_name: [
                    task for task in self._benchmark.train_tasks
                    if task.env_name == env_name
                for env_name, env_class in
            from itertools import islice
            train_env_tasks: Dict[str, List[Task]] = {}
            val_env_tasks: Dict[str, List[Task]] = {}
            test_env_tasks: Dict[str, List[Task]] = {}
            test_fraction = 0.1
            val_fraction = 0.1
            # NOTE(review): the loop variable `env_tasks` rebinds the name of the dict
            # being iterated; after the loop `env_tasks` refers to the last task list,
            # not the dict.
            for env_name, env_tasks in env_tasks.items():
                n_tasks = len(env_tasks)
                # At least one task each goes to validation and test; the remainder
                # is used for training.
                n_val_tasks = int(max(1, n_tasks * val_fraction))
                n_test_tasks = int(max(1, n_tasks * test_fraction))
                n_train_tasks = len(env_tasks) - n_val_tasks - n_test_tasks
                if n_train_tasks <= 1:
                    # Can't create train, val and test tasks.
                    # NOTE(review): the message and closing `)` of this raise are
                    # cut off in this copy.
                    raise RuntimeError(
                        f"There aren't enough tasks for env {env_name} ({n_tasks}) "
                # A single shared iterator is sliced three times in a row, so the
                # train, val and test task lists are consecutive, disjoint chunks.
                tasks_iterator = iter(env_tasks)
                train_env_tasks[env_name] = list(
                    islice(tasks_iterator, n_train_tasks))
                val_env_tasks[env_name] = list(
                    islice(tasks_iterator, n_val_tasks))
                test_env_tasks[env_name] = list(
                    islice(tasks_iterator, n_test_tasks))
                assert train_env_tasks[env_name]
                assert val_env_tasks[env_name]
                assert test_env_tasks[env_name]

            from ..discrete.multienv_wrappers import RandomMultiEnvWrapper
            # TODO: Fix the naming of this MultiTaskEnvironment wrapper:
            from sequoia.common.gym_wrappers import MultiTaskEnvironment
            import operator
            # NOTE: We could do some shuffling here, for instance.
            # NOTE(review): the first and third `zip(` calls below, and the three
            # list comprehensions that follow, are cut off in this copy (no
            # arguments / element expressions / closing brackets are visible).
            train_env_names, train_env_classes = zip(
            val_env_names, val_env_classes = zip(*list(val_env_tasks.items()))
            test_env_names, test_env_classes = zip(

            self.train_envs = [
                for env_name, tasks in train_env_tasks.items()
            self.val_envs = [
                for env_name, tasks in val_env_tasks.items()
            self.test_envs = [
                for env_name, tasks in test_env_tasks.items()

        # if is_monsterkong_env(self.dataset):
        #     if self.force_pixel_observations:
        #         # Add this to the kwargs that will be passed to gym.make, to make sure that
        #         # we observe pixels, and not state.
        #         self.base_env_kwargs["observe_state"] = False
        #     elif self.force_state_observations:
        #         self.base_env_kwargs["observe_state"] = True

        self._using_custom_envs_foreach_task: bool = False
        if self.train_envs:
            # One task per custom training env.
            self._using_custom_envs_foreach_task = True
            self.nb_tasks = len(self.train_envs)
            # TODO: Not sure what to do with the `self.dataset` field, because
            # ContinualRLSetting expects to have a single env, while we have more than
            # one, the __post_init__ tries to create the rest of the fields based on
            # `self.dataset`
            self.dataset = self.train_envs[0]

            # Validation / test envs default to (shallow) copies of the train list.
            if not self.val_envs:
                # TODO: Use a wrapper that sets a different random seed
                self.val_envs = self.train_envs.copy()
            if not self.test_envs:
                # TODO: Use a wrapper that sets a different random seed
                self.test_envs = self.train_envs.copy()
            # Custom envs and explicit task schedules are mutually exclusive.
            if (self.train_task_schedule or self.val_task_schedule
                    or self.test_task_schedule):
                raise RuntimeError(
                    "You can either pass `train/valid/test_envs`, or a task schedule, "
                    "but not both!")
            # NOTE(review): judging by its message, the check below belongs to the
            # case where `train_envs` is empty; an `else:` line appears to be missing
            # in this copy. As shown (inside the `if self.train_envs:` branch, after
            # the defaults above) it would always raise. The closing `)` of the
            # raise is also cut off.
            if self.val_envs or self.test_envs:
                raise RuntimeError(
                    "Can't pass `val_envs` or `test_envs` without passing `train_envs`."


        # NOTE(review): no statement follows in this branch as shown.
        if self._using_custom_envs_foreach_task:
            # TODO: Use 'no-op' task schedules for now.
            # self.train_task_schedule.clear()
            # self.val_task_schedule.clear()
            # self.test_task_schedule.clear()

            # TODO: Check that all the envs have the same observation spaces!
            # (If possible, find a way to check this without having to instantiate all
            # the envs.)

        # TODO: If the dataset has a `max_path_length` attribute, then it's probably
        # a Mujoco / metaworld / etc env, and so we set a limit on the episode length to
        # avoid getting an error.
        max_path_length: Optional[int] = getattr(self._temp_train_env,
                                                 "max_path_length", None)
        # Only fill in the episode limit when the user did not set one explicitly.
        if self.max_episode_steps is None and max_path_length is not None:
            assert max_path_length > 0
            self.max_episode_steps = max_path_length

        # if self.dataset == "MetaMonsterKong-v0":
        #     # TODO: Limit the episode length in monsterkong?
        #     # TODO: Actually end episodes when reaching a task boundary, to force the
        #     # level to change?
        #     self.max_episode_steps = self.max_episode_steps or 500

        # FIXME: Really annoying little bugs with these three arguments!
        # self.nb_tasks = self.max_steps // self.steps_per_task

    # NOTE(review): the `pairwise(` call and the list comprehension below are cut off
    # in this copy (no argument, no closing brackets), so the sequence of step
    # boundaries being paired is not visible here. `current_train_task_length` reads
    # this name as an attribute (`self.train_task_lengths[...]`), so presumably this
    # was a `@property` upstream -- confirm.
    def train_task_lengths(self) -> List[int]:
        """ Gives the length of each training task (in steps for now). """
        # Each length is the difference between two consecutive boundary steps.
        return [
            task_b_step - task_a_step for task_a_step, task_b_step in pairwise(

    # NOTE(review): in this copy the docstring below has no closing triple quote, and
    # the `pairwise(` call / list comprehension are cut off (no argument, no closing
    # brackets). The visible expression is the same consecutive-step difference used
    # in `train_task_lengths`; which boundaries are paired cannot be seen from here.
    def train_phase_lengths(self) -> List[int]:
        """ Gives the length of each training 'phase', i.e. the maximum number of (steps
        for now) that can be taken in the training environment, in a single call to .fit
        return [
            task_b_step - task_a_step for task_a_step, task_b_step in pairwise(

    def current_train_task_length(self) -> int:
        """Deprecated accessor for the number of steps of the task being trained on.

        When `stationary_context` is set, the lengths of all training tasks are
        summed; otherwise only the entry for `current_task_id` is returned.
        """
        if not self.stationary_context:
            # Non-stationary: look up the length of the task currently in effect.
            return self.train_task_lengths[self.current_task_id]
        return sum(self.train_task_lengths)

    # steps_per_task: int = deprecated_property(
    #     "steps_per_task", "current_train_task_length"
    # )
    # @property
    # def steps_per_task(self) -> int:
    #     # unique_task_lengths = list(set(self.train_task_lengths))
    #     warning = DeprecationWarning(
    #         "The 'steps_per_task' attribute is deprecated, use "
    #         "`current_train_task_length` instead, which gives the length of the "
    #         "current task."
    #     )
    #     warnings.warn(warning)
    #     logger.warning(warning)
    #     return self.current_train_task_length

    # @property
    # def steps_per_phase(self) -> int:
    #     # return unique_task_lengths
    #     test_task_lengths: List[int] = [
    #         task_b_step - task_a_step
    #         for task_a_step, task_b_step in pairwise(
    #             sorted(self.test_task_schedule.keys())
    #         )
    #     ]

    def task_label_space(self) -> gym.Space:
        """Build the gym space that task labels are drawn from.

        The base space is `Discrete(nb_tasks)`. When task labels are missing at
        train time and/or test time it is wrapped in `Sparse`: sparsity 0.5 when
        labels are available in exactly one of the two, and 1 when they are
        available in neither.
        """
        # TODO: Explore an alternative design for the task sampling, based more around
        # gym spaces rather than the generic function approach that's currently used?
        # IDEA: Might be cleaner to put this in the assumption class
        label_space = spaces.Discrete(self.nb_tasks)
        labels_at_train = self.task_labels_at_train_time
        labels_at_test = self.task_labels_at_test_time
        if labels_at_train and labels_at_test:
            # Labels are always available: no need for a sparse space.
            return label_space
        # We have task labels "50%" of the time, ish, when only one side has them.
        sparsity = 0.5 if labels_at_train ^ labels_at_test else 1
        return Sparse(label_space, sparsity=sparsity)

    def setup(self, stage: str = None) -> None:
        # Called before the start of each task during training, validation and
        # testing.
        # What's done in ContinualRLSetting:
        # if stage in {"fit", None}:
        #     self.train_wrappers = self.create_train_wrappers()
        #     self.valid_wrappers = self.create_valid_wrappers()
        # elif stage in {"test", None}:
        #     self.test_wrappers = self.create_test_wrappers()
        # NOTE(review): `stage` is not read by any statement visible in this copy.
        # The string literal right under the `if` below looks like the orphaned
        # argument of a logging call whose opening and closing lines are missing; as
        # shown, its indentation does not line up with the assignments that follow.
        if self._using_custom_envs_foreach_task:
                f"Using custom environments from `self.[train/val/test]_envs` for task "
            # NOTE: Here is how this supports passing custom envs for each task: We just
            # switch out the value of this property, and let the
            # `train/val/test_dataloader` methods work as usual!
            self.dataset = self.train_envs[self.current_task_id]
            self.val_dataset = self.val_envs[self.current_task_id]
            # TODO: The test loop goes through all the envs, hence this doesn't really
            # work.
            self.test_dataset = self.test_envs[self.current_task_id]

            # TODO: Check that the observation/action spaces are all the same for all
            # the train/valid/test envs
            # TODO: Inconsistent naming between `val_envs` and `valid_wrappers` etc.
            # TODO: Should we populate the `self.train_envs`, `self.val_envs` and
            # `self.test_envs` fields here as well, just to be consistent?
            # base_env = self.dataset
            # def task_env(task_index: int) -> Callable[[], MultiTaskEnvironment]:
            #     return self._make_env(
            #         base_env=base_env,
            #         wrappers=[],
            #     )
            # self.train_envs = [partial(gym.make, self.dataset) for i in range(self.nb_tasks)]
            # self.val_envs = [partial(gym.make, self.dataset) for i in range(self.nb_tasks)]
            # self.test_envs = [partial(gym.make, self.dataset) for i in range(self.nb_tasks)]
            # assert False, self.train_task_schedule

    # def _setup_fields_using_temp_env(self, temp_env: MultiTaskEnvironment):
    #     """ Setup some of the fields on the Setting using a temporary environment.

    #     This temporary environment only lives during the __post_init__() call.
    #     """
    #     super()._setup_fields_using_temp_env(temp_env)
    def test_dataloader(self,
                        batch_size: Optional[int] = None,
                        num_workers: Optional[int] = None):
        # Creates the test environment. Without custom per-task envs this defers to
        # the parent implementation. With them, the test step/episode budgets are
        # temporarily divided by `nb_tasks` around the parent call and then restored.
        # NOTE(review): both `super().test_dataloader(` calls are cut off in this
        # copy (only `batch_size=batch_size,` is visible, no closing parenthesis), so
        # whether `num_workers` is forwarded cannot be confirmed from here.
        if not self._using_custom_envs_foreach_task:
            return super().test_dataloader(batch_size=batch_size,

        # raise NotImplementedError("TODO:")

        # IDEA: Pretty hacky, but might be cleaner than adding fields for the moment.
        # Remember the full budgets so they can be put back after the parent call.
        test_max_steps = self.test_max_steps
        test_max_episodes = self.test_max_episodes
        self.test_max_steps = test_max_steps // self.nb_tasks
        # The episode budget is only divided when it is set (truthy).
        if self.test_max_episodes:
            self.test_max_episodes = test_max_episodes // self.nb_tasks
        # self.test_env = self.TestEnvironment(self.test_envs[self.current_task_id])

        # NOTE(review): there is no try/finally here, so if the parent call raises,
        # the divided budgets are left on the instance.
        task_test_env = super().test_dataloader(batch_size=batch_size,

        self.test_max_steps = test_max_steps
        self.test_max_episodes = test_max_episodes
        return task_test_env

    def test_loop(self, method: Method["IncrementalRLSetting"]):
        # Runs the test phase. Without custom per-task envs this defers to the parent
        # `test_loop`; otherwise it builds one test env per task and gathers the
        # per-task results into a single Results object.
        # NOTE(review): this copy is missing several lines, flagged below. As shown,
        # nothing appends to `test_envs`, nothing runs between swapping and restoring
        # `self.test_dataloader`, and nothing stores `task_result` into
        # `test_loop_results` -- presumably those statements were lost; confirm
        # against upstream.
        if not self._using_custom_envs_foreach_task:
            return super().test_loop(method)

        # TODO: If we're using custom envs for each task, then the test loop needs to be
        # re-organized.
        # raise NotImplementedError(
        #     f"TODO: Need to add a wrapper that can switch between envs, or "
        #     f"re-write the test loop."
        # )
        assert self.nb_tasks == len(self.test_envs), "assuming this for now."
        test_envs = []
        for task_id in range(self.nb_tasks):
            # TODO: Make sure that self.test_dataloader() uses the right number of steps
            # per test task (current hard-set to self.test_max_steps).
            task_test_env = self.test_dataloader()
        from ..discrete.multienv_wrappers import ConcatEnvsWrapper

        # `None` when the method does not define an `on_task_switch` hook.
        on_task_switch_callback = getattr(method, "on_task_switch", None)
        # NOTE(review): the arguments and closing `)` of this call are cut off.
        joined_test_env = ConcatEnvsWrapper(
        # TODO: Use this 'joined' test environment in this test loop somehow.
        # IDEA: Hacky way to do it: (I don't think this will work as-is though)
        # Temporarily shadow the bound method with a lambda that returns the joined
        # env, then put the original bound method back.
        _test_dataloader_method = self.test_dataloader
        self.test_dataloader = lambda *args, **kwargs: joined_test_env
        self.test_dataloader = _test_dataloader_method

        test_loop_results = DiscreteTaskAgnosticRLSetting.Results()
        for task_id, test_env in enumerate(test_envs):
            # TODO: The results are still of the wrong type, because we aren't changing
            # the type of test environment or the type of Results
            # NOTE(review): the closing `)` of this call is cut off.
            results_of_wrong_type: IncrementalRLResults = test_env.get_results(
            # For now this weird setup means that there will be only one 'result'
            # object in this that actually has metrics:
            # assert results_of_wrong_type.task_results[task_id].metrics
            # Flatten the metric lists of every per-task result into one list.
            all_metrics: List[EpisodeMetrics] = sum([
                result.metrics for result in results_of_wrong_type.task_results
            ], [])
            # NOTE(review): the element expression and closing `]` of this
            # comprehension are cut off.
            n_metrics_in_each_result = [
                for result in results_of_wrong_type.task_results
            # assert all(n_metrics == 0 for i, n_metrics in enumerate(n_metrics_in_each_result) if i != task_id), (n_metrics_in_each_result, task_id)
            # TODO: Also transfer the other properties like runtime, online performance,
            # etc?
            # TODO: Maybe add addition for these?
            # task_result = sum(results_of_wrong_type.task_results)
            task_result = TaskResults(metrics=all_metrics)
            # task_result: TaskResults[EpisodeMetrics] = results_of_wrong_type.task_results[task_id]
        return test_loop_results

    def phases(self) -> int:
        """The number of training 'phases', i.e. how many times `method.fit` will be
        called.

        In this Incremental-RL Setting, fit is called once per task.
        (Same as ClassIncrementalSetting in SL).

        Returns:
            The value of `self.nb_tasks`.
        """
        return self.nb_tasks

    # NOTE(review): this copy is damaged by extraction: the docstring below has no
    # closing triple quote, the `benchmark.train_classes[` subscript and the final
    # `ContinualRLSetting._make_env(` call are cut off, and the `raise
    # NotImplementedError` sits unconditionally inside the `for` loop as shown
    # (presumably a `break` / `else:` pair was lost). There is no `self`/`cls`
    # parameter, so this was presumably a `@staticmethod` upstream -- confirm.
    # NOTE(review): `wrappers` defaults to None and no visible line replaces it
    # before `wrappers.insert(0, ...)`; as shown, the mtenv branch would raise
    # AttributeError when `wrappers` is omitted -- check whether a guard was lost.
    def _make_env(
        base_env: Union[str, gym.Env, Callable[[], gym.Env]],
        wrappers: List[Callable[[gym.Env], gym.Env]] = None,
        **base_env_kwargs: Dict,
    ) -> gym.Env:
        """ Helper function to create a single (non-vectorized) environment.

        This is also used to create the env whenever `self.dataset` is a string that
        isn't registered in gym. This happens for example when using an environment from
        meta-world (or mtenv).
        # Check if the env is registered in a known 'third party' gym-like package, and if
        # needed, create the base env in the way that package requires.
        if isinstance(base_env, str):
            env_id = base_env

            # Check if the id belongs to mtenv
            if MTENV_INSTALLED and env_id in mtenv_envs:
                from mtenv import make as mtenv_make

                # This is super weird. Don't understand at all
                # why they are doing this. Makes no sense to me whatsoever.
                base_env = mtenv_make(env_id, **base_env_kwargs)

                # Add a wrapper that will remove the task information, because we use
                # the same MultiTaskEnv wrapper for all the environments.
                wrappers.insert(0, MTEnvAdapterWrapper)

            if METAWORLD_INSTALLED and env_id in metaworld_envs:
                # TODO: Should we use a particular benchmark here?
                # For now, we find the first benchmark that has an env with this name.
                import metaworld

                for benchmark_class in [metaworld.ML10]:
                    benchmark = benchmark_class()
                    if env_id in benchmark.train_classes.keys():
                        # TODO: We can either let the base_env be an env type, or
                        # actually instantiate it.
                        base_env: Type[MetaWorldEnv] = benchmark.train_classes[
                        # NOTE: (@lebrice) Here I believe it's better to just have the
                        # constructor, that way we re-create the env for each task.
                        # I think this might be better, as I don't know for sure that
                        # the `set_task` can be called more than once in metaworld.
                        # base_env = base_env_type()
                    raise NotImplementedError(
                        f"Can't find a metaworld benchmark that uses env {env_id}"

        return ContinualRLSetting._make_env(

    def create_task_schedule(self, temp_env: gym.Env,
                             change_steps: List[int]) -> Dict[int, Dict]:
        # Builds a dict that maps each step in `change_steps` to a task. With custom
        # per-task envs every step maps to an empty dict; otherwise each task comes
        # from the class-level `_task_sampling_function`.
        # NOTE(review): the sampling call below is cut off in this copy (only the
        # `seed=` keyword is visible, no closing parenthesis), so how `temp_env` and
        # `step` are passed to it cannot be confirmed from here.
        task_schedule: Dict[int, Dict] = {}
        if self._using_custom_envs_foreach_task:
            # If custom envs were passed to be used for each task, then we don't create
            # a "task schedule", because the only reason we're using a task schedule is
            # when we want to change something about the 'base' env in order to get
            # multiple tasks.
            # Create a task schedule dict, just to fit in?
            for i, task_step in enumerate(change_steps):
                task_schedule[task_step] = {}
            return task_schedule

        # TODO: Make it possible to use something other than steps as keys in the task
        # schedule, something like a NamedTuple[int, DeltaType], e.g. Episodes(10) or
        # Steps(10), something like that!
        # IDEA: Even fancier, we could use a TimeDelta to say "do one hour of task 0"!!
        for step in change_steps:
            # TODO: Add a `stage` argument (an enum or something with 'train', 'valid'
            # 'test' as values, and pass it to this function. Tasks should be the same
            # in train/valid for now, given the same task Id.
            # TODO: When the Results become able to handle a different ordering of tasks
            # at train vs test time, allow the test task schedule to have different
            # ordering than train / valid.
            # Looked up on the class (`type(self)`), not on the instance.
            task = type(self)._task_sampling_function(
                seed=self.config.seed if self.config else None,
            task_schedule[step] = task

        return task_schedule

    def create_train_wrappers(self):
        # Picks the base env and the slice of the train task schedule to use, then
        # delegates to `self._make_wrappers`.
        # NOTE(review): this copy is missing lines. `base_env` is assigned twice in a
        # row under the first `if` (presumably an `else:` was lost), the `list(` call
        # and the dict literal in the second half are not closed, the branch
        # separating the stationary / non-stationary cases is not visible, and the
        # `_make_wrappers(` call is cut off after `transforms=`.
        # TODO: Clean this up a bit?
        if self._using_custom_envs_foreach_task:
            # TODO: Maybe do something different here, since we don't actually want to
            # add a CL wrapper at all in this case?
            assert not any(self.train_task_schedule.values())
            base_env = self.train_envs[self.current_task_id]
            base_env = self.train_dataset
        # assert False, super().create_train_wrappers()
        if self.stationary_context:
            # Work on a copy so the stored schedule is not modified.
            task_schedule_slice = self.train_task_schedule.copy()
            assert len(task_schedule_slice) >= 2
            # Need to pop the last task, so that we don't sample it by accident!
            max_step = max(task_schedule_slice)
            last_task = task_schedule_slice.pop(max_step)
            # TODO: Shift the second-to-last task to the last step
            last_boundary = max(task_schedule_slice)
            second_to_last_task = task_schedule_slice.pop(last_boundary)
            task_schedule_slice[max_step] = second_to_last_task
            # If popping removed the entry at step 0, put one back at step 0.
            if 0 not in task_schedule_slice:
                assert self.nb_tasks == 1
                task_schedule_slice[0] = second_to_last_task
            # assert False, (max_step, last_boundary, last_task, second_to_last_task)
            current_task = list(
            task_length = self.train_max_steps // self.nb_tasks
            task_schedule_slice = {
                0: current_task,
                task_length: current_task,
        return self._make_wrappers(
            # TODO: Removing this, but we have to check that it doesn't change when/how
            # the task boundaries are given to the Method.
            # sharp_task_boundaries=self.known_task_boundaries_at_train_time,
            transforms=self.transforms + self.train_transforms,

    def create_valid_wrappers(self):
        # Validation counterpart of `create_train_wrappers`: picks the base env and
        # the validation task schedule slice, then delegates to `_make_wrappers`.
        # NOTE(review): this copy is missing lines. `base_env` is assigned twice in a
        # row under the first `if` (presumably an `else:` was lost), the `list(` call
        # and the dict literal are not closed, and the `_make_wrappers(` call is cut
        # off after `transforms=`.
        # NOTE(review): `task_length` is computed from `self.train_max_steps` here,
        # not from a validation-specific budget -- confirm that this is intended.
        if self._using_custom_envs_foreach_task:
            # TODO: Maybe do something different here, since we don't actually want to
            # add a CL wrapper at all in this case?
            assert not any(self.val_task_schedule.values())
            base_env = self.val_envs[self.current_task_id]
            base_env = self.val_dataset
        # assert False, super().create_train_wrappers()
        if self.stationary_context:
            task_schedule_slice = self.val_task_schedule
            current_task = list(
            task_length = self.train_max_steps // self.nb_tasks
            task_schedule_slice = {
                0: current_task,
                task_length: current_task,
        return self._make_wrappers(
            # TODO: Removing this, but we have to check that it doesn't change when/how
            # the task boundaries are given to the Method.
            # sharp_task_boundaries=self.known_task_boundaries_at_train_time,
            transforms=self.transforms + self.val_transforms,

    def create_test_wrappers(self):
        # Test counterpart of `create_train_wrappers`: picks the base env, uses the
        # whole test task schedule (no per-task slicing), then delegates to
        # `_make_wrappers`.
        # NOTE(review): this copy is missing lines. `base_env` is assigned twice in a
        # row under the `if` (presumably an `else:` was lost), and the
        # `_make_wrappers(` call is cut off after `transforms=`.
        if self._using_custom_envs_foreach_task:
            # TODO: Maybe do something different here, since we don't actually want to
            # add a CL wrapper at all in this case?
            assert not any(self.test_task_schedule.values())
            base_env = self.test_envs[self.current_task_id]
            base_env = self.test_dataset
        # assert False, super().create_train_wrappers()
        task_schedule_slice = self.test_task_schedule
        # if self.stationary_context:
        # else:
        #     current_task = list(self.test_task_schedule.values())[self.current_task_id]
        #     task_length = self.test_max_steps // self.nb_tasks
        #     task_schedule_slice = {
        #         0: current_task,
        #         task_length: current_task,
        #     }
        return self._make_wrappers(
            # TODO: Removing this, but we have to check that it doesn't change when/how
            # the task boundaries are given to the Method.
            # sharp_task_boundaries=self.known_task_boundaries_at_train_time,
            transforms=self.transforms + self.test_transforms,

    # NOTE(review): this copy is damaged by extraction. The signature is cut off (the
    # body uses `self`, but no `self` parameter is visible, and the annotation of
    # `envs_or_env_functions` is not closed), the docstring has no closing triple
    # quote, the early-exit `if` has no body, and the `_make_env(` / `zip(` calls are
    # not closed. What remains visible: the first env is built, then each later env
    # is built and its observation and action spaces are compared against the first,
    # raising RuntimeError on the first mismatch.
    def _check_all_envs_have_same_spaces(
        envs_or_env_functions: List[Union[str, gym.Env, Callable[[],
        wrappers: List[Callable[[gym.Env], gym.Wrapper]],
    ) -> None:
        """ Checks that all the environments in the list have the same
        observation/action spaces.
        if self._using_custom_envs_foreach_task:
            # TODO: Removing this check for now.
        first_env = self._make_env(base_env=envs_or_env_functions[0],
        for task_id, task_env_id_or_function in zip(
                range(1, len(envs_or_env_functions)),
            task_env = self._make_env(
            if task_env.observation_space != first_env.observation_space:
                raise RuntimeError(
                    f"Env at task {task_id} doesn't have the same observation "
                    f"space ({task_env.observation_space}) as the environment of "
                    f"the first task: {first_env.observation_space}.")
            if task_env.action_space != first_env.action_space:
                raise RuntimeError(
                    f"Env at task {task_id} doesn't have the same action "
                    f"space ({task_env.action_space}) as the environment of "
                    f"the first task: {first_env.action_space}")

    # Builds the list of wrapper factories to apply on top of `base_env`, starting
    # from the parent class' `_make_wrappers` result. With custom per-task envs the
    # task schedule is discarded (set to None) before the parent call.
    # NOTE(review): this copy is damaged by extraction. The body uses `self`, but no
    # `self` parameter is visible in the signature. The two f-string lines under
    # `if task_schedule:` look like the orphaned tail of a warning call whose opening
    # line is missing. The `super()._make_wrappers(` call is not closed.
    # `task_label` / `task_label_space` are computed but no visible statement uses
    # them (presumably the line appending a task-label wrapper was lost). The final
    # `if is_monsterkong_env(base_env):` has no body as shown.
    def _make_wrappers(
        base_env: Union[str, gym.Env, Callable[[], gym.Env]],
        task_schedule: Dict[int, Dict],
        # sharp_task_boundaries: bool,
        task_labels_available: bool,
        transforms: List[Transforms],
        starting_step: int,
        max_steps: int,
        new_random_task_on_reset: bool,
    ) -> List[Callable[[gym.Env], gym.Env]]:
        if self._using_custom_envs_foreach_task:
            if task_schedule:
                        f"Ignoring task schedule {task_schedule}, since custom envs were "
                        f"passed for each task!"))
            task_schedule = None

        wrappers = super()._make_wrappers(

        if self._using_custom_envs_foreach_task:
            # If the user passed a specific env to use for each task, then there won't
            # be a MultiTaskEnv wrapper in `wrappers`, since the task schedule is
            # None/empty.
            # Instead, we will add a Wrapper that always gives the task ID of the
            # current task.

            # TODO: There are some 'unused' args above: `starting_step`, `max_steps`,
            # `new_random_task_on_reset` which are still passed to the super() call, but
            # just unused.
            if new_random_task_on_reset:
                raise NotImplementedError(
                    "TODO: Add a MultiTaskEnv wrapper of some sort that alternates "
                    " between the source envs.")

            assert not task_schedule
            # The label is the current task id, or None (with a fully sparse label
            # space) when task labels are not available.
            task_label = self.current_task_id
            task_label_space = spaces.Discrete(self.nb_tasks)
            if not task_labels_available:
                task_label = None
                task_label_space = Sparse(task_label_space, sparsity=1.0)


        if is_monsterkong_env(base_env):
            # TODO: Need to register a MetaMonsterKong-State-v0 or something like that!
            # TODO: Maybe add another field for 'force_state_observations' ?
            # if self.force_pixel_observations:

        return wrappers
コード例 #11
class TaskIncrementalSetting(IncrementalSetting):
    """ Assumption (mixin) for Settings where the task labels are available at
    both train and test time.
    """
    # Task labels are given at train time (declared via `constant(True)`).
    task_labels_at_train_time: bool = constant(True)
    # Task labels are given at test time (declared via `constant(True)`).
    task_labels_at_test_time: bool = constant(True)