Пример #1
0
 def select_action(self,
                   observation: types.NestedArray) -> types.NestedArray:
     """Pick an action for `observation` using the client's current params.

     A fresh key is drawn from `self._rng` on every call and handed to
     `self._policy` together with `self._client.params` and the observation.
     When `self._has_extras` is truthy, the policy output is unpacked as an
     `(action, extras)` pair and the extras are stored on `self._extras`;
     otherwise the policy output is used as the action directly.

     Args:
       observation: the observation forwarded unchanged to the policy.

     Returns:
       The action after conversion by `utils.to_numpy`.
     """
     rng_key = next(self._rng)
     policy_output = self._policy(self._client.params, rng_key, observation)

     # Without extras the policy output already is the action.
     if not self._has_extras:
         return utils.to_numpy(policy_output)

     chosen_action, self._extras = policy_output
     return utils.to_numpy(chosen_action)
Пример #2
0
 def select_action(
         self, observation: network_lib.Observation) -> types.NestedArray:
     """Pick an action for `observation`, threading the random key through.

     `self._policy` is called with `self._client.params`, the stored
     `self._random_key` and the observation. It returns a pair whose second
     element replaces `self._random_key` for the next call. When
     `self._has_extras` is truthy, the first element is unpacked further as
     an `(action, extras)` pair and the extras are stored on `self._extras`.

     Args:
       observation: the observation forwarded unchanged to the policy.

     Returns:
       The action after conversion by `utils.to_numpy`.
     """
     policy_output, next_key = self._policy(
         self._client.params, self._random_key, observation)
     # Store the key returned by the policy so the next call uses it.
     self._random_key = next_key

     if not self._has_extras:
         return utils.to_numpy(policy_output)

     chosen_action, self._extras = policy_output
     return utils.to_numpy(chosen_action)
Пример #3
0
 def select_action(self,
                   observation: types.NestedArray) -> types.NestedArray:
     """Pick an action with the recurrent policy and advance the core state.

     `self._recurrent_policy` receives `self._client.params`, a fresh key
     from `self._rng`, the observation and the current `self._state`. It
     returns the action together with a new state.

     Args:
       observation: the observation forwarded unchanged to the policy.

     Returns:
       The action after conversion by `utils.to_numpy`.
     """
     sampled_action, next_core_state = self._recurrent_policy(
         self._client.params,
         key=next(self._rng),
         observation=observation,
         core_state=self._state,
     )
     # The state that was just fed to the policy is kept as the previous
     # state (to save in replay); the returned state is used on the next
     # policy call.
     self._prev_state, self._state = self._state, next_core_state
     return utils.to_numpy(sampled_action)
Пример #4
0
 def select_action(self,
                   observation):
   """Pick an action, sampling uniformly while the first bias is all zero.

   If every entry of the 'mlp/~/linear_0' bias in `self._params` equals
   zero, `self._state` is split with `jax.random.split`; one half is used
   to draw an action from `jax.random.uniform` between -1.0 and 1.0 with
   the shape of the 'Normal/~/linear' bias, and the other half is stored
   back on `self._state`. Otherwise `self._policy` is called with the
   params, the observation and `self._state`, and its second return value
   replaces `self._state`.

   NOTE(review): an all-zero bias presumably marks parameters that have not
   been trained yet -- confirm against the code that initialises them.

   Args:
     observation: the observation forwarded to the policy (unused on the
       uniform-sampling path).

   Returns:
     The action after conversion by `utils.to_numpy`.
   """
   first_layer_bias = self._params['mlp/~/linear_0']['b']

   if not (first_layer_bias == 0).all():
     policy_action, self._state = self._policy(
         self._params, observation, self._state)
     return utils.to_numpy(policy_action)

   action_shape = self._params['Normal/~/linear']['b'].shape
   sample_key, self._state = jax.random.split(self._state)
   uniform_action = jax.random.uniform(
       key=sample_key, shape=action_shape, minval=-1.0, maxval=1.0)
   return utils.to_numpy(uniform_action)
Пример #5
0
 def select_action(
         self, observation: network_lib.Observation) -> types.NestedArray:
     """Pick an action for `observation` and update the stored state.

     `self._policy` is called with `self._params`, the observation and
     `self._state`. It returns a pair: the action, and a value that
     replaces `self._state` for the next call.

     Args:
       observation: the observation forwarded unchanged to the policy.

     Returns:
       The action after conversion by `utils.to_numpy`.
     """
     policy_output = self._policy(self._params, observation, self._state)
     chosen_action, self._state = policy_output
     return utils.to_numpy(chosen_action)
Пример #6
0
 def observe(self, action: types.NestedArray,
             next_timestep: dm_env.TimeStep):
     """Forward the action and next timestep to the adder, if there is one.

     When `self._adder` is falsy nothing happens. Otherwise
     `self._prev_state` is converted with `utils.to_numpy` and passed to
     `self._adder.add` as a one-element `extras` tuple alongside the action
     and the timestep.

     Args:
       action: the action passed through to `self._adder.add`.
       next_timestep: the timestep passed through to `self._adder.add`.
     """
     if not self._adder:
         return

     prev_state_np = utils.to_numpy(self._prev_state)
     self._adder.add(action, next_timestep, extras=(prev_state_np, ))
Пример #7
0
 def select_action(self, observation):
     """Pick an action conditioned on the wrapped actor's discrete action.

     The wrapped actor is first asked for a discrete action for the
     observation. `self._policy` is then called with `self._client.params`,
     the observation and that discrete action. After the policy call
     succeeds, the discrete action is remembered on
     `self._last_discrete_action`.

     Args:
       observation: the observation given to both the wrapped actor and the
         policy.

     Returns:
       The policy's action after conversion by `utils.to_numpy`.
     """
     chosen_discrete = self._wrapped_actor.select_action(observation)
     policy_action = self._policy(
         self._client.params, observation, chosen_discrete)
     self._last_discrete_action = chosen_discrete
     return utils.to_numpy(policy_action)