Пример #1
0
def test_add_task_dict_to_info():
    """Test that the 'info' dict contains the task dict.

    CartPole is wrapped in a ``MultiTaskEnvironment`` whose task schedule
    changes ``length`` and/or ``gravity`` at steps 10, 20 and 30, with
    ``add_task_dict_to_info=True``. At every step this checks both the env's
    attributes and that the returned ``info`` equals the default task updated
    with the currently-active schedule entry.

    Each schedule entry is expected to be applied on top of the *default*
    task, not on top of the previous entry. That is why ``length`` is back to
    its starting value from step 30 onwards, where the entry only sets
    ``gravity``.
    """
    original: CartPoleEnv = gym.make("CartPole-v0")
    starting_length = original.length
    starting_gravity = original.gravity

    task_schedule = {
        10: dict(length=0.1),
        20: dict(length=0.2, gravity=-12.0),
        30: dict(gravity=0.9),
    }
    env = MultiTaskEnvironment(
        original,
        task_schedule=task_schedule,
        add_task_dict_to_info=True,
    )
    # try/finally so the environment is closed even when an assertion fails;
    # otherwise a failing test would leak it.
    try:
        env.seed(123)
        env.reset()
        for step in range(100):
            _, _, done, info = env.step(env.action_space.sample())
            if done:
                env.reset()

            if step < 10:
                # Before the first schedule entry: the default task is active.
                assert env.length == starting_length and env.gravity == starting_gravity
                assert info == env.default_task
            elif step < 20:
                assert env.length == 0.1
                assert info == dict_union(env.default_task, task_schedule[10])
            elif step < 30:
                assert env.length == 0.2 and env.gravity == -12.0
                assert info == dict_union(env.default_task, task_schedule[20])
            else:
                # Entry 30 only sets gravity, so length reverts to the default.
                assert env.length == starting_length and env.gravity == 0.9
                assert info == dict_union(env.default_task, task_schedule[30])
    finally:
        env.close()
Пример #2
0
 def replace(self, **new_params):
     """Return an object of the same type with some of its values replaced.

     The current values (``self.to_dict()``) are merged with ``new_params``
     through ``dict_union(..., recurse=True)``. The merged dict is then
     converted back with ``type(self).from_dict``. Presumably the values in
     ``new_params`` win over the current ones, including inside nested dicts
     because of ``recurse=True``. Confirm this against ``dict_union``.

     Args:
         **new_params: The values to replace.

     Returns:
         Whatever ``type(self).from_dict`` builds from the merged dict.
     """
     merged = dict_union(self.to_dict(), new_params, recurse=True)
     return type(self).from_dict(merged)
Пример #3
0
from ..continual.setting import (
    ContinualRLSetting,
    ContinualRLTestEnvironment,
    supported_envs as _parent_supported_envs,
)
from .tasks import DiscreteTask, TaskSchedule, is_supported, make_discrete_task
from .tasks import registry, EnvSpec
from .test_environment import DiscreteTaskAgnosticRLTestEnvironment, TestEnvironment

from sequoia.settings.rl.envs import MONSTERKONG_INSTALLED
logger = get_logger(__file__)

# The envs supported by the parent (continual) setting, plus every env in
# `registry` that the parent doesn't already list and that `is_supported`
# accepts. New entries are keyed by their spec's `id`.
supported_envs: Dict[str, EnvSpec] = dict_union(
    _parent_supported_envs,
    {
        registered_spec.id: registered_spec
        for registered_id, registered_spec in registry.env_specs.items()
        if registered_spec.id not in _parent_supported_envs
        and is_supported(registered_id)
    },
)
# Every supported env id doubles as its own dataset name.
available_datasets: Dict[str, str] = {name: name for name in supported_envs}

from .results import DiscreteTaskAgnosticRLResults
from sequoia.settings.base import Results


@dataclass
class DiscreteTaskAgnosticRLSetting(DiscreteContextAssumption,
                                    ContinualRLSetting):
    """ Continual Reinforcement Learning Setting where there are clear task boundaries,