Example #1
0
def test_sampler_config_2():
    """Check that the 'angle' entry of sampler_config_2 builds a GaussianSampler.

    The manager must be non-empty, and the sampler's ``mean`` and ``st_dev``
    must equal the values given in the config entry.
    """
    cfg = sampler_config_2()
    manager = SamplerManager(cfg)
    assert manager.is_empty() is False

    angle = manager.samplers["angle"]
    assert isinstance(angle, GaussianSampler)

    # The Gaussian parameters must match the 'angle' config entry exactly.
    expected = cfg["angle"]
    assert angle.mean == expected["mean"]
    assert angle.st_dev == expected["st_dev"]
Example #2
0
def test_empty_samplers():
    """An empty-dict or None config yields an empty manager that samples ``{}``."""
    # Same checks for both "no samplers" configurations, empty dict first.
    for empty_config in ({}, None):
        manager = SamplerManager(empty_config)
        assert manager.is_empty()
        assert manager.sample_all() == {}
def test_simple():
    """Train PPO on Simple1DEnvironment and check that every brain learns.

    Builds a TrainerController in a temporary directory, passes the parsed
    YAML trainer config to ``start_learning``, and requires each reported
    mean reward to be non-NaN and greater than 0.99.
    """
    config = """
        default:
            trainer: ppo
            batch_size: 16
            beta: 5.0e-3
            buffer_size: 64
            epsilon: 0.2
            hidden_units: 128
            lambd: 0.95
            learning_rate: 5.0e-3
            max_steps: 2500
            memory_size: 256
            normalize: false
            num_epoch: 3
            num_layers: 2
            time_horizon: 64
            sequence_length: 64
            summary_freq: 500
            use_recurrent: false
            reward_signals:
                extrinsic:
                    strength: 1.0
                    gamma: 0.99
    """
    # Create controller and begin training.
    # ``tmp_dir`` rather than ``dir`` so the ``dir`` builtin is not shadowed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        run_id = "id"
        save_freq = 99999
        tc = TrainerController(
            tmp_dir,
            tmp_dir,
            run_id,
            save_freq,
            meta_curriculum=None,
            load=False,
            train=True,
            keep_checkpoints=1,
            lesson=None,
            training_seed=1337,
            fast_simulation=True,
            sampler_manager=SamplerManager(None),
            resampling_interval=None,
        )

        # Begin training
        env = Simple1DEnvironment()
        env_manager = SimpleEnvManager(env)
        trainer_config = yaml.safe_load(config)
        tc.start_learning(env_manager, trainer_config)

        # Only the reward values are checked; brain names are not needed.
        for mean_reward in tc._get_measure_vals().values():
            assert not math.isnan(mean_reward)
            assert mean_reward > 0.99
Example #4
0
def test_sampler_config_1():
    """Check the 'mass' and 'gravity' samplers built from sampler_config_1.

    'mass' must be a UniformSampler and 'gravity' a MultiRangeUniformSampler.
    Their parameters must match the config, and one ``sample_all()`` draw
    must respect the configured ranges.
    """
    cfg = sampler_config_1()
    manager = SamplerManager(cfg)

    assert manager.is_empty() is False
    assert isinstance(manager.samplers["mass"], UniformSampler)
    assert isinstance(manager.samplers["gravity"], MultiRangeUniformSampler)

    sample = manager.sample_all()

    # Uniform sampler for mass: bounds match the config and the draw is within them.
    mass_sampler = manager.samplers["mass"]
    mass_cfg = cfg["mass"]
    assert mass_sampler.min_value == mass_cfg["min_value"]
    assert mass_sampler.max_value == mass_cfg["max_value"]
    assert mass_cfg["min_value"] <= sample["mass"]
    assert mass_cfg["max_value"] >= sample["mass"]

    # Multi-range uniform sampler for gravity: intervals match the config and
    # check_value_in_intervals accepts the drawn value.
    gravity_sampler = manager.samplers["gravity"]
    assert gravity_sampler.intervals == cfg["gravity"]["intervals"]
    assert check_value_in_intervals(sample["gravity"], gravity_sampler.intervals)
def basic_trainer_controller():
    """Return a TrainerController with fixed test settings.

    No trainer factory or meta-curriculum is supplied, training is enabled,
    the seed is 99, and an empty SamplerManager is used with no resampling
    interval.
    """
    settings = dict(
        trainer_factory=None,
        model_path="test_model_path",
        summaries_dir="test_summaries_dir",
        run_id="test_run_id",
        save_freq=100,
        meta_curriculum=None,
        train=True,
        training_seed=99,
        sampler_manager=SamplerManager({}),
        resampling_interval=None,
    )
    return TrainerController(**settings)
