コード例 #1
0
  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
    """Chooses an action for a single (unbatched) observation.

    The recurrent core state is carried across calls. The state fed to the
    policy on this call is stored in `self._prev_state`, and the state the
    policy returns becomes `self._state` for the next call.

    Args:
      observation: a single observation without a batch dimension.

    Returns:
      The policy output converted to numpy with the batch dimension squeezed
      out. Any `tfd.Distribution` leaf in the output is replaced by a sample.
    """
    # Prepend a batch axis of size one. As a side effect this also converts
    # numpy inputs to TF.
    obs_batch = tf2_utils.add_batch_dim(observation)

    if self._state is None:
      # No state yet: build one for a batch of size one.
      self._state = self._network.initial_state(1)

    outputs, next_state = self._policy(obs_batch, self._state)

    def _sample_if_distribution(leaf):
      # Distributions are sampled; every other leaf passes through unchanged.
      return leaf.sample() if isinstance(leaf, tfd.Distribution) else leaf

    outputs = tree.map_structure(_sample_if_distribution, outputs)

    # Remember the state that was consumed this step, then advance to the new
    # one. The right-hand side is evaluated before either attribute is
    # assigned.
    self._prev_state, self._state = self._state, next_state

    return tf2_utils.to_numpy_squeeze(outputs)
コード例 #2
0
  def observe(
      self,
      action: types.NestedArray,
      next_timestep: dm_env.TimeStep,
  ):
    """Passes a transition to the adder, if one is configured.

    `self._prev_state` is attached as the single element of the `extras`
    tuple. It is converted to numpy with the batch dimension squeezed out.

    Args:
      action: the action taken, as supplied by the caller.
      next_timestep: the timestep supplied by the caller for this transition.
    """
    if self._adder:
      state_extra = tf2_utils.to_numpy_squeeze(self._prev_state)
      self._adder.add(action, next_timestep, extras=(state_extra,))
コード例 #3
0
    def observe(
        self,
        action: types.NestedArray,
        next_timestep: dm_env.TimeStep,
    ):
        """Passes a transition plus logits/core-state extras to the adder.

        This is a no-op when no adder is configured. The extras are the
        stored `self._prev_logits` and `self._prev_state`, converted to numpy
        with the batch dimension squeezed out.

        Args:
            action: the action taken, as supplied by the caller.
            next_timestep: the timestep supplied by the caller for this
                transition.
        """
        if self._adder:
            numpy_extras = tf2_utils.to_numpy_squeeze(
                {'logits': self._prev_logits, 'core_state': self._prev_state})
            self._adder.add(action, next_timestep, numpy_extras)
コード例 #4
0
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        """Samples a discrete action from the policy's logits.

        The policy is expected to return `((logits, <ignored>), new_state)`.
        The logits and the state consumed on this step are stored in
        `self._prev_logits` and `self._prev_state`. The returned state becomes
        `self._state` for the next call.

        Args:
            observation: a single observation without a batch dimension.

        Returns:
            A sample from `tfd.Categorical` over the logits, converted to
            numpy with the batch dimension squeezed out.
        """
        # Prepend a batch axis of size one. As a side effect this also
        # converts numpy inputs to TF.
        obs_batch = tf2_utils.add_batch_dim(observation)

        if self._state is None:
            # No state yet: build one for a batch of size one.
            self._state = self._network.initial_state(1)

        (logits, _), next_state = self._policy(obs_batch, self._state)

        # Keep this step's logits and input state, then advance the state.
        self._prev_logits = logits
        self._prev_state, self._state = self._state, next_state

        sampled = tfd.Categorical(logits=logits).sample()
        return tf2_utils.to_numpy_squeeze(sampled)
コード例 #5
0
  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
    """Chooses an action for a single (unbatched) observation.

    Args:
      observation: a single observation without a batch dimension.

    Returns:
      The policy network's output converted to numpy with the batch dimension
      squeezed out. Any `tfd.Distribution` leaf is replaced by a sample.
    """

    def _sample_if_distribution(leaf):
      # Distributions are sampled; every other leaf passes through unchanged.
      return leaf.sample() if isinstance(leaf, tfd.Distribution) else leaf

    # Prepend a batch axis of size one. As a side effect this also converts
    # numpy inputs to TF.
    obs_batch = tf2_utils.add_batch_dim(observation)
    raw_output = self._policy_network(obs_batch)

    sampled_output = tree.map_structure(_sample_if_distribution, raw_output)
    return tf2_utils.to_numpy_squeeze(sampled_output)