コード例 #1
0
def check_env(env):
    """
    Run a short reproducibility probe on ``env``, then the gym check suite.

    The probe samples one action, reseeds the environment with
    ``Seeder(42)``, resets it and takes a single step with that action.
    It does this twice and asserts that the first element returned by
    ``step`` is the same both times.

    Parameters
    ----------
    env: gym.env or rlberry env
        Environment that we want to check.

    Raises
    ------
    AssertionError
        If the two probes return different values.
    """
    # The action is sampled once, before any reseeding, so that both probes
    # apply exactly the same action.
    probe_action = env.action_space.sample()

    def _first_step_after_reseed():
        # Reseed with a fixed seed, reset, and return the first step output.
        safe_reseed(env, Seeder(42))
        env.reset()
        return env.step(probe_action)[0]

    first = _first_step_after_reseed()
    second = _first_step_after_reseed()

    failure = "The environment does not seem to be reproducible"
    if hasattr(first, "__len__"):
        # Sized outputs are compared element-wise.
        assert np.all(np.array(first) == np.array(second)), failure
    else:
        assert first == second, failure

    # Modified check suite from gym
    check_gym_env(env)
コード例 #2
0
ファイル: test_env_seeding.py プロジェクト: omardrwch/rlberry
def test_env_seeding(ModelClass):
    """
    Check how trajectories relate between environments seeded in different ways.

    Five instances of ``ModelClass`` are built and reseeded:
    one with ``Seeder(123)``, one with a seeder spawned from it, two with
    fresh ``Seeder(123)`` objects (the second of which is reseeded a second
    time through ``safe_reseed``), and one with the first seeder again.
    The trajectory comparisons run only when the model reports being online.
    """
    base_env = ModelClass()
    base_seeder = seeding.Seeder(123)
    base_env.reseed(base_seeder)

    spawned_env = ModelClass()
    spawned_seeder = base_seeder.spawn()
    spawned_env.reseed(spawned_seeder)

    twin_env = ModelClass()
    twin_seeder = seeding.Seeder(123)
    twin_env.reseed(twin_seeder)

    twice_seeded_env = ModelClass()
    twice_seeder = seeding.Seeder(123)
    twice_seeded_env.reseed(twice_seeder)

    # Reusing the first seeder object: the test expects this environment to
    # produce trajectories that differ from those of the first environment.
    reused_seeder_env = ModelClass()
    reused_seeder_env.reseed(base_seeder)

    # Second reseeding of the fourth environment with the same seeder object.
    seeding.safe_reseed(twice_seeded_env, twice_seeder)

    # The check is done on a copy so the first environment itself is untouched.
    if not deepcopy(base_env).is_online():
        return

    horizon = 500
    base_traj, spawned_traj, twin_traj, twice_traj, reused_traj = [
        get_env_trajectory(candidate, horizon)
        for candidate in (
            base_env,
            spawned_env,
            twin_env,
            twice_seeded_env,
            reused_seeder_env,
        )
    ]

    assert not compare_trajectories(base_traj, spawned_traj)
    assert compare_trajectories(base_traj, twin_traj)
    assert not compare_trajectories(twin_traj, twice_traj)
    assert not compare_trajectories(base_traj, reused_traj)
コード例 #3
0
ファイル: test_bandits.py プロジェクト: omardrwch/rlberry
def test_adversarial():
    """
    Check the mean reward of 1000 pulls of arm 1 of an ``AdversarialBandit``.

    The reward table has 1000 rows and 2 columns: column 1 holds 1 in the
    first 500 rows and 2 in the last 500. The test expects the mean of the
    second entry of ``env.step(1)`` over 1000 calls to be 1.5 (presumably
    because the bandit replays the table row by row -- confirm against
    ``AdversarialBandit``).
    """
    ones_column = np.ones((500, 1))
    twos_column = 2 * np.ones((500, 1))

    first_half = np.concatenate((twos_column, ones_column), axis=1)
    second_half = np.concatenate((ones_column, twos_column), axis=1)
    reward_table = np.concatenate((first_half, second_half))

    env = AdversarialBandit(rewards=reward_table)
    safe_reseed(env, Seeder(TEST_SEED))

    pulls = []
    for _ in range(1000):
        pulls.append(env.step(1)[1])

    deviation = np.abs(np.mean(pulls) - 1.5)
    assert deviation < 1e-10
