コード例 #1
0
    def testAction(self):
        """TFPyPolicy.action unbatches inputs for the py policy and rebatches outputs.

        A mocked PyPolicy returns fixed action/state/info values; the test checks
        the arguments the mock received and the values the TF wrapper returns.
        """
        obs_spec = array_spec.BoundedArraySpec((3, ), np.int32, 1, 1)
        step_spec = ts.time_step_spec(obs_spec)
        act_spec = array_spec.BoundedArraySpec((7, ), np.int32, 1, 1)
        state_spec = array_spec.BoundedArraySpec((5, ), np.int32, 0, 1)
        info_spec = array_spec.BoundedArraySpec((3, ), np.int32, 0, 1)

        mocked = mock.create_autospec(py_policy.PyPolicy)
        mocked.time_step_spec = step_spec
        mocked.action_spec = act_spec
        mocked.policy_state_spec = state_spec
        mocked.info_spec = info_spec

        def ones_with_batch(spec):
            # numpy array of ones with a leading batch dimension of 1.
            return np.ones((1, ) + spec.shape, spec.dtype)

        want_state_in = np.ones(state_spec.shape, state_spec.dtype)
        want_time_step = tf.nest.map_structure(ones_with_batch, step_spec)
        want_action = ones_with_batch(act_spec)
        want_state_out = np.zeros(state_spec.shape, state_spec.dtype)
        want_info = np.zeros(info_spec.shape, info_spec.dtype)

        # The py policy works on unbatched data, so it returns an unbatched action.
        mocked.action.return_value = policy_step.PolicyStep(
            nest_utils.unbatch_nested_array(want_action), want_state_out,
            want_info)

        wrapper = tf_py_policy.TFPyPolicy(mocked)
        tf_time_step = tf.nest.map_structure(
            lambda spec: tf.ones((1, ) + spec.shape, spec.dtype), step_spec)
        tf_state = tf.ones(state_spec.shape, tf.int32)
        result = self.evaluate(wrapper.action(tf_time_step, tf_state))

        self.assertEqual(1, mocked.action.call_count)
        received = mocked.action.call_args[1]
        np.testing.assert_equal(
            received['time_step'],
            nest_utils.unbatch_nested_array(want_time_step))
        np.testing.assert_equal(received['policy_state'], want_state_in)
        np.testing.assert_equal(result.action, want_action)
        np.testing.assert_equal(result.state, want_state_out)
        np.testing.assert_equal(result.info, want_info)
コード例 #2
0
    def _action(self,
                time_step: ts.TimeStep,
                policy_state: types.NestedArray,
                seed: Optional[types.Seed] = None) -> ps.PolicyStep:
        """Forward a batch of time_step and policy_states to the wrapped policies.

        With a single wrapped policy the batch dimension is removed, the policy
        is called once, and the batch dimension is restored. With several
        policies the batch is split along its leading dimension and item i is
        sent to policy i via `self._execute`.

        Args:
          time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
          policy_state: An Array, or a nested dict, list or tuple of Arrays
            representing the previous policy_state. May be None or an empty
            container for stateless policies.
          seed: Seed value used to initialize a pseudorandom number generator.

        Returns:
          A batch of `PolicyStep` named tuples, each one containing:
            `action`: A nest of action Arrays matching the `action_spec()`.
            `state`: A nest of policy states to be fed into the next call to
              action.
            `info`: Optional side information such as action log probabilities.

        Raises:
          NotImplementedError: if `seed` is not None.
          ValueError: if the leading dimension of `time_step`, or of a
            non-empty `policy_state`, differs from the number of policies.
        """
        if seed is not None:
            raise NotImplementedError(
                "seed is not supported; but saw seed: {}".format(seed))

        if self._num_policies == 1:
            time_step = nest_utils.unbatch_nested_array(time_step)
            policy_state = nest_utils.unbatch_nested_array(policy_state)
            policy_steps = self._policies[0].action(time_step, policy_state)
            return nest_utils.batch_nested_array(policy_steps)

        num_policies = len(self._policies)
        unstacked_time_steps = nest_utils.unstack_nested_arrays(time_step)
        if len(unstacked_time_steps) != num_policies:
            raise ValueError(
                "Primary dimension of time_step items does not match "
                "batch size: %d vs. %d" %
                (len(unstacked_time_steps), num_policies))

        # `policy_state` may be a bare np.ndarray, whose truth value raises
        # ValueError ("ambiguous") when it has more than one element. Only
        # None and empty containers mean "no state", so test that explicitly
        # instead of relying on truthiness.
        has_state = policy_state is not None and not (
            isinstance(policy_state, (tuple, list, dict)) and not policy_state)

        # Stateless policies each receive an empty state.
        unstacked_policy_states = [()] * len(unstacked_time_steps)
        if has_state:
            unstacked_policy_states = nest_utils.unstack_nested_arrays(
                policy_state)
            if len(unstacked_policy_states) != num_policies:
                raise ValueError(
                    "Primary dimension of policy_state items does not match "
                    "batch size: %d vs. %d" %
                    (len(unstacked_policy_states), num_policies))

        policy_steps = self._execute(
            _execute_policy,
            zip(self._policies, unstacked_time_steps, unstacked_policy_states))
        return nest_utils.stack_nested_arrays(policy_steps)
