Code example #1
def evaluate_fn(agent_dir, output_dir, seeds, port, demo, policy_type):
    agent = agent_dir.split('/')[-1]
    if not check_dir(agent_dir):
        logging.error('Evaluation: %s does not exist!' % agent)
        return
    # load config file for env
    config_dir = find_file(agent_dir + '/data/')
    if not config_dir:
        return
    config = configparser.ConfigParser()
    config.read(config_dir)

    # init env
    env, greedy_policy = init_env(config['ENV_CONFIG'],
                                  port=port,
                                  naive_policy=True)
    logging.info(
        'Evaluation: s dim: %d, a dim: %d, s dim ls: %r, a dim ls: %r' %
        (env.n_s, env.n_a, env.n_s_ls, env.n_a_ls))
    env.init_test_seeds(seeds)

    # load model for agent
    if agent != 'greedy':
        # init centralized or multi agent
        if agent == 'a2c':
            model = A2C(env.n_s, env.n_a, 0, config['MODEL_CONFIG'])
        elif agent == 'ia2c':
            model = IA2C(env.n_s_ls, env.n_a_ls, env.n_w_ls, 0,
                         config['MODEL_CONFIG'])
        elif agent == 'ma2c':
            model = MA2C(env.n_s_ls, env.n_a_ls, env.n_w_ls, env.n_f_ls, 0,
                         config['MODEL_CONFIG'])
        elif agent == 'iqld':
            model = IQL(env.n_s_ls,
                        env.n_a_ls,
                        env.n_w_ls,
                        0,
                        config['MODEL_CONFIG'],
                        seed=0,
                        model_type='dqn')
        else:
            model = IQL(env.n_s_ls,
                        env.n_a_ls,
                        env.n_w_ls,
                        0,
                        config['MODEL_CONFIG'],
                        seed=0,
                        model_type='lr')
        if not model.load(agent_dir + '/model/'):
            return
    else:
        model = greedy_policy
    env.agent = agent
    # collect evaluation data
    evaluator = Evaluator(env,
                          model,
                          output_dir,
                          demo=demo,
                          policy_type=policy_type)
    evaluator.run()
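A minimal driver for the function above; the paths, seed list, port, and policy_type value here are hypothetical, and the project's real entry point would parse them from the command line:

if __name__ == '__main__':
    evaluate_fn(agent_dir='./results/ma2c',
                output_dir='./results/ma2c/eva_data',
                seeds=[2000, 2001, 2002],
                port=0,
                demo=False,
                policy_type='default')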
Code example #2
File: main.py Project: zhiyongc/deeprl_network
def evaluate_fn(agent_dir, output_dir, seeds, port, demo):
    agent = agent_dir.split('/')[-1]
    if not check_dir(agent_dir):
        logging.error('Evaluation: %s does not exist!' % agent)
        return
    # load config file 
    config_dir = find_file(agent_dir + '/data/')
    if not config_dir:
        return
    config = configparser.ConfigParser()
    config.read(config_dir)

    # init env
    env = init_env(config['ENV_CONFIG'], port=port)
    env.init_test_seeds(seeds)

    # load model for agent
    model = init_agent(env, config['MODEL_CONFIG'], 0, 0)
    if model is None:
        return
    model_dir = agent_dir + '/model/'
    if not model.load(model_dir):
        return
    # collect evaluation data
    evaluator = Evaluator(env, model, output_dir, gui=demo)
    evaluator.run()
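This variant folds the per-algorithm dispatch into an init_agent factory. A plausible shape for that factory, inferred from the dispatch in code example #6 below (an assumption; the real factory may support different agents and signatures):

def init_agent(env, config, total_step, seed):
    # seed is accepted for parity with the call sites; its handling is
    # omitted in this sketch
    if env.agent == 'ia2c':
        return IA2C(env.n_s_ls, env.n_a, env.neighbor_mask,
                    env.distance_mask, env.coop_gamma, total_step, config)
    if env.agent == 'ia2c_fp':
        return IA2C_FP(env.n_s_ls, env.n_a, env.neighbor_mask,
                       env.distance_mask, env.coop_gamma, total_step, config)
    if env.agent == 'ma2c_nc':
        return MA2C_NC(env.n_s, env.n_a, env.neighbor_mask,
                       env.distance_mask, env.coop_gamma, total_step, config)
    return None  # unknown agent name; callers check for None and bail out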
Code example #3
File: main.py Project: Dogordog/deeprl_dist
def evaluate_fn(agent_dir, output_dir, seeds, port):
    agent = agent_dir.split('/')[-1]
    if not check_dir(agent_dir):
        logging.error('Evaluation: %s does not exist!' % agent)
        return
    # load config file for env
    config_dir = find_file(agent_dir + '/data/')
    if not config_dir:
        return
    config = configparser.ConfigParser()
    config.read(config_dir)

    # init env
    env, greedy_policy = init_env(config['ENV_CONFIG'],
                                  port=port,
                                  naive_policy=True)
    env.init_test_seeds(seeds)

    # load model for agent
    if agent != 'greedy':
        # init centralized or multi agent
        model = init_agent(env, config['MODEL_CONFIG'], 0, 0)
        if model is None:
            return
        if not model.load(agent_dir + '/model/'):
            return
    else:
        model = greedy_policy
    # collect evaluation data
    evaluator = Evaluator(env, model, output_dir)
    evaluator.run()
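Since the agent name is taken from the last path component of agent_dir, pointing it at a directory literally named 'greedy' skips model loading and evaluates the rule-based baseline returned by init_env. A usage sketch with hypothetical paths:

evaluate_fn('./results/greedy',           # agent name resolves to 'greedy'
            './results/greedy/eva_data',
            seeds=[2000, 2001],
            port=0)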
Code example #4
    def run(self):
        # switch the model to evaluation mode
        self.model = self.model.eval()
        self.dataset, test_gallery = get_initial_test(
            self.config, test=True)  # return dataset instance

        print("data set len is :", len(self.dataset))
        data_gallery, vID_gallery, label_gallery = test_gallery[
            0], test_gallery[1], test_gallery[2]

        print("sample leve ----------->", self.config.test.sampler)

        # dataloader define
        self.data_loader = DataLoader(dataset=self.dataset,
                                      batch_size=1,
                                      sampler=SequentialSampler(self.dataset),
                                      collate_fn=self.collate_fn,
                                      num_workers=self.num_workers)

        len_gallery = len(label_gallery)

        feature_gallery = self.extract_gallery_feature(data_gallery,
                                                       len_gallery)

        probe_feature = list()
        probe_vID = list()

        for seq, vID, label, _ in tqdm(self.data_loader):

            seq = torch.from_numpy(seq).float().cuda()
            # print(seq.size())
            fc, out = self.model(seq)
            n, num_bin = fc.size()
            feat = fc.view(n, -1).data.cpu().numpy()

            for ii in range(n):
                feat[ii] = feat[ii] / np.linalg.norm(feat[ii])

            probe_feature.append(feat)
            probe_vID += vID

        test_gallery = feature_gallery, vID_gallery, label_gallery
        feature_probe = np.concatenate(probe_feature, 0)
        test_probe = feature_probe, probe_vID

        self.save_npy(feature_gallery, "feature_gallery.npy")
        self.save_npy(vID_gallery, "vID_gallery.npy")
        self.save_npy(label_gallery, "label_gallery.npy")
        self.save_npy(feature_probe, "feature_probe.npy")
        self.save_npy(probe_vID, "probe_vID.npy")

        evaluation = Evaluator(test_gallery, test_probe, self.config)
        evaluation.run()
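The per-row loop above rescales each probe feature to unit L2 norm. The same step can be written as a single vectorized NumPy expression (equivalent, assuming no feature row is all zeros):

import numpy as np

# normalize every row of feat at once instead of looping over range(n)
feat = feat / np.linalg.norm(feat, axis=1, keepdims=True)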
Code example #5
File: main.py Project: Dogordog/deeprl_dist
def train(args):
    base_dir = args.base_dir
    dirs = init_dir(base_dir)
    init_log(dirs['log'])
    config_dir = args.config_dir
    copy_file(config_dir, dirs['data'])
    config = configparser.ConfigParser()
    config.read(config_dir)
    in_test, post_test = init_test_flag(args.test_mode)

    # init env
    env = init_env(config['ENV_CONFIG'])
    logging.info('Training: a dim %d, agent dim: %d' % (env.n_a, env.n_agent))

    # init step counter
    total_step = int(config.getfloat('TRAIN_CONFIG', 'total_step'))
    test_step = int(config.getfloat('TRAIN_CONFIG', 'test_interval'))
    log_step = int(config.getfloat('TRAIN_CONFIG', 'log_interval'))
    global_counter = Counter(total_step, test_step, log_step)

    # init centralized or multi agent
    seed = config.getint('ENV_CONFIG', 'seed')
    model = init_agent(env, config['MODEL_CONFIG'], total_step, seed)

    # use a single-threaded trainer: SUMO is not safe under multi-threading
    summary_writer = tf.summary.FileWriter(dirs['log'])
    trainer = Trainer(env,
                      model,
                      global_counter,
                      summary_writer,
                      in_test,
                      output_path=dirs['data'])
    trainer.run()

    # save model
    final_step = global_counter.cur_step
    logging.info('Training: save final model at step %d ...' % final_step)
    model.save(dirs['model'], final_step)

    # post-training test
    if post_test:
        test_dirs = init_dir(base_dir, pathes=['eva_data'])
        evaluator = Evaluator(env, model, test_dirs['eva_data'])
        evaluator.run()
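The Counter above bundles the three step limits read from TRAIN_CONFIG; the trainer drives it, and the final model is saved at its cur_step. A minimal sketch of what such a counter could look like (an assumption; the project's own class may track more state):

class Counter:
    def __init__(self, total_step, test_step, log_step):
        self.total_step = total_step
        self.test_step = test_step
        self.log_step = log_step
        self.cur_step = 0

    def next(self):
        # advance the global step; training stops once total_step is reached
        self.cur_step += 1
        return self.cur_step

    def should_test(self):
        return self.cur_step % self.test_step == 0

    def should_log(self):
        return self.cur_step % self.log_step == 0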
Code example #6
File: main.py Project: murtazarang/deeprl_dist
def evaluate_fn(agent_dir, output_dir, seeds, port):
    agent = agent_dir.split('/')[-1]
    if not check_dir(agent_dir):
        logging.error('Evaluation: %s does not exist!' % agent)
        return
    # load config file for env
    config_dir = find_file(agent_dir)
    if not config_dir:
        return
    config = configparser.ConfigParser()
    config.read(config_dir)

    # init env
    env, greedy_policy = init_env(config['ENV_CONFIG'],
                                  port=port,
                                  naive_policy=True)
    env.init_test_seeds(seeds)

    # load model for agent
    if agent != 'greedy':
        # init centralized or multi agent
        if env.agent == 'ia2c':
            model = IA2C(env.n_s_ls, env.n_a, env.neighbor_mask,
                         env.distance_mask, env.coop_gamma, 0,
                         config['MODEL_CONFIG'])
        elif env.agent == 'ia2c_fp':
            model = IA2C_FP(env.n_s_ls, env.n_a, env.neighbor_mask,
                            env.distance_mask, env.coop_gamma, 0,
                            config['MODEL_CONFIG'])
        elif env.agent == 'ma2c_nc':
            model = MA2C_NC(env.n_s, env.n_a, env.neighbor_mask,
                            env.distance_mask, env.coop_gamma, 0,
                            config['MODEL_CONFIG'])
        else:
            return
        if not model.load(agent_dir + '/'):
            return
    else:
        model = greedy_policy
    env.agent = agent
    # collect evaluation data
    evaluator = Evaluator(env, model, output_dir)
    evaluator.run()
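Unlike examples #1-#3, this variant looks for the config directly under agent_dir and loads the model from agent_dir + '/'. Building such paths with os.path.join is a separator-safe equivalent of the string concatenation used throughout this section (purely a style alternative):

import os

config_dir = find_file(os.path.join(agent_dir, 'data'))  # examples #1-#3
model_dir = os.path.join(agent_dir, 'model')             # examples #1, #2, #7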
Code example #7
def evaluate_fn(agent_dir, output_dir, seeds, port, demo):
    agent = agent_dir.split('/')[-1]
    doubleQ = True
    if agent == 'ddqn':
        doubleQ = False
        agent = 'dqn'
    if not check_dir(agent_dir):
        logging.error('Evaluation: %s does not exist!' % agent)
        return
    # load config file for env
    config_dir = find_file(agent_dir + '/data/')
    if not config_dir:
        return
    config = configparser.ConfigParser()
    config.read(config_dir)

    # init env
    env, greedy_policy = init_env(config['ENV_CONFIG'],
                                  port=port,
                                  naive_policy=True)
    logging.info(
        'Evaluation: s dim: %d, a dim: %d, s dim ls: %r, a dim ls: %r' %
        (env.n_s, env.n_a, env.n_s_ls, env.n_a_ls))
    env.init_test_seeds(seeds)

    # load model for agent
    if agent != 'greedy':
        # init centralized or multi agent
        if agent == 'a2c':
            model = A2C(env.n_s, env.n_a, 0, config['MODEL_CONFIG'])
        elif agent == 'ia2c':
            model = IA2C(env.n_s_ls, env.n_a_ls, env.n_w_ls, 0,
                         config['MODEL_CONFIG'])
        elif agent == 'ma2c':
            model = MA2C(env.n_s_ls, env.n_a_ls, env.n_w_ls, env.n_f_ls, 0,
                         config['MODEL_CONFIG'])
        elif agent == 'codql':
            print('This is codql')
            model = MFQ(nb_agent=len(env.n_s_ls),
                        a_dim=env.n_a_ls[0],
                        s_dim=env.n_s_ls[0],
                        s_dim_wave=env.n_s_ls[0] - env.n_w_ls[0],
                        s_dim_wait=env.n_w_ls[0],
                        config=config['MODEL_CONFIG'])
        elif agent == 'dqn':
            model = DQN(nb_agent=len(env.n_s_ls),
                        a_dim=env.n_a_ls[0],
                        s_dim=env.n_s_ls[0],
                        s_dim_wave=env.n_s_ls[0] - env.n_w_ls[0],
                        s_dim_wait=env.n_w_ls[0],
                        config=config['MODEL_CONFIG'],
                        doubleQ=doubleQ)  # doubleQ was set to False above when agent_dir named 'ddqn'
        elif agent == 'ddpg':
            model = DDPGEN(nb_agent=len(env.n_s_ls),
                           share_params=True,
                           a_dim=env.n_a_ls[0],
                           s_dim=env.n_s_ls[0],
                           s_dim_wave=env.n_s_ls[0] - env.n_w_ls[0],
                           s_dim_wait=env.n_w_ls[0])
        elif agent == 'iqld':
            model = IQL(env.n_s_ls,
                        env.n_a_ls,
                        env.n_w_ls,
                        0,
                        config['MODEL_CONFIG'],
                        seed=0,
                        model_type='dqn')
        else:
            model = IQL(env.n_s_ls,
                        env.n_a_ls,
                        env.n_w_ls,
                        0,
                        config['MODEL_CONFIG'],
                        seed=0,
                        model_type='lr')
        if not model.load(agent_dir + '/model/'):
            return
    else:
        model = greedy_policy
    env.agent = agent
    # collect evaluation data
    evaluator = Evaluator(env, model, output_dir, demo=demo)
    evaluator.run()
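The doubleQ flag is derived from the directory name at the top of the function: an agent_dir ending in 'ddqn' flips doubleQ to False, and the agent is then dispatched as 'dqn'. A usage sketch with hypothetical paths:

evaluate_fn('./results/dqn', './out', seeds=[2000], port=0, demo=False)   # doubleQ stays True
evaluate_fn('./results/ddqn', './out', seeds=[2000], port=0, demo=False)  # doubleQ set to False, dispatched as 'dqn'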