コード例 #1
0
def test_preprocessing_init_from_yaml_config():
    """Pre-processor unit test.

    Wraps the dummy structured environment in a ``PreProcessingWrapper`` configured
    from ``dummy_config_file.yml``. It then checks which observation keys are exposed
    after ``reset`` and after one ``step``. Each expected key's value must lie in the
    matching entry of ``env.observation_spaces_dict``.
    """
    config = load_env_config(test_preprocessing_module, "dummy_config_file.yml")

    env = PreProcessingWrapper(build_dummy_structured_environment(), **config["preprocessing_wrapper"])
    assert isinstance(env, PreProcessingWrapper)

    # Observation returned by reset: categorical keys must be absent,
    # feature-series keys present and inside observation_spaces_dict[0].
    reset_obs = env.reset()
    reset_keys = tuple(reset_obs.keys())
    for absent in ('observation_1_categorical_feature',
                   'observation_1_categorical_feature-one_hot'):
        assert absent not in reset_keys
    for present in ('observation_0_feature_series',
                    'observation_0_feature_series-dummy'):
        assert present in reset_keys
        assert reset_obs[present] in env.observation_spaces_dict[0][present]

    # Observation returned by one step with a random action: feature-series keys and
    # the raw categorical key must be absent; the one-hot key must be present and
    # inside observation_spaces_dict[1] (index presumably the sub-step — confirm).
    step_obs = env.step(env.action_space.sample())[0]
    step_keys = tuple(step_obs.keys())
    for absent in ('observation_0_feature_series',
                   'observation_0_feature_series-dummy',
                   'observation_1_categorical_feature'):
        assert absent not in step_keys
    one_hot_key = 'observation_1_categorical_feature-one_hot'
    assert one_hot_key in step_keys
    assert step_obs[one_hot_key] in env.observation_spaces_dict[1][one_hot_key]
コード例 #2
0
def test_instantiation_with_wrapper_factory():
    """
    Test instantiation, types and parsing for/of a wrapped (VR) environment.

    The environment is built from ``dummy_env_config_with_dummy_wrappers.yml`` and
    wrapped via ``WrapperFactory.wrap_from_config``. The result must be an instance of
    every wrapper class and of ``DummyEnvironment``. It must also expose the wrapper
    arguments and ``do_stuff``.
    """
    factory = WrapperFactory()

    config = load_env_config(dummy_wrappers_module, "dummy_env_config_with_dummy_wrappers.yml")
    env_kwargs = config['env']
    # Override the core env with the dummy core environment and the dummy observation space.
    env_kwargs["core_env"] = {"_target_": DummyCoreEnvironment, "observation_space": ObservationConversion().space()}
    base_env = DummyEnvironment(**env_kwargs)
    env: Wrapper[MazeEnv] = factory.wrap_from_config(base_env, config['wrappers'])

    # The wrapped object must satisfy isinstance checks for every layer.
    for expected_type in (Wrapper, DummyWrapper, DummyWrapperA, DummyWrapperB, DummyEnvironment):
        assert isinstance(env, expected_type)

    # Wrapper arguments and methods must be reachable (and truthy) on the wrapped env.
    for attribute in ("do_stuff", "arg_a", "arg_b", "arg_c"):
        assert getattr(env, attribute)
    assert env.do_stuff() == "b"
コード例 #3
0
def test_observation_stack_init_from_yaml_config():
    """Observation stack wrapper test.

    Builds an ``ObservationStackWrapper`` around the dummy structured environment,
    configured from ``dummy_observation_stack_config_file.yml``. It then runs the
    shared ``assertion_routine`` checks on the result.
    """
    stack_config = load_env_config(wrapper_module, "dummy_observation_stack_config_file.yml")

    wrapped_env = ObservationStackWrapper(
        build_dummy_structured_environment(),
        **stack_config["observation_stack_wrapper"],
    )

    # Shared checks on the wrapped environment (assertion_routine is defined elsewhere).
    assertion_routine(wrapped_env)
コード例 #4
0
def test_wrap_method():
    """
    Tests the ``.wrap()`` class method of the dummy wrappers.

    ``DummyWrapperA.wrap`` must succeed when given ``arg_a``. ``DummyWrapperB.wrap``
    must raise ``TypeError`` when called without its arguments, and must succeed once
    ``arg_b`` and ``arg_c`` are supplied.
    """

    default_config: dict = load_env_config(dummy_wrappers_module, "dummy_env_config_with_dummy_wrappers.yml")
    env_config: dict = default_config['env']
    env_config["core_env"] = {"_target_": DummyCoreEnvironment, "observation_space": ObservationConversion().space()}
    env = DummyEnvironment(**env_config)

    env_a: DummyWrapperA = DummyWrapperA.wrap(env, arg_a=1)
    assert isinstance(env_a, DummyWrapperA)

    # Only the call under test sits in the ``try``. The failure is raised from
    # ``else``, outside the handler's scope, as an AssertionError so pytest reports
    # it as a test failure.
    try:
        DummyWrapperB.wrap(env)
    except TypeError:
        pass
    else:
        raise AssertionError("Wrapping shouldn't work without specifying the needed arguments.")

    env_b: DummyWrapperB = DummyWrapperB.wrap(env, arg_b=2, arg_c=3)
    assert isinstance(env_b, DummyWrapperB)
コード例 #5
0
def test_observation_normalization_init_from_yaml_config():
    """Observation normalization test.

    Wraps ``GymMazeEnv("CartPole-v0")`` in an ``ObservationNormalizationWrapper``
    configured from ``dummy_config_file.yml``. It then checks the "observation"
    normalization strategy: its type, its clipping bounds, and its statistics, all
    against fixed expected values (presumably those in the config — confirm).
    """

    # load config
    config = load_env_config(test_observation_normalization_module,
                             "dummy_config_file.yml")

    # init environment
    env = GymMazeEnv("CartPole-v0")
    env = ObservationNormalizationWrapper(
        env, **config["observation_normalization_wrapper"])
    assert isinstance(env, ObservationNormalizationWrapper)

    # Separate asserts so a failure names the missing statistic.
    stats = env.get_statistics()
    assert "stat_1" in stats["observation"]
    assert "stat_2" in stats["observation"]

    norm_strategies = getattr(env, "_normalization_strategies")
    strategy = norm_strategies["observation"]
    assert isinstance(strategy, ObservationNormalizationStrategy)
    assert strategy._clip_min == 0
    assert strategy._clip_max == 1
    # np.array_equal also compares shapes. ``np.all(a == b)`` broadcasts, so a
    # wrongly shaped statistic (e.g. a scalar 0) would pass unnoticed.
    assert np.array_equal(strategy._statistics["stat_1"], np.asarray([0, 0, 0, 0]))
    assert np.array_equal(strategy._statistics["stat_2"], np.asarray([1, 1, 1, 1]))