Пример #1
0
def _check_environment_trains(
    env,
    trainer_config,
    reward_processor=default_reward_processor,
    meta_curriculum=None,
    success_threshold=0.9,
    env_manager=None,
):
    """Run a training session in a temporary directory and check the rewards.

    Args:
        env: Environment to train on. Its ``final_rewards`` mapping is read
            once training has finished.
        trainer_config: Trainer configuration handed to ``TrainerFactory``.
            Its ``"threading"`` entry is set to ``False`` in place.
        reward_processor: Callable applied to every value of
            ``env.final_rewards`` before the checks.
        meta_curriculum: Forwarded unchanged to ``TrainerFactory`` and
            ``TrainerController``.
        success_threshold: Every processed reward must be strictly greater
            than this value. ``None`` skips the reward checks entirely.
        env_manager: Environment manager passed to ``start_learning``. When
            ``None``, a ``SimpleEnvManager`` wrapping ``env`` with a new
            ``FloatPropertiesChannel`` is used.

    Raises:
        AssertionError: If a processed reward is NaN or does not exceed
            ``success_threshold``.
    """
    run_name = "id"
    rng_seed = 1337
    checkpoint_interval = 99999

    with tempfile.TemporaryDirectory() as work_dir:
        # Drop every registered writer so stats are not written to file,
        # then register a DebugWriter in their place.
        StatsReporter.writers.clear()
        StatsReporter.add_writer(DebugWriter())

        # Threading is switched off to keep the run deterministic.
        trainer_config["threading"] = False

        manager = (
            SimpleEnvManager(env, FloatPropertiesChannel())
            if env_manager is None
            else env_manager
        )

        factory = TrainerFactory(
            trainer_config=trainer_config,
            summaries_dir=work_dir,
            run_id=run_name,
            model_path=work_dir,
            keep_checkpoints=1,
            train_model=True,
            load_model=False,
            seed=rng_seed,
            meta_curriculum=meta_curriculum,
            multi_gpu=False,
        )
        controller = TrainerController(
            trainer_factory=factory,
            summaries_dir=work_dir,
            model_path=work_dir,
            run_id=run_name,
            meta_curriculum=meta_curriculum,
            train=True,
            training_seed=rng_seed,
            sampler_manager=SamplerManager(None),
            resampling_interval=None,
            save_freq=checkpoint_interval,
        )

        controller.start_learning(manager)

        # A None threshold marks tests that only check setup, not reward.
        if success_threshold is None:
            return

        final_scores = [
            reward_processor(raw) for raw in env.final_rewards.values()
        ]
        assert not any(math.isnan(score) for score in final_scores)
        assert all(score > success_threshold for score in final_scores)
def check_environment_trains(
    env,
    trainer_config,
    reward_processor=default_reward_processor,
    env_parameter_manager=None,
    success_threshold=0.9,
    env_manager=None,
    training_seed=None,
):
    """Run a training session in a temporary directory and check the rewards.

    Args:
        env: Environment to train on. Its ``final_rewards`` mapping is read
            once training has finished.
        trainer_config: Trainer configuration handed to ``TrainerFactory``.
        reward_processor: Callable applied to every value of
            ``env.final_rewards`` before the checks.
        env_parameter_manager: Passed as ``param_manager`` to both
            ``TrainerFactory`` and ``TrainerController``. When ``None``, a
            new ``EnvironmentParameterManager()`` is created.
        success_threshold: Every processed reward must be strictly greater
            than this value. ``None`` skips the reward checks entirely.
        env_manager: Environment manager passed to ``start_learning``. When
            ``None``, a ``SimpleEnvManager`` wrapping ``env`` with a new
            ``EnvironmentParametersChannel`` is used.
        training_seed: Seed for the factory and controller; 1337 when
            ``None``.

    Raises:
        AssertionError: If a processed reward is NaN or does not exceed
            ``success_threshold``.
    """
    params = (
        EnvironmentParameterManager()
        if env_parameter_manager is None
        else env_parameter_manager
    )

    with tempfile.TemporaryDirectory() as work_dir:
        run_name = "id"
        if training_seed is None:
            rng_seed = 1337
        else:
            rng_seed = training_seed

        # Drop every registered writer so stats are not written to file,
        # then register a DebugWriter in their place.
        StatsReporter.writers.clear()
        StatsReporter.add_writer(DebugWriter())

        manager = (
            SimpleEnvManager(env, EnvironmentParametersChannel())
            if env_manager is None
            else env_manager
        )

        factory = TrainerFactory(
            trainer_config=trainer_config,
            output_path=work_dir,
            train_model=True,
            load_model=False,
            seed=rng_seed,
            param_manager=params,
            multi_gpu=False,
        )
        controller = TrainerController(
            trainer_factory=factory,
            output_path=work_dir,
            run_id=run_name,
            param_manager=params,
            train=True,
            training_seed=rng_seed,
        )

        controller.start_learning(manager)

        # A None threshold marks tests that only check setup, not reward.
        if success_threshold is None:
            return

        final_scores = [
            reward_processor(raw) for raw in env.final_rewards.values()
        ]
        assert not any(math.isnan(score) for score in final_scores)
        assert all(score > success_threshold for score in final_scores)
