def _distribution(self, time_step, policy_state):
  """Generates the distribution over next actions given the time_step.

  Args:
    time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
    policy_state: A Tensor, or a nested dict, list or tuple of Tensors
      representing the previous policy_state.

  Returns:
    A `policy_step.PolicyStep` containing:
      - a tfp.distributions.Categorical capturing the distribution of next
        actions.
      - a policy_state Tensor, or a nested dict, list or tuple of Tensors,
        representing the new policy state.
  """
  q_logits, policy_state = self._q_network(time_step.observation,
                                           time_step.step_type,
                                           policy_state)
  q_logits.shape.assert_has_rank(3)
  q_values = common.convert_q_logits_to_values(q_logits, self._support)

  logits = q_values

  mask_split_fn = self._q_network.mask_split_fn

  if mask_split_fn:
    _, mask = mask_split_fn(time_step.observation)

    # Overwrite the logits for invalid actions to -inf.
    neg_inf = tf.constant(-np.inf, dtype=logits.dtype)
    logits = tf.compat.v2.where(tf.cast(mask, tf.bool), logits, neg_inf)

  dist = tfp.distributions.Categorical(logits=logits,
                                       dtype=self.action_spec.dtype)
  return policy_step.PolicyStep(dist, policy_state)
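# A minimal, standalone sketch (not part of the policy above) of the masking
# step: replacing a logit with -inf gives that action zero probability under
# softmax, so the Categorical distribution can never sample it. The tensors
# below are made-up stand-ins for q_values and the network's action mask.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

q_values = tf.constant([[1.0, 5.0, 3.0]])  # batch of 1, 3 candidate actions
mask = tf.constant([[1, 0, 1]])            # action 1 is invalid in this state
neg_inf = tf.constant(-np.inf, dtype=q_values.dtype)
masked_logits = tf.where(tf.cast(mask, tf.bool), q_values, neg_inf)
dist = tfp.distributions.Categorical(logits=masked_logits)
print(dist.probs_parameter().numpy())  # action 1 gets probability 0.0
print(dist.sample(5).numpy())          # only actions 0 and 2 ever appear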
def testConvertQLogitsToValuesBatch(self):
  logits = tf.constant([[[1., 20., 1.], [1., 1., 20.]],
                        [[20., 1., 1.], [1., 20., 20.]]])
  support = tf.constant([10., 20., 30.])
  values = common.convert_q_logits_to_values(logits, support)
  values_ = self.evaluate(values)
  self.assertAllClose(values_, [[20.0, 30.0], [10., 25.]], 0.001)
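# Hedged sketch of what `common.convert_q_logits_to_values` is computing in
# the test above, following the C51 construction: softmax the logits over the
# atom axis, then take the expected value under the support. The helper name
# below is hypothetical; the real implementation lives in tf_agents' common
# utilities.
import tensorflow as tf

def convert_q_logits_to_values_sketch(logits, support):
  """Maps [batch, actions, atoms] logits to [batch, actions] Q-values."""
  probabilities = tf.nn.softmax(logits)  # normalize over the atom axis
  return tf.reduce_sum(support * probabilities, axis=-1)

# E.g. logits [1., 20., 20.] split their mass ~50/50 between the atoms at
# 20. and 30., giving a value of ~25. -- matching the test expectation above.
logits = tf.constant([[[1., 20., 1.], [1., 20., 20.]]])
support = tf.constant([10., 20., 30.])
print(convert_q_logits_to_values_sketch(logits, support).numpy())  # ~[[20., 25.]]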
def _distribution(self, time_step, policy_state):
  """Generates the distribution over next actions given the time_step.

  Args:
    time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
    policy_state: A Tensor, or a nested dict, list or tuple of Tensors
      representing the previous policy_state.

  Returns:
    A `policy_step.PolicyStep` containing:
      - a tfp.distributions.Categorical capturing the distribution of next
        actions.
      - a policy_state Tensor, or a nested dict, list or tuple of Tensors,
        representing the new policy state.
  """
  network_observation = time_step.observation
  observation_and_action_constraint_splitter = (
      self.observation_and_action_constraint_splitter)

  if observation_and_action_constraint_splitter is not None:
    network_observation, mask = (
        observation_and_action_constraint_splitter(network_observation))

  q_logits, policy_state = self._q_network(
      network_observation,
      step_type=time_step.step_type,
      network_state=policy_state)
  q_logits.shape.assert_has_rank(3)
  q_values = common.convert_q_logits_to_values(q_logits, self._support)

  logits = q_values

  if observation_and_action_constraint_splitter is not None:
    # Overwrite the logits for invalid actions to -inf.
    neg_inf = tf.constant(-np.inf, dtype=logits.dtype)
    logits = tf.compat.v2.where(tf.cast(mask, tf.bool), logits, neg_inf)

  dist = tfp.distributions.Categorical(
      logits=logits, dtype=self.action_spec.dtype)  # pytype: disable=attribute-error
  return policy_step.PolicyStep(dist, policy_state)
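# A minimal sketch of an observation_and_action_constraint_splitter: a
# callable that splits the raw observation into (network_observation, mask).
# The dict keys 'observation' and 'valid_actions' are hypothetical; they
# depend on how the environment packs its observations.
import tensorflow as tf

def observation_and_action_constraint_splitter(obs):
  # Feed only the raw features to the Q-network; keep the mask for the logits.
  return obs['observation'], obs['valid_actions']

obs = {'observation': tf.zeros([1, 4]),
       'valid_actions': tf.constant([[1, 0, 1]])}
network_observation, mask = observation_and_action_constraint_splitter(obs)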
def _distribution(self, time_step, policy_state):
  """Generates the distribution over next actions given the time_step.

  Args:
    time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
    policy_state: A Tensor, or a nested dict, list or tuple of Tensors
      representing the previous policy_state.

  Returns:
    A `policy_step.PolicyStep` containing:
      - a tfp.distributions.Categorical capturing the distribution of next
        actions.
      - a policy_state Tensor, or a nested dict, list or tuple of Tensors,
        representing the new policy state.
  """
  q_logits, policy_state = self._q_network(time_step.observation,
                                           time_step.step_type,
                                           policy_state)
  q_logits.shape.assert_has_rank(3)
  q_values = common.convert_q_logits_to_values(q_logits, self._support)
  return policy_step.PolicyStep(
      tfp.distributions.Categorical(logits=q_values,
                                    dtype=self.action_spec.dtype),
      policy_state)
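# Hedged usage sketch of consuming the returned PolicyStep: its `action`
# field is a distribution, so greedy and stochastic action selection are each
# one call away. The PolicyStep is built by hand here purely for
# illustration, with made-up Q-values.
import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.trajectories import policy_step

q_values = tf.constant([[12.5, 30.0, 25.0]])  # stand-in values for 3 actions
step = policy_step.PolicyStep(
    tfp.distributions.Categorical(logits=q_values, dtype=tf.int32), ())
print(step.action.mode().numpy())    # greedy action: [1]
print(step.action.sample().numpy())  # stochastic action draw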