コード例 #4
0
ファイル: agent_manager.py プロジェクト: omardrwch/rlberry
 def load(self) -> bool:
     """
     Load the agent from ``self._fname`` and reseed its environment.

     Returns
     -------
     bool
         True when both the load and the reseeding completed. False when
         any exception was raised; in that case the stored agent instance
         is reset to None and the error is logged.
     """
     try:
         loaded_agent = self._agent_class.load(self._fname, **self._agent_kwargs)
         self._agent_instance = loaded_agent
         # Reseed the freshly loaded agent's environment with this handler's seeder.
         safe_reseed(self._agent_instance.env, self._seeder)
     except Exception as ex:
         self._agent_instance = None
         logger.error(
             f"Failed call to AgentHandler.load() for {self._agent_class}: {ex}"
         )
         return False
     return True
コード例 #5
0
    def reseed(self, seed_seq=None):
        """
        Replace the agent's seeder and reseed its environments with it.

        Parameters
        ----------
        seed_seq : :class:`numpy.random.SeedSequence`, :class:`rlberry.seeding.seeder.Seeder` or int, default : None
            Source of the new seeder.
            If None, the new seeder is spawned from the current one.
            Otherwise the value is passed to ``Seeder`` to build the new seeder.
        """
        self.seeder = self.seeder.spawn() if seed_seq is None else Seeder(seed_seq)

        # Both the training and the evaluation environment get the new seeder.
        safe_reseed(self.env, self.seeder)
        safe_reseed(self.eval_env, self.seeder)
コード例 #6
0
def test_env_seeding(ModelClass):
    """
    Check how trajectories relate between environments built under global seeds.

    Four instances of ``ModelClass`` are built, each right after setting the
    global seed (123, 456, 123, 123); the last one is additionally reseeded
    with ``safe_reseed``. The trajectory comparisons run only when the model
    reports being online.
    """

    def _build_with_global_seed(seed):
        # The global seed is set immediately before constructing the model.
        seeding.set_global_seed(seed)
        return ModelClass()

    reference_env = _build_with_global_seed(123)
    other_seed_env = _build_with_global_seed(456)
    same_seed_env = _build_with_global_seed(123)
    reseeded_env = _build_with_global_seed(123)
    seeding.safe_reseed(reseeded_env)

    # The check is done on a copy so the reference environment is untouched.
    if not deepcopy(reference_env).is_online():
        return

    horizon = 500
    reference_traj, other_traj, same_traj, reseeded_traj = [
        get_env_trajectory(candidate, horizon)
        for candidate in (
            reference_env,
            other_seed_env,
            same_seed_env,
            reseeded_env,
        )
    ]

    assert not compare_trajectories(reference_traj, other_traj)
    assert compare_trajectories(reference_traj, same_traj)
    assert not compare_trajectories(same_traj, reseeded_traj)
コード例 #7
0
def process_env(env, seeder, copy_env=True):
    """
    Build or copy an environment and reseed it.

    Parameters
    ----------
    env : tuple or environment instance
        If a tuple, ``env[0]`` is called with the keyword arguments in
        ``env[1]`` (an empty dict is used when ``env[1]`` is falsy) to build
        the environment. Otherwise ``env`` is used as the environment,
        after an optional deep copy.
    seeder
        Passed to ``safe_reseed`` together with the processed environment.
    copy_env : bool, default: True
        If True and ``env`` is not a tuple, the environment is deep-copied.
        If False, ``env`` itself is reseeded and returned.

    Returns
    -------
    The constructed, copied or original environment, after the reseeding
    attempt.

    Raises
    ------
    RuntimeError
        If ``copy_env`` is True and ``env`` cannot be deep-copied. The
        original exception is attached as ``__cause__``.
    """
    if isinstance(env, Tuple):
        constructor = env[0]
        kwargs = env[1] or {}
        processed_env = constructor(**kwargs)
    elif copy_env:
        try:
            processed_env = deepcopy(env)
        except Exception as ex:
            # Previously this only logged a warning and fell through with
            # `processed_env` unbound, so the caller got an unrelated
            # UnboundLocalError. Fail here with the real cause instead.
            raise RuntimeError(
                "[Agent] Not possible to deepcopy env: " + str(ex)
            ) from ex
    else:
        processed_env = env

    reseeded = safe_reseed(processed_env, seeder)
    if not reseeded:
        logger.warning("[Agent] Not possible to reseed environment.")
    return processed_env
