def train(args):
    """Train an MLP policy on ``args.env_id`` with MPI-parallel TRPO.

    Opens a single-threaded TF session, derives a distinct seed for every MPI
    worker, wraps the gym environment in ``bench.Monitor`` and delegates the
    optimisation to ``trpo_mpi.learn``. Only MPI rank 0 keeps logging enabled.

    Args:
        args: parsed command-line namespace. Reads ``seed``, ``env_id``,
            ``entcoeff``, ``checkpoint_dir``, ``num_timesteps``,
            ``sample_stochastic``, ``save_per_iter``, ``load_model_path`` and
            ``task``. ``args.checkpoint_dir`` is rebound (mutated) to a
            task-specific subdirectory.
    """
    import gailtf.baselines.common.tf_util as U

    sess = U.single_threaded_session()
    sess.__enter__()

    rank = MPI.COMM_WORLD.Get_rank()
    if rank != 0:
        # Silence non-root workers so output is not duplicated per process.
        logger.set_level(logger.DISABLED)
    # Give every MPI worker its own reproducible seed.
    workerseed = args.seed + 10000 * rank
    set_global_seeds(workerseed)
    env = gym.make(args.env_id)

    def policy_fn(name, ob_space, ac_space):
        # Build the policy from the spaces the caller passes in, rather than
        # silently ignoring them in favour of the closed-over ``env``.
        return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                         hid_size=32, num_hid_layers=2)

    env = bench.Monitor(
        env,
        logger.get_dir() and osp.join(logger.get_dir(), "%i.monitor.json" % rank))
    env.seed(workerseed)
    gym.logger.setLevel(logging.WARN)

    # NOTE(review): the task name encodes args.entcoeff, but entcoeff is not
    # forwarded to trpo_mpi.learn below -- confirm whether learn accepts it.
    task_name = "trpo." + args.env_id.split("-")[0] + "." + ("%.2f" % args.entcoeff)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    try:
        trpo_mpi.learn(env, policy_fn,
                       timesteps_per_batch=1024, max_kl=0.01, cg_iters=10,
                       cg_damping=0.1, max_timesteps=args.num_timesteps,
                       gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3,
                       sample_stochastic=args.sample_stochastic,
                       task_name=task_name, save_per_iter=args.save_per_iter,
                       ckpt_dir=args.checkpoint_dir,
                       load_model_path=args.load_model_path, task=args.task)
    finally:
        # Release the environment even if training raises.
        env.close()
def train(args):
    """Train an MLP policy on ``args.env_id`` with PPO (``pposgd_simple``).

    Creates a TF session using ``args.num_cpu`` threads, seeds everything with
    ``args.seed``, wraps the environment in ``bench.Monitor`` and runs
    ``pposgd_simple.learn`` with a fixed set of PPO hyperparameters.

    Args:
        args: parsed command-line namespace. Reads ``num_cpu``, ``seed``,
            ``env_id``, ``entcoeff``, ``checkpoint_dir``, ``num_timesteps``,
            ``save_per_iter``, ``task``, ``sample_stochastic`` and
            ``load_model_path``. ``args.checkpoint_dir`` is rebound to a
            task-specific subdirectory.
    """
    from gailtf.baselines.ppo1 import mlp_policy, pposgd_simple

    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)

    def policy_fn(name, ob_space, ac_space):
        return mlp_policy.MlpPolicy(
            name=name,
            ob_space=ob_space,
            ac_space=ac_space,
            hid_size=64,
            num_hid_layers=2,
        )

    raw_env = gym.make(args.env_id)
    log_root = logger.get_dir()
    # Falsy log directory -> Monitor receives that falsy value (no file).
    monitor_file = log_root and osp.join(log_root, "monitor.json")
    env = bench.Monitor(raw_env, monitor_file)
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)

    env_prefix = args.env_id.split("-")[0]
    task_name = "ppo." + env_prefix + "." + ("%.2f" % args.entcoeff)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)

    # Optimisation hyperparameters for PPO.
    ppo_config = dict(
        max_timesteps=args.num_timesteps,
        timesteps_per_batch=2048,
        clip_param=0.2,
        entcoeff=args.entcoeff,
        optim_epochs=10,
        optim_stepsize=3e-4,
        optim_batchsize=64,
        gamma=0.99,
        lam=0.95,
        schedule='linear',
    )
    # Checkpointing / bookkeeping options.
    run_config = dict(
        ckpt_dir=args.checkpoint_dir,
        save_per_iter=args.save_per_iter,
        task=args.task,
        sample_stochastic=args.sample_stochastic,
        load_model_path=args.load_model_path,
        task_name=task_name,
    )
    pposgd_simple.learn(env, policy_fn, **ppo_config, **run_config)
    env.close()
