# Example #1
class ContinualRLSetting(ActiveSetting, IncrementalSetting):
    """ Reinforcement Learning Setting where the environment changes over time.

    This is an Active setting which uses gym environments as sources of data.
    These environments' attributes could change over time following a task
    schedule. An example of this could be that the gravity increases over time
    in cartpole, making the task progressively harder as the agent interacts with
    the environment.

    # The type of results returned by an RL experiment.
    Results: ClassVar[Type[Results]] = RLResults

    class Observations(IncrementalSetting.Observations):
        """ Observations produced by the environments of a continual RL Setting.

        On top of the fields inherited from the base class (`x` and
        `task_labels`), this adds an optional `done` field.
        """

        # End-of-episode flags: the 'done' portion of what `step` returns.
        # Carrying them inside the observations means that a method which
        # iterates over the environments dataloader-style (e.g. the
        # BaselineMethod) can still see them.
        done: Optional[Sequence[bool]] = None

        # TODO: If we add the 'task space' (with all the attributes, for
        # instance), then add it to the observations using the
        # `AddInfoToObservations`. That field would hold the 'info' portion of
        # the result of 'step':
        # info: Optional[Sequence[Dict]] = None

    # Image transforms to use.
    transforms: List[Transforms] = list_field()

    # Class variable that holds the dict of available environments.
    available_datasets: ClassVar[Dict[str, str]] = {
        "cartpole": "CartPole-v0",
        "pendulum": "Pendulum-v0",
        "breakout": "Breakout-v0",
        # "duckietown": "Duckietown-straight_road-v0"
    # TODO: Add breakout to 'available_datasets' only when atari_py is installed.

    # Which environment (a.k.a. "dataset") to learn on.
    # The dataset could be either a string (env id or a key from the
    # available_datasets dict), a gym.Env, or a callable that returns a single environment.
    # If self.dataset isn't one of those, an error will be raised.
    dataset: str = choice(available_datasets, default="cartpole")

    # The number of tasks. By default 1 for this setting.
    nb_tasks: int = field(1, alias=["n_tasks", "num_tasks"])

    # Max number of steps per task. (Also acts as the "length" of the training
    # and validation "Datasets")
    max_steps: int = 100_000
    # Maximum episodes per task.
    # TODO: Test that the limit on the number of episodes actually works.
    max_episodes: Optional[int] = None
    # Number of steps per task. When left unset and when `max_steps` is set,
    # takes the value of `max_steps` divided by `nb_tasks`.
    steps_per_task: Optional[int] = None
    # (WIP): Number of episodes per task.
    episodes_per_task: Optional[int] = None

    # Total number of steps in the test loop. (Also acts as the "length" of the testing
    # environment.)
    test_steps: int = 10_000
    # Number of steps per task in the test loop. When left unset and when `test_steps`
    # is set, takes the value of `test_steps` divided by `nb_tasks`.
    test_steps_per_task: Optional[int] = None

    # Standard deviation of the multiplicative Gaussian noise that is used to
    # create the values of the env attributes for each task.
    task_noise_std: float = 0.2

    # Whether the task boundaries are smooth or sudden.
    smooth_task_boundaries: bool = True

    # Whether to observe the state directly, rather than pixels. This can be
    # useful to debug environments like CartPole, for instance.
    observe_state_directly: bool = False

    # Path to a json file from which to read the train task schedule.
    train_task_schedule_path: Optional[Path] = None
    # Path to a json file from which to read the validation task schedule.
    valid_task_schedule_path: Optional[Path] = None
    # Path to a json file from which to read the test task schedule.
    test_task_schedule_path: Optional[Path] = None

    # Whether observations from the environments should include
    # the end-of-episode signal. Only really useful if your method will iterate
    # over the environments in the dataloader style
    # (as the baseline method does).
    add_done_to_observations: bool = False

    # The maximum number of steps per episode. When None, there is no limit.
    max_episode_steps: Optional[int] = None

    # NOTE: Added this `cmd=False` option to mark that we don't want to generate
    # any command-line arguments for these fields.
    train_task_schedule: Dict[int, Dict[str, float]] = dict_field(cmd=False)
    valid_task_schedule: Dict[int, Dict[str, float]] = dict_field(cmd=False)
    test_task_schedule: Dict[int, Dict[str, float]] = dict_field(cmd=False)

    train_wrappers: List[Callable[[gym.Env], gym.Env]] = list_field(cmd=False)
    valid_wrappers: List[Callable[[gym.Env], gym.Env]] = list_field(cmd=False)
    test_wrappers: List[Callable[[gym.Env], gym.Env]] = list_field(cmd=False)

    batch_size: Optional[int] = field(default=None, cmd=False)
    num_workers: Optional[int] = field(default=None, cmd=False)

    def __post_init__(self, *args, **kwargs):
        super().__post_init__(*args, **kwargs)
        self._new_random_task_on_reset: bool = False

        # Post processing of the 'dataset' field:
        if self.dataset in self.available_datasets.keys():
            # the environment name was passed, rather than an id
            # (e.g. 'cartpole' -> 'CartPole-v0").
            self.dataset = self.available_datasets[self.dataset]

        elif self.dataset not in self.available_datasets.values():
            # The passed dataset is assumed to be an environment ID, but it
            # wasn't in the dict of available datasets! We issue a warning, but
            # proceed to let the user use whatever environment they want to.
                    f"The chosen dataset/environment ({self.dataset}) isn't in the "
                    f"available_datasets dict, so we can't garantee this will work!"

        if isinstance(self.dataset, gym.Env) and self.batch_size:
            raise RuntimeError(
                f"Batch size should be None when a gym.Env "
                f"object is passed as `dataset`."
        if not isinstance(self.dataset, (str, gym.Env)) and not callable(self.dataset):
            raise RuntimeError(
                f"`dataset` must be either a string, a gym.Env, or a callable. "
                f"(got {self.dataset})"

        # Set the number of tasks depending on the increment, and vice-versa.
        # (as only one of the two should be used).
        assert self.max_steps, "assuming this should always be set, for now."
        # TODO: Clean this up, not super clear what options take precedence on
        # which other options.

        # Load the task schedules from the corresponding files, if present.
        if self.train_task_schedule_path:
            self.train_task_schedule = self.load_task_schedule(

        if self.valid_task_schedule_path:
            self.valid_task_schedule = self.load_task_schedule(

        if self.test_task_schedule_path:
            self.test_task_schedule = self.load_task_schedule(

        if self.train_task_schedule:
            if self.steps_per_task is not None:
                # If steps per task was passed, then we overwrite the keys of the tasks
                # schedule.
                self.train_task_schedule = {
                    i * self.steps_per_task: self.train_task_schedule[step]
                    for i, step in enumerate(sorted(self.train_task_schedule.keys()))
                # A task schedule was passed: infer the number of tasks from it.
                change_steps = sorted(self.train_task_schedule.keys())
                assert 0 in change_steps, "Schedule needs a task at step 0."
                # TODO: @lebrice: I guess we have to assume that the interval
                # between steps is constant for now? Do we actually depend on this
                # being the case? I think steps_per_task is only really ever used
                # for creating the task schedule, which we already have in this
                # case.
                assert (
                    len(change_steps) >= 2
                ), "WIP: need a minimum of two tasks in the task schedule for now."
                self.steps_per_task = change_steps[1] - change_steps[0]
                # Double-check that this is the case.
                for i in range(len(change_steps) - 1):
                    if change_steps[i + 1] - change_steps[i] != self.steps_per_task:
                        raise NotImplementedError(
                            f"WIP: This might not work yet if the tasks aren't "
                            f"equally spaced out at a fixed interval."

            nb_tasks = len(self.train_task_schedule)
            if self.smooth_task_boundaries:
                # NOTE: When in a ContinualRLSetting with smooth task boundaries,
                # the last entry in the schedule represents the state of the env at
                # the end of the "task". When there are clear task boundaries (i.e.
                # when in 'Class'/Task-Incremental RL), the last entry is the start
                # of the last task.
                nb_tasks -= 1
            if self.nb_tasks != 1:
                if self.nb_tasks != nb_tasks:
                    raise RuntimeError(
                        f"Passed number of tasks {self.nb_tasks} doesn't match the "
                        f"number of tasks deduced from the task schedule ({nb_tasks})"
            self.nb_tasks = nb_tasks

            self.max_steps = max(self.train_task_schedule.keys())
            if not self.smooth_task_boundaries:
                # See above note about the last entry.
                self.max_steps += self.steps_per_task

        elif self.nb_tasks:
            if self.steps_per_task:
                self.max_steps = self.nb_tasks * self.steps_per_task
            elif self.max_steps:
                self.steps_per_task = self.max_steps // self.nb_tasks

        elif self.steps_per_task:
            if self.nb_tasks:
                self.max_steps = self.nb_tasks * self.steps_per_task
            elif self.max_steps:
                self.nb_tasks = self.max_steps // self.steps_per_task

        elif self.max_steps:
            if self.nb_tasks:
                self.steps_per_task = self.max_steps // self.nb_tasks
            elif self.steps_per_task:
                self.nb_tasks = self.max_steps // self.steps_per_task

        if not all([self.nb_tasks, self.max_steps, self.steps_per_task]):
            raise RuntimeError(
                f"You need to provide at least two of 'max_steps', "
                f"'nb_tasks', or 'steps_per_task'."

        assert self.max_steps == self.nb_tasks * self.steps_per_task

        if self.test_task_schedule:
            if 0 not in self.test_task_schedule:
                raise RuntimeError("Task schedules needs to include an initial task.")

            if self.test_steps_per_task is not None:
                # If steps per task was passed, then we overwrite the number of steps
                # for each task in the schedule to match.
                self.test_task_schedule = {
                    i * self.test_steps_per_task: self.test_task_schedule[step]
                    for i, step in enumerate(sorted(self.test_task_schedule.keys()))

            change_steps = sorted(self.test_task_schedule.keys())
            assert 0 in change_steps, "Schedule needs to include task at step 0."

            nb_test_tasks = len(change_steps)
            if self.smooth_task_boundaries:
                nb_test_tasks -= 1
            assert (
                nb_test_tasks == self.nb_tasks
            ), "nb of tasks should be the same for train and test."

            self.test_steps_per_task = change_steps[1] - change_steps[0]
            for i in range(self.nb_tasks - 1):
                if change_steps[i + 1] - change_steps[i] != self.test_steps_per_task:
                    raise NotImplementedError(
                        "WIP: This might not work yet if the test tasks aren't "
                        "equally spaced out at a fixed interval."

            self.test_steps = max(change_steps)
            if not self.smooth_task_boundaries:
                # See above note about the last entry.
                self.test_steps += self.test_steps_per_task

        elif self.test_steps_per_task is None:
            # This is basically never the case, since the test_steps defaults to 10_000.
            assert (
            ), "need to set one of test_steps or test_steps_per_task"
            self.test_steps_per_task = self.test_steps // self.nb_tasks
            # FIXME: This is too complicated for what is is.
            # Check that the test steps must either be the default value, or the right
            # value to use in this case.
            assert self.test_steps in {10_000, self.test_steps_per_task * self.nb_tasks}
            assert (
            ), "need to set one of test_steps or test_steps_per_task"
            self.test_steps = self.test_steps_per_task * self.nb_tasks

        assert self.test_steps // self.test_steps_per_task == self.nb_tasks

        if self.smooth_task_boundaries:
            # If we're operating in the 'Online/smooth task transitions' "regime",
            # then there is only one "task", and we don't have task labels.
            # TODO: HOWEVER, the task schedule could/should be able to have more
            # than one non-stationarity! This indicates a need for a distinction
            # between 'tasks' and 'non-stationarities' (changes in the env).
            self.known_task_boundaries_at_train_time = False
            self.known_task_boundaries_at_test_time = False
            self.task_labels_at_train_time = False
            self.task_labels_at_test_time = False
            # self.steps_per_task = self.max_steps

        # Task schedules for training / validation and testing.

        # Create a temporary environment so we can extract the spaces and create
        # the task schedules.
        with self._make_env(
            self.dataset, self._temp_wrappers(), self.observe_state_directly
        ) as temp_env:
            # FIXME: Replacing the observation space dtypes from their original
            # 'generated' NamedTuples to self.Observations. The alternative
            # would be to add another argument to the MultiTaskEnv wrapper, to
            # pass down a dtype to be set on its observation_space's `dtype`
            # attribute, which would be ugly.
            assert isinstance(temp_env.observation_space, NamedTupleSpace)
            temp_env.observation_space.dtype = self.Observations
            # Populate the task schedules created above.
            if not self.train_task_schedule:
                train_change_steps = list(range(0, self.max_steps, self.steps_per_task))
                if self.smooth_task_boundaries:
                    # Add a last 'task' at the end of the 'epoch', so that the
                    # env changes smoothly right until the end.
                self.train_task_schedule = self.create_task_schedule(
                    temp_env, train_change_steps,

            assert self.train_task_schedule is not None
            # The validation task schedule is the same as the one used in
            # training by default.
            if not self.valid_task_schedule:
                self.valid_task_schedule = deepcopy(self.train_task_schedule)

            if not self.test_task_schedule:
                # The test task schedule is by default the same as in validation
                # except that the interval between the tasks may be different,
                # depending on the value of `self.test_steps_per_task`.
                valid_steps = sorted(self.valid_task_schedule.keys())
                valid_tasks = [self.valid_task_schedule[step] for step in valid_steps]
                self.test_task_schedule = {
                    i * self.test_steps_per_task: deepcopy(task)
                    for i, task in enumerate(valid_tasks)

            # Set the spaces using the temp env.
            self.observation_space = temp_env.observation_space
            self.action_space = temp_env.action_space
            self.reward_range = temp_env.reward_range
            self.reward_space = getattr(
                    low=self.reward_range[0], high=self.reward_range[1], shape=()

        del temp_env

        self.train_env: gym.Env
        self.valid_env: gym.Env
        self.test_env: gym.Env

    def create_task_schedule(
        self, temp_env: MultiTaskEnvironment, change_steps: List[int]
    ) -> Dict[int, Dict]:
        """ Create the task schedule, which maps from a step to the changes that
        will occur in the environment when that step is reached.
        Uses the provided `temp_env` to generate the random tasks at the steps
        given in `change_steps` (a list of integers).

        Returns a dictionary mapping from integers (the steps) to the changes
        that will occur in the env at that step.

        TODO: IDEA: Instead of just setting env attributes, use the
        `methodcaller` or `attrsetter` from the `operator` built-in package,
        that way later when we want to add support for Meta-World, we can just
        use `partial(methodcaller("set_task"), task="new_task")(env)` or
        something like that (i.e. generalize from changing an attribute to
        applying a function on the env, which would allow calling methods in
        addition to setting attributes.)
        task_schedule: Dict[int, Dict] = {}
        # Start with the default task (step 0) and then add a new task at
        # intervals of `self.steps_per_task`
        for task_step in change_steps:
            if task_step == 0:
                # Start with the default task, so that we can recover the 'iid'
                # case with standard env dynamics when there is only one task
                # and no non-stationarity.
                task_schedule[task_step] = temp_env.default_task
                task_schedule[task_step] = temp_env.random_task()

        return task_schedule

    def apply(
        self, method: Method, config: Config = None
    ) -> "ContinualRLSetting.Results":
        """Apply the given method on this setting to producing some results. """
        # Use the supplied config, or parse one from the arguments that were
        # used to create `self`.
        self.config: Config
        if config is not None:
            self.config = config
            logger.debug(f"Using Config {self.config}")
        elif isinstance(getattr(method, "config", None), Config):
            self.config = method.config
            logger.debug(f"Using Config from the Method: {self.config}")
            logger.debug(f"Parsing the Config from the command-line.")
            self.config = Config.from_args(self._argv, strict=False)
            logger.debug(f"Resulting Config: {self.config}")

        # TODO: Test to make sure that this doesn't cause any other bugs with respect to
        # the display of stuff:
        # Call this method, which creates a virtual display if necessary.

        # TODO: Should we really overwrite the method's 'config' attribute here?
        if not getattr(method, "config", None):
            method.config = self.config

        # TODO: Remove `Setting.configure(method)` entirely, from everywhere,
        # and use the `prepare_data` or `setup` methods instead (since these
        # `configure` methods aren't using the `method` anyway.)

        # BUG This won't work if the task schedule uses callables as the values (as
        # they aren't json-serializable.)
        if self._new_random_task_on_reset:
                f"Train tasks: "
                + json.dumps(list(self.train_task_schedule.values()), indent="\t")
                f"Train task schedule:"
                + json.dumps(self.train_task_schedule, indent="\t")
        if self.config.debug:
                f"Test task schedule:"
                + json.dumps(self.test_task_schedule, indent="\t")

        # Run the Training loop (which is defined in IncrementalSetting).
        results = self.main_loop(method)

        logger.info("Results summary:")
        method.receive_results(self, results=results)
        return results

        # Run the Test loop (which is defined in IncrementalSetting).
        # results: RlResults = self.test_loop(method)

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the environment wrappers needed for the given `stage`.

        Called before the start of each task during training, validation and
        testing.

        Parameters
        ----------
        stage : str, optional
            "fit" creates the train and validation wrappers, "test" creates the
            test wrappers, and `None` creates all of them. Any other value
            leaves the wrappers untouched.
        """
        # NOTE: These are two independent checks on purpose. With `if`/`elif`,
        # `stage=None` would match the first branch and never reach the second
        # one, so the test wrappers would silently not be created.
        if stage in {"fit", None}:
            self.train_wrappers = self.create_train_wrappers()
            self.valid_wrappers = self.create_valid_wrappers()
        if stage in {"test", None}:
            self.test_wrappers = self.create_test_wrappers()

    def prepare_data(self, *args, **kwargs) -> None:
        """Make sure a `Config` is available, then defer to the parent class.

        Nothing is actually downloaded for this setting at the moment.
        """
        needs_default_config = self.config is None
        if needs_default_config:
            # Fall back to a default-constructed configuration.
            self.config = Config()
        super().prepare_data(*args, **kwargs)

    def train_dataloader(
        self, batch_size: int = None, num_workers: int = None
    ) -> ActiveEnvironment:
        """Create a training gym.Env/DataLoader for the current task.
        batch_size : int, optional
            The batch size, which in this case is the number of environments to
            run in parallel. When `None`, the env won't be vectorized. Defaults
            to None.
        num_workers : int, optional
            The number of workers (processes) to use in the vectorized env. When
            None, the envs are run in sequence, which could be very slow. Only
            applies when `batch_size` is not None. Defaults to None.

            A (possibly vectorized) environment/dataloader for the current task.
        if not self.has_prepared_data:
        # NOTE: We actually want to call setup every time, so we re-create the
        # wrappers for each task.
        # if not self.has_setup_fit:

        batch_size = batch_size or self.batch_size
        num_workers = num_workers if num_workers is not None else self.num_workers
        env_factory = partial(
        env_dataloader = self._make_env_dataloader(

        if self.monitor_training_performance:
            from sequoia.settings.passive.cl.measure_performance_wrapper import (
            env_dataloader = MeasureRLPerformanceWrapper(
                env_dataloader, wandb_prefix=f"Train/Task {self.current_task_id}"
        self.train_env = env_dataloader
        # BUG: There is a mismatch between the train env's observation space and the
        # shape of its observations.
        self.observation_space = self.train_env.observation_space

        return self.train_env

    def val_dataloader(
        self, batch_size: int = None, num_workers: int = None
    ) -> Environment:
        """Create a validation gym.Env/DataLoader for the current task.
        batch_size : int, optional
            The batch size, which in this case is the number of environments to
            run in parallel. When `None`, the env won't be vectorized. Defaults
            to None.
        num_workers : int, optional
            The number of workers (processes) to use in the vectorized env. When
            None, the envs are run in sequence, which could be very slow. Only
            applies when `batch_size` is not None. Defaults to None.

            A (possibly vectorized) environment/dataloader for the current task.
        if not self.has_prepared_data:

        env_factory = partial(
        env_dataloader = self._make_env_dataloader(
            batch_size=batch_size or self.batch_size,
            num_workers=num_workers if num_workers is not None else self.num_workers,
        self.val_env = env_dataloader
        return self.val_env

    def test_dataloader(
        self, batch_size: int = None, num_workers: int = None
    ) -> TestEnvironment:
        """Create the test 'dataloader/gym.Env' for all tasks.
        NOTE: This test environment isn't just for the current task, it actually
        contains the sequence of all tasks. This is different than the train or
        validation environments, since if the task labels are available at train
        time, then calling train/valid_dataloader` returns the envs for the
        current task only, and the `.fit` method is called once per task.
        This environment is also different in that it is wrapped with a Monitor,
        which we might eventually use to save the results/gifs/logs of the
        testing runs.

        batch_size : int, optional
            The batch size, which in this case is the number of environments to
            run in parallel. When `None`, the env won't be vectorized. Defaults
            to None.
        num_workers : int, optional
            The number of workers (processes) to use in the vectorized env. When
            None, the envs are run in sequence, which could be very slow. Only
            applies when `batch_size` is not None. Defaults to None.

            A testing environment which keeps track of the performance of the
            actor and accumulates logs/statistics that are used to eventually
            create the 'Result' object.
        if not self.has_prepared_data:
        # BUG: gym.wrappers.Monitor doesn't want to play nice when applied to
        # Vectorized env, it seems..
        # FIXME: Remove this when the Monitor class works correctly with
        # batched environments.
        batch_size = batch_size or self.batch_size
        if batch_size is not None:
                        f"WIP: Only support batch size of `None` (i.e., a single env) "
                        f"for the test environments of RL Settings at the moment, "
                        f"because the Monitor class from gym doesn't work with "
                        f"VectorEnvs. (batch size was {batch_size})",
            batch_size = None

        num_workers = num_workers if num_workers is not None else self.num_workers
        env_factory = partial(
        # TODO: Pass the max_steps argument to this `_make_env_dataloader` method,
        # rather than to a `step_limit` on the TestEnvironment.
        env_dataloader = self._make_env_dataloader(
            env_factory, batch_size=batch_size, num_workers=num_workers,
        # TODO: We should probably change the max_steps depending on the
        # batch size of the env.
        test_loop_max_steps = self.test_steps // (batch_size or 1)
        # TODO: Find where to configure this 'test directory' for the outputs of
        # the Monitor.
        test_dir = "results"
        # TODO: Debug wandb Monitor integration.
        self.test_env = ContinualRLTestEnvironment(
            video_callable=None if self.config.render else False,
        return self.test_env

    def phases(self) -> int:
        """The number of training 'phases', i.e. how many times `method.fit` will be
        In the case of ContinualRL, fit is only called once, with an environment that
        shifts between all the tasks.
        return 1
    def _make_env(
        base_env: Union[str, gym.Env, Callable[[], gym.Env]],
        wrappers: List[Callable[[gym.Env], gym.Env]] = None,
        observe_state_directly: bool = False,
    ) -> gym.Env:
        """ Helper function to create a single (non-vectorized) environment. """
        env: gym.Env
        if isinstance(base_env, str):
            if base_env.startswith("MetaMonsterKong") and observe_state_directly:
                env = gym.make(base_env, observe_state=True)
                env = gym.make(base_env)
        elif isinstance(base_env, gym.Env):
            env = base_env
        elif callable(base_env):
            env = base_env()
            raise RuntimeError(
                f"base_env should either be a string, a callable, or a gym "
                f"env. (got {base_env})."
        for wrapper in wrappers:
            env = wrapper(env)
        return env

    def _make_env_dataloader(
        env_factory: Callable[[], gym.Env],
        batch_size: Optional[int],
        num_workers: Optional[int] = None,
        seed: Optional[int] = None,
        max_steps: Optional[int] = None,
        max_episodes: Optional[int] = None,
    ) -> GymDataLoader:
        """ Helper function for creating a (possibly vectorized) environment.
            f"batch_size: {batch_size}, num_workers: {num_workers}, seed: {seed}"

        env: Union[gym.Env, gym.vector.VectorEnv]
        if batch_size is None:
            env = env_factory()
            env = make_batched_env(
                # TODO: Still debugging shared memory + custom spaces (e.g. Sparse).

        ## Apply the "post-batch" wrappers:
        # from sequoia.common.gym_wrappers import ConvertToFromTensors
        # TODO: Only the BaselineMethod requires this, we should enable it only
        # from the BaselineMethod, and leave it 'off' by default.
        if self.add_done_to_observations:
            env = AddDoneToObservation(env)
        # # Convert the samples to tensors and move them to the right device.
        # env = ConvertToFromTensors(env)
        # env = ConvertToFromTensors(env, device=self.config.device)
        # Add a wrapper that converts numpy arrays / etc to Observations/Rewards
        # and from Actions objects to numpy arrays.
        env = TypedObjectsWrapper(
        # Create an IterableDataset from the env using the EnvDataset wrapper.
        dataset = EnvDataset(env, max_steps=max_steps, max_episodes=max_episodes,)

        # Create a GymDataLoader for the EnvDataset.
        env_dataloader = GymDataLoader(dataset)

        if batch_size and seed:
            # Seed each environment with its own seed (based on the base seed).
            env.seed([seed + i for i in range(env_dataloader.num_envs)])

        return env_dataloader

    def create_train_wrappers(self) -> List[Callable[[gym.Env], gym.Env]]:
        """Get the list of wrappers to add to each training environment.
        The result of this method must be pickleable when using
        List[Callable[[gym.Env], gym.Env]]
        # We add a restriction to prevent users from getting data from
        # previous or future tasks.
        # TODO: This assumes that tasks all have the same length.
        starting_step = self.current_task_id * self.steps_per_task
        max_steps = starting_step + self.steps_per_task - 1
        return self._make_wrappers(

    def create_valid_wrappers(self) -> List[Callable[[gym.Env], gym.Env]]:
        """Get the list of wrappers to add to each validation environment.
        The result of this method must be pickleable when using

        List[Callable[[gym.Env], gym.Env]]
        TODO: Decide how this 'validation' environment should behave in
        comparison with the train and test environments. 
        # We add a restriction to prevent users from getting data from
        # previous or future tasks.
        # TODO: Should the validation environment only be for the current task?
        starting_step = self.current_task_id * self.steps_per_task
        max_steps = starting_step + self.steps_per_task - 1
        return self._make_wrappers(

    def create_test_wrappers(self) -> List[Callable[[gym.Env], gym.Env]]:
        """Get the list of wrappers to add to a single test environment.
        The result of this method must be pickleable when using

        List[Callable[[gym.Env], gym.Env]]
        return self._make_wrappers(

    def load_task_schedule(self, file_path: Path) -> Dict[int, Dict]:
        """ Load a task schedule from the json file at `file_path`.

        JSON object keys are always strings, so they are converted back to
        integer steps here, and the result is ordered by increasing step.

        Parameters
        ----------
        file_path : Path
            Path to a json file containing an object that maps (stringified)
            steps to the task starting at that step.

        Returns
        -------
        Dict[int, Dict]
            The task schedule, with integer keys in increasing numeric order.
        """
        with open(file_path, encoding="utf-8") as f:
            raw_schedule = json.load(f)
        # Convert the keys to ints *before* sorting: sorting the raw string
        # keys would order them lexicographically (e.g. "10000" before "2000").
        schedule = {int(step): task for step, task in raw_schedule.items()}
        return {step: schedule[step] for step in sorted(schedule)}

    def _make_wrappers(
        task_schedule: Dict[int, Dict],
        sharp_task_boundaries: bool,
        task_labels_available: bool,
        transforms: List[Transforms],
        starting_step: int,
        max_steps: int,
        new_random_task_on_reset: bool,
    ) -> List[Callable[[gym.Env], gym.Env]]:
        """ helper function for creating the train/valid/test wrappers. 
        These wrappers get applied *before* the batching, if applicable.
        wrappers: List[Callable[[gym.Env], gym.Env]] = []
        # NOTE: When transitions are smooth, there are no "task boundaries".
        assert sharp_task_boundaries == (not self.smooth_task_boundaries)

        # TODO: Add some kind of Wrapper around the dataset to make it
        # semi-supervised?

        if self.max_episode_steps:
                partial(TimeLimit, max_episode_steps=self.max_episode_steps)

        if is_classic_control_env(self.dataset) and not self.observe_state_directly:
            # If we are in a classic control env, and we dont want the state to
            # be fully-observable (i.e. we want pixel observations rather than
            # getting the pole angle, velocity, etc.), then add the
            # PixelObservation wrapper to the list of wrappers.

        if (
            isinstance(self.dataset, str)
            and self.dataset.lower().startswith("metamonsterkong")
            and not self.observe_state_directly
            # TODO: Do we need the AtariPreprocessing wrapper on MonsterKong?
            # wrappers.append(partial(AtariPreprocessing, frame_skip=1))
        elif is_atari_env(self.dataset):
            # TODO: Test & Debug this: Adding the Atari preprocessing wrapper.
            # TODO: Figure out the differences (if there are any) between the
            # AtariWrapper from SB3 and the AtariPreprocessing wrapper from gym.
            # wrappers.append(AtariPreprocessing)

        # Apply image transforms if the env will have image-like obs space
        if not self.observe_state_directly:
            # wrappers.append(ImageObservations)
            # Wrapper to apply the image transforms to the env.
            wrappers.append(partial(TransformObservation, f=transforms))

        # Add a wrapper which will add non-stationarity to the environment.
        # The "task" transitions will either be sharp or smooth.
        # In either case, the task ids for each sample are added to the
        # observations, and the dicts containing the task information (i.e. the
        # current values of the env attributes from the task schedule) get added
        # to the 'info' dicts.
        if sharp_task_boundaries:
            assert self.nb_tasks >= 1
            # Add a wrapper that creates sharp tasks.
            cl_wrapper = MultiTaskEnvironment
            # Add a wrapper that creates smooth tasks.
            cl_wrapper = SmoothTransitions

        # If the task labels aren't available, we then add another wrapper that
        # hides that information (setting both of them to None) and also marks
        # those spaces as `Sparse`.
        if not task_labels_available:
            # NOTE: This sets the task labels to None, rather than removing
            # them entirely.
            # wrappers.append(RemoveTaskLabelsWrapper)

        return wrappers

    def _temp_wrappers(self) -> List[Callable[[gym.Env], gym.Env]]:
        """ Gets the minimal wrappers needed to figure out the Spaces of the
        train/valid/test environments.
        This is called in the 'constructor' (__post_init__) to set the Setting's
        observation/action/reward spaces, so this should depend on as little
        state from `self` as possible, since not all attributes have been
        defined at the time when this is called. 
        return self._make_wrappers(
            # These two shouldn't matter really:
예제 #2
 class A(TestSetup):
     # Dict field whose default mapping is {"bob": 0, "john": 1, "bart": 2}.
     a: Dict[str, int] = dict_field({"bob": 0, "john": 1, "bart": 2})
예제 #3
 class A(TestSetup):
     # Dict field declared through `dict_field()` without an explicit default
     # argument (the resulting default is decided by `dict_field`, defined
     # outside this snippet).
     a: Dict[str, int] = dict_field()
예제 #4
 class A:
     # Dict field whose default is the `default` object, which is defined
     # outside this snippet.
     a: Dict[str, int] = dict_field(default)
예제 #5
 class SomeClass(Serializable):
     # Container fields whose element types use `some_type`, which is defined
     # outside this snippet. Each field is declared through `dict_field()` /
     # `list_field()` without an explicit default argument.
     d: Dict[str, some_type] = dict_field()
     l: List[Tuple[some_type, some_type]] = list_field()
     t: Dict[str, Optional[some_type]] = dict_field()
예제 #6
class Loss(Serializable):
    """ Object used to store the losses and metrics. 

    Used to simplify the return type of the different `get_loss` functions and
    also to help in debugging models that use a combination of different loss

    TODO: Add some kind of histogram plot to show the relative contribution of
    each loss signal?
    TODO: Maybe create a `make_plots()` method to create wandb plots?
    name: str
    loss: Tensor = 0.  # type: ignore
    losses: Dict[str, "Loss"] = dict_field()
    # NOTE: By setting to_dict=False below, we don't include the tensors when
    # serializing the attributes.
    # TODO: Does that also mean that the tensors can't be pickled (moved) by
    # pytorch-lightning during training? Is there a case where that would be
    # useful?
    tensors: Dict[str, Tensor] = dict_field(repr=False, to_dict=False)
    metrics: Dict[str, Union[Metrics, Tensor]] = dict_field()
    # When multiplying the Loss by a value, this keep track of the coefficients
    # used, so that if we wanted to we could recover the 'unscaled' loss.
    _coefficient: Union[float, Tensor] = field(1.0, repr=False)

    x: InitVar[Optional[Tensor]] = None
    h_x: InitVar[Optional[Tensor]] = None
    y_pred: InitVar[Optional[Tensor]] = None
    y: InitVar[Optional[Tensor]] = None

    def __post_init__(
        self,
        x: Tensor = None,
        h_x: Tensor = None,
        y_pred: Tensor = None,
        y: Tensor = None,
    ):
        """ Finishes initialization after the dataclass-generated constructor.

        - Requires a non-empty `name`.
        - If no metrics are registered under `self.name`, asks `get_metrics`
          to build them from the (optional) `x`, `h_x`, `y_pred` and `y`
          init-only arguments, and stores the result when it is truthy.
        - Converts every non-Tensor value in `self.tensors` with
          `torch.as_tensor`, and records in `self._device` the device of the
          first entry that was already a Tensor (None if there is none).
        """
        assert self.name, "Loss objects should be given a name!"

        # Metrics passed in explicitly take precedence over computed ones.
        if self.name not in self.metrics:
            computed = get_metrics(x=x, h_x=h_x, y_pred=y_pred, y=y)
            if computed:
                self.metrics[self.name] = computed

        self._device: torch.device = None
        # Snapshot the items: values are replaced in place while looping.
        for key, value in list(self.tensors.items()):
            if isinstance(value, Tensor):
                if self._device is None:
                    self._device = value.device
            else:
                self.tensors[key] = torch.as_tensor(value)

    def to_pl_dict(self, verbose: bool = False) -> Dict:
        """Creates a pytorch-lightning-style dict from this Loss object.

        Can be used as a return value to the `[training/validation/test]_step'
        methods of a `LightningModule`, like so:
        # (inside some LightningModule)
        def training_step(self, batch, ...) -> Dict:
            x, y = batch
            y_pred = self.forward(x)
            nce = self.loss_fn(y_pred, y)
            loss: Loss = Loss("train", loss=nce, y_pred=y_pred, y=y)
            return loss.to_pl_dict()

            verbose (bool, optional): Wether to keep things short or to include
                everything into the log dictionary. Defaults to False.

            Dict: A dictionary with the usual 'loss', 'log' and 'progress_bar'
                keys, and additionally with a copy of 'self' at the key
        return {
            "loss": self.loss,
            "log": self.to_log_dict(verbose=verbose),
            "progress_bar": self.to_pbar_message(),
            "loss_object": self,

    def total_loss(self) -> Tensor:
        """ Returns `self.loss`, the total loss held by this object.

        NOTE: the `loss` field defaults to the float `0.`, so the returned
        value is not always a Tensor despite the annotation.
        """
        return self.loss
    def requires_grad(self) -> bool:
        """ Tells whether the stored loss is a Tensor that requires gradients.

        Plain numbers (such as the float `0.` default of `loss`) never do, so
        False is returned for them.
        """
        current = self.loss
        if not isinstance(current, Tensor):
            return False
        return current.requires_grad
    def backward(self, *args, **kwargs):
        """ Calls `self.loss.backward(*args, **kwargs)` and returns its result.

        All arguments are forwarded unchanged. `self.loss` must be a Tensor:
        with the float `0.` default there is no `backward` attribute, so an
        AttributeError would be raised.
        """
        return self.loss.backward(*args, **kwargs)
    def metric(self) -> Optional[Metrics]:
        """Shortcut for `self.metrics[self.name]`.

            Optional[Metrics]: The main metrics associated with this Loss.
        return self.metrics.get(self.name)

    def metric(self, value: Metrics) -> None:
        """Shortcut for `self.metrics[self.name] = value`.

        value : Metrics
            The main metrics associated with this Loss.
        assert self.name not in self.metrics, "There's already be a metric?"
        self.metrics[self.name] = value

    def accuracy(self) -> Optional[float]:
        """ Returns `self.metric.accuracy` for classification metrics.

        NOTE: only `ClassificationMetrics` are handled; for any other value of
        `self.metric` the function falls through and implicitly returns None.
        """
        if isinstance(self.metric, ClassificationMetrics):
            return self.metric.accuracy

    def mse(self) -> Tensor:
        """ Returns `self.metric.mse` for regression metrics.

        Raises:
            AssertionError: If `self.metric` is not a `RegressionMetrics`
                (the assertion message is this Loss object itself).
        """
        assert isinstance(self.metric, RegressionMetrics), self
        return self.metric.mse

    def __add__(self, other: Union["Loss", Any]) -> "Loss":
        """Adds two Loss instances together.
        Adds the losses, total loss and metrics. Overwrites the tensors.
        Keeps the name of the first one. This is useful when doing something
        loss = Loss("Test")
        for x, y in dataloader:
            loss += model.get_loss(x=x, y=y)
            The merged/summed up Loss.
        if other == 0:
            return self
        if not isinstance(other, Loss):
            return NotImplemented
        name = self.name
        loss = self.loss + other.loss
        if self.name == other.name:
            losses  = add_dicts(self.losses, other.losses)
            metrics = add_dicts(self.metrics, other.metrics)
            # IDEA: when the names don't match, store the entire Loss
            # object into the 'losses' dict, rather than a single loss tensor.
            losses = add_dicts(self.losses, {other.name: other})
            # TODO: setting in the 'metrics' dict, we are duplicating the
            # metrics, since they now reside in the `self.metrics[other.name]`
            # and `self.losses[other.name].metrics` attributes.
            metrics = self.metrics
            # metrics = add_dicts(self.metrics, {other.name: other.metrics})
        tensors = add_dicts(self.tensors, other.tensors, add_values=False)
        return Loss(

    def __iadd__(self, other: Union["Loss", Any]) -> "Loss":
        """Adds Loss to `self` in-place.
        Adds the losses, total loss and metrics. Overwrites the tensors.
        Keeps the name of the first one. This is useful when doing something
        loss = Loss("Test")
        for x, y in dataloader:
            loss += model.get_loss(x=x, y=y)
            `self`: The merged/summed up Loss.
        self.loss = self.loss + other.loss
        if self.name == other.name:
            self.losses  = add_dicts(self.losses, other.losses)
            self.metrics = add_dicts(self.metrics, other.metrics)
            # IDEA: when the names don't match, store the entire Loss
            # object into the 'losses' dict, rather than a single loss tensor.
            self.losses = add_dicts(self.losses, {other.name: other})
        self.tensors = add_dicts(self.tensors, other.tensors, add_values=False)
        return self

    def __radd__(self, other: Any):
        """Addition operator for when forward addition returned `NotImplemented`.

        For example, doing something like `None + Loss()` will use __radd__,
        whereas doing `Loss() + None` will use __add__.
        if other is None:
            return self
        elif other == 0:
            return self
        if isinstance(other, Tensor):
            # TODO: Other could be a loss tensor, maybe create a Loss object for it?
        return NotImplemented

    def __mul__(self, factor: Union[float, Tensor]) -> "Loss":
        """ Scale each loss tensor by `coefficient`.

            returns a scaled Loss instance.
        result = Loss(
            loss=self.loss * factor,
                k: value * factor for k, value in self.losses.items()
            _coefficient=self._coefficient * factor,
        return result

    def __rmul__(self, factor: Union[float, Tensor]) -> "Loss":
        """ Reflected multiplication (`factor * loss`).

        Scaling is commutative, so this simply delegates to `__mul__`.
        """
        return self.__mul__(factor)

    def __truediv__(self, coefficient: Union[float, Tensor]) -> "Loss":
        """ Divides this loss by `coefficient`.

        Implemented as multiplication by `1 / coefficient`, so a Python
        number equal to zero raises ZeroDivisionError here.
        """
        return self * (1 / coefficient)

    def unscaled_losses(self):
        """ Recovers the 'unscaled' version of this loss.

        TODO: This isn't used anywhere. We could probably remove it.
        return {
            k: value / self._coefficient for k, value in self.losses.items()

    def to_log_dict(self, verbose: bool = False) -> Dict[str, Union[str, float, Dict]]:
        """Creates a dictionary to be logged (e.g. by `wandb.log`).

            verbose (bool, optional): Wether to include a lot of information, or
            to only log the 'essential' stuff. See the `cleanup` function for
            more info. Defaults to False.

            Dict: A dict containing the things to be logged.
        # TODO: Could also produce some wandb plots and stuff here when verbose?
        log_dict: Dict[str, Union[str, float, Dict, Tensor]] = {}
        log_dict["loss"] = round(float(self.loss), 6)

        for name, metric in self.metrics.items():
            if isinstance(metric, Serializable):
                log_dict[name] = metric.to_log_dict(verbose=verbose)
                log_dict[name] = metric

        for name, loss in self.losses.items():
            if isinstance(loss, Serializable):
                log_dict[name] = loss.to_log_dict(verbose=verbose)
                log_dict[name] = loss

        log_dict = add_prefix(log_dict, prefix=self.name, sep="/")
        keys_to_remove: List[str] = []
        if not verbose:
            # when NOT verbose, remove any entries with this matching key.
            # TODO: add/remove keys here if you want to customize what doesn't get logged to wandb.
            # TODO: Could maybe make this a class variable so that it could be
            # extended/overwritten, but that sounds like a bit too much rn.
            keys_to_remove = [
        result = cleanup(log_dict, keys_to_remove=keys_to_remove, sep="/") 
        return result
    def to_pbar_message(self) -> Dict[str, float]:
        """ Smaller, less-detailed version of `to_log_dict()` for progress bars.
        # NOTE: PL actually doesn't seem to accept strings as values 
        message: Dict[str, Union[str, float]] = {}
        message["Loss"] = float(self.loss)

        for name, metric in self.metrics.items():
            if isinstance(metric, Metrics):
                message[name] = metric.to_pbar_message()
                message[name] = metric

        for name, loss_info in self.losses.items():
            message[name] = loss_info.to_pbar_message()

        message = add_prefix(message, prefix=self.name, sep=" ")

        return cleanup(message, sep=" ")

    def clear_tensors(self) -> None:
        """ Clears the `tensors` attribute of `self` and of sublosses.
        NOTE: This could be useful if you want to save some space/compute, but
        it isn't being used atm, and there's no issue. You might want to call
        this if you are storing big tensors (or passing them to the constructor)
        for _, loss in self.losses.items():
        return self

    def absorb(self, other: "Loss") -> None:
        """Absorbs `other` into `self`, merging the losses and metrics.

            other (Loss): Another loss to 'merge' into this one.
        new_name = self.name
        old_name = other.name
        # Here we create a new 'other' and use __iadd__ to merge the attributes.
        new_other = Loss(name=new_name)
        new_other.loss = other.loss
        # We also replace the name in the keys, if present.
        new_other.metrics = {
            k.replace(old_name, new_name): v for k, v in other.metrics.items() 
        new_other.losses = {
            k.replace(old_name, new_name): v for k, v in other.losses.items() 
        self += new_other

    def all_metrics(self) -> Dict[str, Metrics]:
        """ Returns a 'cleaned up' dictionary of all the Metrics objects. """
        assert self.name
        result: Dict[str, Metrics] = {}

        for name, loss in self.losses.items():
            # TODO: Aren't we potentially colliding with 'self.metrics' here?
            subloss_metrics = loss.all_metrics()
            for key, metric in subloss_metrics.items():
                assert key not in result, (
                    f"Collision in metric keys of subloss {name}: key={key}, "
                result[key] = metric
        result = add_prefix(result, prefix=self.name, sep="/")
        return result
예제 #7
 class Leaderboard(Serializable):
     # Maps `Person` objects to ints. The default mapping is built from the
     # `bob` and `peter` objects, which are defined outside this snippet.
     participants: Dict[Person, int] = dict_field({bob: 1, peter: 2})
예제 #8
 class Leaderboard(FrozenSerializable if frozen else Serializable):
     # The base class is picked when the class statement runs: FrozenSerializable
     # if the outer `frozen` flag is truthy, Serializable otherwise.
     # `participants` maps `Person` objects to ints; its default is built from
     # `bob` and `peter`, which are defined outside this snippet.
     participants: Dict[Person, int] = dict_field({bob: 1, peter: 2})