def testSynchronousTrainCollectEval(self):
    """End-to-end integration test of collection followed by training.

    Two phases run against the same environment and `self._root_dir`:
    first `random_collect.gin` is parsed and `train_collect_eval` is called
    with `train_fn=None`; then `train_dqn.gin` is parsed and
    `train_collect_eval` is called again without overriding `train_fn`.
    """
    grasp_env = grasping_env.KukaGraspingProceduralEnv(
        downsample_width=64,
        downsample_height=64,
        continuous=True,
        remove_height_hack=True,
        render_mode='DIRECT',
    )

    def run_phase(gin_filename, **overrides):
        # Load the gin bindings for this phase, then run one full
        # train/collect/eval pass with only a collect environment.
        gin_path = os.path.join(FLAGS.test_srcdir, 'testdata', gin_filename)
        with open(gin_path, 'r') as gin_file:
            gin.parse_config(gin_file)
        train_collect_eval.train_collect_eval(
            collect_env=grasp_env,
            eval_env=None,
            test_env=None,
            root_dir=self._root_dir,
            **overrides
        )

    # Collect initial data from random policy without training.
    run_phase('random_collect.gin', train_fn=None)

    # Run training (synchronous train, collect, & eval).
    run_phase('train_dqn.gin')
def resetEnvironment():
    """Rebuild the module-level grasping environment and reset it.

    Rebinds the global `env` to a new `KukaGraspingProceduralEnv` created
    with `render_mode='GUI'`, prints its action space, and stores the result
    of `env.reset()` in the global `obs`.
    """
    global obs, env

    env_kwargs = dict(
        downsample_width=48,
        downsample_height=48,
        continuous=True,
        remove_height_hack=True,
        render_mode='GUI',
    )
    env = grasping_env.KukaGraspingProceduralEnv(**env_kwargs)
    print(env.action_space)
    obs = env.reset()

    # NOTE(review): the four names below are not declared `global`, so they
    # are plain locals that are discarded when this function returns. If the
    # intent was to reset module-level episode state, they need to be added
    # to the `global` statement -- confirm against the rest of the file.
    done = False
    env_step = 0
    episode_reward = 0.0
    episode_data = []
def testDDPGPolicy(self):
    """Sample one action from a DDPGPolicy and check its basic structure.

    Asserts that the sampled action has length 4 and that the accompanying
    debug output contains a 'q' entry.
    """
    np.random.seed(0)

    grasp_env = grasping_env.KukaGraspingProceduralEnv(
        downsample_width=48,
        downsample_height=48,
        continuous=True,
        remove_height_hack=True,
        render_mode='DIRECT',
    )
    ddpg_policy = policies.DDPGPolicy(
        tf_critics.cnn_v0,
        tf_critics.cnn_ia_v1,
        state_shape=(1, 48, 48, 3),
        action_size=4,
        use_gpu=False,
        build_target=False,
        include_timestep=True,
    )
    ddpg_policy.reset()

    first_observation = grasp_env.reset()
    sampled_action, debug_info = ddpg_policy.sample_action(
        first_observation, 0)

    self.assertLen(sampled_action, 4)
    self.assertIn('q', debug_info)
def testPolicyRun(self, tag, use_root_dir):
    """Run a single episode of `run_env` with a RandomGraspingPolicyD4.

    Args:
        tag: String forwarded to `run_env` as `tag`; also used as the
            sub-directory name under `FLAGS.test_tmpdir` when a root
            directory is requested.
        use_root_dir: If truthy, pass `<test_tmpdir>/<tag>` as `root_dir`;
            otherwise pass `root_dir=None`.
    """
    grasp_env = grasping_env.KukaGraspingProceduralEnv(
        downsample_width=48,
        downsample_height=48,
        continuous=True,
        remove_height_hack=True,
        render_mode='DIRECT',
    )
    random_policy = policies.RandomGraspingPolicyD4()

    if use_root_dir:
        output_dir = os.path.join(FLAGS.test_tmpdir, tag)
    else:
        output_dir = None

    run_env.run_env(
        grasp_env,
        policy=random_policy,
        explore_schedule=None,
        episode_to_transitions_fn=None,
        replay_writer=None,
        root_dir=output_dir,
        tag=tag,
        task=0,
        global_step=0,
        num_episodes=1,
    )