コード例 #3
0
    def _action(self, time_step, policy_state):
        """Compute an action by running the policy graph in `self.session`.

        On first use the graph is built from `time_step`. The batch size of
        `time_step` (leading dim of `step_type`, or None when `step_type` is
        scalar-shaped) must match the batch size recorded previously.

        Args:
          time_step: Time step to act on; a batch dimension is added first when
            the policy is not batched.
          policy_state: Policy state fed to the graph leaf by leaf, or None to
            feed nothing for it.

        Returns:
          `policy_step.PolicyStep(action, state, info)`. When not batched,
          `action` and `info` go through `nest_utils.unbatch_nested_array`;
          `state` is returned exactly as produced by the session.

        Raises:
          ValueError: if the batch size of `time_step` differs from the one
            provided previously.
        """
        if not self._built:
            self._build_from_time_step(time_step)

        step_shape = time_step.step_type.shape
        batch_size = step_shape[0] if step_shape else None
        if self._batch_size != batch_size:
            raise ValueError(
                'The batch size of time_step is different from the batch size '
                'provided previously. Expected {}, but saw {}.'.format(
                    self._batch_size, batch_size))

        if not self._batched:
            # policy_state already arrives batched (the policy produced it and
            # we only hand it back), so only time_step needs a batch dim.
            time_step = nest_utils.batch_nested_array(time_step)

        tf.nest.assert_same_structure(self._time_step, time_step)
        feed_dict = {self._time_step: time_step}
        if policy_state is not None:
            # Feed individual leaves: spec nests containing lists are not
            # hashable, so the whole nest cannot be used as one feed key.
            feed_dict.update(
                zip(tf.nest.flatten(self._policy_state),
                    tf.nest.flatten(policy_state)))

        action, state, info = self.session.run(self._action_step, feed_dict)

        if not self._batched:
            action, info = nest_utils.unbatch_nested_array([action, info])

        return policy_step.PolicyStep(action, state, info)
コード例 #4
0
  def _step(self, actions):
    """Forward a batch of actions to the wrapped environments.

    A single environment gets the unbatched action and its result is
    rebatched. Several environments each get one slice of `actions`, stepped
    through `self._execute`, and their results are stacked.

    Args:
      actions: Batched action, possibly nested, to apply to the environment.

    Raises:
      ValueError: if the number of unstacked actions differs from
        `self.batch_size`.

    Returns:
      The wrapped environments' step results, batched along a leading axis.
    """
    if self._num_envs != 1:
      per_env_actions = unstack_actions(actions)
      if len(per_env_actions) != self.batch_size:
        raise ValueError(
            "Primary dimension of action items does not match "
            "batch size: %d vs. %d" % (len(per_env_actions), self.batch_size))

      def _step_one(env_and_action):
        env, env_action = env_and_action
        return env.step(env_action)

      results = self._execute(_step_one, zip(self._envs, per_env_actions))
      return nest_utils.stack_nested_arrays(results)

    lone_action = nest_utils.unbatch_nested_array(actions)
    return nest_utils.batch_nested_array(self._envs[0].step(lone_action))
コード例 #5
0
    def _add_batch(self, items):
        """Add a batch of experiences to the replay buffer.

        Only batches of size 1 are supported. The single item is written at the
        current cursor `self._np_state.cur_id` and given a priority equal to
        the maximum of `self._prioritized_buffer_priorities` (or 1.0 when the
        buffer is empty). When the buffer is full, the item already stored at
        the cursor is passed to `self._on_delete` before being overwritten.

        Args:
            items: A nest matching `self._data_spec` with a leading batch
                dimension of size 1.

        Raises:
            NotImplementedError: if the batch dimension of `items` is not 1.
        """
        outer_shape = nest_utils.get_outer_array_shape(items, self._data_spec)
        if outer_shape[0] != 1:
            raise NotImplementedError('PyPrioritizedReplayBuffer only supports a batch '
                                      'size of 1, but received `items` with batch '
                                      'size {}.'.format(outer_shape[0]))

        # Log only after validation so the message is accurate.
        logger.info("Adding a batch of 1 experiences to Replay buffer")

        item = nest_utils.unbatch_nested_array(items)

        with self._lock:
            # Read size and priorities under the same lock that guards their
            # mutation below, so the max priority is consistent with the
            # buffer state this insert is applied to.
            if self._np_state.size > 0:
                max_priority = self._prioritized_buffer_priorities.max()
            else:
                # Empty buffer: the initial priority is 1.0.
                max_priority = 1.0

            if self._np_state.size == self._prioritized_buffer_capacity:
                # If we are at capacity, we are deleting element cur_id.
                self._on_delete(self._storage.get(self._np_state.cur_id))

            self._storage.set(self._np_state.cur_id, self._encode(item))
            # The new experience gets the current maximum priority.
            self._prioritized_buffer_priorities[self._np_state.cur_id] = max_priority

            self._np_state.size = np.minimum(self._np_state.size + 1, self._prioritized_buffer_capacity)
            self._np_state.cur_id = (self._np_state.cur_id + 1) % self._prioritized_buffer_capacity
            self._np_state.item_count += 1
