def test_agent_manager_partial_fit_and_tuple_env():
    """Check that partial-fit budgets accumulate and that a tuple env works.

    An environment given as ``(constructor, kwargs)`` — with ``kwargs=None`` —
    must be accepted by AgentManager.
    """
    env_tuple = (GridWorld, None)

    # Both managers are built with identical settings and the same seed.
    shared_kwargs = dict(
        init_kwargs={},
        n_fit=4,
        fit_budget=5,
        eval_kwargs=dict(eval_horizon=10),
        seed=123,
    )
    stats = AgentManager(DummyAgent, env_tuple, **shared_kwargs)
    stats2 = AgentManager(DummyAgent, env_tuple, **shared_kwargs)

    # Partial fit: successive calls accumulate the training budget.
    stats.fit(10)
    stats.fit(20)
    for handler in stats.agent_handlers:
        assert handler.total_budget == 30

    # Regular fit with the default budget.
    stats2.fit()

    # Learning curves.
    plot_writer_data(
        [stats], tag="episode_rewards", show=False, preprocess_func=np.cumsum
    )

    # Compare final policies.
    evaluate_agents([stats], show=False)

    # Delete some writers.
    stats.set_writer(0, None)
    stats.set_writer(3, None)

    stats.clear_output_dir()
    stats2.clear_output_dir()
eval_kwargs = dict(eval_horizon=HORIZON, n_simulations=20)

# -----------------------------
# Run AgentManager
# -----------------------------
ppo_stats = AgentManager(
    PPOAgent,
    train_env,
    fit_budget=N_EPISODES,
    init_kwargs=params_ppo,
    eval_kwargs=eval_kwargs,
    n_fit=4,
)

# Attach tensorboard writers to the first two workers.
ppo_stats.set_writer(0, SummaryWriter, writer_kwargs={"comment": "worker_0"})
ppo_stats.set_writer(1, SummaryWriter, writer_kwargs={"comment": "worker_1"})

agent_manager_list = [ppo_stats]

manager = agent_manager_list[0]
manager.fit()
# After fit, writers are set to None to avoid pickle problems.
manager.save()