def test_initialization_seed(numpy_random_seed, tensorflow_set_seed):
    """Constructing a TrainerController must pass ``training_seed`` to both seed hooks.

    ``numpy_random_seed`` and ``tensorflow_set_seed`` are mock objects
    (presumably injected by patch decorators not shown here — confirm); each
    must have been last called with the seed given to the constructor.
    """
    expected_seed = 27
    TrainerController(
        trainer_factory=None,
        model_path="",
        summaries_dir="",
        run_id="1",
        save_freq=1,
        meta_curriculum=None,
        train=True,
        training_seed=expected_seed,
        sampler_manager=SamplerManager({}),
        resampling_interval=None,
    )
    for seed_mock in (numpy_random_seed, tensorflow_set_seed):
        seed_mock.assert_called_with(expected_seed)
Example #7
0
def _check_environment_trains(env, config):
    """Train on ``env`` with the YAML trainer ``config`` and assert that it learns.

    Builds a TrainerFactory and a TrainerController in a temporary directory,
    runs ``start_learning``, prints the measured values, and requires every
    reported mean reward to be non-NaN and greater than 0.99.

    :param env: Environment to wrap in a SimpleEnvManager.
    :param config: YAML string holding the trainer configuration.
    """
    # Create controller and begin training.
    # ``tmp_dir`` rather than ``dir`` so the ``dir`` builtin is not shadowed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        run_id = "id"
        save_freq = 99999
        seed = 1337

        trainer_config = yaml.safe_load(config)
        env_manager = SimpleEnvManager(env)
        trainer_factory = TrainerFactory(
            trainer_config=trainer_config,
            summaries_dir=tmp_dir,
            run_id=run_id,
            model_path=tmp_dir,
            keep_checkpoints=1,
            train_model=True,
            load_model=False,
            seed=seed,
            meta_curriculum=None,
            multi_gpu=False,
        )

        tc = TrainerController(
            trainer_factory=trainer_factory,
            summaries_dir=tmp_dir,
            model_path=tmp_dir,
            run_id=run_id,
            meta_curriculum=None,
            train=True,
            training_seed=seed,
            fast_simulation=True,
            sampler_manager=SamplerManager(None),
            resampling_interval=None,
            save_freq=save_freq,
        )

        # Begin training
        tc.start_learning(env_manager)

        # Query the measurements once so the printed diagnostics are exactly
        # the values being asserted on.
        measure_vals = tc._get_measure_vals()
        print(measure_vals)
        for mean_reward in measure_vals.values():
            assert not math.isnan(mean_reward)
            assert mean_reward > 0.99
Example #8
0
def create_sampler_manager(sampler_file_path, env_reset_params):
    """Build a SamplerManager from an optional sampler config file.

    :param sampler_file_path: Path to the sampler config file, or None to
        build a manager from a ``None`` config.
    :param env_reset_params: Not used by this function.
    :return: Tuple ``(sampler_manager, resample_interval)``. ``resample_interval``
        is None when ``sampler_file_path`` is None.
    :raises SamplerException: If the file has no ``resampling-interval`` key, or
        its value is not a positive integer.
    """
    sampler_config = None
    resample_interval = None
    if sampler_file_path is not None:
        sampler_config = load_config(sampler_file_path)
        if "resampling-interval" not in sampler_config:
            raise SamplerException(
                "Resampling interval was not specified in the sampler file."
                " Please specify it with the 'resampling-interval' key in the sampler config file."
            )
        # The interval is not a sampler definition, so remove it before the
        # remaining entries are handed to SamplerManager.
        resample_interval = sampler_config.pop("resampling-interval")
        # Check the type before comparing with 0: ordering a non-number
        # (e.g. a string) against 0 raises TypeError instead of the intended
        # SamplerException. bool is rejected explicitly because it subclasses int.
        if (
            isinstance(resample_interval, bool)
            or not isinstance(resample_interval, int)
            or resample_interval <= 0
        ):
            raise SamplerException(
                "Specified resampling-interval is not valid. Please provide"
                " a positive integer value for resampling-interval")
    sampler_manager = SamplerManager(sampler_config)
    return sampler_manager, resample_interval
Example #9
0
def test_incorrect_sampler():
    """SamplerManager must reject an invalid sampler config with UnityException."""
    bad_config = incorrect_sampler_config()
    # Callable form of pytest.raises: invokes SamplerManager(bad_config) and
    # requires it to raise UnityException.
    pytest.raises(UnityException, SamplerManager, bad_config)