def create_render_map(model_labels, model_args_filepaths, model_params_filepaths,
                      multi=False, rand=None, max_steps=200, n_vehs=None,
                      remove_ngsim=False):
    """Load each trained model, roll it out in a video-making env, and collect frames.

    For every model a fresh default graph and TensorFlow session are created,
    the saved args and params are loaded, an NGSIM env and a policy are built,
    the policy parameters are restored, and one rendered simulation is run.

    Args:
        model_labels: one label per model; each label becomes a key of the
            returned dict. A label containing the substring 'hgail' selects
            the hierarchical-policy code path.
        model_args_filepaths: one path per model, loaded with
            ``hyperparams.load_args``; indexed in lockstep with model_labels.
        model_params_filepaths: one path per model, loaded with
            ``hgail.misc.utils.load_params``; indexed in lockstep with
            model_labels.
        multi: if True, forces ``args.env_multiagent = True``. If False, the
            env kwargs come from the globals ``worst_egoid`` and
            ``worst_start``, which are defined outside this function and must
            exist.
        rand: if not None, passed to the env as ``random_seed`` (effective only
            when ``multi`` is True; see the note in the body).
        max_steps: forwarded unchanged as ``max_steps`` to the simulate call.
        n_vehs: if truthy, sets both ``args.n_envs`` and ``args.n_vehs`` to 100.
        remove_ngsim: if True, sets ``args.remove_ngsim_veh = True``.

    Returns:
        dict mapping each label to whatever ``simulate`` (or
        ``mutliagent_simulate`` in multi-agent mode) returned for that model.
    """
    render_map = {}
    env_kwargs = {}
    if rand is not None:
        env_kwargs = dict(random_seed=rand)
    if not multi:
        # NOTE(review): this replaces env_kwargs outright, so ``rand`` is
        # ignored in single-agent mode. Kept as-is; confirm this is intended.
        env_kwargs = dict(egoid=worst_egoid, start=worst_start)
    render_kwargs = dict(camera_rotation=45., canvas_height=500, canvas_width=600)

    for i, label in enumerate(model_labels):
        print('\nrunning: {}'.format(label))
        is_hgail = 'hgail' in label

        # create session: start each model from an empty default graph
        tf.reset_default_graph()
        sess = tf.InteractiveSession()
        try:
            # load args and params
            args = hyperparams.load_args(model_args_filepaths[i])
            print('\nargs loaded from {}'.format(model_args_filepaths[i]))
            if multi:
                args.env_multiagent = True
            if remove_ngsim:
                args.remove_ngsim_veh = True
            if n_vehs:
                # NOTE(review): n_vehs is only used as a flag; both counts are
                # hard-coded to 100. Kept for backward compatibility -- confirm
                # whether the n_vehs value itself should be used here.
                args.n_envs = 100
                args.n_vehs = 100
            params = hgail.misc.utils.load_params(model_params_filepaths[i])
            print('\nparams loaded from {}'.format(model_params_filepaths[i]))

            # load env; videoMaking=True enables the video-making variant
            # (see build_ngsim_env in utils.py for what this changes)
            env, _, _ = utils.build_ngsim_env(args, videoMaking=True)
            print("Raunak says: This is the videmaker reporting")
            normalized_env = hgail.misc.utils.extract_normalizing_env(env)
            if normalized_env is not None:
                # 'normalzing' (sic) must match the key stored in the params file.
                normalized_env._obs_mean = params['normalzing']['obs_mean']
                normalized_env._obs_var = params['normalzing']['obs_var']

            # load policy
            if is_hgail:
                policy = utils.build_hierarchy(args, env)
            else:
                policy = utils.build_policy(args, env)

            # initialize variables before overwriting them with saved values
            sess.run(tf.global_variables_initializer())

            # load params
            if is_hgail:
                for j, level in enumerate(policy):
                    level.algo.policy.set_param_values(params[j]['policy'])
                # roll out with the first level's policy
                policy = policy[0].algo.policy
            else:
                policy.set_param_values(params['policy'])

            # collect imgs
            if args.env_multiagent:
                imgs = mutliagent_simulate(env, policy, max_steps=max_steps,
                                           env_kwargs=env_kwargs,
                                           render_kwargs=render_kwargs)
            else:
                imgs = simulate(env, policy, max_steps=max_steps,
                                env_kwargs=env_kwargs,
                                render_kwargs=render_kwargs)
            render_map[label] = imgs
        finally:
            # InteractiveSession is not a context manager in TF1; close it
            # explicitly so sessions do not pile up across models and its
            # default-graph context is released before the next reset.
            sess.close()
    return render_map
if __name__ == '__main__': parser = argparse.ArgumentParser(description='validation settings') parser.add_argument('--n_proc', type=int, default=1) parser.add_argument('--exp_dir', type=str, default='../../data/experiments/gail/') parser.add_argument('--params_filename', type=str, default='itr_2000.npz') parser.add_argument('--n_runs_per_ego_id', type=int, default=1) parser.add_argument('--use_hgail', type=str2bool, default=False) parser.add_argument('--use_multiagent', type=str2bool, default=False) parser.add_argument('--n_multiagent_trajs', type=int, default=10000) parser.add_argument('--debug', type=str2bool, default=False) run_args = parser.parse_args() args_filepath = os.path.join(run_args.exp_dir, 'imitate/log/args.npz') args = hyperparams.load_args(args_filepath) if run_args.use_multiagent: args.env_multiagent = True if run_args.debug: collect_fn = single_process_collect_trajectories else: collect_fn = parallel_collect_trajectories filenames = [ "trajdata_i80_trajectories-0400-0415.txt", "trajdata_i80_trajectories-0500-0515.txt", "trajdata_i80_trajectories-0515-0530.txt", "trajdata_i101_trajectories-0805am-0820am.txt", "trajdata_i101_trajectories-0820am-0835am.txt", "trajdata_i101_trajectories-0750am-0805am.txt"