Example #1
0
 def testSynchronousTrainCollectEval(self):
   """End-to-end integration test of collection followed by DQN training.

   Phase 1 loads random_collect.gin and runs train_collect_eval with
   train_fn=None to gather initial data from a random policy. Phase 2 loads
   train_dqn.gin and runs synchronous train, collect & eval. Both phases
   reuse one KukaGraspingProceduralEnv and write under self._root_dir.
   """
   kuka_env = grasping_env.KukaGraspingProceduralEnv(
       downsample_width=64,
       downsample_height=64,
       continuous=True,
       remove_height_hack=True,
       render_mode='DIRECT')

   def _load_gin(config_name):
     # Gin files live in FLAGS.test_srcdir/testdata. Nothing clears gin state
     # between the two phases, so the second file is parsed after the first.
     config_path = os.path.join(FLAGS.test_srcdir, 'testdata', config_name)
     with open(config_path, 'r') as config_file:
       gin.parse_config(config_file)

   # Phase 1: collect initial data from a random policy without training.
   _load_gin('random_collect.gin')
   train_collect_eval.train_collect_eval(
       collect_env=kuka_env,
       eval_env=None,
       test_env=None,
       root_dir=self._root_dir,
       train_fn=None)

   # Phase 2: synchronous train, collect & eval.
   _load_gin('train_dqn.gin')
   train_collect_eval.train_collect_eval(
       collect_env=kuka_env,
       eval_env=None,
       test_env=None,
       root_dir=self._root_dir)
Example #2
0
def resetEnvironment():
    """Create a fresh GUI grasping environment and reset episode bookkeeping.

    Rebinds the module-level ``env`` to a new ``KukaGraspingProceduralEnv``
    (downsample_width/height=48, continuous=True, render_mode='GUI'), prints
    its action space, stores the result of ``env.reset()`` in the
    module-level ``obs``, and resets the module-level episode counters
    ``done``, ``env_step``, ``episode_reward`` and ``episode_data``.
    """
    # The episode counters must be declared global as well. Without this they
    # were assigned as function locals and thrown away on return, so module
    # state was never actually reset.
    global obs, env, done, env_step, episode_reward, episode_data
    # NOTE(review): a previously created env is not closed here. If the GUI
    # backend only permits one connection, repeated resets may fail; confirm.
    env = grasping_env.KukaGraspingProceduralEnv(downsample_width=48,
                                                 downsample_height=48,
                                                 continuous=True,
                                                 remove_height_hack=True,
                                                 render_mode='GUI')
    print(env.action_space)
    obs = env.reset()
    done, env_step, episode_reward, episode_data = (False, 0, 0.0, [])
Example #3
0
 def testDDPGPolicy(self):
   """Checks DDPGPolicy returns a length-4 action and a 'q' debug entry."""
   np.random.seed(0)
   side = 48  # Downsampled image width/height; reused for state_shape.
   grasp_env = grasping_env.KukaGraspingProceduralEnv(
       downsample_width=side,
       downsample_height=side,
       continuous=True,
       remove_height_hack=True,
       render_mode='DIRECT')
   ddpg = policies.DDPGPolicy(
       tf_critics.cnn_v0,
       tf_critics.cnn_ia_v1,
       state_shape=(1, side, side, 3),
       action_size=4,
       use_gpu=False,
       build_target=False,
       include_timestep=True)
   ddpg.reset()
   first_obs = grasp_env.reset()
   # Second argument is 0, presumably the timestep (include_timestep=True).
   action, debug_info = ddpg.sample_action(first_obs, 0)
   self.assertLen(action, 4)
   self.assertIn('q', debug_info)
 def testPolicyRun(self, tag, use_root_dir):
     env = grasping_env.KukaGraspingProceduralEnv(downsample_width=48,
                                                  downsample_height=48,
                                                  continuous=True,
                                                  remove_height_hack=True,
                                                  render_mode='DIRECT')
     policy = policies.RandomGraspingPolicyD4()
     root_dir = os.path.join(FLAGS.test_tmpdir,
                             tag) if use_root_dir else None
     run_env.run_env(env,
                     policy=policy,
                     explore_schedule=None,
                     episode_to_transitions_fn=None,
                     replay_writer=None,
                     root_dir=root_dir,
                     tag=tag,
                     task=0,
                     global_step=0,
                     num_episodes=1)