Пример #1
0
def main():
    """Generate the train/test regression datasets and pickle them to disk.

    For each split ('train' with 100000 points, 'test' with 20000 points),
    ``fill`` is called to collect the data and the resulting
    ``[STATES, ACTIONS, VALUES]`` list is pickled under
    ``../../datasets/regression/`` relative to this file's directory.  The
    file name encodes ``Config.MAX_NUM_AGENTS_IN_ENVIRONMENT``,
    ``Config.DATASET_NAME`` and the split name.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    filename_template = (
        here
        + '/../../datasets/regression/{num_agents}_agents_{dataset_name}_cadrl_dataset_action_value_{mode}.p'
    )
    env, one_env = create_env()

    # (split name, number of data points to collect for that split)
    splits = (('train', 100000), ('test', 20000))
    for split_name, n_points in splits:
        states, actions, values = fill(env, one_env, num_datapts=n_points)
        out_path = filename_template.format(
            mode=split_name,
            dataset_name=Config.DATASET_NAME,
            num_agents=Config.MAX_NUM_AGENTS_IN_ENVIRONMENT,
        )
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        with open(out_path, "wb") as out_file:
            pickle.dump([states, actions, values], out_file)

    print("Files written.")
def main():
    """Roll out RVO episodes and record their trajectories.

    Configures ``one_env`` with the four scenarios listed below, an
    ``RVOPolicy`` ego agent and 5 agents, then runs ``num_test_cases``
    episodes.  After each episode, ``add_traj`` is given the agents, the
    ``trajs`` list and a CSV writer on ``<file_dir>/trajs/dataset.csv``
    (presumably it appends to both - its body is not visible here).  Every
    500 episodes, and once more for any final partial batch, ``trajs`` is
    pickled to ``<file_dir>/trajs/RVO<k>.pkl``.

    If ``<file_dir>/trajs/dataset_with_images.csv`` already exists, a message
    is printed and nothing is generated (``dataset.csv`` is not truncated).

    NOTE(review): ``file_dir``, ``num_test_cases``, ``create_env``,
    ``run_episode``, ``add_traj`` and ``tqdm`` are module-level names defined
    outside this block - confirm against the rest of the file.
    """
    env, one_env = create_env()
    dt = one_env.dt_nominal

    last_time = 0.0

    plot_save_dir = file_dir + '/figs/'
    os.makedirs(plot_save_dir, exist_ok=True)

    pkl_dir = file_dir + '/trajs/'
    os.makedirs(pkl_dir, exist_ok=True)

    one_env.plot_save_dir = plot_save_dir
    one_env.scenario = [
        "single_agents_swap",
        "single_agents_random_swap",
        "single_agents_random_positions",
        "single_corridor_scenario",
    ]
    one_env.ego_policy = "RVOPolicy"
    one_env.number_of_agents = 5
    env.reset()

    images_csv = pkl_dir + "dataset_with_images.csv"
    if os.path.isfile(images_csv):
        # Previously the CSV handle was only opened when this file was absent
        # but used unconditionally, so this path crashed with a NameError
        # before running any episode. Bail out explicitly instead.
        print('{} already exists; skipping dataset generation.'.format(images_csv))
        return

    dump_every = 500  # episodes per pickled trajectory chunk
    trajs = []
    chunk_id = 0  # suffix of the next RVO<k>.pkl file

    def _dump_chunk(chunk, idx):
        """Pickle one batch of trajectories to ``<pkl_dir>RVO<idx>.pkl``."""
        fname = pkl_dir + 'RVO' + str(idx) + '.pkl'
        # Protocol 2 makes it compatible for Python 2 and 3.
        with open(fname, 'wb') as pkl_file:
            pickle.dump(chunk, pkl_file, protocol=2)
        print('dumped {}'.format(fname))

    # newline='' is what the csv module expects for files it writes to.
    with open(pkl_dir + "dataset.csv", 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter=',', quoting=csv.QUOTE_MINIMAL)
        test_case = 0
        with tqdm(total=num_test_cases) as pbar:
            while test_case < num_test_cases:
                (times_to_goal, extra_times_to_goal, collision, all_at_goal,
                 any_stuck, agents) = run_episode(env, one_env)

                # Truncate each agent's history to the steps actually taken
                # in this episode before recording it.
                for agent in agents:
                    agent.global_state_history = agent.global_state_history[:agent.step_num]
                last_time = add_traj(agents, trajs, dt, last_time, writer)
                test_case += 1

                pbar.update(1)

                if test_case % dump_every == 0:
                    _dump_chunk(trajs, chunk_id)
                    trajs = []
                    chunk_id += 1

    # Episodes after the last multiple of ``dump_every`` used to be collected
    # but never written; flush them as a final chunk.
    if trajs:
        _dump_chunk(trajs, chunk_id)

    print("Experiment over.")
Пример #3
0
def main():
    """Roll out every policy on each test case and pickle the trajectories.

    For each agent count in ``num_agents_to_test``, every policy in
    ``policies`` is run on ``num_test_cases`` test cases built by
    ``test_case_fn(**test_case_args)``.  ``add_traj`` collects the
    trajectories into ``trajs``, which is pickled to
    ``../results/<results_subdir>/<N>_agents/trajs/<policy>.pkl`` relative
    to this file's directory.  Plots go to the sibling ``figs/`` directory.

    NOTE(review): ``num_agents_to_test``, ``num_test_cases``,
    ``results_subdir``, ``test_case_args``, ``test_case_fn``, ``policies``,
    ``create_env``, ``run_episode``, ``add_traj`` and ``tqdm`` are
    module-level names defined outside this block.  ``test_case_args`` is
    mutated in place.
    """
    env, one_env = create_env()
    dt = one_env.dt_nominal
    file_dir_template = os.path.dirname(os.path.realpath(
        __file__)) + '/../results/{results_subdir}/{num_agents}_agents'

    for num_agents in num_agents_to_test:
        # One buffer per agent count. It used to be created once before this
        # loop, so each agent count's pickle also contained the trajectories
        # of every previously tested agent count.
        trajs = [[] for _ in range(num_test_cases)]

        file_dir = file_dir_template.format(num_agents=num_agents,
                                            results_subdir=results_subdir)
        plot_save_dir = file_dir + '/figs/'
        os.makedirs(plot_save_dir, exist_ok=True)
        one_env.plot_save_dir = plot_save_dir

        test_case_args['num_agents'] = num_agents
        test_case_args['side_length'] = 7
        for test_case in tqdm(range(num_test_cases)):
            for policy in policies:
                policy_cfg = policies[policy]
                one_env.plot_policy_name = policy
                test_case_args['agents_policy'] = policy_cfg['policy']
                agents = test_case_fn(**test_case_args)
                # Policies configured with a checkpoint get the env attached
                # and their network initialized from the policy config.
                if 'checkpt_name' in policy_cfg:
                    for agent in agents:
                        agent.policy.env = env
                        agent.policy.initialize_network(**policy_cfg)
                one_env.set_agents(agents)
                one_env.test_case_index = test_case
                env.reset()

                (times_to_goal, extra_times_to_goal, collision, all_at_goal,
                 any_stuck, agents) = run_episode(env, one_env)

                # Divide each time-to-goal by the step size dt (time -> steps).
                max_ts = [t / dt for t in times_to_goal]
                trajs = add_traj(agents, trajs, dt, test_case, max_ts)

        one_env.reset()

        pkl_dir = file_dir + '/trajs/'
        os.makedirs(pkl_dir, exist_ok=True)
        # NOTE(review): ``policy`` is whatever the inner loop bound last, so
        # the file is named after the last policy even though ``trajs`` holds
        # every policy's rollouts. It is unbound if no episode ran.
        fname = pkl_dir + policy + '.pkl'
        with open(fname, 'wb') as pkl_file:
            pickle.dump(trajs, pkl_file)
        print('dumped {}'.format(fname))

    print("Experiment over.")
Пример #4
0
 def _set_env(self, id):
     """Create the game selected by ``Config.GAME_CHOICE`` and store it in ``self.game``.

     Args:
         id: forwarded as the first argument to ``Gridworld``; unused by the
             collision-avoidance branch. (It shadows the builtin ``id`` but
             is kept for caller compatibility.)

     Raises:
         ValueError: if ``Config.GAME_CHOICE`` is neither ``Config.game_grid``
             nor ``Config.game_collision_avoidance``.
     """
     if Config.GAME_CHOICE == Config.game_grid:
         self.game = Gridworld(id, Config.ENV_ROW, Config.ENV_COL,
                               Config.PIXEL_SIZE, Config.MAX_ITER,
                               Config.AGENT_COLOR, Config.TARGET_COLOR,
                               Config.DISPLAY_SCREEN, Config.TIMER_DURATION,
                               Config.IMAGE_WIDTH, Config.IMAGE_HEIGHT,
                               Config.STACKED_FRAMES, Config.DEBUG)
     elif Config.GAME_CHOICE == Config.game_collision_avoidance:
         # Local import: gym_collision_avoidance is only loaded when this
         # branch runs. Only create_env is needed here.
         from gym_collision_avoidance.experiments.src.env_utils import create_env
         env, _one_env = create_env()
         self.game = env
     else:
         raise ValueError(
             "[ ERROR ] Invalid choice of game. Check Config.py for choices"
         )
Пример #5
0
def main():
    """Run the formation test cases for every configured agent count and policy.

    Plots are saved under ``../results/cadrl_formations/`` relative to this
    file's directory.  NumPy's global RNG is seeded with 0 at start-up and
    again before each policy, so each policy starts from the same RNG state.
    The agents returned by one episode are passed into the next
    ``reset_env`` call.

    Returns:
        True once every episode has run.
    """
    np.random.seed(0)

    case_fn = tc.formation
    case_args = {}

    env, one_env = create_env()

    this_dir = os.path.dirname(os.path.realpath(__file__))
    one_env.set_plot_save_dir(this_dir + '/../results/cadrl_formations/')

    for n_agents in Config.NUM_AGENTS_TO_TEST:
        for policy_name in Config.POLICIES_TO_TEST:
            np.random.seed(0)
            previous_agents = None
            for case_idx in range(Config.NUM_TEST_CASES):
                reset_env(env, one_env, case_fn, case_args, case_idx,
                          n_agents, policies, policy_name, previous_agents)
                _episode_stats, previous_agents = run_episode(env, one_env)

    return True
# Module-level script fragment: full-test-suite plot directory setup.
# NOTE(review): this runs at import time (it calls create_env() and mutates
# Config) and relies on ``num_agents_to_test``, ``policies``, ``create_env``
# and ``Config`` being defined elsewhere in the file.
num_test_cases = 100
test_case_args = {}
Config.PLOT_CIRCLES_ALONG_TRAJ = True

# When True, test cases are built with a preferred-speed constraint and the
# radius bounds below, and the results subdirectory name records this
# (presumably a preferred speed of 1.0, judging by the 'vpref1.0' prefix).
vpref1 = False
radius_bounds = [0.5, 0.5]
if vpref1:
    test_case_args['vpref_constraint'] = True
    test_case_args['radius_bounds'] = radius_bounds
    vpref1_str = 'vpref1.0_r{}-{}/'.format(radius_bounds[0], radius_bounds[1])
else:
    vpref1_str = ''

Config.NUM_TEST_CASES = num_test_cases

env, one_env = create_env()

for num_agents in num_agents_to_test:

    # Plots for this agent count go to
    # ../results/full_test_suites/[<vpref1_str>]<N>_agents/figs/.
    one_env.set_plot_save_dir(
        os.path.dirname(os.path.realpath(__file__)) +
        '/../results/full_test_suites/{vpref1_str}{num_agents}_agents/figs/'.
        format(vpref1_str=vpref1_str, num_agents=num_agents))

    test_case_args['num_agents'] = num_agents
    # Per-policy lists of test-case indices, rebuilt for every agent count.
    # NOTE(review): ``stats`` is never read in this fragment; the loop body
    # looks truncated - confirm against the original script.
    stats = {}
    for policy in policies:
        stats[policy] = {}
        stats[policy]['non_collision_inds'] = []
        stats[policy]['all_at_goal_inds'] = []
        stats[policy]['stuck_inds'] = []
def main():
    """Run the full test suite for every configured agent count and policy.

    For each entry of ``Config.NUM_AGENTS_TO_TEST`` and each policy in
    ``Config.POLICIES_TO_TEST``, ``Config.NUM_TEST_CASES`` episodes built
    from ``tc.full_test_suite`` are run.  Per-episode statistics are
    accumulated in a DataFrame via ``store_stats``.  When
    ``Config.RECORD_PICKLE_FILES`` is set, that DataFrame is written with
    ``to_pickle`` into
    ``../results/full_test_suites/[<vpref1_str>]<N>_agents/stats/`` relative
    to this file's directory.  Plots go to the sibling ``figs/`` directory.

    When ``Config.FIXED_RADIUS_AND_VPREF`` is set, test cases get
    ``vpref_constraint=True`` and radius bounds [0.5, 0.5], and the results
    path gains a ``vpref1.0_r0.5-0.5/`` component.

    Returns:
        True once every episode has run.
    """
    np.random.seed(0)

    suite_fn = tc.full_test_suite
    suite_args = {}

    if not Config.FIXED_RADIUS_AND_VPREF:
        vpref1_str = ''
    else:
        radius_bounds = [0.5, 0.5]
        suite_args['vpref_constraint'] = True
        suite_args['radius_bounds'] = radius_bounds
        vpref1_str = 'vpref1.0_r{}-{}/'.format(*radius_bounds)

    env, one_env = create_env()

    print(
        "Running {test_cases} test cases for {num_agents} for policies: {policies}"
        .format(
            test_cases=Config.NUM_TEST_CASES,
            num_agents=Config.NUM_AGENTS_TO_TEST,
            policies=Config.POLICIES_TO_TEST,
        ))

    suite_root = (os.path.dirname(os.path.realpath(__file__))
                  + '/../results/full_test_suites/' + vpref1_str)
    total_episodes = (len(Config.NUM_AGENTS_TO_TEST)
                      * len(Config.POLICIES_TO_TEST) * Config.NUM_TEST_CASES)

    with tqdm(total=total_episodes) as pbar:
        for n_agents in Config.NUM_AGENTS_TO_TEST:
            agents_dir = suite_root + '{}_agents/'.format(n_agents)
            one_env.set_plot_save_dir(agents_dir + 'figs/')
            for policy_name in Config.POLICIES_TO_TEST:
                # Same RNG state and no carried-over agents for every policy.
                np.random.seed(0)
                previous_agents = None
                results = pd.DataFrame()
                for case_idx in range(Config.NUM_TEST_CASES):
                    reset_env(env, one_env, suite_fn, suite_args, case_idx,
                              n_agents, policies, policy_name, previous_agents)
                    episode_stats, previous_agents = run_episode(env, one_env)
                    results = store_stats(
                        results,
                        {'test_case': case_idx, 'policy_id': policy_name},
                        episode_stats)
                    pbar.update(1)

                if not Config.RECORD_PICKLE_FILES:
                    continue
                stats_dir = agents_dir + 'stats/'
                os.makedirs(stats_dir, exist_ok=True)
                # Yields ".../stats//stats_<policy>.p" (harmless double slash).
                results.to_pickle(stats_dir + '/stats_{}.p'.format(policy_name))

    return True