コード例 #1
0
def create_render_map(model_labels,
                      model_args_filepaths,
                      model_params_filepaths,
                      multi=False,
                      rand=None,
                      max_steps=200,
                      n_vehs=None,
                      remove_ngsim=False):
    """Roll out each saved model and collect its rendered frames.

    For every model a fresh TensorFlow graph and session are created, the
    saved run arguments and parameters are loaded, an NGSIM environment and
    policy are built from them, and one rollout is rendered. The session is
    closed once that model's frames have been collected.

    Args:
        model_labels: label for each model; also the key in the result.
            Labels containing ``'hgail'`` are loaded as a policy hierarchy.
        model_args_filepaths: saved-args file for each model, indexed in
            parallel with ``model_labels``.
        model_params_filepaths: saved-params file for each model, indexed in
            parallel with ``model_labels``.
        multi: if True, force ``args.env_multiagent = True``. If False, the
            rollout is pinned to the module-level ``worst_egoid`` /
            ``worst_start`` values.
        rand: optional seed, passed to the simulate helper in ``env_kwargs``
            as ``random_seed``.
        max_steps: maximum number of steps per rollout.
        n_vehs: multi-agent only; if truthy, used for both ``args.n_vehs``
            and ``args.n_envs``.
        remove_ngsim: multi-agent only; if True, set
            ``args.remove_ngsim_veh = True``.

    Returns:
        dict mapping each label to the images returned by the simulate
        helper for that model.

    Raises:
        IndexError: if a filepath list is shorter than ``model_labels``.
    """
    env_kwargs = {}
    if rand is not None:
        env_kwargs['random_seed'] = rand
    if not multi:
        # Single-agent renders replay a fixed ego vehicle and start frame
        # (worst_egoid / worst_start must be defined at module level).
        # Merge instead of overwriting so a requested seed is not dropped.
        env_kwargs.update(egoid=worst_egoid, start=worst_start)
    render_kwargs = dict(camera_rotation=45.,
                         canvas_height=500,
                         canvas_width=600)

    render_map = {}
    for i, label in enumerate(model_labels):
        print('\nrunning: {}'.format(label))

        # Each model gets its own graph and session so variables from
        # previously loaded models do not collide.
        tf.reset_default_graph()
        sess = tf.InteractiveSession()
        try:
            # load args and params
            args = hyperparams.load_args(model_args_filepaths[i])
            print('\nargs loaded from {}'.format(model_args_filepaths[i]))
            if multi:
                args.env_multiagent = True
                if remove_ngsim:
                    args.remove_ngsim_veh = True
                if n_vehs:
                    # Use the requested count rather than a fixed value.
                    args.n_envs = n_vehs
                    args.n_vehs = n_vehs
            params = hgail.misc.utils.load_params(model_params_filepaths[i])
            print('\nparams loaded from {}'.format(model_params_filepaths[i]))

            # Build the env; videoMaking=True is a project-specific option of
            # utils.build_ngsim_env (see that function for its effect).
            env, _, _ = utils.build_ngsim_env(args, videoMaking=True)
            print("Raunak says: This is the videmaker reporting")
            normalized_env = hgail.misc.utils.extract_normalizing_env(env)
            if normalized_env is not None:
                # NOTE: 'normalzing' (sic) is the key read from the params
                # file; presumably it matches the writer's spelling, so do
                # not rename it here alone.
                normalized_env._obs_mean = params['normalzing']['obs_mean']
                normalized_env._obs_var = params['normalzing']['obs_var']

            is_hierarchy = 'hgail' in label
            if is_hierarchy:
                policy = utils.build_hierarchy(args, env)
            else:
                policy = utils.build_policy(args, env)

            # Initialize first so the loaded values below are not
            # overwritten by fresh initial values.
            sess.run(tf.global_variables_initializer())

            if is_hierarchy:
                # One param set per hierarchy level; render with level 0.
                for j, level in enumerate(policy):
                    level.algo.policy.set_param_values(params[j]['policy'])
                policy = policy[0].algo.policy
            else:
                policy.set_param_values(params['policy'])

            # args.env_multiagent comes from the saved args (forced True when
            # multi), so it may select the multi-agent rollout even when
            # multi is False.
            if args.env_multiagent:
                simulate_fn = mutliagent_simulate
            else:
                simulate_fn = simulate
            render_map[label] = simulate_fn(env,
                                            policy,
                                            max_steps=max_steps,
                                            env_kwargs=env_kwargs,
                                            render_kwargs=render_kwargs)
        finally:
            # Stacked InteractiveSessions keep their resources alive; release
            # this one before the next model opens its own.
            sess.close()
    return render_map
コード例 #2
0
if __name__ == '__main__':
    # Command-line options as (flag, type, default), registered in this
    # order so --help lists them in the same sequence.
    cli_options = (
        ('--n_proc', int, 1),
        ('--exp_dir', str, '../../data/experiments/gail/'),
        ('--params_filename', str, 'itr_2000.npz'),
        ('--n_runs_per_ego_id', int, 1),
        ('--use_hgail', str2bool, False),
        ('--use_multiagent', str2bool, False),
        ('--n_multiagent_trajs', int, 10000),
        ('--debug', str2bool, False),
    )
    parser = argparse.ArgumentParser(description='validation settings')
    for flag, flag_type, flag_default in cli_options:
        parser.add_argument(flag, type=flag_type, default=flag_default)
    run_args = parser.parse_args()

    # Saved training arguments are read from <exp_dir>/imitate/log/args.npz.
    args_filepath = os.path.join(run_args.exp_dir, 'imitate/log/args.npz')
    args = hyperparams.load_args(args_filepath)
    if run_args.use_multiagent:
        # Only ever switch multi-agent on; otherwise keep the saved setting.
        args.env_multiagent = True

    # --debug selects the single-process collector over the parallel one.
    collect_fn = (single_process_collect_trajectories if run_args.debug
                  else parallel_collect_trajectories)
    filenames = [
        "trajdata_i80_trajectories-0400-0415.txt",
        "trajdata_i80_trajectories-0500-0515.txt",
        "trajdata_i80_trajectories-0515-0530.txt",
        "trajdata_i101_trajectories-0805am-0820am.txt",
        "trajdata_i101_trajectories-0820am-0835am.txt",
        "trajdata_i101_trajectories-0750am-0805am.txt"