Example #1
def test_rescale_reward():
    # tolerance
    tol = 1e-14

    rng = Seeder(123).rng

    for _ in range(10):
        # generate random MDP
        S, A = 5, 2
        R = rng.uniform(0.0, 1.0, (S, A))
        P = rng.uniform(0.0, 1.0, (S, A, S))
        for ss in range(S):
            for aa in range(A):
                P[ss, aa, :] /= P[ss, aa, :].sum()
        env = FiniteMDP(R, P)

        # test
        wrapped = RescaleRewardWrapper(env, (-10, 10))
        _ = wrapped.reset()
        for _ in range(100):
            _, reward, _, _ = wrapped.sample(
                wrapped.observation_space.sample(),
                wrapped.action_space.sample())
            assert reward <= 10 + tol and reward >= -10 - tol

        _ = wrapped.reset()
        for _ in range(100):
            _, reward, _, _ = wrapped.step(wrapped.action_space.sample())
            assert reward <= 10 + tol and reward >= -10 - tol
Example #2
def check_env(env):
    """
    Check that the environment is (almost) gym-compatible and that it is reproducible
    in the sense that it returns the same states when given the same seed.

    Parameters
    ----------
    env : gym.Env or rlberry env
        Environment that we want to check.
    """
    # Small reproducibility test
    action = env.action_space.sample()
    safe_reseed(env, Seeder(42))
    env.reset()
    a = env.step(action)[0]

    safe_reseed(env, Seeder(42))
    env.reset()
    b = env.step(action)[0]
    if hasattr(a, "__len__"):
        assert np.all(np.array(a) == np.array(
            b)), "The environment does not seem to be reproducible"
    else:
        assert a == b, "The environment does not seem to be reproducible"

    # Modified check suite from gym
    check_gym_env(env)
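A minimal usage sketch of check_env (the gym_make import path and the "CartPole-v1" id are assumptions, not shown above):

from rlberry.envs import gym_make  # assumed import path for rlberry's gym wrapper

# Run the compatibility and reproducibility check on a freshly built environment.
check_env(gym_make("CartPole-v1"))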
Example #3
    def __init__(self, n):
        """
        Parameters
        ----------
        n : int
            number of elements in the space
        """
        assert n >= 0, "The number of elements in Discrete must be >= 0"
        gym.spaces.Discrete.__init__(self, n)
        self.seeder = Seeder()
Example #4
class Discrete(gym.spaces.Discrete):
    """
    Class that represents discrete spaces.


    Inherited from gym.spaces.Discrete for compatibility with gym.

    rlberry wraps gym.spaces to make sure the seeding
    mechanism is unified in the library (rlberry.seeding)

    Attributes
    ----------
    rng : numpy.random._generator.Generator
        random number generator provided by rlberry.seeding

    Methods
    -------
    reseed()
        get new random number generator
    """
    def __init__(self, n):
        """
        Parameters
        ----------
        n : int
            number of elements in the space
        """
        assert n >= 0, "The number of elements in Discrete must be >= 0"
        gym.spaces.Discrete.__init__(self, n)
        self.seeder = Seeder()

    @property
    def rng(self):
        return self.seeder.rng

    def reseed(self, seed_seq=None):
        """
        Get new random number generator.

        Parameters
        ----------
        seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None
            Seed sequence from which to spawn the random number generator.
            If None, generate random seed.
            If int, use as entropy for SeedSequence.
            If seeder, use seeder.seed_seq
        """
        self.seeder.reseed(seed_seq)

    def sample(self):
        return self.rng.integers(0, self.n)

    def __str__(self):
        objstr = "%d-element Discrete space" % self.n
        return objstr
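A short usage sketch of the Discrete space above (the size and seed values are arbitrary):

space = Discrete(4)
space.reseed(123)                              # int used as entropy for the internal Seeder
samples = [space.sample() for _ in range(5)]   # values in {0, 1, 2, 3}
print(space)                                   # "4-element Discrete space"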
Example #5
def test_seeder_initialized_from_seeder():
    """
    Check that Seeder(seed_seq) respawns seed_seq in the constructor.
    """
    seeder1 = Seeder(43)
    seeder_temp = Seeder(43)
    seeder2 = Seeder(seeder_temp)

    data1 = seeder1.rng.integers(100, size=1000)
    data2 = seeder2.rng.integers(100, size=1000)
    assert (data1 != data2).sum() > 5
Example #6
def test_seeder_reseeding():
    """
    Check that reseeding with a Seeder instance works properly.
    """
    # seeders 1 and 2 are identical
    seeder1 = Seeder(43)
    seeder2 = Seeder(43)

    # reseed seeder 2 using seeder 1
    seeder2.reseed(seeder1)

    data1 = seeder1.rng.integers(100, size=1000)
    data2 = seeder2.rng.integers(100, size=1000)
    assert (data1 != data2).sum() > 5
Example #7
def test_seeder_spawning():
    """
    Check that Seeders obtained via spawn() produce streams that differ from their parent.
    """
    seeder1 = Seeder(43)
    seeder2 = seeder1.spawn()
    seeder3 = seeder2.spawn()

    print(seeder1)
    print(seeder2)
    print(seeder3)

    data1 = seeder1.rng.integers(100, size=1000)
    data2 = seeder2.rng.integers(100, size=1000)
    assert (data1 != data2).sum() > 5
Example #8
def test_gym_safe_reseed(env_name):
    seeder = Seeder(123)
    seeder_aux = Seeder(123)

    env1 = gym.make(env_name)
    env2 = gym.make(env_name)
    env3 = gym.make(env_name)

    safe_reseed(env1, seeder)
    safe_reseed(env2, seeder)
    safe_reseed(env3, seeder_aux)

    traj1 = get_env_trajectory(env1, 500)
    traj2 = get_env_trajectory(env2, 500)
    traj3 = get_env_trajectory(env3, 500)
    assert not compare_trajectories(traj1, traj2)
    assert compare_trajectories(traj1, traj3)
Example #9
def test_mbqvi(S, A):
    rng = Seeder(123).rng

    for sim in range(5):
        # generate random MDP with deterministic transitions
        R = rng.uniform(0.0, 1.0, (S, A))
        P = np.zeros((S, A, S))
        for ss in range(S):
            for aa in range(A):
                ns = rng.integers(0, S)
                P[ss, aa, ns] = 1

        # run MBQVI and check exactness of estimators
        env = FiniteMDP(R, P)
        agent = MBQVIAgent(env, n_samples=1)
        agent.fit()
        assert np.abs(R - agent.R_hat).max() < 1e-16
        assert np.abs(P - agent.P_hat).max() < 1e-16
Example #10
def test_seeder_basic():
    seeder1 = Seeder(43)
    data1 = seeder1.rng.integers(100, size=1000)

    seeder2 = Seeder(44)
    data2 = seeder2.rng.integers(100, size=1000)

    seeder3 = Seeder(44)
    data3 = seeder3.rng.integers(100, size=1000)

    assert (data1 != data2).sum() > 5
    assert (data2 != data3).sum() == 0
    assert (
        seeder2.spawn(1).generate_state(1)[0] == seeder3.spawn(1).generate_state(1)[0]
    )
    assert (
        seeder1.spawn(1).generate_state(1)[0] != seeder3.spawn(1).generate_state(1)[0]
    )
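The basic Seeder pattern exercised by these tests, sketched outside a test function (seed values are arbitrary): equal seeds give identical streams, and spawn() derives child seeders with independent streams.

from rlberry.seeding import Seeder

seeder = Seeder(42)
values = seeder.rng.integers(100, size=5)       # reproducible for a fixed seed
child = seeder.spawn()                          # child seeder, different stream
child_values = child.rng.integers(100, size=5)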
Example #11
class MultiDiscrete(gym.spaces.MultiDiscrete):
    """

    Inherited from gym.spaces.MultiDiscrete for compatibility with gym.

    rlberry wraps gym.spaces to make sure the seeding
    mechanism is unified in the library (rlberry.seeding)

    Attributes
    ----------
    rng : numpy.random._generator.Generator
        random number generator provided by rlberry.seeding

    Methods
    -------
    reseed()
        get new random number generator
    """
    def __init__(self, nvec, dtype=np.int64):
        gym.spaces.MultiDiscrete.__init__(self, nvec, dtype=dtype)
        self.seeder = Seeder()

    @property
    def rng(self):
        return self.seeder.rng

    def reseed(self, seed_seq=None):
        """
        Get new random number generator.

        Parameters
        ----------
        seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None
            Seed sequence from which to spawn the random number generator.
            If None, generate random seed.
            If int, use as entropy for SeedSequence.
            If seeder, use seeder.seed_seq
        """
        self.seeder.reseed(seed_seq)

    def sample(self):
        sample = self.rng.random(self.nvec.shape) * self.nvec
        return sample.astype(self.dtype)
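A short usage sketch of the MultiDiscrete space above (the sizes and seed are arbitrary):

space = MultiDiscrete([3, 2])   # first coordinate in {0, 1, 2}, second in {0, 1}
space.reseed(7)
sample = space.sample()         # e.g. array([2, 0])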
Example #12
def test_rescale_wrapper_seeding(ModelClass):
    env1 = RescaleRewardWrapper(ModelClass(), (0, 1))
    seeder = Seeder(123)
    env1.reseed(seeder)

    env2 = RescaleRewardWrapper(ModelClass(), (0, 1))
    seeder = Seeder(456)
    env2.reseed(seeder)

    env3 = RescaleRewardWrapper(ModelClass(), (0, 1))
    seeder = Seeder(123)
    env3.reseed(seeder)

    if deepcopy(env1).is_online():
        traj1 = get_env_trajectory(env1, 500)
        traj2 = get_env_trajectory(env2, 500)
        traj3 = get_env_trajectory(env3, 500)

        assert not compare_trajectories(traj1, traj2)
        assert compare_trajectories(traj1, traj3)
Example #13
def test_env_seeding(env_name):
    seeder1 = Seeder(123)
    env1 = gym_make(env_name)
    env1.reseed(seeder1)

    seeder2 = Seeder(456)
    env2 = gym_make(env_name)
    env2.reseed(seeder2)

    seeder3 = Seeder(123)
    env3 = gym_make(env_name)
    env3.reseed(seeder3)

    if deepcopy(env1).is_online():
        traj1 = get_env_trajectory(env1, 500)
        traj2 = get_env_trajectory(env2, 500)
        traj3 = get_env_trajectory(env3, 500)

        assert not compare_trajectories(traj1, traj2)
        assert compare_trajectories(traj1, traj3)
Example #14
    def reseed(self, seed_seq=None):
        """
        Get new random number generator for the model.

        Parameters
        ----------
        seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None
            Seed sequence from which to spawn the random number generator.
            If None, generate random seed.
            If int, use as entropy for SeedSequence.
            If seeder, use seeder.seed_seq
        """
        # self.seeder
        if seed_seq is None:
            self.seeder = self.seeder.spawn()
        else:
            self.seeder = Seeder(seed_seq)
        # spaces
        self.observation_space.reseed(self.seeder.seed_seq)
        self.action_space.reseed(self.seeder.seed_seq)
Example #15
def test_double_wrapper_copy_reseeding(ModelClass):
    env = Wrapper(Wrapper(ModelClass()))
    seeder = Seeder(123)
    env.reseed(seeder)

    c_env = deepcopy(env)
    c_env.reseed()

    if deepcopy(env).is_online():
        traj1 = get_env_trajectory(env, 500)
        traj2 = get_env_trajectory(c_env, 500)
        assert not compare_trajectories(traj1, traj2)
Example #16
def test_copy_reseeding(env_name):
    seeder = Seeder(123)
    env = gym_make(env_name)
    env.reseed(seeder)

    c_env = deepcopy(env)
    c_env.reseed()

    if deepcopy(env).is_online():
        traj1 = get_env_trajectory(env, 500)
        traj2 = get_env_trajectory(c_env, 500)
        assert not compare_trajectories(traj1, traj2)
Example #17
def test_adversarial():
    r1 = np.concatenate((2 * np.ones((500, 1)), np.ones((500, 1))), axis=1)

    r2 = np.concatenate((np.ones((500, 1)), 2 * np.ones((500, 1))), axis=1)

    rewards = np.concatenate((r1, r2))

    env = AdversarialBandit(rewards=rewards)
    safe_reseed(env, Seeder(TEST_SEED))

    sample = [env.step(1)[1] for f in range(1000)]
    assert np.abs(np.mean(sample) - 1.5) < 1e-10
Example #18
def test_gym_copy_reseeding():
    seeder = Seeder(123)
    if _GYM_INSTALLED:
        gym_env = gym.make("Acrobot-v1")
        env = Wrapper(gym_env)
        env.reseed(seeder)

        c_env = deepcopy(env)
        c_env.reseed()

        if deepcopy(env).is_online():
            traj1 = get_env_trajectory(env, 500)
            traj2 = get_env_trajectory(c_env, 500)
            assert not compare_trajectories(traj1, traj2)
Example #19
def test_gym_copy_reseeding_2():
    seeder = Seeder(123)
    if _GYM_INSTALLED:
        gym_env = gym.make("Acrobot-v1")
        # nested wrapping
        env = RescaleRewardWrapper(Wrapper(Wrapper(gym_env)), (0, 1))
        env.reseed(seeder)

        c_env = deepcopy(env)
        c_env.reseed()

        if deepcopy(env).is_online():
            traj1 = get_env_trajectory(env, 500)
            traj2 = get_env_trajectory(c_env, 500)
            assert not compare_trajectories(traj1, traj2)
Example #20
    def reseed(self, seed_seq=None):
        # self.seeder
        if seed_seq is None:
            self.seeder = self.seeder.spawn()
        else:
            self.seeder = Seeder(seed_seq)
        # seed gym.Env that is not a rlberry Model
        if not isinstance(self.env, Model):
            # get a seed for gym environment; spaces are reseeded below.
            safe_reseed(self.env, self.seeder, reseed_spaces=False)
        # seed rlberry Model
        else:
            self.env.reseed(self.seeder)
        safe_reseed(self.observation_space, self.seeder)
        safe_reseed(self.action_space, self.seeder)
Example #21
class Model(gym.Env):
    """
    Base class for an environment model.

    Attributes
    ----------
    name : string
        environment identifier
    observation_space : rlberry.spaces.Space
        observation space
    action_space : rlberry.spaces.Space
        action space
    reward_range : tuple
        tuple (r_min, r_max) containing the minimum and the maximum reward
    seeder : rlberry.seeding.Seeder
        Seeder, containing random number generator.

    Methods
    -------
    reseed(seed_seq)
        get new Seeder
    reset()
        puts the environment in a default state and returns this state
    step(action)
        returns the outcome of an action
    sample(state, action)
        returns a transition sampled from taking an action in a given state
    is_online()
        returns true if reset() and step() methods are implemented
    is_generative()
        returns true if sample() method is implemented
    """

    name = ""

    def __init__(self):
        self.observation_space = None
        self.action_space = None
        self.reward_range: tuple = (-np.inf, np.inf)
        # random number generator
        self.seeder = Seeder()

    def reseed(self, seed_seq=None):
        """
        Get new random number generator for the model.

        Parameters
        ----------
        seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None
            Seed sequence from which to spawn the random number generator.
            If None, generate random seed.
            If int, use as entropy for SeedSequence.
            If seeder, use seeder.seed_seq
        """
        # self.seeder
        if seed_seq is None:
            self.seeder = self.seeder.spawn()
        else:
            self.seeder = Seeder(seed_seq)
        # spaces
        self.observation_space.reseed(self.seeder.seed_seq)
        self.action_space.reseed(self.seeder.seed_seq)

    def sample(self, state, action):
        """
        Execute a step from a state-action pair.

        Parameters
        ----------
        state : object
            state from which to sample
        action : object
            action to take in the environment

        Returns
        -------
        observation : object
        reward : float
        done  : bool
        info  : dict
        """
        raise NotImplementedError("sample() method not implemented.")

    def is_online(self):
        logger.warning("Checking if Model is\
online calls reset() and step() methods.")
        try:
            self.reset()
            self.step(self.action_space.sample())
            return True
        except Exception as ex:
            if isinstance(ex, NotImplementedError):
                return False
            else:
                raise

    def is_generative(self):
        logger.warning("Checking if Model is \
generative calls sample() method.")
        try:
            self.sample(self.observation_space.sample(),
                        self.action_space.sample())
            return True
        except Exception as ex:
            if isinstance(ex, NotImplementedError):
                return False
            else:
                raise

    @classmethod
    def _get_param_names(cls):
        """Get parameter names for the Model"""
        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        init = getattr(cls.__init__, "deprecated_original", cls.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect
            return []

        # introspect the constructor arguments to find the model parameters
        # to represent
        init_signature = inspect.signature(init)
        # Consider the constructor parameters excluding 'self'
        parameters = [
            p for p in init_signature.parameters.values()
            if p.name != "self" and p.kind != p.VAR_KEYWORD
        ]

        # Extract and sort argument names excluding 'self'
        return sorted([p.name for p in parameters])

    def get_params(self, deep=True):
        """
        Get parameters for this model.
        Parameters
        ----------
        deep : bool, default=True
            If True, will return the parameters for this model and
            contained subobjects.
        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        out = dict()
        for key in self._get_param_names():
            value = getattr(self, key)
            if deep and hasattr(value, "get_params"):
                deep_items = value.get_params().items()
                out.update((key + "__" + k, val) for k, val in deep_items)
            out[key] = value
        return out

    @property
    def unwrapped(self):
        return self

    @property
    def rng(self):
        """Random number generator."""
        return self.seeder.rng
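A minimal sketch of a generative environment built on the Model base class above; the class name, dynamics and rewards are illustrative only, and the sketch assumes the Discrete space defined earlier is in scope.

class TwoStateModel(Model):
    """Toy generative model: 2 states, 2 actions, deterministic transitions."""
    name = "TwoStateModel"

    def __init__(self):
        Model.__init__(self)
        self.observation_space = Discrete(2)
        self.action_space = Discrete(2)
        self.reward_range = (0.0, 1.0)

    def sample(self, state, action):
        # The chosen action deterministically selects the next state;
        # action 1 yields reward 1.0, action 0 yields reward 0.0.
        next_state = action
        reward = float(action)
        done = False
        info = {}
        return next_state, reward, done, info

Since reset() and step() are not implemented here, is_online() returns False and is_generative() returns True for this sketch.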
Example #22
class Box(gym.spaces.Box):
    """
    Class that represents a space that is a cartesian product in R^n:

    [a_1, b_1] x [a_2, b_2] x ... x [a_n, b_n]


    Inherited from gym.spaces.Box for compatibility with gym.

    rlberry wraps gym.spaces to make sure the seeding
    mechanism is unified in the library (rlberry.seeding)

    Attributes
    ----------
    rng : numpy.random._generator.Generator
        random number generator provided by rlberry.seeding

    Methods
    -------
    reseed()
        get new random number generator
    """
    def __init__(self, low, high, shape=None, dtype=np.float64):
        gym.spaces.Box.__init__(self, low, high, shape=shape, dtype=dtype)
        self.seeder = Seeder()

    @property
    def rng(self):
        return self.seeder.rng

    def reseed(self, seed_seq=None):
        """
        Get new random number generator.

        Parameters
        ----------
        seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None
            Seed sequence from which to spawn the random number generator.
            If None, generate random seed.
            If int, use as entropy for SeedSequence.
            If seeder, use seeder.seed_seq
        """
        self.seeder.reseed(seed_seq)

    def sample(self):
        """
        Adapted from:
        https://raw.githubusercontent.com/openai/gym/master/gym/spaces/box.py


        Generates a single random sample inside of the Box.

        In creating a sample of the box, each coordinate is sampled according
        to the form of the interval:

        * [a, b] : uniform distribution
        * [a, oo) : shifted exponential distribution
        * (-oo, b] : shifted negative exponential distribution
        * (-oo, oo) : normal distribution
        """
        high = self.high if self.dtype.kind == "f" else self.high.astype(
            "int64") + 1
        sample = np.empty(self.shape)

        # Masking arrays which classify the coordinates according to interval
        # type
        unbounded = ~self.bounded_below & ~self.bounded_above
        upp_bounded = ~self.bounded_below & self.bounded_above
        low_bounded = self.bounded_below & ~self.bounded_above
        bounded = self.bounded_below & self.bounded_above

        # Vectorized sampling by interval type
        sample[unbounded] = self.rng.normal(size=unbounded[unbounded].shape)

        sample[low_bounded] = (
            self.rng.exponential(size=low_bounded[low_bounded].shape) +
            self.low[low_bounded])

        sample[upp_bounded] = (
            -self.rng.exponential(size=upp_bounded[upp_bounded].shape) +
            self.high[upp_bounded])

        sample[bounded] = self.rng.uniform(low=self.low[bounded],
                                           high=high[bounded],
                                           size=bounded[bounded].shape)
        if self.dtype.kind == "i":
            sample = np.floor(sample)

        return sample.astype(self.dtype)
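A short usage sketch of the Box space above (bounds and seed are arbitrary), mixing a bounded and an unbounded coordinate:

import numpy as np

space = Box(low=np.array([0.0, -np.inf]), high=np.array([1.0, np.inf]))
space.reseed(0)
x = space.sample()   # first coordinate uniform in [0, 1], second drawn from a normal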
Example #23
    def __init__(self, low, high, shape=None, dtype=np.float64):
        gym.spaces.Box.__init__(self, low, high, shape=shape, dtype=dtype)
        self.seeder = Seeder()
Example #24
"""
A demo of twinrooms environment
===============================
Illustration of the TwinRooms environment.

.. video:: ../../video_plot_twinrooms.mp4
   :width: 600

"""
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_twinrooms.jpg'

from rlberry.envs.benchmarks.generalization.twinrooms import TwinRooms
from rlberry.agents.mbqvi import MBQVIAgent
from rlberry.wrappers.discretize_state import DiscretizeStateWrapper
from rlberry.seeding import Seeder

seeder = Seeder(123)

env = TwinRooms()
env = DiscretizeStateWrapper(env, n_bins=20)
env.reseed(seeder)
horizon = 20
agent = MBQVIAgent(env, n_samples=10, gamma=1.0, horizon=horizon)
agent.reseed(seeder)
agent.fit()

state = env.reset()
env.enable_rendering()
for ii in range(10):
    action = agent.policy(state)
    ns, rr, _, _ = env.step(action)
    state = ns
Example #25
    def __init__(
        self,
        agent_class,
        train_env,
        fit_budget=None,
        eval_env=None,
        init_kwargs=None,
        fit_kwargs=None,
        eval_kwargs=None,
        agent_name=None,
        n_fit=4,
        output_dir=None,
        parallelization="thread",
        max_workers=None,
        mp_context="spawn",
        worker_logging_level="INFO",
        seed=None,
        enable_tensorboard=False,
        outdir_id_style="timestamp",
        default_writer_kwargs=None,
        init_kwargs_per_instance=None,
    ):
        # agent_class should only be None when the constructor is called
        # by the class method AgentManager.load(), since the agent class
        # will be loaded.

        if agent_class is None:
            return None  # Must only happen when load() method is called.

        self.seeder = Seeder(seed)
        self.eval_seeder = self.seeder.spawn(1)

        self.agent_name = agent_name
        if agent_name is None:
            self.agent_name = agent_class.name

        # Check train_env and eval_env
        assert isinstance(
            train_env, Tuple
        ), "[AgentManager]train_env must be Tuple (constructor, kwargs)"
        if eval_env is not None:
            assert isinstance(
                eval_env, Tuple
            ), "[AgentManager]train_env must be Tuple (constructor, kwargs)"

        # check options
        assert outdir_id_style in [None, "unique", "timestamp"]

        # create object identifier
        self.unique_id = metadata_utils.get_unique_id(self)
        self.timestamp_id = metadata_utils.get_readable_id(self)

        # Agent class
        self.agent_class = agent_class

        # Train env
        self.train_env = train_env

        # Check eval_env
        if eval_env is None:
            eval_env = deepcopy(train_env)

        self._eval_env = eval_env

        # check kwargs
        fit_kwargs = fit_kwargs or {}
        eval_kwargs = eval_kwargs or {}

        # params
        base_init_kwargs = init_kwargs or {}
        self._base_init_kwargs = deepcopy(base_init_kwargs)
        self.fit_kwargs = deepcopy(fit_kwargs)
        self.eval_kwargs = deepcopy(eval_kwargs)
        self.n_fit = n_fit
        self.parallelization = parallelization
        self.max_workers = max_workers
        self.mp_context = mp_context
        self.worker_logging_level = worker_logging_level
        self.output_dir = output_dir
        if fit_budget is not None:
            self.fit_budget = fit_budget
        else:
            try:
                self.fit_budget = self.fit_kwargs.pop("fit_budget")
            except KeyError:
                raise ValueError("[AgentManager] fit_budget missing in __init__().")
        # extra params per instance
        if init_kwargs_per_instance is not None:
            assert len(init_kwargs_per_instance) == n_fit
            init_kwargs_per_instance = deepcopy(init_kwargs_per_instance)
        self.init_kwargs_per_instance = init_kwargs_per_instance or [
            dict() for _ in range(n_fit)
        ]

        # output dir
        if output_dir is None:
            output_dir_ = metadata_utils.RLBERRY_TEMP_DATA_DIR
        else:
            output_dir_ = output_dir
        self.output_dir_ = Path(output_dir_) / "manager_data"
        if outdir_id_style == "unique":
            self.output_dir_ = self.output_dir_ / (
                self.agent_name + "_" + self.unique_id
            )
        elif outdir_id_style == "timestamp":
            self.output_dir_ = self.output_dir_ / (
                self.agent_name + "_" + self.timestamp_id
            )

        # Create list of writers for each agent that will be trained
        # 'default' will keep Agent's use of DefaultWriter.
        self.writers = [("default", None) for _ in range(n_fit)]

        # Parameters to setup Agent's DefaultWriter
        self.agent_default_writer_kwargs = [
            dict(
                name=self.agent_name,
                log_interval=3,
                tensorboard_kwargs=None,
                execution_metadata=metadata_utils.ExecutionMetadata(obj_worker_id=idx),
            )
            for idx in range(n_fit)
        ]
        self.tensorboard_dir = None
        if enable_tensorboard:
            self.tensorboard_dir = self.output_dir_ / "tensorboard"
            for idx, params in enumerate(self.agent_default_writer_kwargs):
                params["tensorboard_kwargs"] = dict(
                    log_dir=self.tensorboard_dir / str(idx)
                )
        # Update DefaultWriter according to user's settings.
        default_writer_kwargs = default_writer_kwargs or {}
        if default_writer_kwargs:
            logger.warning(
                "(Re)defining the following DefaultWriter"
                f" parameters in AgentManager: {list(default_writer_kwargs.keys())}"
            )
        for ii in range(n_fit):
            self.agent_default_writer_kwargs[ii].update(default_writer_kwargs)

        # agent handlers and init kwargs
        self._set_init_kwargs()  # init_kwargs for each agent
        self.agent_handlers = None
        self._reset_agent_handlers()
        self.default_writer_data = None
        self.best_hyperparams = None

        # optuna study and database
        self.optuna_study = None
        self.db_filename = None
        self.optuna_storage_url = None

        # rlberry version for reproducibility purpose
        self.rlberry_version = rlberry.__version__
Example #26
    def __init__(self):
        self.observation_space = None
        self.action_space = None
        self.reward_range: tuple = (-np.inf, np.inf)
        # random number generator
        self.seeder = Seeder()
Example #27
def test_cor_normal():
    env = CorruptedNormalBandit(means=[0, 1], cor_prop=0.1)
    safe_reseed(env, Seeder(TEST_SEED))

    sample = [env.step(1)[1] for f in range(1000)]
    assert np.abs(np.median(sample) - 1) < 0.5
Example #28
def test_normal():
    env = NormalBandit(means=[0, 1])
    safe_reseed(env, Seeder(TEST_SEED))

    sample = [env.step(1)[1] for f in range(1000)]
    assert np.abs(np.mean(sample) - 1) < 0.1
Example #29
def test_bernoulli():
    env = BernoulliBandit(p=[0.05, 0.95])
    safe_reseed(env, Seeder(TEST_SEED))

    sample = [env.step(1)[1] for f in range(1000)]
    assert np.abs(np.mean(sample) - 0.95) < 0.1
Example #30
class AgentManager:
    """
    Class to train, optimize hyperparameters, evaluate and gather
    statistics about an agent.

    Notes
    -----
    If parallelization="process" and mp_context="spawn", make sure your main code
    has a guard `if __name__ == '__main__'`. See https://github.com/google/jax/issues/1805
    and https://stackoverflow.com/a/66290106.

    Parameters
    ----------
    agent_class
        Class of the agent.
    train_env : tuple (constructor, kwargs)
        Environment used to initialize/train the agent.
    fit_budget : int
        Budget used to call :meth:`rlberry.agents.agent.Agent.fit`.
        If None, must be given in ``fit_kwargs['fit_budget']``.
    eval_env : Tuple (constructor, kwargs)
        Environment used to evaluate the agent. If None, set to ``train_env``.
    init_kwargs : dict
        Arguments required by the agent's constructor. Shared across all n_fit instances.
    fit_kwargs : dict
        Extra arguments to call :meth:`rlberry.agents.agent.Agent.fit`.
    eval_kwargs : dict
        Arguments required to call :meth:`rlberry.agents.agent.Agent.eval`.
    agent_name : str
        Name of the agent. If None, set to agent_class.name
    n_fit : int
        Number of agent instances to fit.
    output_dir : str or :class:`pathlib.Path`
        Directory where to store data.
    parallelization: {'thread', 'process'}, default: 'thread'
        Whether to parallelize agent training using threads or processes.
    max_workers: None or int, default: None
        Number of processes/threads used in a call to fit().
        If None and parallelization='process', it will default to the
        number of processors on the machine.
        If None and parallelization='thread', it will default to the
        number of processors on the machine, multiplied by 5.
    mp_context: {'spawn', 'fork'}, default: 'spawn'.
        Context for python multiprocessing module.
        Warning: If you're using JAX or PyTorch, it only works with 'spawn'.
                 If running code on a notebook or interpreter, use 'fork'.
    worker_logging_level : str, default: 'INFO'
        Logging level in each of the threads/processes used to fit agents.
    seed : :class:`numpy.random.SeedSequence`, :class:`~rlberry.seeding.seeder.Seeder` or int, default : None
        Seed sequence from which to spawn the random number generator.
        If None, generate random seed.
        If int, use as entropy for SeedSequence.
        If seeder, use seeder.seed_seq
    enable_tensorboard : bool, default : False
        If True, enable tensorboard logging in Agent's :class:`~rlberry.utils.writers.DefaultWriter`.
    outdir_id_style: {None, 'unique', 'timestamp'}, default = 'timestamp'
        If None, data is saved to output_dir/manager_data
        If 'unique', data is saved to ``output_dir/manager_data/<AGENT_NAME_UNIQUE_ID>``
        If 'timestamp', data is saved to ``output_dir/manager_data/<AGENT_NAME_TIMESTAMP_SHORT_ID>``
    default_writer_kwargs : dict
        Optional arguments for :class:`~rlberry.utils.writers.DefaultWriter`.
    init_kwargs_per_instance : List[dict] (optional)
        List of length ``n_fit`` containing the params to initialize each of
        the ``n_fit`` agent instances. It can be useful if different instances
        require different parameters. If the same parameter is defined by
        ``init_kwargs`` and ``init_kwargs_per_instance``, the value given by
        ``init_kwargs_per_instance`` will be used.


    Attributes
    ----------
    output_dir : :class:`pathlib.Path`
        Directory where the manager saves data.
    """

    def __init__(
        self,
        agent_class,
        train_env,
        fit_budget=None,
        eval_env=None,
        init_kwargs=None,
        fit_kwargs=None,
        eval_kwargs=None,
        agent_name=None,
        n_fit=4,
        output_dir=None,
        parallelization="thread",
        max_workers=None,
        mp_context="spawn",
        worker_logging_level="INFO",
        seed=None,
        enable_tensorboard=False,
        outdir_id_style="timestamp",
        default_writer_kwargs=None,
        init_kwargs_per_instance=None,
    ):
        # agent_class should only be None when the constructor is called
        # by the class method AgentManager.load(), since the agent class
        # will be loaded.

        if agent_class is None:
            return None  # Must only happen when load() method is called.

        self.seeder = Seeder(seed)
        self.eval_seeder = self.seeder.spawn(1)

        self.agent_name = agent_name
        if agent_name is None:
            self.agent_name = agent_class.name

        # Check train_env and eval_env
        assert isinstance(
            train_env, Tuple
        ), "[AgentManager]train_env must be Tuple (constructor, kwargs)"
        if eval_env is not None:
            assert isinstance(
                eval_env, Tuple
            ), "[AgentManager]train_env must be Tuple (constructor, kwargs)"

        # check options
        assert outdir_id_style in [None, "unique", "timestamp"]

        # create object identifier
        self.unique_id = metadata_utils.get_unique_id(self)
        self.timestamp_id = metadata_utils.get_readable_id(self)

        # Agent class
        self.agent_class = agent_class

        # Train env
        self.train_env = train_env

        # Check eval_env
        if eval_env is None:
            eval_env = deepcopy(train_env)

        self._eval_env = eval_env

        # check kwargs
        fit_kwargs = fit_kwargs or {}
        eval_kwargs = eval_kwargs or {}

        # params
        base_init_kwargs = init_kwargs or {}
        self._base_init_kwargs = deepcopy(base_init_kwargs)
        self.fit_kwargs = deepcopy(fit_kwargs)
        self.eval_kwargs = deepcopy(eval_kwargs)
        self.n_fit = n_fit
        self.parallelization = parallelization
        self.max_workers = max_workers
        self.mp_context = mp_context
        self.worker_logging_level = worker_logging_level
        self.output_dir = output_dir
        if fit_budget is not None:
            self.fit_budget = fit_budget
        else:
            try:
                self.fit_budget = self.fit_kwargs.pop("fit_budget")
            except KeyError:
                raise ValueError("[AgentManager] fit_budget missing in __init__().")
        # extra params per instance
        if init_kwargs_per_instance is not None:
            assert len(init_kwargs_per_instance) == n_fit
            init_kwargs_per_instance = deepcopy(init_kwargs_per_instance)
        self.init_kwargs_per_instance = init_kwargs_per_instance or [
            dict() for _ in range(n_fit)
        ]

        # output dir
        if output_dir is None:
            output_dir_ = metadata_utils.RLBERRY_TEMP_DATA_DIR
        else:
            output_dir_ = output_dir
        self.output_dir_ = Path(output_dir_) / "manager_data"
        if outdir_id_style == "unique":
            self.output_dir_ = self.output_dir_ / (
                self.agent_name + "_" + self.unique_id
            )
        elif outdir_id_style == "timestamp":
            self.output_dir_ = self.output_dir_ / (
                self.agent_name + "_" + self.timestamp_id
            )

        # Create list of writers for each agent that will be trained
        # 'default' will keep Agent's use of DefaultWriter.
        self.writers = [("default", None) for _ in range(n_fit)]

        # Parameters to setup Agent's DefaultWriter
        self.agent_default_writer_kwargs = [
            dict(
                name=self.agent_name,
                log_interval=3,
                tensorboard_kwargs=None,
                execution_metadata=metadata_utils.ExecutionMetadata(obj_worker_id=idx),
            )
            for idx in range(n_fit)
        ]
        self.tensorboard_dir = None
        if enable_tensorboard:
            self.tensorboard_dir = self.output_dir_ / "tensorboard"
            for idx, params in enumerate(self.agent_default_writer_kwargs):
                params["tensorboard_kwargs"] = dict(
                    log_dir=self.tensorboard_dir / str(idx)
                )
        # Update DefaultWriter according to user's settings.
        default_writer_kwargs = default_writer_kwargs or {}
        if default_writer_kwargs:
            logger.warning(
                "(Re)defining the following DefaultWriter"
                f" parameters in AgentManager: {list(default_writer_kwargs.keys())}"
            )
        for ii in range(n_fit):
            self.agent_default_writer_kwargs[ii].update(default_writer_kwargs)

        # agent handlers and init kwargs
        self._set_init_kwargs()  # init_kwargs for each agent
        self.agent_handlers = None
        self._reset_agent_handlers()
        self.default_writer_data = None
        self.best_hyperparams = None

        # optuna study and database
        self.optuna_study = None
        self.db_filename = None
        self.optuna_storage_url = None

        # rlberry version for reproducibility purpose
        self.rlberry_version = rlberry.__version__

    def _init_optuna_storage_url(self):
        self.output_dir_.mkdir(parents=True, exist_ok=True)
        self.db_filename = self.output_dir_ / "optuna_data.db"

        if create_database(self.db_filename):
            self.optuna_storage_url = f"sqlite:///{self.db_filename}"
        else:
            self.db_filename = None
            self.optuna_storage_url = "sqlite:///:memory:"
            logger.warning(
                f"Unable to create databate {self.db_filename}. Using sqlite:///:memory:"
            )

    def _set_init_kwargs(self):
        init_seeders = self.seeder.spawn(self.n_fit, squeeze=False)
        self.init_kwargs = []
        for ii in range(self.n_fit):
            kwargs_ii = deepcopy(self._base_init_kwargs)
            kwargs_ii.update(
                dict(
                    env=self.train_env,
                    eval_env=self._eval_env,
                    copy_env=False,
                    seeder=init_seeders[ii],
                    output_dir=Path(self.output_dir_) / f"output_{ii}",
                    _execution_metadata=self.agent_default_writer_kwargs[ii][
                        "execution_metadata"
                    ],
                    _default_writer_kwargs=self.agent_default_writer_kwargs[ii],
                )
            )
            per_instance_kwargs = self.init_kwargs_per_instance[ii]
            kwargs_ii.update(per_instance_kwargs)
            self.init_kwargs.append(kwargs_ii)

    def _reset_agent_handlers(self):
        handlers_seeders = self.seeder.spawn(self.n_fit, squeeze=False)
        self.agent_handlers = [
            AgentHandler(
                id=ii,
                filename=self.output_dir_ / Path(f"agent_handlers/idx_{ii}"),
                seeder=handlers_seeders[ii],
                agent_class=self.agent_class,
                agent_instance=None,
                # kwargs
                agent_kwargs=self.init_kwargs[ii],
            )
            for ii in range(self.n_fit)
        ]
        self.clear_handlers()

    def build_eval_env(self) -> types.Env:
        """Return an instantiated and reseeded evaluation environment.

        Returns
        -------
        :class:`types.Env`
            Instance of evaluation environment.
        """
        return process_env(self._eval_env, self.seeder)

    def get_writer_data(self):
        """Return a dataframe containing data from the writer of the agents.

        Returns
        -------
        :class:`pandas.DataFrame`
            Data from the agents' writers.
        """
        return self.default_writer_data

    def get_agent_instances(self):
        """Returns a list containing ``n_fit`` agent instances.

        Returns
        -------
        list of :class:`~rlberry.agents.agent.Agent`
            ``n_fit`` instances of the managed agents.
        """
        if self.agent_handlers:
            return [
                agent_handler.get_instance() for agent_handler in self.agent_handlers
            ]
        return []

    def eval_agents(self, n_simulations: Optional[int] = None) -> list:
        """
        Call :meth:`eval` method in the managed agents and returns a list with the results.

        Parameters
        ----------
        n_simulations : int
            Total number of agent evaluations. If None, set to 2*(number of agents)

        Returns
        -------
        list
            list of length ``n_simulations`` containing the outputs
            of :meth:`~rlberry.agents.agent.Agent.eval`.
        """
        if not n_simulations:
            n_simulations = 2 * self.n_fit
        values = []
        for ii in range(n_simulations):
            # randomly choose one of the fitted agents
            agent_idx = self.eval_seeder.rng.choice(len(self.agent_handlers))
            agent = self.agent_handlers[agent_idx]
            if agent.is_empty():
                logger.error(
                    "Calling eval() in an AgentManager instance contaning an empty AgentHandler."
                    " Returning []."
                )
                return []
            values.append(agent.eval(**self.eval_kwargs))
            logger.info(f"[eval]... simulation {ii + 1}/{n_simulations}")
        return values

    def clear_output_dir(self):
        """Delete output_dir and all its data."""
        try:
            shutil.rmtree(self.output_dir_)
        except FileNotFoundError:
            logger.warning(f"No directory {self.output_dir_} found to be deleted.")

    def clear_handlers(self):
        """Delete files from output_dir/agent_handlers that are managed by this class."""
        for handler in self.agent_handlers:
            if handler._fname.exists():
                handler._fname.unlink()

    def set_writer(self, idx, writer_fn, writer_kwargs=None):
        """Defines the writer for one of the managed agents.

        Note
        -----
        Must be called right after creating an instance of AgentManager.

        Parameters
        ----------
        writer_fn : callable, None or 'default'
            Returns a writer for an agent, e.g. tensorboard SummaryWriter,
            rlberry DefaultWriter.
            If 'default', use the default writer in the Agent class.
            If None, disable any writer
        writer_kwargs : dict or None
            kwargs for writer_fn
        idx : int
            Index of the agent to set the writer (0 <= idx < `n_fit`).
            AgentManager fits `n_fit` agents, so the writer of each of them
            needs to be set separately.
        """
        assert (
            idx >= 0 and idx < self.n_fit
        ), "Invalid index sent to AgentManager.set_writer()"
        writer_kwargs = writer_kwargs or {}
        self.writers[idx] = (writer_fn, writer_kwargs)

    def fit(self, budget=None, **kwargs):
        """Fit the agent instances in parallel.

        Parameters
        ----------
        budget: int or None
            Computational or sample complexity budget.
        """
        del kwargs
        budget = budget or self.fit_budget

        # If spawn, test that protected by if __name__ == "__main__"
        if self.mp_context == "spawn":
            try:
                _check_not_importing_main()
            except RuntimeError as exc:
                raise RuntimeError(
                    """Warning: in AgentManager, if mp_context='spawn' and
                        parallelization="process" then the script must be run
                        outside a notebook and protected by a  if __name__ == '__main__':
                        For example:
                            if __name__ == '__main__':
                                agent = AgentManager(UCBVIAgent,(Chain, {}),
                                                mp_context="spawn",
                                                parallelization="process")

                                agent.fit(10)
                                   """
                ) from exc

        logger.info(
            f"Running AgentManager fit() for {self.agent_name}"
            f" with n_fit = {self.n_fit} and max_workers = {self.max_workers}."
        )
        seeders = self.seeder.spawn(self.n_fit)
        if not isinstance(seeders, list):
            seeders = [seeders]

        # remove agent instances from memory so that the agent handlers can be
        # sent to different workers
        for handler in self.agent_handlers:
            handler.dump()

        if self.parallelization == "thread":
            executor_class = concurrent.futures.ThreadPoolExecutor
            lock = threading.Lock()
        elif self.parallelization == "process":
            executor_class = functools.partial(
                concurrent.futures.ProcessPoolExecutor,
                mp_context=multiprocessing.get_context(self.mp_context),
            )
            lock = multiprocessing.Manager().Lock()
        else:
            raise ValueError(
                f"Invalid backend for parallelization: {self.parallelization}"
            )

        args = [
            (
                lock,
                handler,
                self.agent_class,
                budget,
                init_kwargs,
                deepcopy(self.fit_kwargs),
                writer,
                self.worker_logging_level,
                seeder,
            )
            for init_kwargs, handler, seeder, writer in zip(
                self.init_kwargs, self.agent_handlers, seeders, self.writers
            )
        ]

        if len(args) == 1:
            workers_output = [_fit_worker(args[0])]

        else:
            with executor_class(max_workers=self.max_workers) as executor:
                futures = []
                for arg in args:
                    futures.append(executor.submit(_fit_worker, arg))

                workers_output = []
                for future in concurrent.futures.as_completed(futures):
                    workers_output.append(future.result())
                executor.shutdown()

        workers_output.sort(key=lambda x: x.id)
        self.agent_handlers = workers_output

        logger.info("... trained!")

        # gather all stats in a dictionary
        self._gather_default_writer_data()

    def _gather_default_writer_data(self):
        """Gather DefaultWriter data in a dictionary"""
        self.default_writer_data = {}
        for ii, agent in enumerate(self.agent_handlers):
            if not agent.is_empty() and isinstance(agent.writer, DefaultWriter):
                self.default_writer_data[ii] = agent.writer.data

    def save(self):
        """Save AgentManager data to :attr:`~rlberry.manager.agent_manager.AgentManager.output_dir`.

        Saves object so that the data can be later loaded to recreate an AgentManager instance.

        Returns
        -------
        :class:`pathlib.Path`
            Filename where the AgentManager object was saved.
        """
        # use self.output_dir
        output_dir = self.output_dir_
        output_dir = Path(output_dir)

        # create dir if it does not exist
        output_dir.mkdir(parents=True, exist_ok=True)
        # save optimized hyperparameters
        if self.best_hyperparams is not None:
            fname = Path(output_dir) / "best_hyperparams.json"
            _safe_serialize_json(self.best_hyperparams, fname)
        # save default_writer_data that can be aggregated in a pandas DataFrame
        if self.default_writer_data is not None:
            data_list = []
            for idx in self.default_writer_data:
                df = self.default_writer_data[idx]
                data_list.append(df)
            if len(data_list) > 0:
                all_writer_data = pd.concat(data_list, ignore_index=True)
                try:
                    output = pd.DataFrame(all_writer_data)
                    # save
                    fname = Path(output_dir) / "data.csv"
                    output.to_csv(fname, index=None)
                except Exception:
                    logger.warning("Could not save default_writer_data.")

        #
        # Pickle AgentManager instance
        #

        # clear agent handlers
        for handler in self.agent_handlers:
            handler.dump()

        # save
        filename = Path("manager_obj").with_suffix(".pickle")
        filename = output_dir / filename
        filename.parent.mkdir(parents=True, exist_ok=True)
        try:
            with filename.open("wb") as ff:
                pickle.dump(self.__dict__, ff)
            logger.info("Saved AgentManager({}) using pickle.".format(self.agent_name))
        except Exception:
            try:
                with filename.open("wb") as ff:
                    dill.dump(self.__dict__, ff)
                logger.info(
                    "Saved AgentManager({}) using dill.".format(self.agent_name)
                )
            except Exception as ex:
                logger.warning("[AgentManager] Instance cannot be pickled: " + str(ex))

        return filename

    @classmethod
    def load(cls, filename):
        """Loads an AgentManager instance from a file.

        Parameters
        ----------
        filename: str or :class:`pathlib.Path`

        Returns
        -------
        :class:`rlberry.manager.AgentManager`
            Loaded instance of AgentManager.
        """
        filename = Path(filename).with_suffix(".pickle")

        obj = cls(None, None, None)
        try:
            with filename.open("rb") as ff:
                tmp_dict = pickle.load(ff)
            logger.info("Loaded AgentManager using pickle.")
        except Exception:
            with filename.open("rb") as ff:
                tmp_dict = dill.load(ff)
            logger.info("Loaded AgentManager using dill.")

        obj.__dict__.clear()
        obj.__dict__.update(tmp_dict)

        return obj

    def __eq__(self, other):

        result = True
        self_init_kwargs = [_strip_seed_dir(kw) for kw in self.init_kwargs]
        other_init_kwargs = [_strip_seed_dir(kw) for kw in other.init_kwargs]
        result = result and all(
            [
                self_init_kwargs[f] == other_init_kwargs[f]
                for f in range(len(self_init_kwargs))
            ]
        )

        self_eval_kwargs = self.eval_kwargs or {}
        other_eval_kwargs = other.eval_kwargs or {}
        result = result and (self_eval_kwargs == other_eval_kwargs)

        result = result and (other.agent_class == self.agent_class)

        result = result and (self.fit_kwargs == other.fit_kwargs)

        result = result and (self.fit_budget == other.fit_budget)

        return result

    def optimize_hyperparams(
        self,
        n_trials=256,
        timeout=60,
        n_fit=2,
        n_optuna_workers=2,
        optuna_parallelization="thread",
        sampler_method="optuna_default",
        pruner_method="halving",
        continue_previous=False,
        fit_fraction=1.0,
        sampler_kwargs=None,
        disable_evaluation_writers=True,
    ):
        """Run hyperparameter optimization and updates init_kwargs with the best hyperparameters found.

        Currently supported sampler_method:
            'random' -> Random Search
            'optuna_default' -> TPE
            'grid' -> Grid Search
            'cmaes' -> CMA-ES

        Currently supported pruner_method:
            'none'
            'halving'

        Note
        ----
        * After calling this method, agent handlers from previous calls to fit() will be erased.
        It is suggested to call fit() *after* a call to optimize_hyperparams().
        * This method calls self.save() before the optuna optimization starts, to ensure
        that we can continue the optimization later even if the program is stopped before the
        optimization is finished.

        Parameters
        ----------
        n_trials: int
            Number of agent evaluations
        timeout: int
            Stop study after the given number of second(s).
            Set to None for unlimited time.
        n_fit: int
            Number of agents to fit for each hyperparam evaluation.
        n_optuna_workers: int
            Number of workers used by optuna for optimization.
        optuna_parallelization : 'thread' or 'process'
            Whether to use threads or processes for optuna parallelization.
        sampler_method : str
            Optuna sampling method.
        pruner_method : str
            Optuna pruner method.
        continue_previous : bool
            Set to true to continue previous Optuna study. If true,
            sampler_method and pruner_method will be
            the same as in the previous study.
        fit_fraction : double, in ]0, 1]
            Fraction of the agent to fit for partial evaluation
            (allows pruning of trials).
        sampler_kwargs : dict or None
            Allows users to use different Optuna samplers with
            personalized arguments.
        evaluation_function : callable(agent_list, eval_env, **kwargs)->double, default: None
            Function to maximize, that takes a list of agents and an environment as input, and returns a double.
            If None, search for hyperparameters that maximize the mean reward.
        evaluation_function_kwargs : dict or None
            kwargs for evaluation_function
        disable_evaluation_writers : bool, default: True
            If true, disable writers of agents used in the hyperparameter evaluation.

        Returns
        -------
        dict
            Optimized hyperparameters.
        """
        #
        # setup
        #
        TEMP_DIR = self.output_dir_ / "optim"

        global _OPTUNA_INSTALLED
        if not _OPTUNA_INSTALLED:
            logging.error("Optuna not installed.")
            return

        assert fit_fraction > 0.0 and fit_fraction <= 1.0

        #
        # Create optuna study
        #
        if continue_previous:
            assert self.optuna_study is not None
            study = self.optuna_study

        else:
            if sampler_kwargs is None:
                sampler_kwargs = {}
            # get sampler
            if sampler_method == "random":
                sampler = optuna.samplers.RandomSampler()
            elif sampler_method == "grid":
                assert (
                    sampler_kwargs is not None
                ), "To use GridSampler, a search_space dictionary must be provided."
                sampler = optuna.samplers.GridSampler(**sampler_kwargs)
            elif sampler_method == "cmaes":
                sampler = optuna.samplers.CmaEsSampler(**sampler_kwargs)
            elif sampler_method == "optuna_default":
                sampler = optuna.samplers.TPESampler(**sampler_kwargs)
            else:
                raise NotImplementedError(
                    "Sampler method %s is not implemented." % sampler_method
                )

            # get pruner
            if pruner_method == "halving":
                pruner = optuna.pruners.SuccessiveHalvingPruner(
                    min_resource=1, reduction_factor=4, min_early_stopping_rate=0
                )
            elif pruner_method == "none":
                pruner = None
            else:
                raise NotImplementedError(
                    "Pruner method %s is not implemented." % pruner_method
                )

            # storage
            self._init_optuna_storage_url()
            storage = optuna.storages.RDBStorage(self.optuna_storage_url)

            # optuna study
            study = optuna.create_study(
                sampler=sampler, pruner=pruner, storage=storage, direction="maximize"
            )
            self.optuna_study = study

        # save, so that optimization can be resumed later
        self.save()

        #
        # Objective function
        #
        objective = functools.partial(
            _optuna_objective,
            base_init_kwargs=self._base_init_kwargs,  # self._base_init_kwargs
            agent_class=self.agent_class,  # self.agent_class
            train_env=self.train_env,  # self.train_env
            eval_env=self._eval_env,
            fit_budget=self.fit_budget,  # self.fit_budget
            eval_kwargs=self.eval_kwargs,  # self.eval_kwargs
            n_fit=n_fit,
            temp_dir=TEMP_DIR,  # TEMP_DIR
            disable_evaluation_writers=disable_evaluation_writers,
            fit_fraction=fit_fraction,
        )

        try:
            if optuna_parallelization == "thread":
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    for _ in range(n_optuna_workers):
                        executor.submit(
                            study.optimize,
                            objective,
                            n_trials=n_trials,
                            timeout=timeout,
                            gc_after_trial=True,
                        )
                    executor.shutdown()
            elif optuna_parallelization == "process":
                with concurrent.futures.ProcessPoolExecutor(
                    mp_context=multiprocessing.get_context(self.mp_context)
                ) as executor:
                    for _ in range(n_optuna_workers):
                        executor.submit(
                            study.optimize,
                            objective,
                            n_trials=n_trials // n_optuna_workers,
                            timeout=timeout,
                            gc_after_trial=True,
                        )
                    executor.shutdown()
            else:
                raise ValueError(
                    f"Invalid value for optuna_parallelization: {optuna_parallelization}."
                )

        except KeyboardInterrupt:
            logger.warning("Evaluation stopped.")

        # clear temp folder
        try:
            shutil.rmtree(TEMP_DIR)
        except FileNotFoundError as ex:
            logger.warning(f"Could not delete {TEMP_DIR}: {ex}")

        # continue
        try:
            best_trial = study.best_trial
        except ValueError as ex:
            logger.error(f"Hyperparam optimization failed due to the error: {ex}")
            return dict()

        logger.info(f"Number of finished trials: {len(study.trials)}")
        logger.info("Best trial:")
        logger.info(f"Value: {best_trial.value}")
        logger.info("Params:")
        for key, value in best_trial.params.items():
            logger.info(f"    {key}: {value}")

        # store best parameters
        self.best_hyperparams = best_trial.params

        # update using best parameters
        self._base_init_kwargs.update(best_trial.params)

        # reset init_kwargs and agent handlers, so that they take the new
        # parameters
        self._set_init_kwargs()
        self._reset_agent_handlers()

        return deepcopy(best_trial.params)
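A hedged usage sketch of AgentManager; SomeAgent and SomeEnv are placeholders for any rlberry-compatible agent and environment classes, and train_env must be a (constructor, kwargs) tuple as documented above.

manager = AgentManager(
    SomeAgent,                     # placeholder agent class
    train_env=(SomeEnv, {}),       # (constructor, kwargs) tuple
    fit_budget=1000,
    n_fit=4,
    seed=123,
)
manager.fit()                                        # train n_fit instances in parallel
evaluations = manager.eval_agents(n_simulations=8)   # eval() on randomly chosen fitted instances
manager.save()                                       # pickle the manager to output_dir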