def main(args):
    """Run behaviour cloning and/or GAIL (TRPO) with a transition discriminator.

    Builds a monitored gym environment and a ``Mujoco_Dset`` expert dataset.
    It then optionally trains or evaluates a behaviour-cloning policy at the
    high or low action level (``args.action_space_level``). Finally it trains
    or evaluates a GAIL policy with MPI-parallel TRPO against a
    ``TransitionClassifier`` discriminator.

    Args:
        args: parsed command-line namespace. ``args.checkpoint_dir`` and
            ``args.log_dir`` are rebound to task-specific subdirectories.

    Raises:
        NotImplementedError: for an unsupported ``args.algo`` or ``args.task``.
    """
    from gailtf.baselines.ppo1 import mlp_policy
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                                    reuse=reuse, hid_size=64, num_hid_layers=2)

    env = bench.Monitor(
        env, logger.get_dir() and osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = get_task_name(args)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    args.log_dir = osp.join(args.log_dir, task_name)

    # HFO recording command: it is only printed; launching it via os.system is
    # currently disabled.
    # NOTE(review): the --log-dir path is machine-specific; consider making it
    # configurable.
    cmd = hfo_py.get_hfo_path() + ' --offense-npcs=1 --defense-npcs=1 --log-dir /home/yupeng/Desktop/workspace/src2/GAMIL-tf0/gail-tf/log/soccer_data/ --record --frames=200'
    print(cmd)

    dataset = Mujoco_Dset(expert_data_path=args.expert_data_path,
                          ret_threshold=args.ret_threshold,
                          traj_limitation=args.traj_limitation)

    pretrained_weight = None
    if (args.pretrained and args.task == 'train') or args.algo == 'bc':
        # Pretrain with behaviour cloning.
        from gailtf.algo import behavior_clone
        if args.algo == 'bc' and args.task == 'evaluate':
            behavior_clone.evaluate(env, policy_fn, args.load_model_path_high,
                                    args.load_model_path_low,
                                    stochastic_policy=args.stochastic_policy)
            sys.exit()
        # Keep the BC result in ``pretrained_weight`` so it actually reaches
        # trpo_mpi.learn below. Previously it was stored in unused locals and
        # GAIL always received None.
        if args.task == 'train' and args.action_space_level == 'high':
            print("training high level policy")
            pretrained_weight = behavior_clone.learn(
                env, policy_fn, dataset, max_iters=args.BC_max_iter,
                pretrained=args.pretrained,
                ckpt_dir=args.checkpoint_dir + '/high_level',
                log_dir=args.log_dir + '/high_level',
                task_name=task_name, high_level=True)
        if args.task == 'train' and args.action_space_level == 'low':
            print("training low level policy")
            pretrained_weight = behavior_clone.learn(
                env, policy_fn, dataset, max_iters=args.BC_max_iter,
                pretrained=args.pretrained,
                ckpt_dir=args.checkpoint_dir + '/low_level',
                log_dir=args.log_dir + '/low_level',
                task_name=task_name, high_level=False)
        if args.algo == 'bc':
            sys.exit()

    from gailtf.network.adversary import TransitionClassifier
    # Discriminator.
    discriminator = TransitionClassifier(env, args.adversary_hidden_size,
                                         entcoeff=args.adversary_entcoeff)
    if args.algo == 'trpo':
        # Set up a distinct seed for each MPI worker.
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = args.seed + 10000 * rank
        set_global_seeds(workerseed)
        env.seed(workerseed)
        from gailtf.algo import trpo_mpi
        if args.task == 'train':
            trpo_mpi.learn(env, policy_fn, discriminator, dataset,
                           pretrained=args.pretrained,
                           pretrained_weight=pretrained_weight,
                           g_step=args.g_step, d_step=args.d_step,
                           timesteps_per_batch=1024, max_kl=args.max_kl,
                           cg_iters=10, cg_damping=0.1,
                           max_timesteps=args.num_timesteps,
                           entcoeff=args.policy_entcoeff, gamma=0.995,
                           lam=0.97, vf_iters=5, vf_stepsize=1e-3,
                           ckpt_dir=args.checkpoint_dir, log_dir=args.log_dir,
                           save_per_iter=args.save_per_iter,
                           load_model_path=args.load_model_path,
                           task_name=task_name)
        elif args.task == 'evaluate':
            trpo_mpi.evaluate(env, policy_fn, args.load_model_path,
                              timesteps_per_batch=1024, number_trajs=10,
                              stochastic_policy=args.stochastic_policy)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError
    env.close()
def main(args):
    """Run behaviour cloning and/or GAIL (TRPO) with a trajectory discriminator.

    Builds a monitored gym environment and a ``Mujoco_Traj_Dset`` expert
    dataset. If ``args.adversary_seq_size`` is None it is filled in from the
    dataset. The function optionally pretrains with behaviour cloning, then
    trains (``trpo_traj_mpi``) or evaluates (``trpo_mpi``) a GAIL policy
    against a ``TrajectoryClassifier`` discriminator.

    Args:
        args: parsed command-line namespace. ``args.checkpoint_dir``,
            ``args.log_dir`` and possibly ``args.adversary_seq_size`` are
            rebound.

    Raises:
        NotImplementedError: for an unsupported ``args.algo`` or ``args.task``.
    """
    from gailtf.baselines.ppo1 import mlp_policy
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                                    reuse=reuse, hid_size=64, num_hid_layers=2)

    env = bench.Monitor(env, logger.get_dir() and osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = get_task_name(args)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    args.log_dir = osp.join(args.log_dir, task_name)
    dataset = Mujoco_Traj_Dset(expert_path=args.expert_path,
                               ret_threshold=args.ret_threshold,
                               traj_limitation=args.traj_limitation,
                               sentence_size=args.adversary_seq_size)
    if args.adversary_seq_size is None:
        # Fall back to the sequence size the dataset settled on.
        args.adversary_seq_size = dataset.sentence_size

    pretrained_weight = None
    if (args.pretrained and args.task == 'train') or args.algo == 'bc':
        # Pretrain with behaviour cloning.
        from gailtf.algo import behavior_clone
        if args.algo == 'bc' and args.task == 'evaluate':
            behavior_clone.evaluate(env, policy_fn, args.load_model_path,
                                    stochastic_policy=args.stochastic_policy)
            sys.exit()
        pretrained_weight = behavior_clone.learn(
            env, policy_fn, dataset, max_iters=args.BC_max_iter,
            pretrained=args.pretrained, ckpt_dir=args.checkpoint_dir,
            log_dir=args.log_dir, task_name=task_name)
        if args.algo == 'bc':
            sys.exit()

    from gailtf.network.adversary_traj import TrajectoryClassifier
    # Discriminator.
    discriminator = TrajectoryClassifier(env, args.adversary_hidden_size,
                                         args.adversary_seq_size,
                                         args.adversary_attn_size,
                                         cell_type=args.adversary_cell_type,
                                         entcoeff=args.adversary_entcoeff)
    if args.algo == 'trpo':
        # Set up a distinct seed for each MPI worker.
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = args.seed + 10000 * rank
        set_global_seeds(workerseed)
        env.seed(workerseed)
        from gailtf.algo import trpo_traj_mpi
        if args.task == 'train':
            trpo_traj_mpi.learn(env, policy_fn, discriminator, dataset,
                                pretrained=args.pretrained,
                                pretrained_weight=pretrained_weight,
                                g_step=args.g_step, d_step=args.d_step,
                                episodes_per_batch=100, dropout_keep_prob=0.5,
                                sequence_size=args.adversary_seq_size,
                                max_kl=args.max_kl, cg_iters=10, cg_damping=0.1,
                                max_timesteps=args.num_timesteps,
                                entcoeff=args.policy_entcoeff, gamma=0.995,
                                lam=0.97, vf_iters=5, vf_stepsize=1e-3,
                                ckpt_dir=args.checkpoint_dir,
                                log_dir=args.log_dir,
                                save_per_iter=args.save_per_iter,
                                load_model_path=args.load_model_path,
                                task_name=task_name)
        elif args.task == 'evaluate':
            # The evaluator lives in trpo_mpi, which nothing else in this
            # function imports; without this import the call raises NameError.
            from gailtf.algo import trpo_mpi
            trpo_mpi.evaluate(env, policy_fn, args.load_model_path,
                              timesteps_per_batch=1024, number_trajs=10,
                              stochastic_policy=args.stochastic_policy)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError
    env.close()
def train(args):
    """Set up the environment and dispatch on ``args.task``.

    Supported tasks:
      * ``'play_agent'``      -- render a trained agent playing the game.
      * ``'RL_expert'``       -- sample expert trajectories from a trained policy.
      * ``'human_expert'``    -- record human play and pickle the trajectories.
      * ``'train_RL_expert'`` -- train an RL expert with TRPO.
      * ``'train_gail'``      -- (optionally BC-pretrained) GAIL with TRPO.

    Each recognised task ends the process with ``sys.exit()`` after closing
    the environment. The environment is stored in the module-level ``env``.

    Raises:
        NotImplementedError: for an unsupported ``args.networkName`` or
            ``args.alg``.
    """
    global env
    if args.expert_path is not None:
        assert osp.exists(args.expert_path)
    if args.load_model_path is not None:
        assert osp.exists(args.load_model_path + '.meta')
        # Resuming from a saved model: skip behaviour-cloning pretraining.
        args.pretrained = False
    printArgs(args)

    # ================================ ENVIRONMENT ================================
    U.make_session(num_cpu=args.num_cpu).__enter__()
    set_global_seeds(args.seed)
    if args.networkName == "MLP":
        env = gym.make(args.env_id)
        env = ActionWrapper(env, args.discrete)
    elif args.networkName == "CNN":
        env = make_atari(args.env_id)
        env = ActionWrapper(env, args.discrete)
        if args.deepmind:
            from gailtf.baselines.common.atari_wrappers import wrap_deepmind
            env = wrap_deepmind(env, False)
        env.metadata = 0
    else:
        # Fail loudly instead of continuing with an unbound/stale ``env``.
        raise NotImplementedError("unsupported networkName: %r" % (args.networkName,))
    env = bench.Monitor(env, logger.get_dir() and osp.join(logger.get_dir(), "monitor.json"),
                        allow_early_resets=True)
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    discrete = (".D." if args.discrete else ".MD")

    # ================================ PLAY AGENT =================================
    if args.task == 'play_agent':
        logger.log("Playing agent...")
        from environments.atari.atari_agent import playAtari
        agent = policy_fn(args, PI, env, reuse=False)
        playAtari(env, agent, U, modelPath=args.load_model_path, fps=15,
                  stochastic=args.stochastic_policy, zoom=2, delay=10)
        env.close()
        sys.exit()

    # ========================= SAMPLE TRAJECTORY FROM RL =========================
    if args.task == 'RL_expert':
        logger.log("Sampling trajectory...")
        stoch = 'stochastic.' if args.stochastic_policy else 'deterministic.'
        taskName = (stoch + args.alg + "." + args.env_id + discrete + "."
                    + str(args.maxSampleTrajectories))
        taskName = osp.join("data/expert", taskName)
        currentPolicy = policy_fn(args, PI, env, reuse=False)
        episodesGenerator = traj_episode_generator(
            currentPolicy, env, args.trajectoriesPerBatch,
            stochastic=args.stochastic_policy, render=args.visualize,
            downsample=args.downsample)
        sample_trajectory(args.load_model_path, episodesGenerator, taskName,
                          args.stochastic_policy,
                          max_sample_traj=args.maxSampleTrajectories)
        env.close()
        sys.exit()

    # ======================== SAMPLE TRAJECTORY FROM HUMAN =======================
    if args.task == 'human_expert':
        logger.log("Human plays...")
        taskName = "human." + args.env_id + "_" + args.networkName + "." + "50.pkl"
        args.checkpoint_dir = osp.join(args.checkpoint_dir, taskName)
        taskName = osp.join("data/expert", taskName)
        from environments.atari.atari_human import playAtari
        sampleTrajectories = playAtari(env, fps=15, zoom=2, taskName=taskName)
        # Close the file deterministically instead of leaking the handle.
        with open(taskName, "wb") as out_file:
            pkl.dump(sampleTrajectories, out_file)
        env.close()
        sys.exit()

    # ============================== TRAIN RL EXPERT ==============================
    if args.task == "train_RL_expert":
        logger.log("Training RL expert...")
        if args.alg == 'trpo':
            from gailtf.baselines.trpo_mpi import trpo_mpi
            taskName = (args.alg + "." + args.env_id + "." + str(args.policy_hidden_size)
                        + discrete + "." + str(args.maxSampleTrajectories))
            rank = MPI.COMM_WORLD.Get_rank()
            if rank != 0:
                logger.set_level(logger.DISABLED)
            workerseed = args.seed + 10000 * rank
            set_global_seeds(workerseed)
            # The environment built above is replaced here; close it first so
            # it is not leaked.
            env.close()
            env = gym.make(args.env_id)
            env = bench.Monitor(
                env, logger.get_dir() and osp.join(logger.get_dir(), "%i.monitor.json" % rank))
            env.seed(workerseed)
            gym.logger.setLevel(logging.WARN)
            args.checkpoint_dir = osp.join("data/training", taskName)
            trpo_mpi.learn(args, env, policy_fn, timesteps_per_batch=1024,
                           max_iters=50_000, vf_iters=5, vf_stepsize=1e-3,
                           task_name=taskName)
            env.close()
            sys.exit()
        else:
            # Previously ``return NotImplementedError`` returned the class
            # instead of signalling the error.
            raise NotImplementedError

    # =================================== GAIL ====================================
    if args.task == 'train_gail':
        taskName = get_task_name(args)
        args.checkpoint_dir = osp.join(args.checkpoint_dir, taskName)
        args.log_dir = osp.join(args.log_dir, taskName)
        args.task_name = taskName
        dataset = Mujoco_Dset(expert_path=args.expert_path,
                              ret_threshold=args.ret_threshold,
                              traj_limitation=args.traj_limitation)
        # Discriminator: a CNN variant for image-like observations (rank > 2),
        # otherwise the Wasserstein or standard transition classifier.
        if len(env.observation_space.shape) > 2:
            from gailtf.network.adversary_cnn import TransitionClassifier
        elif args.wasserstein:
            from gailtf.network.w_adversary import TransitionClassifier
        else:
            from gailtf.network.adversary import TransitionClassifier
        discriminator = TransitionClassifier(env, args.adversary_hidden_size,
                                             entcoeff=args.adversary_entcoeff)
        pretrained_weight = None
        # Optional behaviour-cloning pretraining (args.task is always
        # 'train_gail' inside this branch).
        if args.pretrained or args.alg == 'bc':
            from gailtf.algo import behavior_clone
            if args.load_model_path is None:
                pretrained_weight = behavior_clone.learn(args, env, policy_fn, dataset)
            if args.alg == 'bc':
                env.close()
                sys.exit()
        if args.alg == 'trpo':
            # Set up a distinct seed for each MPI worker.
            rank = MPI.COMM_WORLD.Get_rank()
            if rank != 0:
                logger.set_level(logger.DISABLED)
            workerseed = args.seed + 10000 * rank
            set_global_seeds(workerseed)
            env.seed(workerseed)
            # NOTE: a Wasserstein-specific trainer (w_trpo_mpi) is currently
            # disabled; the standard trpo_mpi is used for both discriminators.
            from gailtf.algo import trpo_mpi as trpo
            trpo.learn(args, env, policy_fn, discriminator, dataset,
                       pretrained_weight=pretrained_weight, cg_damping=0.1,
                       vf_iters=5, vf_stepsize=1e-3)
        else:
            raise NotImplementedError
        env.close()
        sys.exit()