Example #1
def test_config(normalization_config):
    # init environment
    env = GymMazeEnv("CartPole-v0")

    # wrap env with observation normalization
    env = ObservationNormalizationWrapper(
        env,
        default_strategy=normalization_config["default_strategy"],
        default_strategy_config=normalization_config["default_strategy_config"],
        default_statistics=normalization_config["default_statistics"],
        statistics_dump=normalization_config["statistics_dump"],
        sampling_policy=normalization_config["sampling_policy"],
        exclude=normalization_config["exclude"],
        manual_config=normalization_config["manual_config"])

    # check if observation space clipping was applied
    assert np.alltrue(env.observation_space["observation"].high <= 1.0)
    assert np.alltrue(env.observation_space["observation"].low >= 0.0)

    # check if stats have been set properly
    statistics = env.get_statistics()
    assert np.all(statistics["observation"]["mean"] == np.zeros(shape=4))
    assert np.all(statistics["observation"]["std"] == np.ones(shape=4))

    # test sampling
    obs = random_env_steps(env, steps=100)
    assert np.min(obs) >= 0 and np.max(obs) <= 1
def test_heuristic_lunar_lander_policy():
    """unit tests"""
    policy = HeuristicLunarLanderPolicy()
    env = GymMazeEnv("LunarLander-v2")

    obs = env.reset()
    action = policy.compute_action(obs)
    obs, _, _, _ = env.step(action)
Example #3
def train_function(n_epochs: int, distributed_env_cls) -> A2C:
    """Trains the cart pole environment with the multi-step a2c implementation.
    """

    # initialize distributed env
    envs = distributed_env_cls([lambda: GymMazeEnv(env="CartPole-v0") for _ in range(2)])

    # initialize the evaluation env and enable statistics collection
    eval_env = distributed_env_cls([lambda: GymMazeEnv(env="CartPole-v0") for _ in range(2)],
                                   logging_prefix='eval')

    # init distribution mapper
    env = GymMazeEnv(env="CartPole-v0")
    distribution_mapper = DistributionMapper(action_space=env.action_space, distribution_mapper_config={})

    # initialize policies
    policies = {0: FlattenConcatPolicyNet({'observation': (4,)}, {'action': (2,)}, hidden_units=[16], non_lin=nn.Tanh)}

    # initialize critic
    critics = {0: FlattenConcatStateValueNet({'observation': (4,)}, hidden_units=[16], non_lin=nn.Tanh)}

    # algorithm configuration
    algorithm_config = A2CAlgorithmConfig(
        n_epochs=n_epochs,
        epoch_length=2,
        patience=10,
        critic_burn_in_epochs=0,
        n_rollout_steps=20,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.0,
        max_grad_norm=0.0,
        device="cpu",
        rollout_evaluator=RolloutEvaluator(eval_env=eval_env, n_episodes=1, model_selection=None, deterministic=True)
    )

    # initialize actor critic model
    model = TorchActorCritic(
        policy=TorchPolicy(networks=policies, distribution_mapper=distribution_mapper, device=algorithm_config.device),
        critic=TorchSharedStateCritic(networks=critics, obs_spaces_dict=env.observation_spaces_dict,
                                      device=algorithm_config.device,
                                      stack_observations=False),
        device=algorithm_config.device)

    a2c = A2C(rollout_generator=RolloutGenerator(envs),
              algorithm_config=algorithm_config,
              evaluator=algorithm_config.rollout_evaluator,
              model=model,
              model_selection=None)

    # train agent
    a2c.train()

    return a2c
Example #4
def cartpole_env_factory():
    """ Env factory for the cartpole MazeEnv """
    # Registered gym environments can be instantiated first and then provided to GymMazeEnv:
    cartpole_env = gym.make("CartPole-v0")
    maze_env = GymMazeEnv(env=cartpole_env)

    # Another possibility is to supply the gym env string to GymMazeEnv directly:
    maze_env = GymMazeEnv(env="CartPole-v0")

    return maze_env
Example #5
def test_observation_monitoring():
    """ Observation logging unit test """
    env = GymMazeEnv(env="CartPole-v0")

    env = ObservationVisualizationWrapper.wrap(env, plot_function=None)
    env = LogStatsWrapper.wrap(env, logging_prefix="train")

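    # step through one full episode with random actions while stats are logged to log_dir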
    with SimpleStatsLoggingSetup(env, log_dir="."):
        env.reset()
        done = False
        while not done:
            obs, rew, done, info = env.step(env.action_space.sample())
def test_examples_part2():
    """
    Tests snippets in maze/docs/source/concepts_and_structure/run_context_overview.rst.
    Adds some performance-specific configuration that should not influence snippets' functionality.
    Split for runtime reasons.
    """

    a2c_overrides = {"runner.concurrency": 1}
    es_overrides = {
        "algorithm.n_epochs": 1,
        "algorithm.n_rollouts_per_update": 1
    }
    env_factory = lambda: GymMazeEnv('CartPole-v0')
    alg_config = _get_alg_config('CartPole-v0', "dev")

    # ------------------------------------------------------------------

    rc = RunContext(algorithm=alg_config,
                    runner="dev",
                    overrides=a2c_overrides,
                    configuration="test")
    rc.train(n_epochs=1)

    # ------------------------------------------------------------------

    with pytest.raises(omegaconf.errors.ConfigAttributeError):
        rc = RunContext(overrides={"algorithm": alg_config}, runner="dev")
        rc.train(n_epochs=1)

    # ------------------------------------------------------------------

    rc = RunContext(env=lambda: env_factory(),
                    overrides=es_overrides,
                    runner="dev",
                    configuration="test")
    rc.train(n_epochs=1)

    # Run trained policy.
    env = env_factory()
    obs = env.reset()
    for i in range(10):
        action = rc.compute_action(obs)
        obs, rewards, dones, info = env.step(action)

    # ------------------------------------------------------------------

    rc = RunContext(env=lambda: GymMazeEnv('CartPole-v0'),
                    overrides=es_overrides,
                    runner="dev",
                    configuration="test")
    rc.train()
    rc.evaluate()
Example #7
def test_evaluation():
    """
    Tests evaluation.
    """

    # Test with ES: No rollout evaluator in config.
    rc = run_context.RunContext(
        env=lambda: GymMazeEnv(env=gym.make("CartPole-v0")),
        silent=True,
        configuration="test",
        overrides={
            "runner.normalization_samples": 1,
            "runner.shared_noise_table_size": 10
        })
    rc.train(1)
    stats = rc.evaluate(n_episodes=5)
    assert len(stats) == 1
    assert stats[0][(BaseEnvEvents.reward, "episode_count", None)] in (5, 6)

    # Test with A2C: Partially specified rollout evaluator in config.
    rc = run_context.RunContext(
        env=lambda: GymMazeEnv(env=gym.make("CartPole-v0")),
        silent=True,
        algorithm="a2c",
        configuration="test",
        overrides={"runner.concurrency": 1})
    rc.train(1)
    stats = rc.evaluate(n_episodes=2)
    assert len(stats) == 1
    assert stats[0][(BaseEnvEvents.reward, "episode_count", None)] in (2, 3)

    # Test with A2C and instantiated RolloutEvaluator.
    rc = run_context.RunContext(
        env=lambda: GymMazeEnv(env=gym.make("CartPole-v0")),
        silent=True,
        algorithm="a2c",
        configuration="test",
        overrides={
            "runner.concurrency": 1,
            "algorithm.rollout_evaluator": RolloutEvaluator(
                eval_env=SequentialVectorEnv([lambda: GymMazeEnv("CartPole-v0")]),
                n_episodes=1,
                model_selection=None,
                deterministic=True)
        })
    rc.train(1)
    stats = rc.evaluate(n_episodes=5)
    assert len(stats) == 1
    assert stats[0][(BaseEnvEvents.reward, "episode_count", None)] in (1, 2)
Example #8
def test_multirun():
    """
    Tests multirun capabilities.
    """

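    # passing a comma-separated sweep string with multirun=False is expected to fail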
    with pytest.raises(BaseException):
        rc = run_context.RunContext(env=lambda: GymMazeEnv('CartPole-v0'),
                                    silent=True,
                                    algorithm="ppo",
                                    overrides={
                                        "runner.normalization_samples": 1,
                                        "runner.concurrency": 1,
                                        "algorithm.lr": "0.0001,0.0005,0.001"
                                    },
                                    configuration="test",
                                    multirun=False)
        rc.train(n_epochs=1)

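    # a Python list for a scalar parameter is rejected when multirun is disabled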
    with pytest.raises(TypeError):
        rc = run_context.RunContext(env=lambda: GymMazeEnv('CartPole-v0'),
                                    silent=True,
                                    algorithm="ppo",
                                    overrides={
                                        "runner.normalization_samples": 1,
                                        "runner.concurrency": 1,
                                        "algorithm.lr":
                                        [0.0001, 0.0005, 0.001]
                                    },
                                    configuration="test",
                                    multirun=False)
        rc.train(n_epochs=1)

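    # with multirun=True the lr sweep spawns one training run per value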
    rc = run_context.RunContext(env=lambda: GymMazeEnv('CartPole-v0'),
                                silent=True,
                                algorithm="ppo",
                                overrides={
                                    "runner.normalization_samples": 1,
                                    "runner.concurrency": 1,
                                    "algorithm.lr": [0.0001, 0.0005, 0.001]
                                },
                                configuration="test",
                                multirun=True)
    rc.train(n_epochs=1)

    assert len(rc.policy) == 3
    assert len(rc.run_dir) == 3
    assert len(rc.config[RunMode.TRAINING]) == 3
    assert len(rc.env_factory) == 3
    assert len(rc.evaluate()) == 3
Example #9
def test_inconsistency_identification_type_3() -> None:
    """
    Tests identification of inconsistency due to derived config group.
    """

    es_dev_runner_config = {
        'state_dict_dump_file': 'state_dict.pt',
        'spaces_config_dump_file': 'spaces_config.pkl',
        'normalization_samples': 10000,
        '_target_': 'maze.train.trainers.es.ESDevRunner',
        'n_eval_rollouts': 10,
        'shared_noise_table_size': 100000000
    }
    a2c_alg_config = A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(eval_env=SequentialVectorEnv(
            [lambda: GymMazeEnv(env="CartPole-v0")]),
                                           n_episodes=1,
                                           model_selection=None,
                                           deterministic=True))
    default_overrides = {
        "runner.normalization_samples": 1,
        "runner.concurrency": 1
    }

    rc = run_context.RunContext(algorithm=a2c_alg_config,
                                env=lambda: GymMazeEnv(env="CartPole-v0"),
                                silent=True,
                                runner="dev",
                                overrides=default_overrides)
    rc.train(1)

    rc = run_context.RunContext(env=lambda: GymMazeEnv(env="CartPole-v0"),
                                runner=es_dev_runner_config,
                                silent=True,
                                overrides=default_overrides)
    rc.train(1)
Example #10
def test_observation_normalization_manual_default_stats():
    """ observation normalization test """

    # init environment
    env = GymMazeEnv("CartPole-v0")

    # normalization config
    normalization_config = {
        "default_strategy":
        "maze.normalization_strategies.MeanZeroStdOneObservationNormalizationStrategy",
        "default_strategy_config": {
            "clip_range": (0, 1),
            "axis": 0
        },
        "default_statistics": {
            "mean": [0, 0, 0, 0],
            "std": [1, 1, 1, 1]
        },
        "statistics_dump": "statistics.pkl",
        "sampling_policy": RandomPolicy(env.action_spaces_dict),
        "exclude": None,
        "manual_config": None,
    }

    # wrap env with observation normalization
    env = ObservationNormalizationWrapper(
        env,
        default_strategy=normalization_config["default_strategy"],
        default_strategy_config=normalization_config[
            "default_strategy_config"],
        default_statistics=normalization_config["default_statistics"],
        statistics_dump=normalization_config["statistics_dump"],
        sampling_policy=normalization_config['sampling_policy'],
        exclude=normalization_config["exclude"],
        manual_config=normalization_config["manual_config"])

    # check if observation space clipping was applied
    assert np.alltrue(env.observation_space["observation"].high <= 1.0)
    assert np.alltrue(env.observation_space["observation"].low >= 0.0)

    # check if stats have been set properly
    statistics = env.get_statistics()
    assert np.all(statistics["observation"]["mean"] == np.zeros(shape=4))
    assert np.all(statistics["observation"]["std"] == np.ones(shape=4))

    # test sampling
    obs = random_env_steps(env, steps=100)
    assert np.min(obs) >= 0 and np.max(obs) <= 1
Example #11
def run_observation_normalization_pipeline(
        normalization_config) -> ObservationNormalizationWrapper:
    """ observation normalization test """

    # wrap env with observation normalization
    env = GymMazeEnv("CartPole-v0")
    env = ObservationNormalizationWrapper(
        env,
        default_strategy=normalization_config["default_strategy"],
        default_strategy_config=normalization_config[
            "default_strategy_config"],
        default_statistics=normalization_config["default_statistics"],
        statistics_dump=normalization_config["statistics_dump"],
        exclude=normalization_config["exclude"],
        sampling_policy=RandomPolicy(env.action_spaces_dict),
        manual_config=normalization_config["manual_config"])

    # estimate normalization statistics
    statistics = obtain_normalization_statistics(env, n_samples=1000)

    # check statistics
    for sub_step_key in env.observation_spaces_dict:
        for obs_key in env.observation_spaces_dict[sub_step_key].spaces:
            assert obs_key in statistics
            for stats_key in statistics[obs_key]:
                stats = statistics[obs_key][stats_key]
                assert isinstance(stats, np.ndarray)

    # test normalization
    random_env_steps(env, steps=100)

    return env
Example #12
def test_rollouts_from_python():
    env, agent = GymMazeEnv("CartPole-v0"), DummyCartPolePolicy()

    sequential = SequentialRolloutRunner(n_episodes=2,
                                         max_episode_steps=2,
                                         record_trajectory=False,
                                         record_event_logs=False,
                                         render=False)
    sequential.maze_seeding = MazeSeeding(1234, 4321, False)
    sequential.run_with(env=env, wrappers={}, agent=agent)

    parallel = ParallelRolloutRunner(n_episodes=2,
                                     max_episode_steps=2,
                                     record_trajectory=False,
                                     record_event_logs=False,
                                     n_processes=2)
    parallel.maze_seeding = MazeSeeding(1234, 4321, False)
    # Test with a wrapper config as well
    parallel.run_with(env=env,
                      wrappers={
                          MazeEnvMonitoringWrapper: {
                              "observation_logging": True,
                              "action_logging": False,
                              "reward_logging": False
                          }
                      },
                      agent=agent)
def _get_alg_config(env_name: str, runner_type: str) -> A2CAlgorithmConfig:
    """
    Returns algorithm config used in tests.
    :param env_name: Env name for rollout evaluator.
    :param runner_type: Runner type. "dev" or "local".
    :return: A2CAlgorithmConfig instance.
    """

    env_factory = lambda: GymMazeEnv(env_name)
    return A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(
            eval_env=SubprocVectorEnv([env_factory])
            if runner_type == "local" else SequentialVectorEnv([env_factory]),
            n_episodes=1,
            model_selection=None,
            deterministic=True))
Example #14
def test_readme():
    """
    Tests snippets in readme.md.
    """

    rc = RunContext(env=lambda: GymMazeEnv('CartPole-v0'))
    rc.train(n_epochs=1)

    # Run trained policy.
    env = GymMazeEnv('CartPole-v0')
    obs = env.reset()
    done = False

    while not done:
        action = rc.compute_action(obs)
        obs, reward, done, info = env.step(action)
        break
Example #15
def test_gets_formatted_actions_and_observations():
    gym_env = gym.make("CartPole-v0")
    gym_obs = gym_env.reset()
    gym_act = gym_env.action_space.sample()

    wrapped_env = GymMazeEnv(env="CartPole-v0")
    wrapped_env.seed(1234)
    assert not wrapped_env.is_actor_done()
    assert wrapped_env.actor_id() == (0, 0)
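    # convert the raw gym observation and action into the Maze dict format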
    obs_dict, act_dict = wrapped_env.get_observation_and_action_dicts(
        gym_obs, gym_act, False)
    assert np.all(gym_obs.astype(np.float32) == obs_dict[0]["observation"])
    assert np.all(gym_act == act_dict[0]["action"])
    wrapped_env.close()
Example #16
def test_sequential_rollout_with_rendering():
    env, agent = GymMazeEnv("CartPole-v0"), DummyCartPolePolicy()
    sequential = SequentialRolloutRunner(n_episodes=2,
                                         max_episode_steps=2,
                                         record_trajectory=True,
                                         record_event_logs=False,
                                         render=True)
    sequential.maze_seeding = MazeSeeding(1234, 4321, False)
    sequential.run_with(env=env, wrappers={}, agent=agent)
Example #17
def test_random_sampling_seeding():
    """Test the seeding with a random env version and random sampling (fully stochastic)"""
    env = GymMazeEnv(env="CartPole-v0")
    policy = RandomPolicy(env.action_spaces_dict)

    perform_seeding_test(env,
                         policy,
                         is_deterministic_env=False,
                         is_deterministic_agent=False)
Example #18
def test_heuristic_sampling():
    """Test the seeding with a deterministic env and deterministic heuristic"""
    env = GymMazeEnv(env="CartPole-v0")
    policy = DummyCartPolePolicy()

    perform_seeding_test(env,
                         policy,
                         is_deterministic_env=False,
                         is_deterministic_agent=True)
Example #19
def main(n_epochs: int) -> None:
    """Trains the cart pole environment with the ES implementation.
    """

    env = GymMazeEnv(env="CartPole-v0")
    distribution_mapper = DistributionMapper(action_space=env.action_space,
                                             distribution_mapper_config={})

    obs_shapes = observation_spaces_to_in_shapes(env.observation_spaces_dict)
    action_shapes = {
        step_key: {
            action_head: distribution_mapper.required_logits_shape(action_head)
            for action_head in env.action_spaces_dict[step_key].spaces.keys()
        }
        for step_key in env.action_spaces_dict.keys()
    }

    # initialize policies
    policies = [
        PolicyNet(obs_shapes=obs_shapes[0],
                  action_logits_shapes=action_shapes[0],
                  non_lin=nn.SELU)
    ]

    # initialize policy
    policy = TorchPolicy(networks=list_to_dict(policies),
                         distribution_mapper=distribution_mapper,
                         device="cpu")

    shared_noise = SharedNoiseTable(count=1_000_000)

    algorithm_config = ESAlgorithmConfig(n_rollouts_per_update=100,
                                         n_timesteps_per_update=0,
                                         max_steps=0,
                                         optimizer=Adam(step_size=0.01),
                                         l2_penalty=0.005,
                                         noise_stddev=0.02,
                                         n_epochs=n_epochs,
                                         policy_wrapper=None)

    trainer = ESTrainer(algorithm_config=algorithm_config,
                        torch_policy=policy,
                        shared_noise=shared_noise,
                        normalization_stats=None)

    setup_logging(job_config=None)

    maze_rng = np.random.RandomState(None)

    # run with pseudo-distribution, without worker processes
    trainer.train(
        ESDummyDistributedRollouts(
            env=env,
            n_eval_rollouts=10,
            shared_noise=shared_noise,
            agent_instance_seed=MazeSeeding.generate_seed_from_random_state(maze_rng)),
        model_selection=None)
Example #20
def test_cartpole_model_composer():
    env = GymMazeEnv(env='CartPole-v0')
    path_to_model_config = code_snippets.__path__._path[0] + '/custom_plain_cartpole_net.yaml'

    model_composer = Factory(base_type=BaseModelComposer).instantiate(
        yaml.safe_load(open(path_to_model_config, 'r')),
        action_spaces_dict=env.action_spaces_dict,
        observation_spaces_dict=env.observation_spaces_dict,
        agent_counts_dict=env.agent_counts_dict)
Example #21
def _generate_inconsistency_type_2_configs(
) -> Tuple[Dict, Dict, Dict, A2CAlgorithmConfig, Dict]:
    """
    Returns configs for tests of inconsistencies of type 2.
    :return: es_dev_runner_config, a2c_dev_runner_config, invalid_a2c_dev_runner_config, a2c_alg_config,
             default_overrides.
    """

    gym_env_name = "CartPole-v0"
    es_dev_runner_config = {
        'state_dict_dump_file': 'state_dict.pt',
        'spaces_config_dump_file': 'spaces_config.pkl',
        'normalization_samples': 1,
        '_target_': 'maze.train.trainers.es.ESDevRunner',
        'n_eval_rollouts': 1,
        'shared_noise_table_size': 10,
        "dump_interval": None
    }
    a2c_dev_runner_config = {
        'state_dict_dump_file': 'state_dict.pt',
        'spaces_config_dump_file': 'spaces_config.pkl',
        'normalization_samples': 1,
        '_target_':
        'maze.train.trainers.common.actor_critic.actor_critic_runners.ACDevRunner',
        "trainer_class": "maze.train.trainers.a2c.a2c_trainer.A2C",
        'concurrency': 1,
        "dump_interval": None,
        "eval_concurrency": 1
    }
    invalid_a2c_dev_runner_config = copy.deepcopy(a2c_dev_runner_config)
    invalid_a2c_dev_runner_config[
        "trainer_class"] = "maze.train.trainers.es.es_trainer.ESTrainer"

    a2c_alg_config = A2CAlgorithmConfig(
        n_epochs=1,
        epoch_length=25,
        patience=15,
        critic_burn_in_epochs=0,
        n_rollout_steps=100,
        lr=0.0005,
        gamma=0.98,
        gae_lambda=1.0,
        policy_loss_coef=1.0,
        value_loss_coef=0.5,
        entropy_coef=0.00025,
        max_grad_norm=0.0,
        device='cpu',
        rollout_evaluator=RolloutEvaluator(eval_env=SequentialVectorEnv(
            [lambda: GymMazeEnv(gym_env_name)]),
                                           n_episodes=1,
                                           model_selection=None,
                                           deterministic=True))
    default_overrides = {"env.name": gym_env_name}

    return es_dev_runner_config, a2c_dev_runner_config, invalid_a2c_dev_runner_config, a2c_alg_config, default_overrides
def test_no_dict_observation_wrapper():
    """ gym env wrapper unit test """
    base_env = GymMazeEnv(env="CartPole-v0")
    env = NoDictObservationWrapper.wrap(base_env)

    assert isinstance(env.observation_space, spaces.Box)
    assert isinstance(env.observation_spaces_dict, dict)

    assert isinstance(env.observation_space.sample(), np.ndarray)
    assert env.observation_space.contains(env.observation_space.sample())
    assert env.observation_space.contains(env.reset())
Example #23
File: test_es.py Project: enlite-ai/maze
def test_subproc_distributed_rollouts():
    policy, env, trainer = train_setup(n_epochs=2)

    rollouts = ESSubprocDistributedRollouts(
        env_factory=lambda: GymMazeEnv(env="CartPole-v0"),
        n_training_workers=2,
        n_eval_workers=1,
        shared_noise=trainer.shared_noise,
        env_seeds=[1337] * 3,
        agent_seed=1337)

    trainer.train(rollouts, model_selection=None)
Example #24
def test_autoresolving_proxy_attribute():
    """
    Tests auto-resolving proxy attributes like critic (see :py:class:`maze.api.utils._ATTRIBUTE_PROXIES` for more
    info).
    """

    cartpole_env_factory = lambda: GymMazeEnv(env=gym.make("CartPole-v0"))

    _, _, critic_composer, _, _ = _get_cartpole_setup_components()
    alg_config = A2CAlgorithmConfig(n_epochs=1,
                                    epoch_length=25,
                                    patience=15,
                                    critic_burn_in_epochs=0,
                                    n_rollout_steps=100,
                                    lr=0.0005,
                                    gamma=0.98,
                                    gae_lambda=1.0,
                                    policy_loss_coef=1.0,
                                    value_loss_coef=0.5,
                                    entropy_coef=0.00025,
                                    max_grad_norm=0.0,
                                    device='cpu',
                                    rollout_evaluator=RolloutEvaluator(
                                        eval_env=SequentialVectorEnv(
                                            [cartpole_env_factory]),
                                        n_episodes=1,
                                        model_selection=None,
                                        deterministic=True))
    default_overrides = {
        "runner.normalization_samples": 1,
        "runner.concurrency": 1
    }

    rc = run_context.RunContext(env=cartpole_env_factory,
                                silent=True,
                                algorithm=alg_config,
                                critic=critic_composer,
                                runner="dev",
                                overrides=default_overrides)
    rc.train(n_epochs=1)
    assert isinstance(rc._runners[RunMode.TRAINING][0].model_composer.critic,
                      TorchSharedStateCritic)

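    # specifying the critic as "template_state" should resolve to a step-wise state critic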
    rc = run_context.RunContext(env=cartpole_env_factory,
                                silent=True,
                                algorithm=alg_config,
                                critic="template_state",
                                runner="dev",
                                overrides=default_overrides)
    rc.train(n_epochs=1)
    assert isinstance(rc._runners[RunMode.TRAINING][0].model_composer.critic,
                      TorchStepStateCritic)
Example #25
def test_cartpole_policy_model():
    env = GymMazeEnv(env='CartPole-v0')
    observation_spaces_dict = env.observation_spaces_dict
    action_spaces_dict = env.action_spaces_dict

    flat_action_space = flat_structured_space(action_spaces_dict)
    distribution_mapper = DistributionMapper(action_space=flat_action_space,
                                             distribution_mapper_config={})

    action_logits_shapes = {
        step_key: {
            action_head: distribution_mapper.required_logits_shape(action_head)
            for action_head in action_spaces_dict[step_key].spaces.keys()
        }
        for step_key in action_spaces_dict.keys()
    }

    obs_shapes = observation_spaces_to_in_shapes(observation_spaces_dict)

    policy = CustomPlainCartpolePolicyNet(obs_shapes[0],
                                          action_logits_shapes[0],
                                          hidden_layer_0=16,
                                          hidden_layer_1=32,
                                          use_bias=True)

    critic = CustomPlainCartpoleCriticNet(obs_shapes[0],
                                          hidden_layer_0=16,
                                          hidden_layer_1=32,
                                          use_bias=True)

    obs_np = env.reset()
    obs = {k: torch.from_numpy(v) for k, v in obs_np.items()}

    actions = policy(obs)
    values = critic(obs)

    assert 'action' in actions
    assert 'value' in values
Example #26
def test_observation_normalization_init_from_yaml_config():
    """ observation normalization test """

    # load config
    config = load_env_config(test_observation_normalization_module,
                             "dummy_config_file.yml")

    # init environment
    env = GymMazeEnv("CartPole-v0")
    env = ObservationNormalizationWrapper(
        env, **config["observation_normalization_wrapper"])
    assert isinstance(env, ObservationNormalizationWrapper)

    stats = env.get_statistics()
    assert "stat_1" in stats["observation"] and "stat_2" in stats["observation"]

    norm_strategies = getattr(env, "_normalization_strategies")
    strategy = norm_strategies["observation"]
    assert isinstance(strategy, ObservationNormalizationStrategy)
    assert strategy._clip_min == 0
    assert strategy._clip_max == 1
    assert np.all(strategy._statistics["stat_1"] == np.asarray([0, 0, 0, 0]))
    assert np.all(strategy._statistics["stat_2"] == np.asarray([1, 1, 1, 1]))
Example #27
def test_manual_rollout() -> None:
    """
    Test manual rollout via control loop.
    """

    env_factory = lambda: GymMazeEnv('CartPole-v0')
    rc = run_context.RunContext(env=lambda: env_factory(), silent=True)
    rc.train(n_epochs=1)

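    # roll out the trained policy for a couple of steps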
    env = env_factory()
    obs = env.reset()
    for i in range(2):
        action = rc.compute_action(obs)
        obs, rewards, dones, info = env.step(action)
Example #28
def test_no_dict_action_wrapper():
    """ gym env wrapper unit test """
    base_env = GymMazeEnv(env="CartPole-v0")
    env = NoDictActionWrapper.wrap(base_env)

    assert isinstance(env.action_space, spaces.Discrete)
    assert isinstance(env.action_spaces_dict, dict)

    action = env.action_space.sample()
    out_action = env.action(action)
    assert isinstance(out_action, dict)
    assert out_action['action'] == action

    assert env.action_space.contains(env.reverse_action(out_action))
    assert env.reverse_action(out_action) == action
Example #29
def test_experiment():
    """
    Tests whether experiments are correctly loaded.
    """

    rc = run_context.RunContext(env=lambda: GymMazeEnv('CartPole-v0'),
                                silent=True,
                                overrides={
                                    "runner.normalization_samples": 1,
                                    "runner.concurrency": 1
                                },
                                experiment="cartpole_ppo_wrappers")
    rc.train(1)

    assert isinstance(rc._runners[RunMode.TRAINING][0]._trainer, PPO)
    assert rc._runners[RunMode.TRAINING][0]._cfg.algorithm.lr == 0.0001
Example #30
File: test_es.py Project: enlite-ai/maze
def train_setup(
        n_epochs: int,
        policy_wrapper=None) -> Tuple[TorchPolicy, StructuredEnv, ESTrainer]:
    """Trains the cart pole environment with the multi-step a2c implementation.
    """

    # initialize the environment
    env = GymMazeEnv(env="CartPole-v0")

    # initialize distribution mapper
    distribution_mapper = DistributionMapper(action_space=env.action_space,
                                             distribution_mapper_config={})

    # initialize policies
    policies = {
        0:
        FlattenConcatPolicyNet({'observation': (4, )}, {'action': (2, )},
                               hidden_units=[16],
                               non_lin=nn.Tanh)
    }

    # initialize policy
    policy = TorchPolicy(networks=policies,
                         distribution_mapper=distribution_mapper,
                         device="cpu")

    # reduce the noise table size to speed up testing
    shared_noise = SharedNoiseTable(count=1_000_000)

    algorithm_config = ESAlgorithmConfig(n_rollouts_per_update=100,
                                         n_timesteps_per_update=0,
                                         max_steps=0,
                                         optimizer=Adam(step_size=0.01),
                                         l2_penalty=0.005,
                                         noise_stddev=0.02,
                                         n_epochs=n_epochs,
                                         policy_wrapper=policy_wrapper)

    # train agent
    trainer = ESTrainer(algorithm_config=algorithm_config,
                        shared_noise=shared_noise,
                        torch_policy=policy,
                        normalization_stats=None)

    return policy, env, trainer