def do_rollout(env, action_function):
    """Run a single episode in ``env`` and return what was recorded along the way.

    ``action_function(env, ob)`` is asked for an action at every step. The
    episode stops when ``env.step`` reports ``done`` or after the number of
    steps returned by ``get_timesteps_per_episode(env)``, whichever is first.

    Returns a dict of numpy arrays with the keys:
      * ``"obs"`` - the observation each action was chosen from
      * ``"original_rewards"`` - the reward returned by each ``env.step``
      * ``"actions"`` - the action taken at each step
      * ``"human_obs"`` - ``info.get("human_obs")`` for each step (``None``
        when the info dict has no such key)
    """
    step_limit = get_timesteps_per_episode(env)
    seen_obs = []
    step_rewards = []
    taken_actions = []
    frames = []

    current_ob = env.reset()

    # Main interaction loop: pick an action, record it, advance the env.
    for _ in range(step_limit):
        chosen = action_function(env, current_ob)
        # Record the observation the action was based on, not the one it leads to.
        seen_obs.append(current_ob)
        taken_actions.append(chosen)

        current_ob, reward, finished, step_info = env.step(chosen)
        step_rewards.append(reward)
        frames.append(step_info.get("human_obs"))

        if finished:
            break

    # Package the trajectory as arrays.
    return {
        "obs": np.array(seen_obs),
        "original_rewards": np.array(step_rewards),
        "actions": np.array(taken_actions),
        "human_obs": np.array(frames),
    }