# Compare final policies.
output = evaluate_agents(agent_manager_list)
print(output)
def execute_message(
    message: interface.Message, resources: interface.Resources
) -> interface.Message:
    """Dispatch a client Message to the corresponding AgentManager action.

    Parameters
    ----------
    message : interface.Message
        Incoming request; ``message.command`` selects the action and
        ``message.params`` carries its arguments.
    resources : interface.Resources
        Mapping of resource name -> resource info; each entry must expose a
        ``"description"`` field.

    Returns
    -------
    interface.Message
        Response for the client. Defaults to an ECHO message for commands
        that produce no data.
    """
    response = interface.Message.create(command=interface.Command.ECHO)

    # LIST_RESOURCES
    if message.command == interface.Command.LIST_RESOURCES:
        info = {}
        for rr in resources:
            info[rr] = resources[rr]["description"]
        response = interface.Message.create(info=info)

    # AGENT_MANAGER_CREATE_INSTANCE
    elif message.command == interface.Command.AGENT_MANAGER_CREATE_INSTANCE:
        params = message.params
        base_dir = pathlib.Path(metadata_utils.RLBERRY_DEFAULT_DATA_DIR)
        # All server-side outputs live under <base_dir>/server_data.
        if "output_dir" in params:
            params["output_dir"] = base_dir / "server_data" / params["output_dir"]
        else:
            params["output_dir"] = base_dir / "server_data/"
        agent_manager = AgentManager(**params)
        filename = str(agent_manager.save())
        response = interface.Message.create(
            info=dict(
                filename=filename,
                agent_name=agent_manager.agent_name,
                # Report the path as the client will see it.
                output_dir=str(agent_manager.output_dir).replace(
                    "server_data/", "client_data/"
                ),
            )
        )
        del agent_manager

    # AGENT_MANAGER_FIT
    elif message.command == interface.Command.AGENT_MANAGER_FIT:
        filename = message.params["filename"]
        budget = message.params["budget"]
        extra_params = message.params["extra_params"]
        agent_manager = AgentManager.load(filename)
        agent_manager.fit(budget, **extra_params)
        agent_manager.save()
        response = interface.Message.create(command=interface.Command.ECHO)
        del agent_manager

    # AGENT_MANAGER_EVAL
    elif message.command == interface.Command.AGENT_MANAGER_EVAL:
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        eval_output = agent_manager.eval_agents(message.params["n_simulations"])
        response = interface.Message.create(data=dict(output=eval_output))
        del agent_manager

    # AGENT_MANAGER_CLEAR_OUTPUT_DIR
    elif message.command == interface.Command.AGENT_MANAGER_CLEAR_OUTPUT_DIR:
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        agent_manager.clear_output_dir()
        response = interface.Message.create(
            message=f"Cleared output dir: {agent_manager.output_dir}"
        )
        del agent_manager

    # AGENT_MANAGER_CLEAR_HANDLERS
    elif message.command == interface.Command.AGENT_MANAGER_CLEAR_HANDLERS:
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        agent_manager.clear_handlers()
        agent_manager.save()
        # NOTE(review): the original text here appeared garbled
        # (an f-string with no placeholder); restored to reference the
        # loaded filename, matching the CLEAR_OUTPUT_DIR message — confirm.
        response = interface.Message.create(message=f"Cleared handlers: {filename}")
        del agent_manager

    # AGENT_MANAGER_SET_WRITER
    elif message.command == interface.Command.AGENT_MANAGER_SET_WRITER:
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        agent_manager.set_writer(**message.params["kwargs"])
        agent_manager.save()
        del agent_manager

    # AGENT_MANAGER_OPTIMIZE_HYPERPARAMS
    elif message.command == interface.Command.AGENT_MANAGER_OPTIMIZE_HYPERPARAMS:
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        best_params_dict = agent_manager.optimize_hyperparams(
            **message.params["kwargs"]
        )
        agent_manager.save()
        del agent_manager
        response = interface.Message.create(data=best_params_dict)

    # AGENT_MANAGER_GET_WRITER_DATA
    elif message.command == interface.Command.AGENT_MANAGER_GET_WRITER_DATA:
        # Writer scalar data, serialized per-index as CSV strings.
        filename = message.params["filename"]
        agent_manager = AgentManager.load(filename)
        writer_data = agent_manager.get_writer_data()
        writer_data = writer_data or dict()
        for idx in writer_data:
            writer_data[idx] = writer_data[idx].to_csv(index=False)
        # Tensorboard data: zip the directory and ship it base64-encoded.
        tensorboard_bin_data = None
        if agent_manager.tensorboard_dir is not None:
            tensorboard_zip_file = rlberry.utils.io.zipdir(
                agent_manager.tensorboard_dir,
                agent_manager.output_dir / "tensorboard_data.zip",
            )
            if tensorboard_zip_file is not None:
                # Fix: use a context manager so the zip file handle is
                # closed instead of leaking (original used a bare open()).
                with open(tensorboard_zip_file, "rb") as zip_fd:
                    raw = zip_fd.read()
                tensorboard_bin_data = base64.b64encode(raw).decode("ascii")
        response = interface.Message.create(
            data=dict(
                writer_data=writer_data,
                tensorboard_bin_data=tensorboard_bin_data,
            )
        )
        del agent_manager

    # end
    return response
def test_agent_manager_2():
    """Fit two identical managers, evaluate them, then save/load and optimize."""
    # Train and evaluation environments as (constructor, kwargs) tuples.
    train_env = (GridWorld, {})
    eval_env = (GridWorld, {})

    # Both managers share the same configuration and seed.
    shared_kwargs = dict(
        eval_env=eval_env,
        fit_budget=5,
        eval_kwargs=dict(eval_horizon=10),
        init_kwargs={},
        n_fit=4,
        seed=123,
    )
    stats_agent1 = AgentManager(DummyAgent, train_env, **shared_kwargs)
    stats_agent2 = AgentManager(DummyAgent, train_env, **shared_kwargs)
    agent_manager_list = [stats_agent1, stats_agent2]
    for manager in agent_manager_list:
        manager.fit()

    # Compare final policies (called twice, as in the original scenario).
    evaluate_agents(agent_manager_list, show=False)
    evaluate_agents(agent_manager_list, show=False)

    # Learning curves.
    plot_writer_data(agent_manager_list, tag="episode_rewards", show=False)

    # Every instance must have been fitted.
    for manager in agent_manager_list:
        assert len(manager.agent_handlers) == 4
        for handler in manager.agent_handlers:
            assert handler.fitted

    # Saving/loading round-trip preserves the unique id.
    fname = stats_agent1.save()
    loaded_stats = AgentManager.load(fname)
    assert stats_agent1.unique_id == loaded_stats.unique_id

    # Hyperparameter optimization on the loaded copy.
    loaded_stats.optimize_hyperparams(n_trials=5)

    # Delete some writers.
    stats_agent1.set_writer(1, None)
    stats_agent1.set_writer(2, None)

    stats_agent1.clear_output_dir()
    stats_agent2.clear_output_dir()