# Example #1
# 0
def _check_environment_trains(
    env, config, meta_curriculum=None, success_threshold=0.99
):
    """Train on ``env`` and assert that every brain's measure beats the threshold.

    A ``TrainerFactory`` and a ``TrainerController`` are built from ``config``,
    training is run to completion through a ``SimpleEnvManager`` wrapping
    ``env``, and the values returned by ``tc._get_measure_vals()`` are checked.

    Args:
        env: Environment handed to ``SimpleEnvManager``.
        config: YAML text holding the trainer configuration; parsed with
            ``yaml.safe_load``.
        meta_curriculum: Optional meta-curriculum, forwarded unchanged to both
            the trainer factory and the trainer controller.
        success_threshold: Value each brain's mean reward must strictly exceed.

    Raises:
        AssertionError: If any reported mean reward is NaN or is not greater
            than ``success_threshold``.
    """
    # Summaries and model files are written to a temporary directory that is
    # removed when the ``with`` block exits.
    with tempfile.TemporaryDirectory() as output_dir:
        run_id = "id"
        # NOTE(review): presumably large enough that no periodic save happens
        # during the run; confirm against TrainerController.
        save_freq = 99999
        seed = 1337

        trainer_config = yaml.safe_load(config)
        env_manager = SimpleEnvManager(env, FloatPropertiesChannel())
        trainer_factory = TrainerFactory(
            trainer_config=trainer_config,
            summaries_dir=output_dir,
            run_id=run_id,
            model_path=output_dir,
            keep_checkpoints=1,
            train_model=True,
            load_model=False,
            seed=seed,
            meta_curriculum=meta_curriculum,
            multi_gpu=False,
        )

        tc = TrainerController(
            trainer_factory=trainer_factory,
            summaries_dir=output_dir,
            model_path=output_dir,
            run_id=run_id,
            meta_curriculum=meta_curriculum,
            train=True,
            training_seed=seed,
            sampler_manager=SamplerManager(None),
            resampling_interval=None,
            save_freq=save_freq,
        )

        # Begin training
        tc.start_learning(env_manager)

        # Fetch the measures once so the printed values are exactly the ones
        # that get asserted on.
        measure_vals = tc._get_measure_vals()
        print(measure_vals)
        for brain_name, mean_reward in measure_vals.items():
            assert not math.isnan(
                mean_reward
            ), f"Mean reward for {brain_name!r} is NaN"
            assert mean_reward > success_threshold, (
                f"Mean reward for {brain_name!r} is {mean_reward}, "
                f"expected more than {success_threshold}"
            )
