def testSynchronousTrainCollectEval(self):
    """End-to-end integration test.

    Runs ``train_collect_eval.train_collect_eval`` twice with the same
    ``KukaGraspingProceduralEnv`` and the same ``self._root_dir``:

    1. a collect-only pass configured by ``testdata/random_collect.gin``
       (``train_fn=None``, so no training happens);
    2. a synchronous train / collect / eval pass configured by
       ``testdata/train_dqn.gin``.

    Neither pass uses an eval or test environment.
    """

    def load_gin(filename):
        # Gin files for this test live under <FLAGS.test_srcdir>/testdata.
        config_path = os.path.join(FLAGS.test_srcdir, 'testdata', filename)
        with open(config_path, 'r') as config_file:
            gin.parse_config(config_file)

    env = grasping_env.KukaGraspingProceduralEnv(
        downsample_width=64,
        downsample_height=64,
        continuous=True,
        remove_height_hack=True,
        render_mode='DIRECT',
    )

    # Arguments common to both phases.
    common_kwargs = dict(
        collect_env=env,
        eval_env=None,
        test_env=None,
        root_dir=self._root_dir,
    )

    # Phase 1: gather initial data from a random policy, without training.
    load_gin('random_collect.gin')
    train_collect_eval.train_collect_eval(train_fn=None, **common_kwargs)

    # Phase 2: synchronous train, collect & eval.
    # NOTE(review): bindings parsed from random_collect.gin are not cleared
    # before train_dqn.gin is parsed, so they presumably stay in effect unless
    # overridden. Confirm this is intended.
    load_gin('train_dqn.gin')
    train_collect_eval.train_collect_eval(**common_kwargs)
def main(_):
    """Seed RNGs, load gin bindings, and run the mode chosen by FLAGS.run_mode.

    Gin configuration: if ``FLAGS.gin_config`` names an existing file, it is
    parsed as a gin file. Any other non-empty value is parsed as literal gin
    bindings. The config is then finalized.

    Run modes:
      * ``'collect_eval_once'``: collect/eval without training
        (``train_fn=None``), using ``FLAGS.task`` directly.
      * ``'train_only'``: training without collect/eval, forwarding
        ``FLAGS.master`` and ``FLAGS.ps_tasks``.
      * ``'collect_eval_loop'``: not implemented.
      * anything else: synchronous train / collect / eval.

    Args:
      _: Ignored. Presumably the argv list passed by the app runner; verify
        against the caller.

    Raises:
      NotImplementedError: if ``FLAGS.run_mode == 'collect_eval_loop'``.
    """
    seed = FLAGS.task
    np.random.seed(seed)
    tf.set_random_seed(seed)

    # Training-related modes use task 0 unless running distributed.
    task = FLAGS.task if FLAGS.distributed else 0

    config = FLAGS.gin_config
    if config:
        if tf.gfile.Exists(config):
            with tf.gfile.Open(config) as config_file:
                gin.parse_config(config_file)
        else:
            # Not a path on disk: treat the flag value as inline gin bindings.
            gin.parse_config(config)
    gin.finalize()

    run_mode = FLAGS.run_mode
    if run_mode == 'collect_eval_loop':
        raise NotImplementedError('collect_eval_loops')

    if run_mode == 'collect_eval_once':
        # Collect-only workers keep their own FLAGS.task id.
        mode_kwargs = dict(train_fn=None, task=FLAGS.task)
    elif run_mode == 'train_only':
        mode_kwargs = dict(
            do_collect_eval=False,
            task=task,
            master=FLAGS.master,
            ps_tasks=FLAGS.ps_tasks,
        )
    else:
        # Default: synchronous train-collect-eval.
        mode_kwargs = dict(task=task)

    train_collect_eval.train_collect_eval(root_dir=FLAGS.root_dir, **mode_kwargs)