def main(unused_argv):
  if not tf.gfile.Exists(FLAGS.workdir):
    tf.gfile.MakeDirs(FLAGS.workdir)
  series_dict = {}
  for metric in XmSeries._fields:
    series_dict[metric] = utils.create_measurement_series(
        FLAGS.workdir, metric)
  xm_series = XmSeries(**series_dict)  # type: ignore
  if FLAGS.fully_qualified_level == 'merged':
    # Use a fake level when merged training data is used. This corresponds to
    # the directory used by shuffle_examples.cc to write the merged dataset.
    level = constants.Level('merged')
  else:
    level = Const.find_level(FLAGS.fully_qualified_level)
  r_trainer = RTrainer(workdir=FLAGS.workdir, level=level, xm_series=xm_series)
  r_trainer.train()
  for series in xm_series._asdict().values():
    if not series:
      continue
    utils.maybe_close_measurements(series)
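
# Hedged sketch (an assumption, not shown in this excerpt): the `_fields` and
# `_asdict()` calls above imply XmSeries is a collections.namedtuple whose
# fields name the tracked measurement series. The field names below are
# hypothetical placeholders, not the original ones.
import collections

XmSeries = collections.namedtuple(
    'XmSeries', ['loss_train', 'accuracy_valid', 'accuracy_test'])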
def train(workdir, env_name, num_timesteps,
          nsteps=256, nminibatches=4, noptepochs=4,
          learning_rate=2.5e-4, ent_coef=0.01):
  """Runs PPO training.

  Args:
    workdir: Where to store experiment results/logs.
    env_name: The name of the environment to run.
    num_timesteps: For how many timesteps to run training.
    nsteps: Number of consecutive environment steps to use during training.
    nminibatches: Minibatch size.
    noptepochs: Number of optimization epochs.
    learning_rate: Initial learning rate.
    ent_coef: Entropy coefficient.
  """
  train_measurements = utils.create_measurement_series(workdir, 'reward_train')
  valid_measurements = utils.create_measurement_series(workdir, 'reward_valid')
  test_measurements = utils.create_measurement_series(workdir, 'reward_test')

  def measurement_callback(unused_eplenmean, eprewmean, global_step_val):
    if train_measurements:
      train_measurements.create_measurement(
          objective_value=eprewmean, step=global_step_val)
    logger.logkv('eprewmean_train', eprewmean)

  def eval_callback_on_valid(eprewmean, global_step_val):
    if valid_measurements:
      valid_measurements.create_measurement(
          objective_value=eprewmean, step=global_step_val)
    logger.logkv('eprewmean_valid', eprewmean)

  def eval_callback_on_test(eprewmean, global_step_val):
    if test_measurements:
      test_measurements.create_measurement(
          objective_value=eprewmean, step=global_step_val)
    logger.logkv('eprewmean_test', eprewmean)

  logger_dir = workdir
  logger.configure(dir=logger_dir,
                   format_strs=['tensorboard', 'stdout', 'log', 'csv'])
  logger.Logger.DEFAULT = logger.Logger.CURRENT

  env, valid_env, test_env = get_environment(env_name)
  is_ant = env_name.startswith('parkour:')

  # Validation metric.
  policy_evaluator_on_valid = eval_policy.PolicyEvaluator(
      valid_env,
      metric_callback=eval_callback_on_valid,
      video_filename=None)

  # Test metric (+ videos).
  video_filename = os.path.join(FLAGS.workdir, 'video')
  policy_evaluator_on_test = eval_policy.PolicyEvaluator(
      test_env,
      metric_callback=eval_callback_on_test,
      video_filename=video_filename,
      grayscale=(env_name.startswith('atari:')))

  # Delay to make sure that all the DMLab environments acquire the GPU
  # resources before TensorFlow acquires the rest of the memory.
  # TODO(damienv): Possibly use allow_growth in a TensorFlow session
  # so that there is no such problem anymore.
  time.sleep(15)

  cloud_sync_callback = lambda: None

  def evaluate_valid_test(model_step_fn, global_step):
    if not is_ant:
      policy_evaluator_on_valid.evaluate(model_step_fn, global_step)
    policy_evaluator_on_test.evaluate(model_step_fn, global_step)

  sess_config = tf.ConfigProto()
  sess_config.gpu_options.allow_growth = True
  with tf.Session(config=sess_config):
    policy = {'cnn': policies.CnnPolicy,
              'lstm': policies.LstmPolicy,
              'lnlstm': policies.LnLstmPolicy,
              'mlp': policies.MlpPolicy}[FLAGS.policy_architecture]

    # OpenAI baselines never performs exactly num_timesteps env steps because
    # of the way it samples training data in batches. The number of timesteps
    # is multiplied by 1.1 (hacky) to ensure at least num_timesteps are
    # performed.
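    # Note (added for clarity; behavior inferred from OpenAI baselines' ppo2):
    # when lr / cliprange are callables, ppo2 invokes them on each update with
    # a fraction f that anneals linearly from 1 to 0 over training, so
    # `lambda f: f * learning_rate` below is a linear decay, while the ant
    # (parkour) branch keeps both values constant.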
    ppo2.learn(policy, env=env, nsteps=nsteps, nminibatches=nminibatches,
               lam=0.95, gamma=0.99, noptepochs=noptepochs, log_interval=1,
               ent_coef=ent_coef,
               lr=learning_rate if is_ant else lambda f: f * learning_rate,
               cliprange=0.2 if is_ant else lambda f: f * 0.1,
               total_timesteps=int(num_timesteps * 1.1),
               train_callback=measurement_callback,
               eval_callback=evaluate_valid_test,
               cloud_sync_callback=cloud_sync_callback,
               save_interval=200, workdir=workdir,
               use_curiosity=FLAGS.use_curiosity,
               curiosity_strength=FLAGS.curiosity_strength,
               forward_inverse_ratio=FLAGS.forward_inverse_ratio,
               curiosity_loss_strength=FLAGS.curiosity_loss_strength,
               random_state_predictor=FLAGS.random_state_predictor)

  cloud_sync_callback()
  test_env.close()
  valid_env.close()
  utils.maybe_close_measurements(train_measurements)
  utils.maybe_close_measurements(valid_measurements)
  utils.maybe_close_measurements(test_measurements)
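
# Hedged usage sketch (not in the original file). train() also reads FLAGS
# internally (policy_architecture, use_curiosity, curiosity_strength, ...),
# so flags must be parsed before calling it. The environment name and step
# budget below are illustrative assumptions; the code above only shows that
# names prefixed 'atari:' or 'parkour:' are recognized.
#
#   train(workdir=FLAGS.workdir,
#         env_name='atari:breakout',
#         num_timesteps=int(10e6))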