Example #1
0
 def test_PoseEnv(self):
     """Smoke-test PoseToyEnv driven by PoseEnvRandomPolicy for three steps.

     Checks that the pybullet URDF root exists, builds the environment from
     it, and steps it three times. The environment is reset whenever a step
     reports ``done``.
     """
     urdf_root = pose_env.get_pybullet_urdf_root()
     self.assertTrue(os.path.exists(urdf_root))
     env = pose_env.PoseToyEnv(urdf_root=urdf_root)
     obs = env.reset()
     policy = pose_env.PoseEnvRandomPolicy()
     for _ in range(3):
         # Sample from the most recent observation on every step so the
         # observations returned by step()/reset() actually feed the policy.
         # The second argument keeps the original value 0; its meaning
         # (exploration/step parameter?) is not visible here -- TODO confirm.
         action, _ = policy.sample_action(obs, 0)
         obs, _, done, _ = env.step(action)
         if done:
             obs = env.reset()
    def _test_policy_interface(self, policy, restore=True):
        """Exercise a policy's select/adapt/select cycle against PoseToyEnv.

        Creates a ``PoseToyEnv`` with ``render_mode='DIRECT'``, resets its
        task and episode, optionally restores the policy, then selects one
        action, steps the environment once, feeds that single transition to
        ``policy.adapt`` and selects again from the resulting observation.

        Args:
          policy: policy under test; must provide ``restore``, ``reset_task``,
            ``SelectAction`` and ``adapt``.
          restore: when true, ``policy.restore()`` is called before use.
        """
        root = pose_env.get_pybullet_urdf_root()
        self.assertTrue(os.path.exists(root))
        toy_env = pose_env.PoseToyEnv(urdf_root=root, render_mode='DIRECT')
        toy_env.reset_task()
        first_obs = toy_env.reset()

        if restore:
            policy.restore()
        policy.reset_task()

        # Third argument is 0 for the first selection and 1 for the second;
        # presumably a timestep index -- TODO confirm against SelectAction.
        first_action = policy.SelectAction(first_obs, None, 0)
        next_obs, reward, is_done, debug_info = toy_env.step(first_action)

        transition = (first_obs, first_action, reward, next_obs, is_done,
                      debug_info)
        # Nested as [[transition]]: presumably a list of episodes, each a
        # list of transitions -- verify against policy.adapt.
        policy.adapt([[transition]])

        policy.SelectAction(next_obs, None, 1)
Example #3
0
    def test_run_pose_env_collect(self, demo_policy_cls):
        """Run the random-collect gin config end to end and check its output.

        Parses ``run_random_collect.gin`` and overrides the URDF root, output
        directory, task and episode counts, and policy class. It then runs
        ``continuous_collect_eval.collect_eval_loop`` and asserts that exactly
        two ``.tfrecord`` files were written under ``policy_collect``.

        Args:
          demo_policy_cls: policy class bound to
            ``collect_eval_loop.policy_class`` for this run.
        """
        # gin bindings are process-global. Clear them when this test finishes,
        # even on failure, so the config parsed and bound below cannot leak
        # into other tests (e.g. other parameterizations of this method).
        self.addCleanup(gin.clear_config)

        urdf_root = pose_env.get_pybullet_urdf_root()

        config_dir = 'research/pose_env/configs'
        gin_config = os.path.join(FLAGS.test_srcdir, config_dir,
                                  'run_random_collect.gin')
        gin.parse_config_file(gin_config)
        tmp_dir = absltest.get_default_test_tmpdir()
        # Keyed by the policy class so runs with different classes write to
        # separate output directories.
        root_dir = os.path.join(tmp_dir, str(demo_policy_cls))
        gin.bind_parameter('PoseToyEnv.urdf_root', urdf_root)
        gin.bind_parameter('collect_eval_loop.root_dir', root_dir)
        gin.bind_parameter('run_meta_env.num_tasks', 2)
        gin.bind_parameter('run_meta_env.num_episodes_per_adaptation', 1)
        gin.bind_parameter('collect_eval_loop.policy_class', demo_policy_cls)
        continuous_collect_eval.collect_eval_loop()
        output_files = tf.io.gfile.glob(
            os.path.join(root_dir, 'policy_collect', '*.tfrecord'))
        self.assertLen(output_files, 2)