Example #1
# Imports assumed by this excerpt (FLAGS, train_eval and eval_only are
# defined elsewhere in the module).
from absl import logging
from tf_agents.system import system_multiprocessing as multiprocessing


def main(_):
    logging.set_verbosity(logging.INFO)

    # !! IMPORTANT: make sure these are multiples of the PPO agent's num_epochs.
    FLAGS.eval_interval = FLAGS.eval_interval * FLAGS.num_epochs
    FLAGS.summary_interval = FLAGS.summary_interval * FLAGS.num_epochs
    FLAGS.checkpoint_interval = FLAGS.checkpoint_interval * FLAGS.num_epochs
    FLAGS.log_interval = FLAGS.log_interval * FLAGS.num_epochs

    main_fn = train_eval
    if FLAGS.just_eval or FLAGS.just_eval_random:
        main_fn = eval_only

    if FLAGS.use_multiprocessing:
        multiprocessing.enable_interactive_mode()
        # handle_main invokes the callable with argv (see the TypeError test
        # in Example #2), so the lambda must accept and discard one argument.
        multiprocessing.handle_main(lambda _: main_fn(FLAGS))
    else:
        main_fn(FLAGS)
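The snippet above rescales absl flags that are defined elsewhere in the module. A minimal sketch of the kind of definitions it assumes (flag names come from the code above; defaults and help strings are hypothetical):

from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_integer('num_epochs', 25, 'PPO epochs per training iteration.')
flags.DEFINE_integer('eval_interval', 500, 'Iterations between evaluations.')
flags.DEFINE_integer('summary_interval', 50, 'Iterations between summaries.')
flags.DEFINE_integer('checkpoint_interval', 500,
                     'Iterations between checkpoints.')
flags.DEFINE_integer('log_interval', 50, 'Iterations between log lines.')
flags.DEFINE_bool('use_multiprocessing', False,
                  'Whether to run main_fn under tf_agents multiprocessing.')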
Example #2
def testArgExpected(self):
    # handle_main passes argv to the callable, so a zero-argument function
    # must be rejected with a TypeError.
    no_argument_main_fn = lambda: None
    with self.assertRaises(TypeError):
        multiprocessing.handle_main(no_argument_main_fn)
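The test above pins down handle_main's contract: the supplied callable is invoked with an argv-style argument, so a zero-argument function fails with a TypeError. A minimal conforming entry point, as a sketch (it assumes only that one positional argument is passed through):

from tf_agents.system import system_multiprocessing as multiprocessing


def main(argv):
    del argv  # Unused, but required by handle_main's calling convention.
    print('main process started')


if __name__ == '__main__':
    multiprocessing.handle_main(main)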
Example #3
        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )


def main(_):
    logging.set_verbosity(logging.INFO)
    tf.compat.v1.enable_v2_behavior()
    train_eval(
        FLAGS.root_dir,
        env_name=FLAGS.env_name,
        use_rnns=FLAGS.use_rnns,
        num_environment_steps=FLAGS.num_environment_steps,
        collect_episodes_per_iteration=FLAGS.collect_episodes_per_iteration,
        num_parallel_environments=FLAGS.num_parallel_environments,
        replay_buffer_capacity=FLAGS.replay_buffer_capacity,
        num_epochs=FLAGS.num_epochs,
        num_eval_episodes=FLAGS.num_eval_episodes)


if __name__ == '__main__':
    flags.mark_flag_as_required('root_dir')
    multiprocessing.handle_main(functools.partial(app.run, main))
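Example #3 adapts app.run to handle_main with functools.partial: app.run takes the entry point first and argv as its second positional parameter, so the partial leaves exactly one open slot for the argument handle_main supplies. A sketch of the two equivalent spellings (Example #4 below instead discards argv and lets app.run fall back to sys.argv):

import functools

from absl import app
from tf_agents.system import system_multiprocessing as multiprocessing


def main(argv):
    del argv


if __name__ == '__main__':
    # app.run(main, argv=None, ...): the partial binds main and leaves the
    # argv slot open for handle_main to fill.
    multiprocessing.handle_main(functools.partial(app.run, main))
    # Equivalent explicit form (app.run exits, so only one can actually run):
    # multiprocessing.handle_main(lambda argv: app.run(main, argv))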
Example #4
    tf.enable_v2_behavior()

    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)

    # Wait for the collect policy to become available, then load it.
    collect_policy_dir = os.path.join(FLAGS.root_dir,
                                      learner.POLICY_SAVED_MODEL_DIR,
                                      learner.COLLECT_POLICY_SAVED_MODEL_DIR)
    collect_policy = train_utils.wait_for_policy(collect_policy_dir,
                                                 load_specs_from_pbtxt=True)

    # Prepare summary directory.
    summary_dir = os.path.join(FLAGS.root_dir, learner.TRAIN_DIR,
                               str(FLAGS.task))

    # Perform collection.
    collect(summary_dir=summary_dir,
            environment_name=gin.REQUIRED,
            collect_policy=collect_policy,
            replay_buffer_server_address=FLAGS.replay_buffer_server_address,
            variable_container_server_address=(
                FLAGS.variable_container_server_address))


if __name__ == '__main__':
    flags.mark_flags_as_required([
        'root_dir', 'replay_buffer_server_address',
        'variable_container_server_address'
    ])
    multiprocessing.handle_main(lambda _: app.run(main))
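collect is called with environment_name=gin.REQUIRED, so the value has to come from the parsed gin files or bindings; gin raises an error at call time if nothing supplies it. A hypothetical binding (the environment name is a placeholder, and this assumes collect is gin-configurable in the full module):

import gin

# Equivalent to passing
#   --gin_bindings="collect.environment_name = 'CartPole-v0'"
# on the command line.
gin.parse_config_files_and_bindings(
    None, ["collect.environment_name = 'CartPole-v0'"])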