コード例 #1
0
    def testFollowsScript(self):
        """The policy follows the script, skipping an entry whose count is 0.

        The middle script entry has a repeat count of 0, so after emitting the
        first entry the policy is expected to jump to the third entry and emit
        it on each of the next two calls.
        """
        action_spec = [
            array_spec.BoundedArraySpec((2, 2), np.int32, -10, 10),
            array_spec.BoundedArraySpec((1, 2), np.int32, -10, 10)
        ]

        action_script = [
            (1, [
                np.array([[5, 2], [1, 3]], dtype=np.int32),
                np.array([[4, 6]], dtype=np.int32)
            ]),
            (0, [
                np.array([[0, 0], [0, 0]], dtype=np.int32),
                np.array([[0, 0]], dtype=np.int32)
            ]),
            (2, [
                np.array([[1, 2], [3, 4]], dtype=np.int32),
                np.array([[5, 6]], dtype=np.int32)
            ]),
        ]

        def assert_action_equal(expected, actual):
            # Compare leaf by leaf. ``assertEqual`` on lists of multi-element
            # arrays only passes when the leaves are the *same* objects; for
            # equal-but-distinct arrays it raises "truth value of an array is
            # ambiguous" instead of passing.
            self.assertEqual(len(expected), len(actual))
            for want, got in zip(expected, actual):
                got = np.asarray(got)
                self.assertEqual(want.dtype, got.dtype)
                np.testing.assert_array_equal(want, got)

        policy = scripted_py_policy.ScriptedPyPolicy(
            time_step_spec=self._time_step_spec,
            action_spec=action_spec,
            action_script=action_script)
        policy_state = policy.get_initial_state()

        action_step = policy.action(self._time_step, policy_state)
        assert_action_equal(action_script[0][1], action_step.action)
        action_step = policy.action(self._time_step, action_step.state)
        assert_action_equal(action_script[2][1], action_step.action)
        action_step = policy.action(self._time_step, action_step.state)
        assert_action_equal(action_script[2][1], action_step.action)
コード例 #2
0
    def testEpisodeLength(self):
        """Requesting more steps than the script covers raises ValueError.

        The script yields 1 + 2 = 3 actions, so the fourth ``action`` call is
        expected to fail with an "Episode is longer than" error.
        """
        action_spec = [
            array_spec.BoundedArraySpec((2, 2), np.int32, -10, 10),
            array_spec.BoundedArraySpec((1, 2), np.int32, -10, 10)
        ]

        action_script = [
            (1, [
                np.array([[5, 2], [1, 3]], dtype=np.int32),
                np.array([[4, 6]], dtype=np.int32)
            ]),
            (2, [
                np.array([[1, 2], [3, 4]], dtype=np.int32),
                np.array([[5, 6]], dtype=np.int32)
            ]),
        ]

        def assert_action_equal(expected, actual):
            # Compare leaf by leaf. ``assertEqual`` on lists of multi-element
            # arrays only passes when the leaves are the *same* objects; for
            # equal-but-distinct arrays it raises "truth value of an array is
            # ambiguous" instead of passing.
            self.assertEqual(len(expected), len(actual))
            for want, got in zip(expected, actual):
                got = np.asarray(got)
                self.assertEqual(want.dtype, got.dtype)
                np.testing.assert_array_equal(want, got)

        policy = scripted_py_policy.ScriptedPyPolicy(
            time_step_spec=self._time_step_spec,
            action_spec=action_spec,
            action_script=action_script)
        policy_state = policy.get_initial_state()

        action_step = policy.action(self._time_step, policy_state)
        assert_action_equal(action_script[0][1], action_step.action)
        action_step = policy.action(self._time_step, action_step.state)
        assert_action_equal(action_script[1][1], action_step.action)
        action_step = policy.action(self._time_step, action_step.state)
        assert_action_equal(action_script[1][1], action_step.action)
        # ``assertRaisesRegexp`` was a deprecated alias removed in Python 3.12;
        # ``assertRaisesRegex`` is the supported name.
        with self.assertRaisesRegex(ValueError, '.*Episode is longer than.*'):
            policy.action(self._time_step, action_step.state)
コード例 #3
0
    def testChecksSpecBounds(self):
        """A scripted action outside the spec bounds makes ``action`` raise.

        Both spec leaves are bounded to [-10, 10]; the first scripted action
        contains 15, so the very first ``action`` call must raise ValueError.
        """
        action_spec = [
            array_spec.BoundedArraySpec(shape, np.int32, -10, 10)
            for shape in ((2, 2), (1, 2))
        ]

        out_of_bounds = [
            np.array([[15, 2], [1, 3]], dtype=np.int32),  # 15 > upper bound 10
            np.array([[4, 6]], dtype=np.int32),
        ]
        in_bounds = [
            np.array([[1, 2], [3, 4]], dtype=np.int32),
            np.array([[5, 6]], dtype=np.int32),
        ]

        policy = scripted_py_policy.ScriptedPyPolicy(
            time_step_spec=self._time_step_spec,
            action_spec=action_spec,
            action_script=[(1, out_of_bounds), (2, in_bounds)])
        initial_state = policy.get_initial_state()

        with self.assertRaises(ValueError):
            policy.action(self._time_step, initial_state)
