Example #1
    def observe(self, action: types.NestedArray,
                next_timestep: dm_env.TimeStep):
        if self._adder:
            # Convert the pre-step recurrent state to numpy and store it as
            # extras so it can be written to replay with the transition.
            numpy_state = utils.to_numpy_squeeze(self._prev_state)
            self._adder.add(action,
                            next_timestep,
                            extras=(numpy_state,) + self._extras)
Example #2
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        result, new_state = self._recurrent_policy(self._client.params,
                                                   key=next(self._rng),
                                                   observation=observation,
                                                   core_state=self._state)
        self._prev_state = self._state  # Keep previous state to save in replay.
        self._state = new_state  # Keep new state for next policy call.

        if self._has_extras:
            action, extras = result
            self._extras = utils.to_numpy_squeeze(
                extras)  # Keep to save in replay.
        else:
            action = result
        return utils.to_numpy_squeeze(action)
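This snippet assumes `self._recurrent_policy` is a pure (typically jit-compiled) function with the signature `(params, key, observation, core_state) -> (result, new_core_state)`, where `result` is either the action or an `(action, extras)` pair. A minimal, hypothetical stand-in (the names and shapes below are purely illustrative):

    import jax
    import jax.numpy as jnp

    def recurrent_policy(params, key, observation, core_state):
        # Toy "core": accumulate observations into the recurrent state.
        new_state = core_state + observation
        logits = jnp.dot(new_state, params['w'])
        action = jax.random.categorical(key, logits)
        return action, new_state

    params = {'w': jnp.ones((4, 3))}
    state = jnp.zeros(4)
    action, state = recurrent_policy(
        params, jax.random.PRNGKey(0), jnp.ones(4), state)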
Example #3
    def observe(self, action: network_lib.Action,
                next_timestep: dm_env.TimeStep):
        if self._adder:
            numpy_state = utils.to_numpy_squeeze(self._prev_state)
            self._adder.add(action,
                            next_timestep,
                            extras=(numpy_state,) + self._extras)
Example #4
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        key = next(self._rng)
        # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
        observation = utils.add_batch_dim(observation)
        action = self._policy(self._client.params, key, observation)
        # Squeeze out the batch dimension and return a numpy action.
        return utils.to_numpy_squeeze(action)
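Here `self._policy` is assumed to be a (usually jitted) pure function mapping `(params, key, batched_observation)` to a batched action. A toy, hypothetical stand-in consistent with the call above:

    import jax
    import jax.numpy as jnp

    def policy(params, key, observation):
        # `observation` carries the leading batch dim added by add_batch_dim.
        logits = observation @ params['w']
        return jax.random.categorical(key, logits)  # Shape: [batch].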
Example #5
    def select_action(
            self, observation: network_lib.Observation) -> network_lib.Action:
        result, new_state, self._random_key = self._recurrent_policy(
            self._client.params,
            key=self._random_key,
            observation=observation,
            core_state=self._state)
        self._prev_state = self._state  # Keep previous state to save in replay.
        self._state = new_state  # Keep new state for next policy call.

        if self._has_extras:
            action, extras = result
            self._extras = utils.to_numpy_squeeze(
                extras)  # Keep to save in replay.
        else:
            action = result
        return utils.to_numpy_squeeze(action)
Example #6
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        action, new_state = self._recurrent_policy(
            self._client.params,
            key=next(self._rng),
            observation=utils.add_batch_dim(observation),
            core_state=self._state)
        self._prev_state = self._state  # Keep previous state to save in replay.
        self._state = new_state  # Keep new state for next policy call.
        return utils.to_numpy_squeeze(action)
Example #7
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        key = next(self._rng)
        # The policy expects batched inputs, so add a leading batch dimension
        # and squeeze it back out of the resulting action.
        observation = utils.add_batch_dim(observation)
        action = self._policy(self._client.params, key, observation)
        return utils.to_numpy_squeeze(action)
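The `select_action`/`observe` pairs above follow a common actor-style interface and are typically driven in alternation by an environment loop: `select_action` produces an action from the latest observation, and `observe` hands the resulting timestep (plus any extras) to the adder. A minimal sketch of such a loop, assuming a `dm_env.Environment` called `env` and an actor object `actor` exposing the methods shown above:

    import dm_env

    def run_episode(env: dm_env.Environment, actor) -> None:
        timestep = env.reset()
        # Some actor implementations also expose observe_first(timestep)
        # for the initial step; omitted here for brevity.
        while not timestep.last():
            action = actor.select_action(timestep.observation)
            timestep = env.step(action)
            actor.observe(action, next_timestep=timestep)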