Exemplo n.º 1
0
 def testSynchronousTrainCollectEval(self):
   """End-to-end integration test: random collection, then synchronous training.

   Phase 1 parses `random_collect.gin` and runs train_collect_eval with
   `train_fn=None`. Phase 2 parses `train_dqn.gin` and runs train_collect_eval
   again with its default train_fn. Both phases share the same procedurally
   generated Kuka grasping environment and the same root directory.
   """
   env = grasping_env.KukaGraspingProceduralEnv(
       downsample_width=64,
       downsample_height=64,
       continuous=True,
       remove_height_hack=True,
       render_mode='DIRECT')
   testdata_dir = os.path.join(FLAGS.test_srcdir, 'testdata')

   def load_gin(basename):
     # NOTE(review): gin.clear_config() is never called between phases, so the
     # second file is parsed on top of the first one's bindings. Confirm this
     # is intended.
     with open(os.path.join(testdata_dir, basename), 'r') as gin_file:
       gin.parse_config(gin_file)

   # Both runs only collect into the collect env; eval/test envs are disabled.
   shared_kwargs = dict(collect_env=env,
                        eval_env=None,
                        test_env=None,
                        root_dir=self._root_dir)

   # Phase 1: gather initial data with a random policy; no training function.
   load_gin('random_collect.gin')
   train_collect_eval.train_collect_eval(train_fn=None, **shared_kwargs)

   # Phase 2: synchronous train, collect and eval.
   load_gin('train_dqn.gin')
   train_collect_eval.train_collect_eval(**shared_kwargs)
def main(_):
    """Seed the RNGs, load gin bindings, then dispatch on ``FLAGS.run_mode``.

    Both numpy and TensorFlow are seeded with ``FLAGS.task``. If
    ``FLAGS.gin_config`` is set, it is parsed as a file when such a file
    exists; otherwise the flag value itself is parsed as gin bindings. The gin
    config is finalized before any mode runs.

    Run modes:
      * ``'collect_eval_once'``: calls train_collect_eval with
        ``train_fn=None`` and the raw ``FLAGS.task`` index.
      * ``'train_only'``: calls train_collect_eval with
        ``do_collect_eval=False`` and forwards ``FLAGS.master`` /
        ``FLAGS.ps_tasks``.
      * ``'collect_eval_loop'``: not implemented.
      * any other value: synchronous train, collect and eval.

    Args:
      _: Unused positional argument (presumably argv from the app runner).

    Raises:
      NotImplementedError: If ``FLAGS.run_mode == 'collect_eval_loop'``.
    """
    seed = FLAGS.task
    np.random.seed(seed)
    tf.set_random_seed(seed)

    # Only distributed jobs use their real task index; local runs act as task 0.
    task = FLAGS.task if FLAGS.distributed else 0

    gin_source = FLAGS.gin_config
    if gin_source:
        if not tf.gfile.Exists(gin_source):
            # Not a file on disk: treat the flag value as literal bindings.
            gin.parse_config(gin_source)
        else:
            with tf.gfile.Open(gin_source) as gin_file:
                gin.parse_config(gin_file)

    gin.finalize()

    run_mode = FLAGS.run_mode
    if run_mode == 'collect_eval_loop':
        raise NotImplementedError('collect_eval_loops')

    root_dir = FLAGS.root_dir
    if run_mode == 'collect_eval_once':
        # Collection/eval only; this mode deliberately uses FLAGS.task as-is.
        train_collect_eval.train_collect_eval(
            root_dir=root_dir, train_fn=None, task=FLAGS.task)
    elif run_mode == 'train_only':
        train_collect_eval.train_collect_eval(
            root_dir=root_dir,
            do_collect_eval=False,
            task=task,
            master=FLAGS.master,
            ps_tasks=FLAGS.ps_tasks)
    else:
        # Default: synchronous train, collect and eval.
        train_collect_eval.train_collect_eval(root_dir=root_dir, task=task)