def select_action(self, observation: types.NestedArray) -> types.NestedArray:
  # Add a dummy batch dimension and as a side effect convert numpy to TF.
  batched_obs = tf2_utils.add_batch_dim(observation)

  # Initialize the RNN state if necessary.
  if self._state is None:
    self._state = self._network.initial_state(1)

  # Forward.
  policy_output, new_state = self._policy(batched_obs, self._state)

  # If the policy network parameterises a distribution, sample from it.
  def maybe_sample(output):
    if isinstance(output, tfd.Distribution):
      output = output.sample()
    return output

  policy_output = tree.map_structure(maybe_sample, policy_output)

  self._prev_state = self._state
  self._state = new_state

  # Convert to numpy and squeeze out the batch dimension.
  action = tf2_utils.to_numpy_squeeze(policy_output)

  return action
def observe(
    self,
    action: types.NestedArray,
    next_timestep: dm_env.TimeStep,
):
  # Nothing to record if this actor has no adder attached.
  if not self._adder:
    return

  # Record the recurrent state from before this step alongside the transition.
  numpy_state = tf2_utils.to_numpy_squeeze(self._prev_state)
  self._adder.add(action, next_timestep, extras=(numpy_state,))
def observe(
    self,
    action: types.NestedArray,
    next_timestep: dm_env.TimeStep,
):
  # Nothing to record if this actor has no adder attached.
  if not self._adder:
    return

  # Pass the behaviour logits and recurrent core state to the learner as extras.
  extras = {'logits': self._prev_logits, 'core_state': self._prev_state}
  extras = tf2_utils.to_numpy_squeeze(extras)
  self._adder.add(action, next_timestep, extras)
def select_action(self, observation: types.NestedArray) -> types.NestedArray:
  # Add a dummy batch dimension and as a side effect convert numpy to TF.
  batched_obs = tf2_utils.add_batch_dim(observation)

  # Initialize the RNN state if necessary.
  if self._state is None:
    self._state = self._network.initial_state(1)

  # Forward.
  (logits, _), new_state = self._policy(batched_obs, self._state)

  # Keep the logits and the pre-step state for the next observe() call.
  self._prev_logits = logits
  self._prev_state = self._state
  self._state = new_state

  # Sample an action from the categorical distribution over the logits,
  # then convert to numpy and squeeze out the batch dimension.
  action = tfd.Categorical(logits).sample()
  action = tf2_utils.to_numpy_squeeze(action)

  return action
def select_action(self, observation: types.NestedArray) -> types.NestedArray:
  # Add a dummy batch dimension and as a side effect convert numpy to TF.
  batched_obs = tf2_utils.add_batch_dim(observation)

  # Forward the policy network.
  policy_output = self._policy_network(batched_obs)

  # If the policy network parameterises a distribution, sample from it.
  def maybe_sample(output):
    if isinstance(output, tfd.Distribution):
      output = output.sample()
    return output

  policy_output = tree.map_structure(maybe_sample, policy_output)

  # Convert to numpy and squeeze out the batch dimension.
  action = tf2_utils.to_numpy_squeeze(policy_output)

  return action
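For context, here is a minimal sketch of how an actor exposing select_action and observe is typically driven against a dm_env environment. The `environment` and `actor` names are placeholders, and the observe_first/update calls of the full actor interface are elided for brevity.

# Minimal interaction-loop sketch (assumes a dm_env.Environment instance named
# `environment` and an actor instance named `actor`; both are placeholders).
timestep = environment.reset()
while not timestep.last():
  # Pick an action from the current observation.
  action = actor.select_action(timestep.observation)
  # Step the environment and let the actor record the resulting transition.
  timestep = environment.step(action)
  actor.observe(action, next_timestep=timestep)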