def worker(
    parent_conn: Connection,
    step_queue: Queue,
    pickled_env_factory: str,
    worker_id: int,
    engine_configuration: EngineConfig,
) -> None:
    """Create one environment and serve commands received on ``parent_conn``.

    The environment factory is unpickled with ``cloudpickle`` and called with
    ``worker_id`` plus two side channels (float properties and engine
    configuration). The function then loops on ``parent_conn.recv()`` and
    dispatches on ``cmd.name``:

    * ``"step"``: apply every non-empty action in the payload, step the
      environment, and put a ``StepResponse`` on ``step_queue``.
    * ``"external_brains"``: send the brain parameters of every agent group.
    * ``"get_properties"``: send the current float properties as a dict.
    * ``"reset"``: set the float properties from the payload, reset the
      environment, and send the resulting brain info.
    * ``"close"``: leave the loop.

    Any other command name is ignored. On ``KeyboardInterrupt``,
    ``UnityCommunicationException`` or ``UnityTimeOutException`` an
    ``"env_close"`` response is put on ``step_queue``. In every case the queue
    and the environment are closed before returning.

    Args:
        parent_conn: Connection used to receive commands and send non-step
            responses.
        step_queue: Queue receiving step results and the ``"env_close"`` notice.
        pickled_env_factory: Pickled factory, loaded via ``cloudpickle.loads``.
        worker_id: Identifier passed to the factory and stamped on every
            response.
        engine_configuration: Configuration applied to the engine
            configuration side channel before the environment is created.
    """
    env_factory: Callable[
        [int, List[SideChannel]], UnityEnvironment
    ] = cloudpickle.loads(pickled_env_factory)
    shared_float_properties = FloatPropertiesChannel()
    engine_configuration_channel = EngineConfigurationChannel()
    engine_configuration_channel.set_configuration(engine_configuration)
    side_channels = [shared_float_properties, engine_configuration_channel]
    env: BaseEnv = env_factory(worker_id, side_channels)

    def _send_response(cmd_name, payload):
        # Replies to everything except "step" travel back over the pipe.
        response = EnvironmentResponse(cmd_name, worker_id, payload)
        parent_conn.send(response)

    def _collect_brain_info() -> AllBrainInfo:
        # One BrainInfo per agent group, built from its latest step result.
        return {
            group: step_result_to_brain_info(
                env.get_step_result(group),
                env.get_agent_group_spec(group),
                worker_id,
            )
            for group in env.get_agent_groups()
        }

    def _collect_brain_parameters():
        # Brain parameters for each agent group, derived from its group spec.
        return {
            group: group_spec_to_brain_parameters(
                group, env.get_agent_group_spec(group)
            )
            for group in env.get_agent_groups()
        }

    try:
        while True:
            cmd: EnvironmentCommand = parent_conn.recv()
            command = cmd.name

            if command == "close":
                break

            if command == "step":
                for group, action_info in cmd.payload.items():
                    # Groups with an empty action are left untouched.
                    if len(action_info.action) != 0:
                        env.set_actions(group, action_info.action)
                env.step()
                brain_info = _collect_brain_info()
                # Timers in this process are independent of those in every
                # other process, including the "main" one, so they can be
                # cleared safely once the root timer has been sent back.
                # Returning timers only a fraction of the time would be a way
                # to reduce the amount of data transferred.
                # TODO get gauges from the workers and merge them in the main process too.
                payload = StepResponse(brain_info, get_timer_root())
                step_queue.put(EnvironmentResponse("step", worker_id, payload))
                reset_timers()
            elif command == "external_brains":
                _send_response("external_brains", _collect_brain_parameters())
            elif command == "get_properties":
                current_properties = {
                    key: shared_float_properties.get_property(key)
                    for key in shared_float_properties.list_properties()
                }
                _send_response("get_properties", current_properties)
            elif command == "reset":
                for key, value in cmd.payload.items():
                    shared_float_properties.set_property(key, value)
                env.reset()
                _send_response("reset", _collect_brain_info())
    except (
        KeyboardInterrupt,
        UnityCommunicationException,
        UnityTimeOutException,
    ):
        logger.info(
            f"UnityEnvironment worker {worker_id}: environment stopping.")
        step_queue.put(EnvironmentResponse("env_close", worker_id, None))
    finally:
        # An item this worker put on the step queue that the EnvManager never
        # processed would make the process hang until it is consumed.
        # Queue.cancel_join_thread() avoids that; see
        # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue.cancel_join_thread
        logger.debug(f"UnityEnvironment worker {worker_id} closing.")
        step_queue.cancel_join_thread()
        step_queue.close()
        env.close()
        logger.debug(f"UnityEnvironment worker {worker_id} done.")
# Example #3
# 0
def test_float_properties():
    """Check that float properties set on one channel reach another channel.

    Two ``FloatPropertiesChannel`` instances are used. Properties are set on
    ``sender``, side-channel data is generated from it and parsed into
    ``receiver``, and the receiver's view is verified after each round.
    """
    sender = FloatPropertiesChannel()
    receiver = FloatPropertiesChannel()

    def _round_trip():
        # Generate side-channel data from `sender` and parse it into `receiver`.
        message = UnityEnvironment._generate_side_channel_data(
            {sender.channel_type: sender})
        UnityEnvironment._parse_side_channel_message(
            {receiver.channel_type: receiver}, message)

    # First round: only "prop1" exists, so "prop2" must read back as None.
    sender.set_property("prop1", 1.0)
    _round_trip()
    assert receiver.get_property("prop1") == 1.0
    assert receiver.get_property("prop2") is None

    # Second round: add "prop2" and check both values on the receiver.
    sender.set_property("prop2", 2.0)
    _round_trip()
    assert receiver.get_property("prop1") == 1.0
    assert receiver.get_property("prop2") == 2.0

    # The receiver lists exactly the two properties that were sent.
    assert len(receiver.list_properties()) == 2
    assert "prop1" in receiver.list_properties()
    assert "prop2" in receiver.list_properties()

    # The sender still reports the value it was given.
    assert sender.get_property("prop1") == 1.0

    # Dictionary copies agree with the expected content and with each other.
    assert receiver.get_property_dict_copy() == {"prop1": 1.0, "prop2": 2.0}
    assert receiver.get_property_dict_copy() == sender.get_property_dict_copy()