def _init(args):
    """Launch a one-point parameter sweep of ``main`` from parsed arguments.

    Every option is wrapped in a single-element list before being handed to
    the sweep runner.  The runner is chosen from ``args.repeat``:

    * ``repeat > 1``: emit a warning, then call ``run_sweep_parallel`` with
      ``repeat=args.repeat``; ``parallel_sampler`` is not initialised on
      this path.
    * otherwise: call ``parallel_sampler.initialize(n_parallel=8)`` and
      ``parallel_sampler.set_seed(1)``, then ``run_sweep_serial`` with
      ``repeat=1``.

    Args:
        args: object exposing ``env_name``, ``rundir``, ``trpo_ent``,
            ``trpo_step``, ``hid_size``, ``hid_layers`` and ``repeat``.
    """
    chosen_env = args.env_name
    print('Using environment %s' % chosen_env)

    sweep_params = dict(
        env_name=[chosen_env],
        rundir=[args.rundir],
        ent_wt=[args.trpo_ent],
        trpo_step=[args.trpo_step],
        hid_size=[args.hid_size],
        hid_layers=[args.hid_layers],
        many_runs=[args.repeat > 1],
    )

    # Single-run case: set up the sampler and run the sweep serially.
    if not (args.repeat > 1):
        parallel_sampler.initialize(n_parallel=8)
        parallel_sampler.set_seed(1)
        run_sweep_serial(main, sweep_params, repeat=1)
        return

    # Repeated runs: the sampler is intentionally left uninitialised here,
    # because stacking it under the parallel sweep runner does not work.
    warnings.warn(
        "You're trying to use --repeat N for N > 1, but that "
        "disables parallel sampling. This is probably going to be "
        "heinously slow or something, use at own risk.")
    run_sweep_parallel(main, sweep_params, repeat=args.repeat)
expert_trajs=experts, state_only=True, fusion=fusion, max_itrs=10) policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32)) algo = IRLTRPO( env=env, policy=policy, irl_model=irl_model, n_itr=1000, batch_size=10000, max_path_length=500, discount=0.99, store_paths=True, irl_model_wt=1.0, entropy_weight=0.1, zero_environment_reward=True, baseline=LinearFeatureBaseline(env_spec=env.spec), ) with rllab_logdir(algo=algo, dirname='data/ant_state_irl/%s' % exp_name): with tf.Session(): algo.train() if __name__ == "__main__": params_dict = {'fusion': [True]} run_sweep_parallel(main, params_dict, repeat=3)
default=1e-5, type=float, help='step size for --adaptive-beta', ) if __name__ == "__main__": args = parser.parse_args() env_name = args.env_name print('Args:', args) params_dict = { 'method': [args.method], 'rundir': [args.rundir], 'env_name': [env_name], 'disc_step': [args.disc_step], 'trpo_step': [args.trpo_step], 'ent_wt': [args.trpo_ent], 'hid_size': [args.hid_size], 'hid_layers': [args.hid_layers], 'disc_iters': [args.disc_iters], 'disc_batch_size': [args.disc_batch_size], 'disc_gp': [args.dreg_gp], 'max_traj': [args.max_traj], # VAIR/VAIL params 'beta': [getattr(args, 'beta', None)], 'adaptive_beta': [getattr(args, 'adaptive_beta', None)], 'target_kl': [getattr(args, 'adaptive_beta_target_kl', None)], 'beta_step': [getattr(args, 'adaptive_beta_step', None)], } run_sweep_parallel(main, params_dict, repeat=args.repeat) # run_sweep_serial(main, params_dict, repeat=1)