예제 #1
0
def init_agent(env, config, total_step, seed):
    """Build the agent model selected by ``env.agent``.

    Every supported agent class is constructed with the same arguments:
    ``env.n_s_ls``, ``env.n_a_ls``, ``env.neighbor_mask``,
    ``env.distance_mask`` and ``env.coop_gamma``, followed by ``total_step``,
    ``config`` and ``seed=seed``.

    Args:
        env: Environment object exposing ``agent`` and the attributes above.
        config: Model configuration forwarded unchanged to the agent class.
        total_step: Forwarded unchanged to the agent class.
        seed: Forwarded as the ``seed`` keyword argument.

    Returns:
        The constructed agent, or ``None`` if ``env.agent`` is not one of
        the recognised names.
    """
    agent_name = env.agent

    # Pick the class first; the constructor call is identical for all of
    # them. The if/elif chain keeps class-name lookup lazy, so only the
    # class that is actually selected has to be defined.
    if agent_name == 'ia2c':
        agent_cls = IA2C
    elif agent_name == 'ia2c_fp':
        agent_cls = IA2C_FP
    elif agent_name == 'ma2c_nc':
        agent_cls = MA2C_NC
    elif agent_name == 'ma2c_ic3':
        # this is actually CommNet
        agent_cls = MA2C_IC3
    elif agent_name == 'ma2c_cu':
        # NOTE(review): the 'ma2c_cu' key maps to the IA2C_CU class --
        # confirm the ma2c/ia2c naming mismatch is intended.
        agent_cls = IA2C_CU
    elif agent_name == 'ma2c_dial':
        agent_cls = MA2C_DIAL
    else:
        return None

    return agent_cls(
        env.n_s_ls,
        env.n_a_ls,
        env.neighbor_mask,
        env.distance_mask,
        env.coop_gamma,
        total_step,
        config,
        seed=seed,
    )
예제 #2
0
def evaluate_fn(agent_dir, output_dir, seeds, port):
    """Evaluate the agent stored in ``agent_dir`` and run an ``Evaluator``.

    The last path component of ``agent_dir`` is used as the agent name. The
    name ``'greedy'`` selects the greedy policy returned by ``init_env``;
    any other name loads a trained model whose class is chosen from
    ``env.agent``.

    Args:
        agent_dir: Directory containing the config file and saved model. A
            trailing ``'/'`` is tolerated.
        output_dir: Passed to ``Evaluator`` as its output location.
        seeds: Passed to ``env.init_test_seeds``.
        port: Forwarded to ``init_env``.

    Returns:
        None. Returns early, without running the evaluator, when the
        directory check fails, no config file is found, ``env.agent`` is
        not supported, or ``model.load`` reports failure.
    """
    # Strip trailing slashes first: 'runs/ia2c/'.split('/')[-1] would be ''
    # and that empty name would later be written into env.agent.
    agent = agent_dir.rstrip('/').split('/')[-1]
    if not check_dir(agent_dir):
        logging.error('Evaluation: %s does not exist!', agent)
        return

    # load config file for env
    config_dir = find_file(agent_dir)
    if not config_dir:
        return
    config = configparser.ConfigParser()
    config.read(config_dir)

    # init env
    env, greedy_policy = init_env(config['ENV_CONFIG'],
                                  port=port,
                                  naive_policy=True)
    env.init_test_seeds(seeds)

    # load model for agent
    if agent != 'greedy':
        # init centralized or multi agent; the independent agents take the
        # state-dimension list (n_s_ls), MA2C_NC takes env.n_s.
        if env.agent == 'ia2c':
            model_cls, n_s = IA2C, env.n_s_ls
        elif env.agent == 'ia2c_fp':
            model_cls, n_s = IA2C_FP, env.n_s_ls
        elif env.agent == 'ma2c_nc':
            model_cls, n_s = MA2C_NC, env.n_s
        else:
            # Previously a silent return; report why nothing was evaluated.
            logging.error('Evaluation: unsupported agent type %s!', env.agent)
            return
        # total_step is 0 because no training happens during evaluation.
        model = model_cls(n_s, env.n_a, env.neighbor_mask,
                          env.distance_mask, env.coop_gamma, 0,
                          config['MODEL_CONFIG'])
        if not model.load(agent_dir + '/'):
            return
    else:
        model = greedy_policy

    # Overwrite env.agent with the directory-derived name.
    # NOTE(review): presumably used downstream to label outputs -- confirm.
    env.agent = agent

    # collect evaluation data
    evaluator = Evaluator(env, model, output_dir)
    evaluator.run()
