Пример #1
0
def test_agent_manager_partial_fit_and_tuple_env():
    """Check that a (constructor, kwargs) tuple env works and that partial fits accumulate budget."""
    # A (constructor, kwargs) tuple must be accepted as an env spec by AgentManager
    env_spec = (GridWorld, None)

    # Shared configuration
    init_params = {}
    evaluation_kwargs = dict(eval_horizon=10)

    # Build two identically-configured, identically-seeded managers
    manager_a, manager_b = (
        AgentManager(
            DummyAgent,
            env_spec,
            init_kwargs=init_params,
            n_fit=4,
            fit_budget=5,
            eval_kwargs=evaluation_kwargs,
            seed=123,
        )
        for _ in range(2)
    )

    # Two partial fits: each handler must see the summed budget
    manager_a.fit(10)
    manager_a.fit(20)
    assert all(handler.total_budget == 30 for handler in manager_a.agent_handlers)

    # Full fit using the budget given at construction time
    manager_b.fit()

    # Learning curves
    plot_writer_data(
        [manager_a],
        tag="episode_rewards",
        show=False,
        preprocess_func=np.cumsum,
    )

    # Compare final policies
    evaluate_agents([manager_a], show=False)

    # Remove some writers
    manager_a.set_writer(0, None)
    manager_a.set_writer(3, None)

    manager_a.clear_output_dir()
    manager_b.clear_output_dir()
Пример #2
0
    # Evaluation settings: rollout horizon and number of Monte Carlo simulations
    # per evaluation (HORIZON is defined outside this fragment — not shown here).
    eval_kwargs = dict(eval_horizon=HORIZON, n_simulations=20)

    # -----------------------------
    # Run AgentManager
    # -----------------------------
    # Train 4 PPO instances in parallel on train_env; params_ppo / N_EPISODES
    # come from earlier in the enclosing scope (not visible in this fragment).
    ppo_stats = AgentManager(
        PPOAgent,
        train_env,
        fit_budget=N_EPISODES,
        init_kwargs=params_ppo,
        eval_kwargs=eval_kwargs,
        n_fit=4,
    )

    # Attach a SummaryWriter to the first two workers only; the "comment"
    # kwarg tags each worker's logs so they can be told apart.
    ppo_stats.set_writer(0,
                         SummaryWriter,
                         writer_kwargs={"comment": "worker_0"})
    ppo_stats.set_writer(1,
                         SummaryWriter,
                         writer_kwargs={"comment": "worker_1"})

    agent_manager_list = [ppo_stats]

    agent_manager_list[0].fit()
    agent_manager_list[0].save(
    )  # after fit, writers are set to None to avoid pickle problems.

    # compare final policies
    output = evaluate_agents(agent_manager_list)
    print(output)
