def launch(config):
    """Train, test and persist the CIM policy-optimization learner.

    Builds the environment and one PO agent per environment agent index,
    then drives a ``SimpleLearner`` through an ``ActorProxy``.

    Args:
        config: Configuration passed through ``convert_dottable``. The keys
            read here are ``env.scenario``, ``env.topology``,
            ``env.durations``, ``env.state_shaping``, ``agents`` and
            ``main_loop.max_episode``. ``agents.input_dim`` is overwritten
            with the state shaper's ``dim``.

    Raises:
        KeyError: If the ``GROUP`` or ``NUM_ACTORS`` environment variable
            is not set.
        ValueError: If ``NUM_ACTORS`` cannot be parsed as an integer.
    """
    config = convert_dottable(config)
    env_cfg = config.env
    env = Env(env_cfg.scenario, env_cfg.topology, durations=env_cfg.durations)
    agent_id_list = list(map(str, env.agent_idx_list))

    # The agents' input size is dictated by the state shaper, so it is
    # injected into the agent config before the agents are created.
    shaper = CIMStateShaper(**config.env.state_shaping)
    config["agents"]["input_dim"] = shaper.dim

    po_agents = create_po_agents(agent_id_list, config.agents)
    agent_manager = POAgentManager(
        name="cim_learner",
        mode=AgentManagerMode.TRAIN,
        agent_dict=po_agents,
    )

    # Group name and actor count come from the process environment only;
    # there is no fallback, so a missing variable fails fast.
    group_name = os.environ["GROUP"]
    actor_count = int(os.environ["NUM_ACTORS"])
    proxy_params = {
        "group_name": group_name,
        "expected_peers": {"actor": actor_count},
        "redis_address": ("localhost", 6379),
    }

    actor_proxy = ActorProxy(
        proxy_params=proxy_params,
        experience_collecting_func=merge_experiences_with_trajectory_boundaries,
    )
    episode_scheduler = Scheduler(config.main_loop.max_episode)
    learner_logger = Logger("cim_learner", auto_timestamp=False)

    learner = SimpleLearner(
        agent_manager=agent_manager,
        actor=actor_proxy,
        scheduler=episode_scheduler,
        logger=learner_logger,
    )

    learner.learn()
    learner.test()
    model_dir = os.path.join(os.getcwd(), "models")
    learner.dump_models(model_dir)
    learner.exit()
def cim_dqn_learner():
    """Assemble and run the off-policy DQN learner for the CIM scenario.

    All settings are read from the module-level ``training_config`` mapping
    (keys ``env``, ``max_episode``, ``exploration``, ``group``,
    ``num_actors``, ``learner_update_trigger`` and ``training``).
    """
    env = Env(**training_config["env"])

    # One freshly built DQN agent per agent index exposed by the environment.
    agents_by_name = {}
    for agent_name in env.agent_idx_list:
        agents_by_name[agent_name] = get_dqn_agent()
    multi_agent = MultiAgentWrapper(agents_by_name)

    max_episode = training_config["max_episode"]
    exploration_schedule = TwoPhaseLinearParameterScheduler(
        max_episode, **training_config["exploration"]
    )

    actor_proxy = ActorProxy(
        training_config["group"],
        training_config["num_actors"],
        update_trigger=training_config["learner_update_trigger"],
    )

    off_policy_learner = OffPolicyLearner(
        actor_proxy,
        exploration_schedule,
        multi_agent,
        **training_config["training"],
    )
    off_policy_learner.run()
