Example No. 1
def test_sampler_config_2():
    config = sampler_config_2()
    sampler = SamplerManager(config)
    assert sampler.is_empty() is False
    assert isinstance(sampler.samplers["angle"], GaussianSampler)

    # Check angle gaussian sampler
    assert sampler.samplers["angle"].mean == config["angle"]["mean"]
    assert sampler.samplers["angle"].st_dev == config["angle"]["st_dev"]
Example No. 2
def test_empty_samplers():
    empty_sampler = SamplerManager({})
    assert empty_sampler.is_empty()
    empty_cur_sample = empty_sampler.sample_all()
    assert empty_cur_sample == {}

    none_sampler = SamplerManager(None)
    assert none_sampler.is_empty()
    none_cur_sample = none_sampler.sample_all()
    assert none_cur_sample == {}
Example No. 3
def _check_environment_trains(
    env,
    trainer_config,
    reward_processor=default_reward_processor,
    meta_curriculum=None,
    success_threshold=0.9,
    env_manager=None,
):
    # Create controller and begin training.
    with tempfile.TemporaryDirectory() as dir:
        run_id = "id"
        save_freq = 99999
        seed = 1337
        # Clear StatsReporters so we don't write to file
        StatsReporter.writers.clear()
        debug_writer = DebugWriter()
        StatsReporter.add_writer(debug_writer)
        # Make sure threading is turned off for determinism
        trainer_config["threading"] = False
        if env_manager is None:
            env_manager = SimpleEnvManager(env, FloatPropertiesChannel())
        trainer_factory = TrainerFactory(
            trainer_config=trainer_config,
            summaries_dir=dir,
            run_id=run_id,
            model_path=dir,
            keep_checkpoints=1,
            train_model=True,
            load_model=False,
            seed=seed,
            meta_curriculum=meta_curriculum,
            multi_gpu=False,
        )

        tc = TrainerController(
            trainer_factory=trainer_factory,
            summaries_dir=dir,
            model_path=dir,
            run_id=run_id,
            meta_curriculum=meta_curriculum,
            train=True,
            training_seed=seed,
            sampler_manager=SamplerManager(None),
            resampling_interval=None,
            save_freq=save_freq,
        )

        # Begin training
        tc.start_learning(env_manager)
        # For tests where we are just checking setup and not reward
        if success_threshold is not None:
            processed_rewards = [
                reward_processor(rewards)
                for rewards in env.final_rewards.values()
            ]
            assert all(not math.isnan(reward) for reward in processed_rewards)
            assert all(reward > success_threshold for reward in processed_rewards)
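`default_reward_processor` is referenced in the signature but not defined here. A minimal sketch of such a processor, assuming it averages the last few episode rewards; the name comes from the example, the body is an assumption.

import numpy as np

def default_reward_processor(rewards, last_n_rewards=5):
    # Hypothetical helper: average the most recent rewards so a single noisy
    # episode does not decide whether the success_threshold check passes.
    return np.array(rewards[-last_n_rewards:], dtype=np.float32).mean()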
Example No. 4
def test_sampler_config_1():
    config = sampler_config_1()
    sampler = SamplerManager(config)

    assert sampler.is_empty() is False
    assert isinstance(sampler.samplers["mass"], UniformSampler)
    assert isinstance(sampler.samplers["gravity"], MultiRangeUniformSampler)

    cur_sample = sampler.sample_all()

    # Check uniform sampler for mass
    assert sampler.samplers["mass"].min_value == config["mass"]["min_value"]
    assert sampler.samplers["mass"].max_value == config["mass"]["max_value"]
    assert config["mass"]["min_value"] <= cur_sample["mass"]
    assert config["mass"]["max_value"] >= cur_sample["mass"]

    # Check multirange_uniform sampler for gravity
    assert sampler.samplers["gravity"].intervals == config["gravity"][
        "intervals"]
    assert check_value_in_intervals(cur_sample["gravity"],
                                    sampler.samplers["gravity"].intervals)
Example No. 5
def basic_trainer_controller():
    return TrainerController(
        trainer_factory=None,
        model_path="test_model_path",
        summaries_dir="test_summaries_dir",
        run_id="test_run_id",
        save_freq=100,
        meta_curriculum=None,
        train=True,
        training_seed=99,
        sampler_manager=SamplerManager({}),
        resampling_interval=None,
    )
Example No. 6
def basic_trainer_controller():
    trainer_factory_mock = MagicMock()
    trainer_factory_mock.ghost_controller = GhostController()
    return TrainerController(
        trainer_factory=trainer_factory_mock,
        output_path="test_model_path",
        run_id="test_run_id",
        meta_curriculum=None,
        train=True,
        training_seed=99,
        sampler_manager=SamplerManager({}),
        resampling_interval=None,
    )
Example No. 7
def test_initialization_seed(numpy_random_seed, tensorflow_set_seed):
    seed = 27
    TrainerController(
        trainer_factory=None,
        model_path="",
        summaries_dir="",
        run_id="1",
        save_freq=1,
        meta_curriculum=None,
        train=True,
        training_seed=seed,
        sampler_manager=SamplerManager({}),
        resampling_interval=None,
    )
    numpy_random_seed.assert_called_with(seed)
    tensorflow_set_seed.assert_called_with(seed)
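`numpy_random_seed` and `tensorflow_set_seed` are mocks injected with `unittest.mock.patch`. The patch targets below are assumptions (they depend on how the module under test imports numpy and TensorFlow); the point is the ordering rule: the decorator closest to the function supplies the first argument.

from unittest.mock import patch

@patch("tensorflow.set_random_seed")  # hypothetical target for the TF seed call
@patch("numpy.random.seed")           # hypothetical target for the numpy seed call
def test_initialization_seed(numpy_random_seed, tensorflow_set_seed):
    ...  # body as in the example above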
Example No. 8
def test_initialization_seed(numpy_random_seed, tensorflow_set_seed):
    seed = 27
    trainer_factory_mock = MagicMock()
    trainer_factory_mock.ghost_controller = GhostController()
    TrainerController(
        trainer_factory=trainer_factory_mock,
        output_path="",
        run_id="1",
        meta_curriculum=None,
        train=True,
        training_seed=seed,
        sampler_manager=SamplerManager({}),
        resampling_interval=None,
    )
    numpy_random_seed.assert_called_with(seed)
    tensorflow_set_seed.assert_called_with(seed)
Example No. 9
def _check_environment_trains(env,
                              config,
                              meta_curriculum=None,
                              success_threshold=0.99):
    # Create controller and begin training.
    with tempfile.TemporaryDirectory() as dir:
        run_id = "id"
        save_freq = 99999
        seed = 1337
        # Clear StatsReporters so we don't write to file
        StatsReporter.writers.clear()
        trainer_config = yaml.safe_load(config)
        env_manager = SimpleEnvManager(env, FloatPropertiesChannel())
        trainer_factory = TrainerFactory(
            trainer_config=trainer_config,
            summaries_dir=dir,
            run_id=run_id,
            model_path=dir,
            keep_checkpoints=1,
            train_model=True,
            load_model=False,
            seed=seed,
            meta_curriculum=meta_curriculum,
            multi_gpu=False,
        )

        tc = TrainerController(
            trainer_factory=trainer_factory,
            summaries_dir=dir,
            model_path=dir,
            run_id=run_id,
            meta_curriculum=meta_curriculum,
            train=True,
            training_seed=seed,
            sampler_manager=SamplerManager(None),
            resampling_interval=None,
            save_freq=save_freq,
        )

        # Begin training
        tc.start_learning(env_manager)
        print(tc._get_measure_vals())
        # For tests where we are just checking setup and not reward
        if success_threshold is not None:
            for mean_reward in tc._get_measure_vals().values():
                assert not math.isnan(mean_reward)
                assert mean_reward > success_threshold
Example No. 10
def create_sampler_manager(sampler_config, run_seed=None):
    resample_interval = None
    if sampler_config is not None:
        if "resampling-interval" in sampler_config:
            # Pop the interval out so that only environment reset parameters remain
            resample_interval = sampler_config.pop("resampling-interval")
            if not isinstance(resample_interval, int) or resample_interval <= 0:
                raise SamplerException(
                    "Specified resampling-interval is not valid. Please provide"
                    " a positive integer value for resampling-interval"
                )

        else:
            raise SamplerException(
                "Resampling interval was not specified in the sampler file."
                " Please specify it with the 'resampling-interval' key in the sampler config file."
            )

    sampler_manager = SamplerManager(sampler_config, run_seed)
    return sampler_manager, resample_interval
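A usage sketch for `create_sampler_manager`: the per-parameter entry is hypothetical, but the `resampling-interval` handling follows directly from the function above, so under these assumptions both assertions hold.

# Hypothetical config; "mass" and its sampler-type entry are illustrative only.
sampler_config = {
    "resampling-interval": 5000,  # popped and returned separately
    "mass": {"sampler-type": "uniform", "min_value": 5, "max_value": 10},
}
sampler_manager, resample_interval = create_sampler_manager(sampler_config, run_seed=1337)
assert resample_interval == 5000
assert not sampler_manager.is_empty()  # the "mass" sampler remains in the manager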
Example No. 11
def _check_environment_trains(env, config):
    # Create controller and begin training.
    with tempfile.TemporaryDirectory() as dir:
        run_id = "id"
        save_freq = 99999
        seed = 1337

        trainer_config = yaml.safe_load(config)
        env_manager = SimpleEnvManager(env, FloatPropertiesChannel())
        trainer_factory = TrainerFactory(
            trainer_config=trainer_config,
            summaries_dir=dir,
            run_id=run_id,
            model_path=dir,
            keep_checkpoints=1,
            train_model=True,
            load_model=False,
            seed=seed,
            meta_curriculum=None,
            multi_gpu=False,
        )

        tc = TrainerController(
            trainer_factory=trainer_factory,
            summaries_dir=dir,
            model_path=dir,
            run_id=run_id,
            meta_curriculum=None,
            train=True,
            training_seed=seed,
            sampler_manager=SamplerManager(None),
            resampling_interval=None,
            save_freq=save_freq,
        )

        # Begin training
        tc.start_learning(env_manager)
        print(tc._get_measure_vals())
        for brain_name, mean_reward in tc._get_measure_vals().items():
            assert not math.isnan(mean_reward)
            assert mean_reward > 0.99
Example No. 12
def __init__(
    self,
    trainer_factory: TrainerFactory,
    model_path: str,
    summaries_dir: str,
    run_id: str,
    save_freq: int,
    meta_curriculum: Optional[MetaCurriculumAAI],
    train: bool,
    training_seed: int,
):
    # We remove the sampler manager as it is irrelevant for AAI
    super().__init__(
        trainer_factory=trainer_factory,
        model_path=model_path,
        summaries_dir=summaries_dir,
        run_id=run_id,
        save_freq=save_freq,
        meta_curriculum=meta_curriculum,
        train=train,
        training_seed=training_seed,
        sampler_manager=SamplerManager(reset_param_dict={}),
        resampling_interval=None,
    )
Example No. 13
def test_incorrect_sampler():
    config = incorrect_sampler_config()
    with pytest.raises(TrainerError):
        SamplerManager(config)
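`incorrect_sampler_config()` is not shown either. One plausible shape that would make `SamplerManager` raise, assuming an unrecognized "sampler-type" value is rejected; the exact trigger in the real fixture may differ.

def incorrect_sampler_config():
    # Hypothetical fixture: a sampler type the factory does not know about.
    return {"mass": {"sampler-type": "no-such-sampler", "min_value": 5, "max_value": 10}}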