def _distribution(self, time_step, policy_state):
  # In DQN, we always either take a uniformly random action, or the action
  # with the highest Q-value. However, to support more complicated policies,
  # we expose all Q-values as a categorical distribution with Q-values as
  # logits, and apply the GreedyPolicy wrapper in dqn_agent.py to select the
  # action with the highest Q-value.
  q_values, policy_state = self._q_network(time_step.observation,
                                           time_step.step_type, policy_state)

  # TODO(b/122314058): Validate and enforce that sampling distributions
  # created with the q_network logits generate the right action shapes. This
  # is currently patching the problem.

  # If the action spec says each action should be shaped (1,), add another
  # dimension so the final shape is (B, 1, A), where A is the number of
  # actions. This will make Categorical emit events shaped (B, 1) rather than
  # (B,). Using axis -2 to allow for (B, T, 1, A) shaped q_values.
  if self._flat_action_spec.shape.ndims == 1:
    q_values = tf.expand_dims(q_values, -2)

  # TODO(kbanoop): Handle distributions over nests.
  distribution = shifted_categorical.ShiftedCategorical(
      logits=q_values,
      dtype=self._flat_action_spec.dtype,
      shift=self._flat_action_spec.minimum)
  distribution = tf.nest.pack_sequence_as(self._action_spec, [distribution])
  return policy_step.PolicyStep(distribution, policy_state)
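# A minimal usage sketch of the pattern above, assuming the same TF-Agents
# import used throughout these snippets. It shows why ShiftedCategorical is
# needed: a plain Categorical samples in [0, A), while an action spec with a
# nonzero minimum needs samples in [minimum, minimum + A). The q_values below
# are made up for illustration.
import tensorflow as tf
from tf_agents.distributions import shifted_categorical

# Fake Q-values for a batch of 2 observations and 3 actions.
q_values = tf.constant([[1.0, 5.0, 2.0], [4.0, 0.5, 0.5]])

# Suppose the action spec has minimum=2 and maximum=4: valid actions are
# {2, 3, 4}, so samples must be offset by the spec minimum.
dist = shifted_categorical.ShiftedCategorical(
    logits=q_values, dtype=tf.int32, shift=2)
sample = dist.sample()  # Values in {2, 3, 4} rather than {0, 1, 2}.
mode = dist.mode()      # argmax(q_values, axis=-1) + 2 -> [3, 2].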
def _distribution(self, time_step, policy_state):
  network_observation = time_step.observation
  if self._observation_and_action_constraint_splitter:
    network_observation, mask = (
        self._observation_and_action_constraint_splitter(network_observation))
  q_values, policy_state = self._q_network(network_observation,
                                           time_step.step_type, policy_state)
  if self._flat_action_spec.shape.rank == 1:
    q_values = tf.expand_dims(q_values, -2)

  logits = q_values
  if self._observation_and_action_constraint_splitter:
    # Expand the mask as needed in the same way as q_values above.
    if self._flat_action_spec.shape.rank == 1:
      mask = tf.expand_dims(mask, -2)
    # Overwrite the logits for invalid actions to -inf.
    neg_inf = tf.constant(-np.inf, dtype=logits.dtype)
    logits = tf.compat.v2.where(tf.cast(mask, tf.bool), logits, neg_inf)

  distribution = shifted_categorical.ShiftedCategorical(
      logits=logits,
      dtype=self._flat_action_spec.dtype,
      shift=self._flat_action_spec.minimum)
  distribution = tf.nest.pack_sequence_as(self._action_spec, [distribution])
  return policy_step.PolicyStep(distribution, policy_state)
def testCompareToCategorical(self):
  # Use the same probabilities for normal categorical and shifted one.
  shift = 2
  probabilities = [0.3, 0.3, 0.4]
  distribution = tfp.distributions.Categorical(probs=probabilities)
  shifted_distribution = shifted_categorical.ShiftedCategorical(
      probs=probabilities, shift=shift)

  # Compare outputs of basic methods, using the same starting seed.
  tf.compat.v1.set_random_seed(1)  # required per b/131171329, only with TF2.
  sample = distribution.sample(seed=1)
  tf.compat.v1.set_random_seed(1)  # required per b/131171329, only with TF2.
  shifted_sample = shifted_distribution.sample(seed=1)
  mode = distribution.mode()
  shifted_mode = shifted_distribution.mode()

  sample, shifted_sample = self.evaluate([sample, shifted_sample])
  mode, shifted_mode = self.evaluate([mode, shifted_mode])
  self.assertEqual(shifted_sample, sample + shift)
  self.assertEqual(shifted_mode, mode + shift)

  # These functions should return the same values for shifted values.
  fns = ['cdf', 'log_cdf', 'prob', 'log_prob']
  for fn_name in fns:
    fn = getattr(distribution, fn_name)
    shifted_fn = getattr(shifted_distribution, fn_name)
    value, shifted_value = self.evaluate(
        [fn(sample), shifted_fn(shifted_sample)])
    self.assertEqual(value, shifted_value)
def _distribution(self, time_step, policy_state):
  network_obs = time_step.observation
  q_values, policy_state = self._network(network_obs, time_step.step_type,
                                         policy_state)
  logits = q_values
  distribution = shifted_categorical.ShiftedCategorical(
      logits=logits,
      dtype=self._flat_action_spec.dtype,
      shift=self._flat_action_spec.minimum)
  distribution = tf.nest.pack_sequence_as(self._action_spec, [distribution])
  return policy_step.PolicyStep(distribution, policy_state)
def testCopy(self):
  """Confirm we can copy the distribution."""
  distribution = shifted_categorical.ShiftedCategorical(
      logits=[100.0, 100.0, 100.0], shift=2)
  copy = distribution.copy()
  with self.cached_session() as s:
    probs_np = s.run(copy.probs_parameter())
    logits_np = s.run(copy.logits_parameter())
    ref_probs_np = s.run(distribution.probs_parameter())
    ref_logits_np = s.run(distribution.logits_parameter())
  self.assertAllEqual(ref_logits_np, logits_np)
  self.assertAllEqual(ref_probs_np, probs_np)
def _action_distribution(self, time_step, policy_state, seed):
  """Implementation of `_action_distribution`.

  Args:
    time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
    policy_state: A Tensor, or a nested dict, list or tuple of Tensors
      representing the previous policy_state.
    seed: Seed to use if action performs sampling (optional).

  Returns:
    A tuple of:
      - A `PolicyStep` named tuple containing:
        `action`: An action Tensor matching the `action_spec()`.
        `state`: A policy state tensor to be fed into the next call to action.
        `info`: Optional side information such as action log probabilities.
      - A list of `PolicyStep` named tuples, one per action head, each
        containing:
        `action`: A `ShiftedCategorical` distribution whose logits are
          Q-values.
        `state`: A policy state tensor to be fed into the next call to action.
        `info`: Optional side information such as action log probabilities.
  """
  raw_action, transformed_actions, action_logits = self._feedforward(
      time_step, policy_state, seed)

  def _to_distribution(action_or_distribution):
    if isinstance(action_or_distribution, tf.Tensor):
      # This is an action tensor, so wrap it in a deterministic distribution.
      return tfp.distributions.Deterministic(loc=action_or_distribution)
    return action_or_distribution

  combined_actions = dict()
  combined_actions[self._raw_action_key] = raw_action
  combined_actions[self._transformed_action_key] = transformed_actions

  q_distributions = [
      shifted_categorical.ShiftedCategorical(
          logits=logit, dtype=tf.int32, shift=0)
      for logit in action_logits
  ]
  distribution_steps = [
      policy_step.PolicyStep(q_dist, policy_state)
      for q_dist in q_distributions
  ]

  actions = tf.nest.map_structure(_to_distribution, combined_actions)
  action_step = policy_step.PolicyStep(actions, policy_state)
  seed_stream = tfp.distributions.SeedStream(
      seed=seed, salt='sc2_sequential_policy')
  actions = tf.nest.map_structure(lambda d: d.sample(seed=seed_stream()),
                                  action_step.action)
  return action_step._replace(action=actions), distribution_steps
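# A small sketch of the `_to_distribution` pattern used above, with made-up
# values. Wrapping concrete action tensors in `Deterministic` distributions
# lets a later, uniform `d.sample()` call work whether a given output is a
# tensor or already a distribution.
import tensorflow as tf
import tensorflow_probability as tfp

outputs = {
    'raw': tf.constant([3]),  # a concrete action tensor
    'logits': tfp.distributions.Categorical(logits=[[0.0, 1.0]]),
}

def _to_distribution(action_or_distribution):
  if isinstance(action_or_distribution, tf.Tensor):
    return tfp.distributions.Deterministic(loc=action_or_distribution)
  return action_or_distribution

distributions = tf.nest.map_structure(_to_distribution, outputs)
samples = tf.nest.map_structure(lambda d: d.sample(), distributions)
# samples['raw'] is always [3]; samples['logits'] is drawn from the logits.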
def _distribution(self, time_step, policy_state):
  # In DQN, we always either take a uniformly random action, or the action
  # with the highest Q-value. However, to support more complicated policies,
  # we expose all Q-values as a categorical distribution with Q-values as
  # logits, and apply the GreedyPolicy wrapper in dqn_agent.py to select the
  # action with the highest Q-value.
  observation_and_action_constraint_splitter = (
      self.observation_and_action_constraint_splitter)
  network_observation = time_step.observation

  if observation_and_action_constraint_splitter is not None:
    network_observation, mask = observation_and_action_constraint_splitter(
        network_observation)

  q_values, policy_state = self._q_network(network_observation,
                                           time_step.step_type, policy_state)

  # TODO(b/122314058): Validate and enforce that sampling distributions
  # created with the q_network logits generate the right action shapes. This
  # is currently patching the problem.

  # If the action spec says each action should be shaped (1,), add another
  # dimension so the final shape is (B, 1, A), where A is the number of
  # actions. This will make Categorical emit events shaped (B, 1) rather than
  # (B,). Using axis -2 to allow for (B, T, 1, A) shaped q_values.
  if self._flat_action_spec.shape.rank == 1:
    q_values = tf.expand_dims(q_values, -2)

  logits = q_values
  if observation_and_action_constraint_splitter is not None:
    # Expand the mask as needed in the same way as q_values above.
    if self._flat_action_spec.shape.rank == 1:
      mask = tf.expand_dims(mask, -2)
    # Overwrite the logits for invalid actions to -inf.
    neg_inf = tf.constant(-np.inf, dtype=logits.dtype)
    logits = tf.compat.v2.where(tf.cast(mask, tf.bool), logits, neg_inf)

  # TODO(kbanoop): Handle distributions over nests.
  if self._flat_action_spec.minimum != 0:
    distribution = shifted_categorical.ShiftedCategorical(
        logits=logits,
        dtype=self._flat_action_spec.dtype,
        shift=self._flat_action_spec.minimum)
  else:
    distribution = tfp.distributions.Categorical(
        logits=logits, dtype=self._flat_action_spec.dtype)

  distribution = tf.nest.pack_sequence_as(self._action_spec, [distribution])
  return policy_step.PolicyStep(distribution, policy_state)
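# A standalone sketch of the masking step above, with made-up values. Invalid
# actions get a logit of -inf, so they receive zero probability under the
# resulting categorical distribution and can never be sampled.
import numpy as np
import tensorflow as tf

q_values = tf.constant([[1.0, 5.0, 2.0]])
mask = tf.constant([[1, 0, 1]])  # action 1 is invalid here

neg_inf = tf.constant(-np.inf, dtype=q_values.dtype)
logits = tf.where(tf.cast(mask, tf.bool), q_values, neg_inf)
# logits == [[1.0, -inf, 2.0]]; softmax assigns probability 0 to action 1.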
def testShiftedSampling(self):
  distribution = shifted_categorical.ShiftedCategorical(
      probs=[0.1, 0.8, 0.1], shift=2)
  sample = distribution.sample()
  log_prob = distribution.log_prob(sample)

  results = []
  with self.cached_session() as s:
    for _ in range(100):
      value, _ = s.run([sample, log_prob])
      results.append(value)

  results = np.array(results, dtype=np.int32)
  self.assertTrue(np.all(results >= 2))
  self.assertTrue(np.all(results <= 4))
def _distribution(self, time_step, policy_state):
  observation = time_step.observation
  q_values = self.func(observation)
  logits = q_values
  if self.q_policy._flat_action_spec.minimum != 0:
    distribution = shifted_categorical.ShiftedCategorical(
        logits=logits,
        dtype=self.q_policy._flat_action_spec.dtype,
        shift=self.q_policy._flat_action_spec.minimum)
  else:
    distribution = tfp.distributions.Categorical(
        logits=logits, dtype=self.q_policy._flat_action_spec.dtype)
  distribution = tf.nest.pack_sequence_as(self.q_policy._action_spec,
                                          [distribution])
  return policy_step.PolicyStep(distribution, policy_state)
def _distribution(self, time_step, policy_state):
  # In DQN, we always either take a uniformly random action, or the action
  # with the highest Q-value. However, to support more complicated policies,
  # we expose all Q-values as a categorical distribution with Q-values as
  # logits, and apply the GreedyPolicy wrapper in dqn_agent.py to select the
  # action with the highest Q-value.
  observation_and_action_constraint_splitter = (
      self.observation_and_action_constraint_splitter)
  network_observation = time_step.observation

  if observation_and_action_constraint_splitter is not None:
    network_observation, mask = observation_and_action_constraint_splitter(
        network_observation)

  q_values, policy_state = self._q_network(
      network_observation,
      network_state=policy_state,
      step_type=time_step.step_type)

  logits = q_values
  if observation_and_action_constraint_splitter is not None:
    # Overwrite the logits for invalid actions to logits.dtype.min.
    almost_neg_inf = tf.constant(logits.dtype.min, dtype=logits.dtype)
    logits = tf.compat.v2.where(
        tf.cast(mask, tf.bool), logits, almost_neg_inf)

  if self._flat_action_spec.minimum != 0:
    distribution = shifted_categorical.ShiftedCategorical(
        logits=logits,
        dtype=self._flat_action_spec.dtype,
        shift=self._flat_action_spec.minimum)
  else:
    distribution = tfp.distributions.Categorical(
        logits=logits, dtype=self._flat_action_spec.dtype)

  distribution = tf.nest.pack_sequence_as(self._action_spec, [distribution])
  return policy_step.PolicyStep(distribution, policy_state)
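# A sketch of a plausible reason this variant masks with logits.dtype.min
# rather than -np.inf: if every action in a row happens to be masked, -inf
# logits make the softmax NaN, while dtype.min keeps the arithmetic finite
# (and still effectively excludes masked actions whenever at least one action
# is valid). Values below are illustrative.
import numpy as np
import tensorflow as tf

logits = tf.constant([[1.0, 2.0, 3.0]])
all_masked = tf.constant([[0, 0, 0]])

inf_masked = tf.where(tf.cast(all_masked, tf.bool), logits,
                      tf.constant(-np.inf))
min_masked = tf.where(tf.cast(all_masked, tf.bool), logits,
                      tf.constant(logits.dtype.min))

print(tf.nn.softmax(inf_masked).numpy())  # [[nan nan nan]]
print(tf.nn.softmax(min_masked).numpy())  # [[0.333 0.333 0.333]]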
def _distribution(self, time_step, policy_state):
  observation = time_step.observation
  if self.observation_and_action_constraint_splitter is not None:
    observation, _ = self.observation_and_action_constraint_splitter(
        observation)

  predictions, policy_state = self._reward_network(
      observation, time_step.step_type, policy_state)
  batch_size = tf.shape(predictions)[0]

  if isinstance(self._reward_network,
                heteroscedastic_q_network.HeteroscedasticQNetwork):
    predicted_reward_values = predictions.q_value_logits
  else:
    predicted_reward_values = predictions

  predicted_reward_values.shape.with_rank_at_least(2)
  predicted_reward_values.shape.with_rank_at_most(3)
  if (predicted_reward_values.shape[-1] is not None and
      predicted_reward_values.shape[-1] != self._expected_num_actions):
    raise ValueError(
        'The number of actions ({}) does not match the reward_network output'
        ' size ({}).'.format(self._expected_num_actions,
                             predicted_reward_values.shape[-1]))

  mask = constr.construct_mask_from_multiple_sources(
      time_step.observation,
      self._observation_and_action_constraint_splitter, self._constraints,
      self._expected_num_actions)

  # Apply the temperature scaling, needed for Boltzmann exploration.
  logits = predicted_reward_values / self._get_temperature_value()

  # Apply masking if needed. Overwrite the logits for invalid actions to
  # logits.dtype.min.
  if mask is not None:
    almost_neg_inf = tf.constant(logits.dtype.min, dtype=logits.dtype)
    logits = tf.compat.v2.where(
        tf.cast(mask, tf.bool), logits, almost_neg_inf)

  if self._action_offset != 0:
    distribution = shifted_categorical.ShiftedCategorical(
        logits=logits,
        dtype=self._action_spec.dtype,
        shift=self._action_offset)
  else:
    distribution = tfp.distributions.Categorical(
        logits=logits, dtype=self._action_spec.dtype)

  actions = distribution.sample()

  bandit_policy_values = tf.fill(
      [batch_size, 1], policy_utilities.BanditPolicyType.BOLTZMANN)

  if self._accepts_per_arm_features:
    # Saving the features for the chosen action to the policy_info.
    def gather_observation(obs):
      return tf.gather(params=obs, indices=actions, batch_dims=1)

    chosen_arm_features = tf.nest.map_structure(
        gather_observation,
        observation[bandit_spec_utils.PER_ARM_FEATURE_KEY])
    policy_info = policy_utilities.PerArmPolicyInfo(
        log_probability=distribution.log_prob(actions)
        if policy_utilities.InfoFields.LOG_PROBABILITY in
        self._emit_policy_info else (),
        predicted_rewards_mean=(
            predicted_reward_values
            if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in
            self._emit_policy_info else ()),
        bandit_policy_type=(
            bandit_policy_values
            if policy_utilities.InfoFields.BANDIT_POLICY_TYPE in
            self._emit_policy_info else ()),
        chosen_arm_features=chosen_arm_features)
  else:
    policy_info = policy_utilities.PolicyInfo(
        log_probability=distribution.log_prob(actions)
        if policy_utilities.InfoFields.LOG_PROBABILITY in
        self._emit_policy_info else (),
        predicted_rewards_mean=(
            predicted_reward_values
            if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in
            self._emit_policy_info else ()),
        bandit_policy_type=(
            bandit_policy_values
            if policy_utilities.InfoFields.BANDIT_POLICY_TYPE in
            self._emit_policy_info else ()))

  return policy_step.PolicyStep(
      tfp.distributions.Deterministic(loc=actions), policy_state, policy_info)
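# A sketch of the temperature scaling used above for Boltzmann exploration,
# with made-up reward estimates. Dividing the predicted rewards by a high
# temperature flattens the distribution (more exploration); a low temperature
# sharpens it toward the greedy action (more exploitation).
import tensorflow as tf
import tensorflow_probability as tfp

predicted_rewards = tf.constant([[1.0, 2.0, 4.0]])

for temperature in (0.1, 1.0, 10.0):
  dist = tfp.distributions.Categorical(logits=predicted_rewards / temperature)
  print(temperature, dist.probs_parameter().numpy())
# temperature 0.1 -> nearly one-hot on action 2; 10.0 -> nearly uniform.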
def _distribution(self, time_step, policy_state):
  # In DQN, we always either take a uniformly random action, or the action
  # with the highest Q-value. However, to support more complicated policies,
  # we expose all Q-values as a categorical distribution with Q-values as
  # logits, and apply the GreedyPolicy wrapper in dqn_agent.py to select the
  # action with the highest Q-value.
  neg_inf = tf.constant(-np.inf, dtype=tf.float32)

  previous_action = (time_step.observation[self._previous_action_key]
                     if self._use_previous_action else None)
  func_action = (time_step.observation[self._func_action_key]
                 if self._use_previous_action else None)
  assert (self._available_actions_key in time_step.observation) == \
      (self._available_actions_key in self._time_step_spec.observation)
  available_actions = None
  if self._available_actions_key in time_step.observation:
    available_actions = time_step.observation[self._available_actions_key]

  # time_step.observation is a dict of screen, minimap and structured info.
  time_step_obs = time_step.observation
  (spatial_q_values,
   structured_q_values), q_policy_state = self._mixed_q_network(
       time_step_obs, time_step.step_type, policy_state)
  assert isinstance(spatial_q_values, dict) and isinstance(
      structured_q_values, dict)

  if available_actions is not None:
    available_actions = tf.convert_to_tensor(available_actions)
    assert all([available_actions.shape == t.shape
                for t in structured_q_values.values()])
    structured_q_values = tf.nest.map_structure(
        lambda x: tf.where(tf.equal(available_actions, 1), x, neg_inf),
        structured_q_values)

  # TODO(b/122314058): Validate and enforce that sampling distributions
  # created with the q_network logits generate the right action shapes. This
  # is currently patching the problem.

  # If the action spec says each action should be shaped (1,), add another
  # dimension so the final shape is (B, 1, A), where A is the number of
  # actions. This will make Categorical emit events shaped (B, 1) rather than
  # (B,). Using axis -2 to allow for (B, T, 1, A) shaped q_values.
  assert all([s.shape.ndims == 1 for s in self._flat_action_spec]) \
      or all([s.shape.ndims == 0 for s in self._flat_action_spec]), \
      "all action specs' ndims should be consistently 1 or 0"

  if previous_action is not None:
    tf.assert_equal(
        tf.add_n([spatial_q_values[k].shape[-1]
                  if k in spatial_q_values else 0
                  for k in self._spatial_names] +
                 [structured_q_values[k].shape[-1]
                  if k in structured_q_values else 0
                  for k in self._structured_names]),
        self._func_arg_mask.shape[-1])

  if func_action is None:
    discrete_func_action = None
  else:
    assert isinstance(previous_action, dict)
    discrete_func_action = (func_action[self._discrete_action_key]
                            if self._discrete_action_key in func_action
                            else None)
  spatial_q_values, structured_q_values = self._mask_logits(
      spatial_q_values, structured_q_values, discrete_func_action)

  if self._flat_action_spec[0].shape.ndims == 1:
    for k, v in spatial_q_values.items():
      spatial_q_values[k] = tf.expand_dims(v, -2)
    for k, v in structured_q_values.items():
      structured_q_values[k] = tf.expand_dims(v, -2)

  # Concatenate q_values into a single representation.
  q_values = []
  for name in self._spatial_names:
    if name in spatial_q_values:
      q_values.append(spatial_q_values[name])
  for name in self._structured_names:
    if name in structured_q_values:
      q_values.append(structured_q_values[name])
  q_values = tf.concat(q_values, axis=-1)

  # TODO: mask_split_fn needs to be investigated.
  # logits = spatial_q_values[self._spatial_names[0]]
  # mask_split_fn = self._q_network.mask_split_fn
  # neg_inf = tf.constant(-np.inf, dtype=tf.float32)
  # if mask_split_fn:
  #   _, mask = mask_split_fn(time_step.observation)
  #   # Expand the mask as needed in the same way as q_values above.
  #   if self._flat_action_spec.shape.ndims == 1:
  #     mask = tf.expand_dims(mask, -2)
  #   # Overwrite the logits for invalid actions to -inf.
  #   logits = tf.compat.v2.where(tf.cast(mask, tf.bool), logits, neg_inf)

  # TODO(kbanoop): Handle distributions over nests.
  q_distribution = shifted_categorical.ShiftedCategorical(
      logits=q_values, dtype=tf.int32, shift=0)
  # q_distributions = dict()
  # for k, v in spatial_q_values.items():
  #   q_distributions[k] = shifted_categorical.ShiftedCategorical(
  #       logits=v,
  #       dtype=self._action_spec[k].dtype,
  #       shift=self._action_spec[k].minimum[0])
  # for k, v in structured_q_values.items():
  #   q_distributions[k] = shifted_categorical.ShiftedCategorical(
  #       logits=v,
  #       dtype=self._action_spec[k].dtype,
  #       shift=self._action_spec[k].minimum[0])
  # q_distribution = tf.nest.pack_sequence_as(self._action_spec,
  #                                           [q_distribution])
  return policy_step.PolicyStep(q_distribution, q_policy_state)