def launch(config, distributed_config):
    """Train, test and persist the CIM DQN learner in distributed mode.

    Builds the environment and one DQN agent per environment agent index,
    then drives a ``SimpleLearner`` through an ``ActorProxy``.

    Args:
        config: Experiment configuration passed through
            ``convert_dottable``. The keys read here are ``env.scenario``,
            ``env.topology``, ``env.durations``, ``env.state_shaping``,
            ``agents``, ``main_loop.max_episode`` and
            ``main_loop.exploration``. ``agents.algorithm.input_dim`` is
            overwritten with the state shaper's ``dim``.
        distributed_config: Distributed settings passed through
            ``convert_dottable``. ``redis.hostname`` and ``redis.port`` are
            always read; ``group`` and ``num_actors`` are read only when the
            matching environment variable is absent.

    Raises:
        ValueError: If the resolved actor count cannot be converted by
            ``int``.
    """
    config = convert_dottable(config)
    distributed_config = convert_dottable(distributed_config)

    env_cfg = config.env
    env = Env(env_cfg.scenario, env_cfg.topology, durations=env_cfg.durations)
    agent_id_list = list(map(str, env.agent_idx_list))

    # The algorithm's input size is dictated by the state shaper, so it is
    # injected into the agent config before the agents are created.
    shaper = CIMStateShaper(**config.env.state_shaping)
    config["agents"]["algorithm"]["input_dim"] = shaper.dim

    dqn_agents = create_dqn_agents(agent_id_list, config.agents)
    agent_manager = DQNAgentManager(
        name="cim_learner",
        mode=AgentManagerMode.TRAIN,
        agent_dict=dqn_agents,
    )

    # Environment variables take precedence over the distributed config.
    # The config attribute is touched only when the variable is missing.
    if "GROUP" in os.environ:
        group_name = os.environ["GROUP"]
    else:
        group_name = distributed_config.group

    if "NUM_ACTORS" in os.environ:
        raw_actor_count = os.environ["NUM_ACTORS"]
    else:
        raw_actor_count = distributed_config.num_actors
    actor_count = int(raw_actor_count)

    redis_cfg = distributed_config.redis
    proxy_params = {
        "group_name": group_name,
        "expected_peers": {"actor": actor_count},
        "redis_address": (redis_cfg.hostname, redis_cfg.port),
        "max_retries": 15,
    }

    actor_proxy = ActorProxy(
        proxy_params=proxy_params,
        experience_collecting_func=concat_experiences_by_agent,
    )
    exploration_schedule = TwoPhaseLinearParameterScheduler(
        config.main_loop.max_episode, **config.main_loop.exploration
    )
    learner_logger = Logger("cim_learner", auto_timestamp=False)

    learner = SimpleLearner(
        agent_manager=agent_manager,
        actor=actor_proxy,
        scheduler=exploration_schedule,
        logger=learner_logger,
    )

    learner.learn()
    learner.test()
    model_dir = os.path.join(os.getcwd(), "models")
    learner.dump_models(model_dir)
    learner.exit()
# `os` is needed for os.path.join / os.getcwd when dumping models below.
import os

from maro.simulator import Env
from maro.utils import Logger

from components.agent_manager import DQNAgentManager
from components.config import config
from components.state_shaper import CIMStateShaper

# NOTE(review): TwoPhaseLinearExplorer, AgentMode, SimpleLearner and ActorProxy
# are used below but are not imported in the visible part of this script --
# confirm they are brought into scope elsewhere, or add the missing import.

if __name__ == "__main__":
    # Entry point of the distributed CIM learner: build the environment, the
    # explorer and the DQN agent manager, then train, test and dump models.
    env = Env(
        config.env.scenario,
        config.env.topology,
        durations=config.env.durations,
    )
    agent_id_list = [str(agent_id) for agent_id in env.agent_idx_list]
    state_shaper = CIMStateShaper(**config.state_shaping)

    # The same exploration settings are applied to every agent through the
    # "_all_" key.
    exploration_config = {
        "epsilon_range_dict": {"_all_": config.exploration.epsilon_range},
        "split_point_dict": {"_all_": config.exploration.split_point},
        "with_cache": config.exploration.with_cache,
    }
    explorer = TwoPhaseLinearExplorer(
        agent_id_list,
        config.general.total_training_episodes,
        **exploration_config,
    )

    agent_manager = DQNAgentManager(
        name="cim_remote_learner",
        agent_id_list=agent_id_list,
        mode=AgentMode.TRAIN,
        state_shaper=state_shaper,
        explorer=explorer,
    )

    proxy_params = {
        "group_name": config.distributed.group_name,
        "expected_peers": config.distributed.learner.peer,
        "redis_address": (
            config.distributed.redis.host_name,
            config.distributed.redis.port,
        ),
    }

    learner = SimpleLearner(
        trainable_agents=agent_manager,
        actor=ActorProxy(proxy_params=proxy_params),
        logger=Logger("distributed_cim_learner", auto_timestamp=False),
    )

    learner.train(total_episodes=config.general.total_training_episodes)
    learner.test()
    # Models are written to a "models" directory under the current working
    # directory.
    learner.dump_models(os.path.join(os.getcwd(), "models"))