コード例 #6
0
 def _action(self, time_step, policy_state):
     """Run `self._policy_action_fn` on a single time step.

     A batch dimension is added to `time_step` (presumably it arrives
     unbatched), and its leaves are converted to tensors before the call.

     Args:
       time_step: Time step nest to act on.
       policy_state: Passed unchanged to `self._policy_action_fn`.

     Returns:
       The policy step from `self._policy_action_fn`, with every leaf of
       `action` converted to numpy and passed through
       `nest_utils.unbatch_nested_array`. `state` and `info` are returned as
       produced.
     """
     batched_time_step = nest_utils.batch_nested_array(time_step)
     # Avoid passing numpy arrays to avoid retracing of the tf.function.
     batched_time_step = tf.nest.map_structure(tf.convert_to_tensor,
                                               batched_time_step)
     step = self._policy_action_fn(batched_time_step, policy_state)
     # The action may be a nest of tensors (nested action spec), so convert
     # each leaf instead of calling `.numpy()` on the structure itself. For a
     # single tensor this is identical to `step.action.numpy()`.
     numpy_action = tf.nest.map_structure(lambda t: t.numpy(), step.action)
     return step._replace(action=nest_utils.unbatch_nested_array(numpy_action))
コード例 #7
0
 def _get_initial_state(self, batch_size: int) -> types.NestedArray:
     """Return the wrapped policies' initial states as one batched nest.

     With a single policy its initial state is given a leading batch
     dimension. With several, each policy's initial state is fetched through
     `self._execute`, unbatched, and stacked along a new leading axis.

     Note: `batch_size` is not read by this implementation.
     """
     if self._num_policies != 1:
         fetched = self._execute(_execute_get_initial_state, self._policies)
         per_policy = nest_utils.unbatch_nested_array(fetched)
         return nest_utils.stack_nested_arrays(per_policy)
     only_state = self._policies[0].get_initial_state()
     return nest_utils.batch_nested_array(only_state)
コード例 #8
0
    def write(self, *data):
        """Encode a batch of tensor data and write it with `self._writer`.

        The flat arguments are passed through `nest_utils.unbatch_nested_array`,
        packed into the structure of `self._array_data_spec`, encoded with
        `self._encoder`, and written.

        Args:
          *data: (unpacked) list/tuple of batched np.arrays, in the flattened
            order of `self._array_data_spec`.
        """
        flat_items = nest_utils.unbatch_nested_array(data)
        packed = tf.nest.pack_sequence_as(self._array_data_spec, flat_items)
        encoded = self._encoder(packed)
        self._writer.write(encoded)
コード例 #9
0
 def step_adversary(self, actions):
     """Apply a batch of adversary actions to the wrapped environments.

     A single environment gets the unbatched action and its result is
     rebatched. Several environments each get one slice of `actions` via
     `step_adversary`, run through `self._execute`, and the results are
     stacked.

     Args:
       actions: Batched (possibly nested) adversary actions.

     Returns:
       The environments' `step_adversary` results, batched along a leading
       axis.

     Raises:
       ValueError: if the number of unstacked actions differs from
         `self.batch_size`.
     """
     if self._num_envs != 1:
         split_actions = batched_py_environment.unstack_actions(actions)
         if len(split_actions) != self.batch_size:
             raise ValueError(
                 'Primary dimension of action items does not match '
                 'batch size: %d vs. %d' % (len(split_actions), self.batch_size))

         def _advance(env_and_action):
             env, env_action = env_and_action
             return env.step_adversary(env_action)

         outcomes = self._execute(_advance, zip(self._envs, split_actions))
         return nest_utils.stack_nested_arrays(outcomes)

     lone_action = nest_utils.unbatch_nested_array(actions)
     return nest_utils.batch_nested_array(
         self._envs[0].step_adversary(lone_action))