예제 #3
0
def train(args):
    """Train the agent selected by the config file and save the final model.

    Steps: create the output directories and log, copy and parse the config
    file, build the environment and step counter, construct the model that
    matches ``env.agent``, run the ``Trainer``, save the model, and
    optionally run a post-training ``Evaluator``.

    Args:
        args: Object with ``base_dir``, ``config_dir`` and ``test_mode``
            attributes.

    Raises:
        ValueError: If ``env.agent`` is not ``'ia2c'``, ``'ia2c_fp'`` or
            ``'ma2c_nc'``. Previously this case passed ``model = None`` on
            to the trainer and could only fail later with an
            ``AttributeError`` (at the latest in ``model.save``).
    """
    base_dir = args.base_dir
    dirs = init_dir(base_dir)
    init_log(dirs['log'])
    config_dir = args.config_dir
    copy_file(config_dir, dirs['data'])
    config = configparser.ConfigParser()
    config.read(config_dir)
    in_test, post_test = init_test_flag(args.test_mode)

    # init env
    env = init_env(config['ENV_CONFIG'])
    logging.info('Training: a dim %d, agent dim: %d', env.n_a, env.n_agent)

    # init step counter; values are parsed as floats first so the config
    # may use notation such as 1e6, then truncated to int.
    total_step = int(config.getfloat('TRAIN_CONFIG', 'total_step'))
    test_step = int(config.getfloat('TRAIN_CONFIG', 'test_interval'))
    log_step = int(config.getfloat('TRAIN_CONFIG', 'log_interval'))
    global_counter = Counter(total_step, test_step, log_step)

    # init centralized or multi agent
    seed = config.getint('ENV_CONFIG', 'seed')

    # The independent agents take the state-dimension list (n_s_ls),
    # MA2C_NC takes env.n_s; all other constructor arguments are shared.
    if env.agent == 'ia2c':
        model_cls, n_s = IA2C, env.n_s_ls
    elif env.agent == 'ia2c_fp':
        model_cls, n_s = IA2C_FP, env.n_s_ls
    elif env.agent == 'ma2c_nc':
        model_cls, n_s = MA2C_NC, env.n_s
    else:
        # Fail fast with a clear message instead of training with None.
        logging.error('Training: unsupported agent type %s!', env.agent)
        raise ValueError('Training: unsupported agent type %r' % (env.agent,))

    model = model_cls(n_s,
                      env.n_a,
                      env.neighbor_mask,
                      env.distance_mask,
                      env.coop_gamma,
                      total_step,
                      config['MODEL_CONFIG'],
                      seed=seed)

    # Training runs directly in the calling thread (original note:
    # multi-threading is disabled for safe SUMO implementation).
    summary_writer = tf.summary.FileWriter(dirs['log'])
    trainer = Trainer(env,
                      model,
                      global_counter,
                      summary_writer,
                      in_test,
                      output_path=dirs['data'])
    trainer.run()

    # save model
    final_step = global_counter.cur_step
    logging.info('Training: save final model at step %d ...', final_step)
    model.save(dirs['model'], final_step)

    # post-training test
    if post_test:
        test_dirs = init_dir(base_dir, pathes=['eva_data'])
        evaluator = Evaluator(env, model, test_dirs['eva_data'])
        evaluator.run()