Пример #3
0
def execute_message(message: interface.Message,
                    resources: interface.Resources) -> interface.Message:
    """Dispatch a client request to the matching AgentManager operation.

    Parameters
    ----------
    message : interface.Message
        Incoming request; ``message.command`` selects the operation and
        ``message.params`` carries its arguments.
    resources : interface.Resources
        Mapping of resource name -> metadata; used by LIST_RESOURCES.

    Returns
    -------
    interface.Message
        Response for the client. Commands that produce no data fall back to
        the default ECHO response created below.
    """
    # Default response, returned unless a branch below overrides it.
    response = interface.Message.create(command=interface.Command.ECHO)
    # LIST_RESOURCES: return the description of every available resource.
    if message.command == interface.Command.LIST_RESOURCES:
        info = {}
        for rr in resources:
            info[rr] = resources[rr]["description"]
        response = interface.Message.create(info=info)
    # AGENT_MANAGER_CREATE_INSTANCE: build a manager, persist it, and return
    # its save filename so later commands can reload it by name.
    elif message.command == interface.Command.AGENT_MANAGER_CREATE_INSTANCE:
        params = message.params
        base_dir = pathlib.Path(metadata_utils.RLBERRY_DEFAULT_DATA_DIR)
        # Force any client-provided output_dir under the server's data root.
        if "output_dir" in params:
            params[
                "output_dir"] = base_dir / "server_data" / params["output_dir"]
        else:
            params["output_dir"] = base_dir / "server_data/"
        agent_manager = AgentManager(**params)
        filename = str(agent_manager.save())
        response = interface.Message.create(info=dict(
            filename=filename,
            agent_name=agent_manager.agent_name,
            # Report a client-side view of the server-side path.
            output_dir=str(agent_manager.output_dir).replace(
                "server_data/", "client_data/"),
        ))
        del agent_manager
    # AGENT_MANAGER_FIT: reload, fit with the requested budget, re-save.
    elif message.command == interface.Command.AGENT_MANAGER_FIT:
        filename = message.params["filename"]
        budget = message.params["budget"]
        extra_params = message.params["extra_params"]
        agent_manager = AgentManager.load(filename)
        agent_manager.fit(budget, **extra_params)
        agent_manager.save()
        response = interface.Message.create(command=interface.Command.ECHO)
        del agent_manager
    # AGENT_MANAGER_EVAL: reload and run evaluation simulations.
    elif message.command == interface.Command.AGENT_MANAGER_EVAL:
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        eval_output = agent_manager.eval_agents(
            message.params["n_simulations"])
        response = interface.Message.create(data=dict(output=eval_output))
        del agent_manager
    # AGENT_MANAGER_CLEAR_OUTPUT_DIR: reload and wipe the output directory.
    elif message.command == interface.Command.AGENT_MANAGER_CLEAR_OUTPUT_DIR:
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        agent_manager.clear_output_dir()
        response = interface.Message.create(
            message=f"Cleared output dir: {agent_manager.output_dir}")
        del agent_manager
    # AGENT_MANAGER_CLEAR_HANDLERS: reload, drop agent handlers, re-save.
    elif message.command == interface.Command.AGENT_MANAGER_CLEAR_HANDLERS:
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        agent_manager.clear_handlers()
        agent_manager.save()
        # NOTE: the f-prefix was dropped — the literal has no placeholders,
        # so the runtime string is byte-identical.
        response = interface.Message.create(
            message="Cleared handlers: (unknown)")
        del agent_manager
    # AGENT_MANAGER_SET_WRITER: reload, configure a writer, re-save.
    elif message.command == interface.Command.AGENT_MANAGER_SET_WRITER:
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        agent_manager.set_writer(**message.params["kwargs"])
        agent_manager.save()
        del agent_manager
    # AGENT_MANAGER_OPTIMIZE_HYPERPARAMS: reload, tune, re-save, return params.
    elif message.command == interface.Command.AGENT_MANAGER_OPTIMIZE_HYPERPARAMS:
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        best_params_dict = agent_manager.optimize_hyperparams(
            **message.params["kwargs"])
        agent_manager.save()
        del agent_manager
        response = interface.Message.create(data=best_params_dict)
    # AGENT_MANAGER_GET_WRITER_DATA: reload and ship writer/tensorboard data.
    elif message.command == interface.Command.AGENT_MANAGER_GET_WRITER_DATA:
        # writer scalar data, serialized per-index as CSV text
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        writer_data = agent_manager.get_writer_data()
        writer_data = writer_data or dict()
        for idx in writer_data:
            writer_data[idx] = writer_data[idx].to_csv(index=False)
        # tensorboard data, zipped then base64-encoded for transport
        tensorboard_bin_data = None
        if agent_manager.tensorboard_dir is not None:
            tensorboard_zip_file = rlberry.utils.io.zipdir(
                agent_manager.tensorboard_dir,
                agent_manager.output_dir / "tensorboard_data.zip",
            )
            if tensorboard_zip_file is not None:
                # Use a context manager so the zip's file handle is closed
                # promptly (the original relied on the GC to close it).
                with open(tensorboard_zip_file, "rb") as zip_fp:
                    raw_zip_bytes = zip_fp.read()
                tensorboard_bin_data = base64.b64encode(
                    raw_zip_bytes).decode("ascii")
        response = interface.Message.create(
            data=dict(writer_data=writer_data,
                      tensorboard_bin_data=tensorboard_bin_data))
        del agent_manager
    # end
    return response
Пример #4
0
def test_agent_manager_2():
    """Fit two managers, evaluate, plot, round-trip save/load, and tune hyperparameters."""
    # Environment specs as (constructor, kwargs) tuples
    training_env = (GridWorld, {})
    evaluation_env = (GridWorld, {})

    # Shared configuration
    agent_params = {}
    evaluation_kwargs = dict(eval_horizon=10)

    def make_manager():
        # Both managers share the exact same configuration and seed.
        return AgentManager(
            DummyAgent,
            training_env,
            eval_env=evaluation_env,
            fit_budget=5,
            eval_kwargs=evaluation_kwargs,
            init_kwargs=agent_params,
            n_fit=4,
            seed=123,
        )

    first_manager = make_manager()
    second_manager = make_manager()
    managers = [first_manager, second_manager]
    for manager in managers:
        manager.fit()

    # Compare final policies (twice, as in the original scenario)
    evaluate_agents(managers, show=False)
    evaluate_agents(managers, show=False)

    # Learning curves
    plot_writer_data(managers, tag="episode_rewards", show=False)

    # Every manager must hold 4 fitted agent handlers
    for manager in managers:
        assert len(manager.agent_handlers) == 4
        assert all(handler.fitted for handler in manager.agent_handlers)

    # Saving/loading round-trip preserves the unique id
    saved_path = first_manager.save()
    restored = AgentManager.load(saved_path)
    assert first_manager.unique_id == restored.unique_id

    # Hyperparameter optimization on the reloaded manager
    restored.optimize_hyperparams(n_trials=5)

    # Drop a couple of writers
    first_manager.set_writer(1, None)
    first_manager.set_writer(2, None)

    first_manager.clear_output_dir()
    second_manager.clear_output_dir()