コード例 #8
0
 def reseed(self):
     """
     Refresh this object's random number generator and reseed the wrapped env.

     When the wrapped environment is a rlberry ``Model``, it is reseeded
     through its own ``reseed`` method and both spaces share its ``rng``.
     Otherwise the environment and both spaces are passed to
     ``seeding.safe_reseed``.
     """
     self.rng = seeding.get_rng()

     if isinstance(self.env, Model):
         # rlberry Model: reseed it, then reuse its generator for the spaces.
         self.env.reseed()
         self.observation_space.rng = self.env.rng
         self.action_space.rng = self.env.rng
         return

     # gym.Env that is not a rlberry Model: reseed env and spaces separately.
     seeding.safe_reseed(self.env)
     seeding.safe_reseed(self.observation_space)
     seeding.safe_reseed(self.action_space)
コード例 #9
0
ファイル: basewrapper.py プロジェクト: omardrwch/rlberry
 def reseed(self, seed_seq=None):
     """
     Replace this wrapper's seeder, then reseed the wrapped env and the spaces.

     Parameters
     ----------
     seed_seq : optional, default: None
         If None, the new seeder is spawned from the current one.
         Otherwise the value is passed to ``Seeder`` to build the new seeder.
     """
     self.seeder = self.seeder.spawn() if seed_seq is None else Seeder(seed_seq)

     if isinstance(self.env, Model):
         # rlberry Model: use its own reseed method.
         self.env.reseed(self.seeder)
     else:
         # gym.Env that is not a rlberry Model; its spaces are not reseeded
         # here because the wrapper's spaces are reseeded just below.
         safe_reseed(self.env, self.seeder, reseed_spaces=False)

     safe_reseed(self.observation_space, self.seeder)
     safe_reseed(self.action_space, self.seeder)
コード例 #10
0
ファイル: test_bandits.py プロジェクト: omardrwch/rlberry
def test_cor_normal():
    """
    Check the median of 1000 pulls of arm 1 of a ``CorruptedNormalBandit``.

    The bandit is built with means ``[0, 1]`` and ``cor_prop=0.1``; the
    median of the second entry of ``env.step(1)`` must be within 0.5 of 1.
    """
    bandit = CorruptedNormalBandit(means=[0, 1], cor_prop=0.1)
    safe_reseed(bandit, Seeder(TEST_SEED))

    pulls = []
    for _ in range(1000):
        pulls.append(bandit.step(1)[1])

    deviation = np.abs(np.median(pulls) - 1)
    assert deviation < 0.5
コード例 #11
0
ファイル: test_bandits.py プロジェクト: omardrwch/rlberry
def test_normal():
    """
    Check the mean of 1000 pulls of arm 1 of a ``NormalBandit``.

    The bandit is built with means ``[0, 1]``; the mean of the second entry
    of ``env.step(1)`` must be within 0.1 of 1.
    """
    bandit = NormalBandit(means=[0, 1])
    safe_reseed(bandit, Seeder(TEST_SEED))

    pulls = []
    for _ in range(1000):
        pulls.append(bandit.step(1)[1])

    deviation = np.abs(np.mean(pulls) - 1)
    assert deviation < 0.1
コード例 #12
0
ファイル: test_bandits.py プロジェクト: omardrwch/rlberry
def test_bernoulli():
    """
    Check the mean of 1000 pulls of arm 1 of a ``BernoulliBandit``.

    The bandit is built with ``p=[0.05, 0.95]``; the mean of the second
    entry of ``env.step(1)`` must be within 0.1 of 0.95.
    """
    bandit = BernoulliBandit(p=[0.05, 0.95])
    safe_reseed(bandit, Seeder(TEST_SEED))

    pulls = []
    for _ in range(1000):
        pulls.append(bandit.step(1)[1])

    deviation = np.abs(np.mean(pulls) - 0.95)
    assert deviation < 0.1