コード例 #10
0
  def _add_batch(self, items):
    """Store a single-item batch in the ring buffer.

    The item is written at `self._np_state.cur_id`, and the cursor then
    advances modulo `self._capacity`. When the buffer is full, the item
    currently at the cursor (the oldest one) is passed to `self._on_delete`
    before being overwritten.

    Args:
      items: A nest matching `self._data_spec` whose leading dimension is 1.

    Raises:
      NotImplementedError: if the leading (batch) dimension is not 1.
    """
    batch_dim = nest_utils.get_outer_array_shape(items, self._data_spec)[0]
    if batch_dim != 1:
      raise NotImplementedError('PyUniformReplayBuffer only supports a batch '
                                'size of 1, but received `items` with batch '
                                'size {}.'.format(batch_dim))

    single_item = nest_utils.unbatch_nested_array(items)
    capacity = self._capacity
    with self._lock:
      full = self._np_state.size == capacity
      if full:
        # The slot under the cursor is about to be overwritten.
        self._on_delete(self._storage.get(self._np_state.cur_id))
      self._storage.set(self._np_state.cur_id, self._encode(single_item))
      self._np_state.size = np.minimum(self._np_state.size + 1, capacity)
      self._np_state.cur_id = (self._np_state.cur_id + 1) % capacity
      self._np_state.item_count += 1
コード例 #11
0
  def _action(self, time_step, policy_state):
    """Compute an action by running `self._action_step` in `self.session`.

    Args:
      time_step: Time step to act on; a batch dimension is added first when
        the policy is not batched.
      policy_state: Policy state fed to the graph leaf by leaf, or None to
        feed nothing for it.

    Returns:
      `policy_step.PolicyStep(action, state, info)`. When not batched,
      `action` and `info` go through `nest_utils.unbatch_nested_array`;
      `state` is returned exactly as produced by the session.
    """
    if not self._batched:
      # policy_state already arrives batched (the policy produced it and we
      # only hand it back), so only time_step needs a batch dimension.
      time_step = nest_utils.batch_nested_array(time_step)

    nest.assert_same_structure(self._time_step, time_step)
    feed_dict = {self._time_step: time_step}
    if policy_state is not None:
      # Feed individual leaves: spec nests containing lists are not
      # hashable, so the whole nest cannot be used as one feed key.
      feed_dict.update(
          zip(nest.flatten(self._policy_state), nest.flatten(policy_state)))

    action, state, info = self.session.run(self._action_step, feed_dict)

    if not self._batched:
      action, info = nest_utils.unbatch_nested_array([action, info])

    return policy_step.PolicyStep(action, state, info)
コード例 #12
0
ファイル: dqn.py プロジェクト: mshinji/mahjong
    def _step(self, action):
        """Discard the tile chosen by `action` and advance the game.

        After the discard, `self.game.next()` is called until it returns a
        falsy value. `self.reward` is set for this step, and the same value is
        reported as the TimeStep reward:
          * RYUKYOKU_STATE / AGARI_STATE: change in `self.score()` over the step;
          * SYUKYOKU_STATE: `[90, 45, 0, -180][self.rank()] * 1000`;
          * otherwise: 0.

        Args:
          action: Batched action; the batch dimension is removed before use.

        Returns:
          A batched TimeStep. It is a termination when the game reaches one of
          the end states above, otherwise a transition with discount 1.
        """
        action = nest_utils.unbatch_nested_array(action)
        score_before = self.score()
        dahai = self.dahai(action)

        self.reward = 0
        self.game.dahai(dahai, self.player)
        while self.game.next():
            pass

        # Pass the computed reward into the TimeStep. It was previously
        # hard-coded to 0 in every branch, so the agent never saw it.
        if self.game.state in [Const.RYUKYOKU_STATE, Const.AGARI_STATE]:
            self.reward = self.score() - score_before
            self.game_end = True
            time_step = ts.termination(self.board(), reward=self.reward)
        elif self.game.state == Const.SYUKYOKU_STATE:
            self.reward = [90, 45, 0, -180][self.rank()] * 1000
            self.game_end = True
            time_step = ts.termination(self.board(), reward=self.reward)
        else:
            time_step = ts.transition(self.board(), reward=self.reward,
                                      discount=1)

        return nest_utils.batch_nested_array(time_step)
コード例 #13
0
 def _action(self, time_step, policy_state):
     """Run `self._policy.action` on one time step and unbatch its action.

     A batch dimension is added to `time_step` before the call. The returned
     step's `action` is converted with `.numpy()` and passed through
     `nest_utils.unbatch_nested_array`; `state` and `info` are left as
     produced.

     Args:
       time_step: Time step nest to act on.
       policy_state: Passed unchanged to `self._policy.action`.

     Returns:
       The policy step with its `action` replaced by the unbatched numpy
       action.
     """
     step = self._policy.action(
         nest_utils.batch_nested_array(time_step), policy_state)
     unbatched_action = nest_utils.unbatch_nested_array(step.action.numpy())
     return step._replace(action=unbatched_action)