예제 #1
0
def select_env(model_class: BaseAlgorithm) -> gym.Env:
    """
    Build the identity environment that matches what the algorithm can act in.

    DQN only supports discrete action spaces, so it is given ``IdentityEnv``;
    any other algorithm class is given ``IdentityEnvBox``.

    :param model_class: the algorithm class the environment is meant for
    :return: a new ``IdentityEnv(10)`` for DQN, otherwise a new ``IdentityEnvBox(10)``
    """
    is_dqn = model_class == DQN
    return IdentityEnv(10) if is_dqn else IdentityEnvBox(10)
def select_env(model_class: BaseAlgorithm) -> gym.Env:
    """
    Build the identity environment that matches what the algorithm can act in.

    QRDQN, DQNReg and DQNClipped only support discrete action spaces, so they
    are given ``IdentityEnv``; any other algorithm class is given
    ``IdentityEnvBox``.

    :param model_class: the algorithm class the environment is meant for
    :return: a new ``IdentityEnv(10)`` for the discrete-only algorithms,
        otherwise a new ``IdentityEnvBox(10)``
    """
    discrete_only_algos = {QRDQN, DQNReg, DQNClipped}
    # Guard clause: everything outside the discrete-only set gets the Box variant.
    if model_class not in discrete_only_algos:
        return IdentityEnvBox(10)
    return IdentityEnv(10)
예제 #3
0
def test_identity(model_name):
    """
    Check that a model can learn the identity mapping (return the observation
    as the action) and that its ``action_probability`` output is consistent.

    The model is obtained from ``LEARN_FUNC_DICT[model_name]`` (presumably a
    factory that also trains it -- confirm against its definition), evaluated
    on a vectorised ``IdentityEnv(10)``, and then queried for probabilities.

    :param model_name: (str) Name of the RL model
    """
    vec_env = DummyVecEnv([lambda: IdentityEnv(10)])

    agent = LEARN_FUNC_DICT[model_name](vec_env)
    evaluate_policy(agent, vec_env, n_eval_episodes=20, reward_threshold=90)

    first_obs = vec_env.reset()

    # With no action given, one row of 10 entries is expected for the single env.
    full_shape = agent.action_probability(first_obs).shape
    assert full_shape == (1, 10), "Error: action_probability not returning correct shape"

    # With a specific action given, exactly one value must come back.
    sampled_action = vec_env.action_space.sample()
    prob = agent.action_probability(first_obs, actions=sampled_action)
    assert np.prod(prob.shape) == 1, "Error: not scalar probability"

    # The log-probability must agree with the plain probability.
    log_prob = agent.action_probability(first_obs, actions=sampled_action, logp=True)
    assert np.allclose(prob, np.exp(log_prob)), (prob, log_prob)

    # Free memory
    del agent, vec_env
예제 #4
0
from stable_baselines3 import A2C, PPO, SAC, TD3, DQN
from stable_baselines3.common.identity_env import (IdentityEnvBox, IdentityEnv,
                                                   IdentityEnvMultiBinary,
                                                   IdentityEnvMultiDiscrete)

from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.noise import NormalActionNoise

DIM = 4


@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
@pytest.mark.parametrize("env", [
    IdentityEnv(DIM),
    IdentityEnvMultiDiscrete(DIM),
    IdentityEnvMultiBinary(DIM)
])
def test_discrete(model_class, env):
    env_ = DummyVecEnv([lambda: env])
    kwargs = {}
    n_steps = 3000
    if model_class == DQN:
        kwargs = dict(learning_starts=0)
        n_steps = 4000
        # DQN only support discrete actions
        if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
            return

    model = model_class('MlpPolicy', env_, gamma=0.5, seed=1,
import numpy as np
import pytest

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.identity_env import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.vec_env import DummyVecEnv

DIM = 4


@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
@pytest.mark.parametrize("env", [IdentityEnv(DIM), IdentityEnvMultiDiscrete(DIM), IdentityEnvMultiBinary(DIM)])
def test_discrete(model_class, env):
    """
    Train an algorithm on an identity environment and check evaluation and
    prediction shape.

    Each parametrized algorithm is trained with ``MlpPolicy`` on the given
    environment, evaluated against a reward threshold, and then asked for a
    prediction whose shape must match the observation's shape.

    :param model_class: algorithm class under test (A2C, PPO or DQN)
    :param env: identity environment instance created by the parametrization
    """
    # DQN only supports discrete actions. Report the unsupported combinations
    # as skipped instead of returning early, which pytest would count as a pass.
    if model_class == DQN and isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
        pytest.skip("DQN only supports discrete actions")

    env_ = DummyVecEnv([lambda: env])
    kwargs = {}
    n_steps = 3000
    if model_class == DQN:
        # DQN starts learning from the first step and gets a larger step budget.
        kwargs = dict(learning_starts=0)
        n_steps = 4000

    model = model_class("MlpPolicy", env_, gamma=0.5, seed=1, **kwargs).learn(n_steps)

    evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90)

    # Reset the raw (non-vectorised) env so the observation has no batch axis.
    obs = env.reset()

    assert np.shape(model.predict(obs)[0]) == np.shape(obs)
예제 #6
0
def test_discrete(model_class):
    """
    Train the given algorithm on ``IdentityEnv(10)`` and evaluate it.

    The model is built with ``MlpPolicy``, trained for 3000 steps, and then
    passed to ``evaluate_policy`` with a reward threshold of 90 over 20 episodes.

    :param model_class: algorithm class under test
    """
    identity_env = IdentityEnv(10)

    agent = model_class('MlpPolicy', identity_env, gamma=0.5, seed=0)
    # Keep whatever ``learn`` returns, exactly as the chained call would.
    agent = agent.learn(3000)

    evaluate_policy(agent, identity_env, n_eval_episodes=20, reward_threshold=90)