def _distribution(self, time_step, policy_state):
  network_state, latent_state, last_action = policy_state
  latent_state = tf.where(
      time_step.is_first(),
      self._model_network.first_filter(time_step.observation),
      self._model_network.filter(time_step.observation, latent_state,
                                 last_action))
  # Update the latent state.
  policy_state = (network_state, latent_state, last_action)

  # Actor network outputs nested structure of distributions or actions.
  actions_or_distributions, network_state = self._apply_actor_network(
      time_step, policy_state)
  # Update the network state.
  policy_state = (network_state, latent_state, last_action)

  def _to_distribution(action_or_distribution):
    if isinstance(action_or_distribution, tf.Tensor):
      # This is an action tensor, so wrap it in a deterministic distribution.
      return tfp.distributions.Deterministic(loc=action_or_distribution)
    return action_or_distribution

  distributions = tf.nest.map_structure(_to_distribution,
                                        actions_or_distributions)

  # Prepare policy_info.
  if self._collect:
    policy_info = ppo_utils.get_distribution_params(distributions)
  else:
    policy_info = ()

  return policy_step.PolicyStep(distributions, policy_state, policy_info)

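# A minimal, self-contained sketch of the tf.where reset pattern used above:
# on the first step of an episode the latent state is re-initialized from the
# observation filter, otherwise the carried-over state is kept. The shapes and
# values below are illustrative stand-ins for the model network's filters.
import tensorflow as tf

def _reset_latent_on_first_step(is_first, init_latent, updated_latent):
  # The [batch, 1] episode-start mask broadcasts over the latent dimension,
  # so only environments that just reset receive the fresh latent state.
  return tf.where(is_first, init_latent, updated_latent)

is_first = tf.constant([[True], [False]])  # Batch of two environments.
init_latent = tf.zeros([2, 4])             # Fresh latent at episode start.
updated_latent = tf.ones([2, 4])           # Latent carried from the last step.
print(_reset_latent_on_first_step(is_first, init_latent, updated_latent))
# Row 0 (episode start) is all zeros; row 1 keeps the carried-over ones.
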
def _distribution(self, time_step, policy_state, training=False):
  # Actor network outputs a list of distributions or actions (one for each
  # agent), and a list of policy states for each agent.
  actions_or_distributions, policy_state, attention_weights = (
      self._apply_actor_network(time_step, policy_state, training=training))

  def _to_distribution(action_or_distribution):
    if isinstance(action_or_distribution, tf.Tensor):
      # This is an action tensor, so wrap it in a deterministic distribution.
      return tfp.distributions.Deterministic(loc=action_or_distribution)
    return action_or_distribution

  distributions = tf.nest.map_structure(_to_distribution,
                                        actions_or_distributions)

  # Prepare policy_info.
  if self._collect:
    policy_info = ppo_utils.get_distribution_params(distributions, False)

    # Wrap policy info to be compatible with the new spec.
    policy_info = list(policy_info)
    for a in range(len(policy_info)):
      if not self.inactive_agent_ids or a not in self.inactive_agent_ids:
        policy_info[a] = {'dist_params': policy_info[a]}
        policy_info[a].update({'attention_weights': attention_weights[a]})

    # Fake logits for fixed agents.
    if self.inactive_agent_ids and self.learning_agents:
      for a in self.inactive_agent_ids:
        policy_info[a] = {
            'dist_params': {
                'logits':
                    tf.zeros_like(policy_info[self.learning_agents[0]]
                                  ['dist_params']['logits'])
            }
        }
    policy_info = tuple(policy_info)

    # PolicyStep has actions, state, info.
    step_result = policy_step.PolicyStep(distributions, policy_state,
                                         policy_info)
  else:
    # A GreedyPolicy wrapper cannot be combined with an overridden _action
    # here, so the greedy behavior is replicated inline.
    def dist_fn(dist):
      try:
        greedy_action = dist.mode()
      except NotImplementedError:
        raise ValueError("Your network's distribution does not implement "
                         'mode, making it incompatible with a greedy policy.')
      return greedy_policy.DeterministicWithLogProb(loc=greedy_action)

    actions = tf.nest.map_structure(dist_fn, distributions)
    step_result = policy_step.PolicyStep(actions, policy_state, ())
  return step_result

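# A standalone sketch of the greedy fallback replicated above, using plain
# TensorFlow Probability instead of TF-Agents' DeterministicWithLogProb:
# distributions that implement mode() are collapsed onto their mode, and ones
# that do not raise a clear error up front rather than failing later.
import tensorflow as tf
import tensorflow_probability as tfp

def greedy_distribution(dist):
  try:
    greedy_action = dist.mode()
  except NotImplementedError:
    raise ValueError("Your network's distribution does not implement mode, "
                     'making it incompatible with a greedy policy.')
  # All probability mass sits on the mode, so sample() is deterministic.
  return tfp.distributions.Deterministic(loc=greedy_action)

categorical = tfp.distributions.Categorical(logits=tf.constant([0.1, 2.0, 0.3]))
print(greedy_distribution(categorical).sample())  # Always 1, the argmax logit.
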
def _distribution(self, time_step, policy_state, training=False):
  if not policy_state:
    policy_state = {'actor_network_state': (), 'value_network_state': ()}
  else:
    policy_state = policy_state.copy()

  if 'actor_network_state' not in policy_state:
    policy_state['actor_network_state'] = ()
  if 'value_network_state' not in policy_state:
    policy_state['value_network_state'] = ()

  new_policy_state = {'actor_network_state': (), 'value_network_state': ()}

  def _to_distribution(action_or_distribution):
    if isinstance(action_or_distribution, tf.Tensor):
      # This is an action tensor, so wrap it in a deterministic distribution.
      return tfp.distributions.Deterministic(loc=action_or_distribution)
    return action_or_distribution

  (actions_or_distributions,
   new_policy_state['actor_network_state']) = self._apply_actor_network(
       time_step, policy_state['actor_network_state'], training=training)
  distributions = tf.nest.map_structure(_to_distribution,
                                        actions_or_distributions)

  if self._collect:
    policy_info = {
        'dist_params': ppo_utils.get_distribution_params(distributions)
    }
    if not self._compute_value_and_advantage_in_train:
      # If value_prediction is not computed in agent.train it needs to be
      # computed and saved here.
      (policy_info['value_prediction'],
       new_policy_state['value_network_state']) = self.apply_value_network(
           time_step.observation,
           time_step.step_type,
           value_state=policy_state['value_network_state'],
           training=False)
  else:
    policy_info = ()

  if (not new_policy_state['actor_network_state'] and
      not new_policy_state['value_network_state']):
    new_policy_state = ()
  elif not new_policy_state['value_network_state']:
    new_policy_state.pop('value_network_state', None)
  elif not new_policy_state['actor_network_state']:
    new_policy_state.pop('actor_network_state', None)

  return policy_step.PolicyStep(distributions, new_policy_state, policy_info)

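# A self-contained sketch of the policy_state normalization at the top of the
# function above: whatever the caller passes (an empty state, or a dict with
# either key missing) is coerced into a dict holding both network states, and
# the caller's dict is copied so it is never mutated in place. setdefault is
# an equivalent, slightly more compact form of the membership checks above.
def normalize_policy_state(policy_state):
  if not policy_state:
    policy_state = {'actor_network_state': (), 'value_network_state': ()}
  else:
    policy_state = policy_state.copy()
  policy_state.setdefault('actor_network_state', ())
  policy_state.setdefault('value_network_state', ())
  return policy_state

print(normalize_policy_state(()))
# {'actor_network_state': (), 'value_network_state': ()}
print(normalize_policy_state({'actor_network_state': ('h',)}))
# {'actor_network_state': ('h',), 'value_network_state': ()}
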
def test_get_distribution_params(self):
  ones = tf.ones(shape=[2], dtype=tf.float32)
  distribution = (tfp.distributions.Categorical(logits=ones),
                  tfp.distributions.Normal(ones, ones))
  params = ppo_utils.get_distribution_params(distribution)
  self.assertAllEqual([set(['logits']), set(['loc', 'scale'])],
                      [set(d.keys()) for d in params])
  self.assertAllEqual([[[2]], [[2], [2]]],
                      [[d[k].shape.as_list() for k in d] for d in params])

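# The test above pins down the contract of ppo_utils.get_distribution_params:
# for each distribution in the nest it returns a dict of the tensor-valued
# constructor parameters ('logits' for Categorical, 'loc'/'scale' for Normal).
# A hypothetical re-implementation consistent with that contract, not the
# actual tf_agents code, could read each distribution's parameters dict:
import tensorflow as tf
import tensorflow_probability as tfp

def get_distribution_params_sketch(nested_distribution):
  def _tensor_parameters(d):
    # Keep only the constructor arguments that are tensors; flags such as
    # validate_args and unset alternatives such as probs=None are dropped.
    return {k: v for k, v in d.parameters.items() if tf.is_tensor(v)}
  return tf.nest.map_structure(_tensor_parameters, nested_distribution)

ones = tf.ones(shape=[2], dtype=tf.float32)
params = get_distribution_params_sketch(
    (tfp.distributions.Categorical(logits=ones),
     tfp.distributions.Normal(ones, ones)))
print([set(d.keys()) for d in params])  # [{'logits'}, {'loc', 'scale'}]
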
def _distribution(self, time_step, policy_state, training=False):
  if not policy_state:
    policy_state = {'actor_network_state': (), 'value_network_state': ()}
  else:
    policy_state = policy_state.copy()

  if 'actor_network_state' not in policy_state:
    policy_state['actor_network_state'] = ()
  if 'value_network_state' not in policy_state:
    policy_state['value_network_state'] = ()

  new_policy_state = {'actor_network_state': (), 'value_network_state': ()}
  (distributions, new_policy_state['actor_network_state']) = (
      self._apply_actor_network(
          time_step, policy_state['actor_network_state'], training=training))

  if self._collect:
    policy_info = {
        'dist_params':
            ppo_utils.get_distribution_params(
                distributions,
                legacy_distribution_network=isinstance(
                    self._actor_network, network.DistributionNetwork))
    }
    if not self._compute_value_and_advantage_in_train:
      # If value_prediction is not computed in agent.train it needs to be
      # computed and saved here.
      (policy_info['value_prediction'],
       new_policy_state['value_network_state']) = self.apply_value_network(
           time_step.observation,
           time_step.step_type,
           value_state=policy_state['value_network_state'],
           training=False)
  else:
    policy_info = ()

  if (not new_policy_state['actor_network_state'] and
      not new_policy_state['value_network_state']):
    new_policy_state = ()
  elif not new_policy_state['value_network_state']:
    del new_policy_state['value_network_state']
  elif not new_policy_state['actor_network_state']:
    del new_policy_state['actor_network_state']

  return policy_step.PolicyStep(distributions, new_policy_state, policy_info)

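# A self-contained sketch of the state pruning at the end of the two
# _distribution variants above: empty network states are stripped so that a
# fully stateless policy reports policy_state == (), and a policy with only
# one stateful network carries only that network's state. This mirrors the
# dict handling above with no policy class involved.
def prune_empty_network_states(new_policy_state):
  if (not new_policy_state['actor_network_state'] and
      not new_policy_state['value_network_state']):
    return ()
  if not new_policy_state['value_network_state']:
    del new_policy_state['value_network_state']
  elif not new_policy_state['actor_network_state']:
    del new_policy_state['actor_network_state']
  return new_policy_state

print(prune_empty_network_states(
    {'actor_network_state': (), 'value_network_state': ()}))  # ()
print(prune_empty_network_states(
    {'actor_network_state': ('h',), 'value_network_state': ()}))
# {'actor_network_state': ('h',)}
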
def test_get_distribution_params(self, legacy_distribution_network):
  ones = tf.Variable(tf.ones(shape=[2], dtype=tf.float32))
  distribution = (tfp.distributions.Categorical(logits=ones),
                  tfp.distributions.Normal(ones, ones))
  params = ppo_utils.get_distribution_params(distribution,
                                             legacy_distribution_network)
  self.assertEqual(
      [set(['logits']), set(['loc', 'scale'])],
      [set(d.keys()) for d in params])  # pytype: disable=attribute-error
  self.assertAllEqual(
      [[[2]], [[2], [2]]],
      [[d[k].shape.as_list() for k in d]
       for d in params])  # pytype: disable=attribute-error

def _distribution(self, time_step, policy_state):
  # Actor network outputs nested structure of distributions or actions.
  actions_or_distributions, policy_state = self._apply_actor_network(
      time_step, policy_state)

  def _to_distribution(action_or_distribution):
    if isinstance(action_or_distribution, tf.Tensor):
      # This is an action tensor, so wrap it in a deterministic distribution.
      return tfp.distributions.Deterministic(loc=action_or_distribution)
    return action_or_distribution

  distributions = tf.nest.map_structure(_to_distribution,
                                        actions_or_distributions)

  # Prepare policy_info.
  if self._collect:
    policy_info = ppo_utils.get_distribution_params(distributions)
  else:
    policy_info = ()
  return policy_step.PolicyStep(distributions, policy_state, policy_info)

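# A minimal sketch of why raw action tensors are wrapped in
# tfp.distributions.Deterministic above: downstream code can then call
# sample(), mode(), or log_prob() uniformly, whether the actor network emitted
# a distribution or a concrete action. Plain TFP; no TF-Agents required.
import tensorflow as tf
import tensorflow_probability as tfp

def to_distribution(action_or_distribution):
  if isinstance(action_or_distribution, tf.Tensor):
    return tfp.distributions.Deterministic(loc=action_or_distribution)
  return action_or_distribution

action = tf.constant([0.5, -0.2])
dist = to_distribution(action)
print(dist.sample())          # Always [0.5, -0.2].
print(dist.log_prob(action))  # [0., 0.]: all mass sits on the wrapped action.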