def main():
    """Command-line entry point: build a reward predictor, pretrain it, train an agent.

    Steps, in order:
      1. Parse the command line and reject unusable combinations up front.
      2. Create the summary writer and the environment.
      3. Build the predictor: ``TraditionalRLRewardPredictor`` for
         ``--predictor rl``; otherwise a ``ComparisonRewardPredictor`` fed by a
         synthetic or human comparison collector.
      4. For comparison predictors, generate random-rollout segments, pair
         them into comparisons, wait until 75% of them are labeled, then run
         ``--pretrain_iters`` predictor training iterations.
      5. Optionally wrap the predictor in ``SegmentVideoRecorder``.
      6. Train the agent selected by ``--agent``.

    Raises:
        ValueError: if ``--predictor`` or ``--agent`` is not a supported value,
            or if ``--predictor human`` is used without a ``gs://`` value in
            the ``RL_TEACHER_GCS_BUCKET`` environment variable.
        SystemExit: via ``parser.error`` when a comparison predictor is
            requested without ``--n_labels`` or ``--pretrain_labels``.

    NOTE(review): this file contains another ``def main()`` further down; the
    later definition rebinds the name at import time - confirm which one is
    meant to be the live entry point.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--env_id', required=True)
    parser.add_argument('-p', '--predictor', required=True)
    parser.add_argument('-n', '--name', required=True)
    parser.add_argument('-s', '--seed', default=1, type=int)
    parser.add_argument('-w', '--workers', default=4, type=int)
    parser.add_argument('-l', '--n_labels', default=None, type=int)
    parser.add_argument('-L', '--pretrain_labels', default=None, type=int)
    # The default is a float literal; it is converted with int() below.
    parser.add_argument('-t', '--num_timesteps', default=5e6, type=int)
    parser.add_argument('-a', '--agent', default="parallel_trpo", type=str)
    parser.add_argument('-i', '--pretrain_iters', default=10000, type=int)
    parser.add_argument('-V', '--no_videos', action="store_true")
    parser.add_argument('-x', '--human_labels', default=1000, type=int)
    args = parser.parse_args()

    # Fail fast: reject bad arguments before creating writers/environments and
    # before spending time on pretraining.
    if args.predictor not in ("rl", "synth", "human"):
        raise ValueError("Bad value for --predictor: %s" % args.predictor)
    if args.agent not in ("parallel_trpo", "pposgd_mpi"):
        raise ValueError("%s is not a valid choice for args.agent" % args.agent)
    if args.predictor != "rl" and not args.pretrain_labels and args.n_labels is None:
        # Without this check the fallback `args.n_labels // 4` below would
        # raise an opaque TypeError on None.
        parser.error(
            "one of -l/--n_labels or -L/--pretrain_labels is required "
            "unless --predictor is 'rl'")

    print("Setting things up...")

    env_id = args.env_id
    run_name = "%s/%s-%s" % (env_id, args.name, int(time()))
    summary_writer = make_summary_writer(run_name)

    env = make_with_torque_removed(env_id)

    num_timesteps = int(args.num_timesteps)
    experiment_name = slugify(args.name)

    if args.predictor == "rl":
        predictor = TraditionalRLRewardPredictor(summary_writer)
    else:
        agent_logger = AgentLogger(summary_writer)

        # Explicit --pretrain_labels wins; otherwise use a quarter of --n_labels.
        pretrain_labels = args.pretrain_labels or args.n_labels // 4

        if args.n_labels:
            label_schedule = LabelAnnealer(
                agent_logger,
                final_timesteps=num_timesteps,
                final_labels=args.n_labels,
                pretrain_labels=pretrain_labels)
        else:
            print(
                "No label limit given. We will request one label every few seconds."
            )
            label_schedule = ConstantLabelSchedule(
                pretrain_labels=pretrain_labels)

        if args.predictor == "synth":
            comparison_collector = SyntheticComparisonCollector(
                run_name, args.human_labels)
        else:  # "human" - the only remaining value after the validation above
            bucket = os.environ.get('RL_TEACHER_GCS_BUCKET')
            # Raised explicitly rather than asserted so the check survives `python -O`.
            if not (bucket and bucket.startswith("gs://")):
                raise ValueError(
                    "env variable RL_TEACHER_GCS_BUCKET must start with gs://")
            comparison_collector = HumanComparisonCollector(
                env_id, experiment_name=experiment_name)

        predictor = ComparisonRewardPredictor(
            env,
            summary_writer,
            comparison_collector=comparison_collector,
            agent_logger=agent_logger,
            label_schedule=label_schedule,
        )

        print(
            "Starting random rollouts to generate pretraining segments. No learning will take place..."
        )
        # Twice as many segments as labels: each comparison consumes two segments.
        pretrain_segments = segments_from_rand_rollout(
            env_id,
            make_with_torque_removed,
            n_desired_segments=pretrain_labels * 2,
            clip_length_in_seconds=CLIP_LENGTH,
            workers=args.workers)
        # Turn our random segments into comparisons: pair segment i with
        # segment i + pretrain_labels.
        for i in range(pretrain_labels):
            comparison_collector.add_segment_pair(
                pretrain_segments[i], pretrain_segments[i + pretrain_labels])

        # Wait until most (75%) of the pretraining comparisons are labeled.
        while len(comparison_collector.labeled_comparisons) < int(
                pretrain_labels * 0.75):
            comparison_collector.label_unlabeled_comparisons()
            if args.predictor == "synth":
                print("%s synthetic labels generated... " %
                      (len(comparison_collector.labeled_comparisons)))
            elif args.predictor == "human":
                print(
                    "%s/%s comparisons labeled. Please add labels w/ the human-feedback-api. Sleeping... "
                    % (len(comparison_collector.labeled_comparisons),
                       pretrain_labels))
                sleep(5)

        # Pretrain the predictor on the labels collected so far.
        for i in range(args.pretrain_iters):
            predictor.train_predictor()
            if i % 100 == 0:
                print("%s/%s predictor pretraining iters... " %
                      (i, args.pretrain_iters))

    # Wrap the predictor to capture videos every so often:
    if not args.no_videos:
        predictor = SegmentVideoRecorder(
            predictor, env,
            save_dir=osp.join('/tmp/rl_teacher_vids', run_name))

    # We use a vanilla agent from openai/baselines that contains a single change that blinds it to the true reward
    # The single changed section is in `rl_teacher/agent/trpo/core.py`
    print("Starting joint training of predictor and agent")
    if args.agent == "parallel_trpo":
        train_parallel_trpo(
            env_id=env_id,
            make_env=make_with_torque_removed,
            predictor=predictor,
            summary_writer=summary_writer,
            workers=args.workers,
            runtime=(num_timesteps / 1000),
            max_timesteps_per_episode=get_timesteps_per_episode(env),
            timesteps_per_batch=8000,
            max_kl=0.001,
            seed=args.seed,
        )
    else:  # "pposgd_mpi" - validated right after argument parsing

        def make_env():
            return make_with_torque_removed(env_id)

        train_pposgd_mpi(make_env, num_timesteps=num_timesteps,
                         seed=args.seed, predictor=predictor)
def main():
    """Command-line entry point: build a reward predictor and train an agent against it.

    Every argument has a default, so the script can be started without flags.
    Steps, in order:
      1. Parse the command line and reject unusable combinations up front.
      2. Create the summary writer and the environment.
      3. Build the predictor: ``TraditionalRLRewardPredictor`` for
         ``--predictor rl``; otherwise a ``ComparisonRewardPredictor`` fed by a
         synthetic or human comparison collector.
      4. Optionally wrap the predictor in ``SegmentVideoRecorder``.
      5. Train the agent selected by ``--agent``; for ``pposgd_mpi`` the
         logger is configured first, per MPI rank.

    The predictor pretraining phase is NOT run by this entry point, so
    ``--pretrain_iters`` is parsed but has no effect here.

    Raises:
        ValueError: if ``--predictor`` or ``--agent`` is not a supported value.
        SystemExit: via ``parser.error`` when a comparison predictor is
            requested with neither ``--n_labels`` nor a non-zero
            ``--pretrain_labels``.

    NOTE(review): this file contains another ``def main()`` above; being the
    later definition, this one is what the name ``main`` ends up bound to -
    confirm that is intended.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--env_id', default="ShortHopper-v1", type=str)
    parser.add_argument('-p', '--predictor', default="human", type=str)
    parser.add_argument('-n', '--name', default="human-175-hopper", type=str)
    parser.add_argument('-s', '--seed', default=6, type=int)
    parser.add_argument('-w', '--workers', default=4, type=int)
    parser.add_argument('-l', '--n_labels', default=None, type=int)
    parser.add_argument('-L', '--pretrain_labels', default=20, type=int)
    # The default is a float literal; it is converted with int() below.
    parser.add_argument('-t', '--num_timesteps', default=5e6, type=int)
    parser.add_argument('-a', '--agent', default="pposgd_mpi", type=str)
    parser.add_argument('-i', '--pretrain_iters', default=1, type=int)
    parser.add_argument('-V', '--no_videos', action="store_true")
    parser.add_argument('--log_path',
                        help='Directory to save learning curve data.',
                        default='tmp/openaiTest', type=str)
    args = parser.parse_args()

    # Fail fast: reject bad arguments before creating writers/environments.
    if args.predictor not in ("rl", "synth", "human"):
        raise ValueError("Bad value for --predictor: %s" % args.predictor)
    if args.agent not in ("parallel_trpo", "pposgd_mpi"):
        raise ValueError("%s is not a valid choice for args.agent" % args.agent)
    if args.predictor != "rl" and not args.pretrain_labels and args.n_labels is None:
        # Reachable with `-L 0` and no `-l`: the fallback `args.n_labels // 4`
        # below would otherwise raise an opaque TypeError on None.
        parser.error(
            "one of -l/--n_labels or a non-zero -L/--pretrain_labels is "
            "required unless --predictor is 'rl'")

    print("Setting things up...")

    env_id = args.env_id
    run_name = "%s/%s-%s" % (env_id, args.name, int(time()))
    summary_writer = make_summary_writer(run_name)

    env = make_with_torque_removed(env_id)

    num_timesteps = int(args.num_timesteps)
    experiment_name = slugify(args.name)

    if args.predictor == "rl":
        predictor = TraditionalRLRewardPredictor(summary_writer)
    else:
        agent_logger = AgentLogger(summary_writer)

        # online and offline
        # A non-zero --pretrain_labels wins; otherwise use a quarter of --n_labels.
        pretrain_labels = args.pretrain_labels or args.n_labels // 4

        if args.n_labels:
            label_schedule = LabelAnnealer(
                agent_logger,
                final_timesteps=num_timesteps,
                final_labels=args.n_labels,
                pretrain_labels=pretrain_labels)
        else:
            print(
                "No label limit given. We will request one label every few seconds."
            )
            label_schedule = ConstantLabelSchedule(
                pretrain_labels=pretrain_labels)

        if args.predictor == "synth":
            # NOTE(review): the other main() in this file calls
            # SyntheticComparisonCollector(run_name, args.human_labels);
            # confirm which signature is current.
            comparison_collector = SyntheticComparisonCollector()
        else:  # "human" - the only remaining value after the validation above
            # No GCS bucket handling here: nothing in this function consumes
            # a bucket value.
            # NOTE(review): presumably the collector resolves
            # RL_TEACHER_GCS_BUCKET itself - confirm it is set where needed.
            comparison_collector = HumanComparisonCollector(
                env_id, experiment_name=experiment_name)

        predictor = ComparisonRewardPredictor(
            env,
            summary_writer,
            comparison_collector=comparison_collector,
            agent_logger=agent_logger,
            label_schedule=label_schedule,
        )

        # Pretraining (random-rollout segments -> comparisons -> waiting for
        # labels -> predictor warm-up iterations) is intentionally not run in
        # this entry point, and no model checkpoint is saved.

    # Wrap the predictor to capture videos every so often:
    if not args.no_videos:
        predictor = SegmentVideoRecorder(
            predictor, env,
            save_dir=osp.join('/tmp/rl_teacher_vids', run_name))

    # We use a vanilla agent from openai/baselines that contains a single change that blinds it to the true reward
    # The single changed section is in `rl_teacher/agent/trpo/core.py`
    print("Starting joint training of predictor and agent")
    if args.agent == "parallel_trpo":
        train_parallel_trpo(
            env_id=env_id,
            make_env=make_with_torque_removed,
            predictor=predictor,
            summary_writer=summary_writer,
            workers=args.workers,
            runtime=(num_timesteps / 1000),
            max_timesteps_per_episode=get_timesteps_per_episode(env),
            timesteps_per_batch=8000,
            max_kl=0.001,
            seed=args.seed,
        )
    else:  # "pposgd_mpi" - validated right after argument parsing

        def make_env():
            return make_with_torque_removed(env_id)

        # mpi4py is optional; without it we behave like a single rank-0 process.
        try:
            from mpi4py import MPI
        except ImportError:
            MPI = None

        def configure_logger(log_path, **kwargs):
            # NOTE(review): kwargs are forwarded only when log_path is None.
            # --log_path has a non-None default, so the format_strs=[] passed
            # for non-zero ranks below is currently dropped - confirm intended.
            if log_path is not None:
                logger.configure(log_path)
            else:
                logger.configure(**kwargs)

        rank = 0 if MPI is None else MPI.COMM_WORLD.Get_rank()
        if rank == 0:
            configure_logger(args.log_path)
        else:
            # Non-root ranks request no output formats.
            configure_logger(args.log_path, format_strs=[])

        train_pposgd_mpi(make_env, num_timesteps=num_timesteps,
                         seed=args.seed, predictor=predictor)