コード例 #4
0
    def testFollowsScriptWithListInsteadOfNpArrays(self):
        """Script entries given as nested Python lists come back as arrays.

        The first entry is emitted once and the second (count 2) twice; each
        emitted action is compared leaf by leaf against int32 arrays.
        """
        action_spec = [
            array_spec.BoundedArraySpec(shape, np.int32, -10, 10)
            for shape in ((2, 2), (1, 2))
        ]

        first_entry = [[[5, 2], [1, 3]], [[4, 6]]]
        second_entry = [[[1, 2], [3, 4]], [[5, 6]]]
        action_script = [(1, first_entry), (2, second_entry)]

        # The same values as arrays; these are what the policy should emit.
        first_expected, second_expected = (
            [np.array(leaf, dtype=np.int32) for leaf in entry]
            for entry in (first_entry, second_entry)
        )

        policy = scripted_py_policy.ScriptedPyPolicy(
            time_step_spec=self._time_step_spec,
            action_spec=action_spec,
            action_script=action_script)  # pytype: disable=wrong-arg-types
        state = policy.get_initial_state()

        for wanted in (first_expected, second_expected, second_expected):
            step = policy.action(self._time_step, state)
            np.testing.assert_array_equal(wanted[0], step.action[0])
            np.testing.assert_array_equal(wanted[1], step.action[1])
            state = step.state
コード例 #5
0
def create_action_gif():
    """Create a gif showing discretization of agent actions.

    Builds a ``LakeMonsterEnvironment`` with ``monster_speed`` 0.0 and
    ``n_actions`` discrete actions. It then scripts a sequence of actions,
    each repeated once, and passes the environment and the scripted policy to
    ``episode_as_gif`` with the target path ``configs.ASSETS_DIR/actions.gif``
    at 1 fps.
    """
    n_actions = 8  # should be even
    back = n_actions // 2  # action index used for the "back" move

    py_env = LakeMonsterEnvironment(
        monster_speed=0.0,
        timeout_factor=8,
        step_size=0.5,
        n_actions=n_actions,
    )

    # Start with action 0 followed by "back", then alternate
    # "step forward" (action 1) and "back" n_actions - 1 more times.
    action_script = [(1, 0), (1, back)]
    for _ in range(n_actions - 1):
        action_script += [(1, 1), (1, back)]

    policy = scripted_py_policy.ScriptedPyPolicy(
        time_step_spec=None,
        action_spec=py_env.action_spec(),
        action_script=action_script,
    )
    episode_as_gif(
        py_env,
        policy,
        os.path.join(configs.ASSETS_DIR, 'actions.gif'),
        fps=1,
        show_path=False,
    )
コード例 #6
0
def collect_pair_episodes(policy,
                          env_name,
                          max_steps=None,
                          random_seed=None,
                          frame_shape=(84, 84, 3),
                          max_episodes=10):
    """Collect episodes with ``policy`` and replay its actions in a twin env.

    Runs ``policy`` in an evaluation environment built from ``env_name`` and
    records the actions it took. It then replays exactly those actions, each
    scripted with a repeat count of 1, through a ``ScriptedPyPolicy`` in a
    second environment built with the same arguments. The metrics of the two
    runs must match.

    Args:
      policy: Policy used for the first collection run.
      env_name: Environment name passed to ``utils.load_dm_env_for_eval``.
      max_steps: Forwarded to ``run_env`` for both runs.
      random_seed: Passed as ``task_kwargs={'random': random_seed}`` to both
        environments. NOTE(review): presumably this must be a fixed seed for
        the replay to reproduce the original trajectories — confirm.
      frame_shape: Frame shape used for BOTH environments.
      max_episodes: Forwarded to ``run_env`` and ``get_complete_episodes``.

    Returns:
      Tuple ``(episodes, paired_episodes)``: the results of
      ``get_complete_episodes`` on the original and the replayed buffer.

    Raises:
      AssertionError: If any pair of metric results differs.
    """
    env = utils.load_dm_env_for_eval(env_name,
                                     frame_shape=frame_shape,
                                     task_kwargs={'random': random_seed})

    buffer, metrics = run_env(env,
                              policy,
                              max_steps=max_steps,
                              max_episodes=max_episodes)

    # Second environment with identical settings. frame_shape used to be
    # hard-coded to (84, 84, 3) here, which desynchronised the two envs
    # whenever a caller passed a different shape.
    env_copy = utils.load_dm_env_for_eval(env_name,
                                          frame_shape=frame_shape,
                                          task_kwargs={'random': random_seed})

    # Replay every recorded action once, in the order it was taken.
    action_script = [(1, step.action) for step in buffer]
    optimal_policy = scripted_py_policy.ScriptedPyPolicy(
        time_step_spec=env.time_step_spec(),
        action_spec=env.action_spec(),
        action_script=action_script)
    paired_buffer, paired_metrics = run_env(env_copy,
                                            optimal_policy,
                                            max_steps=max_steps,
                                            max_episodes=max_episodes)

    for metric, paired_metric in zip(metrics, paired_metrics):
        # Explicit raise instead of `assert` so the check still runs under
        # `python -O`; the exception type and message are unchanged.
        if metric.result() != paired_metric.result():
            raise AssertionError('Metric results don\'t match')
        logging.info('%s: %.2f', metric.name, metric.result())

    episodes = get_complete_episodes(buffer, max_episodes)
    paired_episodes = get_complete_episodes(paired_buffer, max_episodes)
    return episodes, paired_episodes
コード例 #7
0
 def testPolicyStateSpecIsEmpty(self):
     """A policy built from an empty spec and script has ``()`` as state spec."""
     empty_policy = scripted_py_policy.ScriptedPyPolicy(
         time_step_spec=self._time_step_spec, action_spec=[], action_script=[])
     state_spec = empty_policy.policy_state_spec
     self.assertEqual(state_spec, ())