def init_agent(env, config, total_step, seed):
    if env.agent == 'ia2c':
        return IA2C(env.n_s_ls, env.n_a_ls, env.neighbor_mask, env.distance_mask,
                    env.coop_gamma, total_step, config, seed=seed)
    elif env.agent == 'ia2c_fp':
        return IA2C_FP(env.n_s_ls, env.n_a_ls, env.neighbor_mask, env.distance_mask,
                       env.coop_gamma, total_step, config, seed=seed)
    elif env.agent == 'ma2c_nc':
        return MA2C_NC(env.n_s_ls, env.n_a_ls, env.neighbor_mask, env.distance_mask,
                       env.coop_gamma, total_step, config, seed=seed)
    elif env.agent == 'ma2c_ic3':
        # this is actually CommNet
        return MA2C_IC3(env.n_s_ls, env.n_a_ls, env.neighbor_mask, env.distance_mask,
                        env.coop_gamma, total_step, config, seed=seed)
    elif env.agent == 'ma2c_cu':
        return IA2C_CU(env.n_s_ls, env.n_a_ls, env.neighbor_mask, env.distance_mask,
                       env.coop_gamma, total_step, config, seed=seed)
    elif env.agent == 'ma2c_dial':
        return MA2C_DIAL(env.n_s_ls, env.n_a_ls, env.neighbor_mask, env.distance_mask,
                         env.coop_gamma, total_step, config, seed=seed)
    else:
        return None
def evaluate_fn(agent_dir, output_dir, seeds, port):
    agent = agent_dir.split('/')[-1]
    if not check_dir(agent_dir):
        logging.error('Evaluation: %s does not exist!' % agent)
        return
    # load config file for env
    config_dir = find_file(agent_dir)
    if not config_dir:
        return
    config = configparser.ConfigParser()
    config.read(config_dir)
    # init env
    env, greedy_policy = init_env(config['ENV_CONFIG'], port=port, naive_policy=True)
    env.init_test_seeds(seeds)
    # load model for agent
    if agent != 'greedy':
        # init centralized or multi agent via the shared dispatcher;
        # total_step=0 since no training schedule is needed for evaluation
        seed = config.getint('ENV_CONFIG', 'seed')
        model = init_agent(env, config['MODEL_CONFIG'], 0, seed)
        if model is None:
            return
        if not model.load(agent_dir + '/'):
            return
    else:
        model = greedy_policy
    env.agent = agent
    # collect evaluation data
    evaluator = Evaluator(env, model, output_dir)
    evaluator.run()
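# Sketch (assumption, not part of the original file): evaluate_fn takes an explicit
# port so that several trained agents can be evaluated concurrently, each SUMO
# instance bound to its own port. A minimal driver along those lines, with the
# directory layout (base_dir/<agent_name>) assumed for illustration:
def evaluate_all(base_dir, agents, output_dir, seeds, base_port=1000):
    import threading
    threads = []
    for i, agent in enumerate(agents):
        # one evaluation thread per agent, each with its own simulation port
        t = threading.Thread(target=evaluate_fn,
                             args=(base_dir + '/' + agent, output_dir, seeds, base_port + i))
        t.start()
        threads.append(t)
    # wait for all evaluation runs to finish
    for t in threads:
        t.join()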
def train(args):
    base_dir = args.base_dir
    dirs = init_dir(base_dir)
    init_log(dirs['log'])
    config_dir = args.config_dir
    copy_file(config_dir, dirs['data'])
    config = configparser.ConfigParser()
    config.read(config_dir)
    in_test, post_test = init_test_flag(args.test_mode)
    # init env
    env = init_env(config['ENV_CONFIG'])
    logging.info('Training: a dim %d, agent dim: %d' % (env.n_a, env.n_agent))
    # init step counter
    total_step = int(config.getfloat('TRAIN_CONFIG', 'total_step'))
    test_step = int(config.getfloat('TRAIN_CONFIG', 'test_interval'))
    log_step = int(config.getfloat('TRAIN_CONFIG', 'log_interval'))
    global_counter = Counter(total_step, test_step, log_step)
    # init centralized or multi agent via the shared dispatcher
    seed = config.getint('ENV_CONFIG', 'seed')
    model = init_agent(env, config['MODEL_CONFIG'], total_step, seed)
    # disable multi-threading for safe SUMO implementation
    summary_writer = tf.summary.FileWriter(dirs['log'])
    trainer = Trainer(env, model, global_counter, summary_writer, in_test,
                      output_path=dirs['data'])
    trainer.run()
    # save model
    final_step = global_counter.cur_step
    logging.info('Training: save final model at step %d ...' % final_step)
    model.save(dirs['model'], final_step)
    # post-training test
    if post_test:
        test_dirs = init_dir(base_dir, pathes=['eva_data'])
        evaluator = Evaluator(env, model, test_dirs['eva_data'])
        evaluator.run()
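# Sketch (assumption): a minimal command-line entry point wiring up the argparse
# attributes that train() reads (base_dir, config_dir, test_mode). The flag names
# and the 'no_test' default are guesses for illustration, not the repo's actual CLI;
# accepted test_mode values depend on init_test_flag.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--base-dir', dest='base_dir', type=str, required=True,
                        help='experiment output directory')
    parser.add_argument('--config-dir', dest='config_dir', type=str, required=True,
                        help='path to the .ini config file')
    parser.add_argument('--test-mode', dest='test_mode', type=str, default='no_test',
                        help='in-training / post-training test flag')
    train(parser.parse_args())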