Пример #3
0
def _check_environment_trains(
    env, config, meta_curriculum=None, success_threshold=0.99
):
    """Train on ``env`` from a YAML config and check the reported rewards.

    Args:
        env: Environment wrapped in a ``SimpleEnvManager`` for training.
        config: YAML text parsed with ``yaml.safe_load`` to obtain the
            trainer configuration.
        meta_curriculum: Forwarded unchanged to ``TrainerFactory`` and
            ``TrainerController``.
        success_threshold: Every value reported by the controller's
            ``_get_measure_vals()`` must be strictly greater than this.
            ``None`` skips the reward checks entirely.

    Raises:
        AssertionError: If a reported value is NaN or does not exceed
            ``success_threshold``.
    """
    run_name = "id"
    rng_seed = 1337
    checkpoint_interval = 99999

    with tempfile.TemporaryDirectory() as work_dir:
        # Drop every registered writer so stats are not written to file.
        StatsReporter.writers.clear()

        parsed_config = yaml.safe_load(config)
        manager = SimpleEnvManager(env, FloatPropertiesChannel())

        factory = TrainerFactory(
            trainer_config=parsed_config,
            summaries_dir=work_dir,
            run_id=run_name,
            model_path=work_dir,
            keep_checkpoints=1,
            train_model=True,
            load_model=False,
            seed=rng_seed,
            meta_curriculum=meta_curriculum,
            multi_gpu=False,
        )
        controller = TrainerController(
            trainer_factory=factory,
            summaries_dir=work_dir,
            model_path=work_dir,
            run_id=run_name,
            meta_curriculum=meta_curriculum,
            train=True,
            training_seed=rng_seed,
            sampler_manager=SamplerManager(None),
            resampling_interval=None,
            save_freq=checkpoint_interval,
        )

        controller.start_learning(manager)
        print(controller._get_measure_vals())

        # A None threshold marks tests that only check setup, not reward.
        if success_threshold is None:
            return

        for measured in controller._get_measure_vals().values():
            assert not math.isnan(measured)
            assert measured > success_threshold
Пример #4
0
def _check_environment_trains(
    env, config, meta_curriculum=None, success_threshold=0.99
):
    """Train on ``env`` from a YAML config and check the reported rewards.

    Args:
        env: Environment wrapped in a ``SimpleEnvManager`` for training.
        config: YAML text parsed with ``yaml.safe_load`` to obtain the
            trainer configuration.
        meta_curriculum: Forwarded to ``TrainerFactory`` and
            ``TrainerController``. Defaults to ``None``, the value that was
            previously hard-coded.
        success_threshold: Every value reported by the controller's
            ``_get_measure_vals()`` must be strictly greater than this.
            Defaults to ``0.99``, the value that was previously hard-coded.
            ``None`` skips the reward checks, for tests that only exercise
            setup.

    Raises:
        AssertionError: If a reported value is NaN or does not exceed
            ``success_threshold``.
    """
    # Create controller and begin training. The directory is named
    # ``tmp_dir`` so the ``dir`` builtin is not shadowed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        run_id = "id"
        save_freq = 99999
        seed = 1337

        trainer_config = yaml.safe_load(config)
        env_manager = SimpleEnvManager(env, FloatPropertiesChannel())
        trainer_factory = TrainerFactory(
            trainer_config=trainer_config,
            summaries_dir=tmp_dir,
            run_id=run_id,
            model_path=tmp_dir,
            keep_checkpoints=1,
            train_model=True,
            load_model=False,
            seed=seed,
            meta_curriculum=meta_curriculum,
            multi_gpu=False,
        )

        tc = TrainerController(
            trainer_factory=trainer_factory,
            summaries_dir=tmp_dir,
            model_path=tmp_dir,
            run_id=run_id,
            meta_curriculum=meta_curriculum,
            train=True,
            training_seed=seed,
            sampler_manager=SamplerManager(None),
            resampling_interval=None,
            save_freq=save_freq,
        )

        # Begin training
        tc.start_learning(env_manager)

        # Query the measures once and reuse them for printing and checking.
        measure_vals = tc._get_measure_vals()
        print(measure_vals)

        if success_threshold is None:
            return

        # Only the values are checked; the keys are not needed.
        for mean_reward in measure_vals.values():
            assert not math.isnan(mean_reward)
            assert mean_reward > success_threshold