Example #1
    def _distribution(self, time_step, policy_state):
        # In DQN, we always either take a uniformly random action, or the action
        # with the highest Q-value. However, to support more complicated policies,
        # we expose all Q-values as a categorical distribution with Q-values as
        # logits, and apply the GreedyPolicy wrapper in dqn_agent.py to select the
        # action with the highest Q-value.
        q_values, policy_state = self._q_network(time_step.observation,
                                                 time_step.step_type,
                                                 policy_state)

        # TODO(b/122314058): Validate and enforce that sampling distributions
        # created with the q_network logits generate the right action shapes. This
        # is currently patching the problem.

        # If the action spec says each action should be shaped (1,), add another
        # dimension so the final shape is (B, 1, A), where A is the number of
        # actions. This will make Categorical emit events shaped (B, 1) rather than
        # (B,). Using axis -2 to allow for (B, T, 1, A) shaped q_values.
        if self._flat_action_spec.shape.ndims == 1:
            q_values = tf.expand_dims(q_values, -2)

        # TODO(kbanoop): Handle distributions over nests.
        distribution = shifted_categorical.ShiftedCategorical(
            logits=q_values,
            dtype=self._flat_action_spec.dtype,
            shift=self._flat_action_spec.minimum)
        distribution = tf.nest.pack_sequence_as(self._action_spec,
                                                [distribution])
        return policy_step.PolicyStep(distribution, policy_state)
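
The comment in Example #1 explains that the policy exposes Q-values as the logits of a categorical distribution and leaves greedy selection to a wrapper. A minimal sketch of that idea, assuming hypothetical Q-values and an action spec whose minimum is 1 (so `shift=1`):

import tensorflow as tf
from tf_agents.distributions import shifted_categorical  # assumed module path

# Hypothetical Q-values for a batch of two observations and three actions.
q_values = tf.constant([[1.0, 5.0, 2.0],
                        [0.5, 0.1, 3.0]])

# Assume the flat action spec has minimum=1, so valid actions are 1..3.
distribution = shifted_categorical.ShiftedCategorical(
    logits=q_values, dtype=tf.int32, shift=1)

# A greedy wrapper would take the mode: argmax of the Q-values plus the shift.
print(distribution.mode())  # [2, 3]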
Example #2
    def _distribution(self, time_step, policy_state):
        network_observation = time_step.observation

        if self._observation_and_action_constraint_splitter:
            network_observation, mask = (
                self._observation_and_action_constraint_splitter(network_observation))

        q_values, policy_state = self._q_network(network_observation, time_step.step_type, policy_state)

        if self._flat_action_spec.shape.rank == 1:
            q_values = tf.expand_dims(q_values, -2)

        logits = q_values

        if self._observation_and_action_constraint_splitter:
            if self._flat_action_spec.shape.rank == 1:
                mask = tf.expand_dims(mask, -2)

            neg_inf = tf.constant(-np.inf, dtype=logits.dtype)
            logits = tf.compat.v2.where(tf.cast(mask, tf.bool), logits, neg_inf)

        distribution = shifted_categorical.ShiftedCategorical(logits=logits,
                                                              dtype=self._flat_action_spec.dtype,
                                                              shift=self._flat_action_spec.minimum)

        distribution = tf.nest.pack_sequence_as(self._action_spec, [distribution])
        return policy_step.PolicyStep(distribution, policy_state)
Example #3
  def testCompareToCategorical(self):
    # Use the same probabilities for normal categorical and shifted one.
    shift = 2
    probabilities = [0.3, 0.3, 0.4]
    distribution = tfp.distributions.Categorical(probs=probabilities)
    shifted_distribution = shifted_categorical.ShiftedCategorical(
        probs=probabilities, shift=shift)

    # Compare outputs of basic methods, using the same starting seed.
    tf.compat.v1.set_random_seed(1)  # required per b/131171329, only with TF2.
    sample = distribution.sample(seed=1)
    tf.compat.v1.set_random_seed(1)  # required per b/131171329, only with TF2.
    shifted_sample = shifted_distribution.sample(seed=1)

    mode = distribution.mode()
    shifted_mode = shifted_distribution.mode()

    sample, shifted_sample = self.evaluate([sample, shifted_sample])
    mode, shifted_mode = self.evaluate([mode, shifted_mode])

    self.assertEqual(shifted_sample, sample + shift)
    self.assertEqual(shifted_mode, mode + shift)

    # These functions should return the same values for shifted values.
    fns = ['cdf', 'log_cdf', 'prob', 'log_prob']
    for fn_name in fns:
      fn = getattr(distribution, fn_name)
      shifted_fn = getattr(shifted_distribution, fn_name)
      value, shifted_value = self.evaluate([fn(sample),
                                            shifted_fn(shifted_sample)])
      self.assertEqual(value, shifted_value)
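
The test above checks that `ShiftedCategorical` is an ordinary categorical whose support is relabeled from {0, ..., K-1} to {shift, ..., shift+K-1}. A short eager-mode sketch of the same identity (the probabilities and shift are arbitrary):

import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.distributions import shifted_categorical  # assumed module path

probs = [0.3, 0.3, 0.4]
shift = 2
base = tfp.distributions.Categorical(probs=probs)
shifted = shifted_categorical.ShiftedCategorical(probs=probs, shift=shift)

# Shifting only relabels the support {0, 1, 2} to {2, 3, 4}; probabilities agree.
for x in range(3):
  tf.debugging.assert_near(base.log_prob(x), shifted.log_prob(x + shift))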
Example #4
    def _distribution(self, time_step, policy_state):
        network_obs = time_step.observation

        q_values, policy_state = self._network(network_obs, time_step.step_type, policy_state)
        logits = q_values
        distribution = shifted_categorical.ShiftedCategorical(logits=logits,
                                                              dtype=self._flat_action_spec.dtype,
                                                              shift=self._flat_action_spec.minimum)
        distribution = tf.nest.pack_sequence_as(self._action_spec, [distribution])
        return policy_step.PolicyStep(distribution, policy_state)
Example #5
  def testCopy(self):
    """Confirm we can copy the distribution."""
    distribution = shifted_categorical.ShiftedCategorical(
        logits=[100.0, 100.0, 100.0], shift=2)
    copy = distribution.copy()
    with self.cached_session() as s:
      probs_np = s.run(copy.probs_parameter())
      logits_np = s.run(copy.logits_parameter())
      ref_probs_np = s.run(distribution.probs_parameter())
      ref_logits_np = s.run(distribution.logits_parameter())
    self.assertAllEqual(ref_logits_np, logits_np)
    self.assertAllEqual(ref_probs_np, probs_np)
Example #6
    def _action_distribution(self, time_step, policy_state, seed):
        """Implementation of `_action_distribution`.

         Args:
           time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
           policy_state: A Tensor, or a nested dict, list or tuple of Tensors
             representing the previous policy_state.
           seed: Seed to use if action performs sampling (optional).

         Returns:
           A `PolicyStep` named tuple containing:
             `action`: An action Tensor matching the `action_spec()`.
             `state`: A policy state tensor to be fed into the next call to action.
             `info`: Optional side information such as action log probabilities.
           A `PolicyStep` named tuple containing:
             `action`: A Logit Tensor representing Q values`.
             `state`: A policy state tensor to be fed into the next call to action.
             `info`: Optional side information such as action log probabilities.
         """

        raw_action, transformed_actions, action_logits = self._feedforward(
            time_step, policy_state, seed)

        def _to_distribution(action_or_distribution):
            if isinstance(action_or_distribution, tf.Tensor):
                # This is an action tensor, so wrap it in a deterministic distribution.
                return tfp.distributions.Deterministic(
                    loc=action_or_distribution)
            return action_or_distribution

        combined_actions = dict()
        combined_actions[self._raw_action_key] = raw_action
        combined_actions[self._transformed_action_key] = transformed_actions

        q_distributions = [
            shifted_categorical.ShiftedCategorical(logits=logit,
                                                   dtype=tf.int32,
                                                   shift=0)
            for logit in action_logits
        ]
        distribution_steps = [
            policy_step.PolicyStep(q_dist, policy_state)
            for q_dist in q_distributions
        ]

        actions = tf.nest.map_structure(_to_distribution, combined_actions)
        action_step = policy_step.PolicyStep(actions, policy_state)
        seed_stream = tfp.distributions.SeedStream(
            seed=seed, salt='sc2_sequential_policy')
        actions = tf.nest.map_structure(lambda d: d.sample(seed=seed_stream()),
                                        action_step.action)

        return action_step._replace(action=actions), distribution_steps
Example #7
    def _distribution(self, time_step, policy_state):
        # In DQN, we always either take a uniformly random action, or the action
        # with the highest Q-value. However, to support more complicated policies,
        # we expose all Q-values as a categorical distribution with Q-values as
        # logits, and apply the GreedyPolicy wrapper in dqn_agent.py to select the
        # action with the highest Q-value.
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)
        network_observation = time_step.observation

        if observation_and_action_constraint_splitter is not None:
            network_observation, mask = observation_and_action_constraint_splitter(
                network_observation)

        q_values, policy_state = self._q_network(network_observation,
                                                 time_step.step_type,
                                                 policy_state)

        # TODO(b/122314058): Validate and enforce that sampling distributions
        # created with the q_network logits generate the right action shapes. This
        # is currently patching the problem.

        # If the action spec says each action should be shaped (1,), add another
        # dimension so the final shape is (B, 1, A), where A is the number of
        # actions. This will make Categorical emit events shaped (B, 1) rather than
        # (B,). Using axis -2 to allow for (B, T, 1, A) shaped q_values.
        if self._flat_action_spec.shape.rank == 1:
            q_values = tf.expand_dims(q_values, -2)

        logits = q_values

        if observation_and_action_constraint_splitter is not None:
            # Expand the mask as needed in the same way as q_values above.
            if self._flat_action_spec.shape.rank == 1:
                mask = tf.expand_dims(mask, -2)

            # Overwrite the logits for invalid actions to -inf.
            neg_inf = tf.constant(-np.inf, dtype=logits.dtype)
            logits = tf.compat.v2.where(tf.cast(mask, tf.bool), logits,
                                        neg_inf)

        # TODO(kbanoop): Handle distributions over nests.
        if self._flat_action_spec.minimum != 0:
            distribution = shifted_categorical.ShiftedCategorical(
                logits=logits,
                dtype=self._flat_action_spec.dtype,
                shift=self._flat_action_spec.minimum)
        else:
            distribution = tfp.distributions.Categorical(
                logits=logits, dtype=self._flat_action_spec.dtype)
        distribution = tf.nest.pack_sequence_as(self._action_spec,
                                                [distribution])
        return policy_step.PolicyStep(distribution, policy_state)
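
Example #7 masks invalid actions by overwriting their logits with -inf before building the distribution, which drives their probability to exactly zero. A self-contained sketch of the masking step with made-up Q-values and a 0/1 validity mask:

import numpy as np
import tensorflow as tf

q_values = tf.constant([[2.0, 1.0, 3.0, 0.5]])
mask = tf.constant([[1, 0, 1, 0]])

# Invalid actions get -inf logits, so softmax assigns them zero probability
# and they can never be sampled.
neg_inf = tf.constant(-np.inf, dtype=q_values.dtype)
masked_logits = tf.where(tf.cast(mask, tf.bool), q_values, neg_inf)

print(tf.nn.softmax(masked_logits))  # mass only on actions 0 and 2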
Example #8
  def testShiftedSampling(self):
    distribution = shifted_categorical.ShiftedCategorical(
        probs=[0.1, 0.8, 0.1], shift=2)
    sample = distribution.sample()
    log_prob = distribution.log_prob(sample)
    results = []

    with self.cached_session() as s:
      for _ in range(100):
        value, _ = s.run([sample, log_prob])
        results.append(value)

    results = np.array(results, dtype=np.int32)
    self.assertTrue(np.all(results >= 2))
    self.assertTrue(np.all(results <= 4))
Example #9
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        q_values = self.func(observation)
        logits = q_values

        if self.q_policy._flat_action_spec.minimum != 0:
            distribution = shifted_categorical.ShiftedCategorical(
                logits=logits,
                dtype=self.q_policy._flat_action_spec.dtype,
                shift=self.q_policy._flat_action_spec.minimum)
        else:
            distribution = tfp.distributions.Categorical(
                logits=logits, dtype=self.q_policy._flat_action_spec.dtype)

        distribution = tf.nest.pack_sequence_as(self.q_policy._action_spec,
                                                [distribution])
        return policy_step.PolicyStep(distribution, policy_state)
Example #10
  def _distribution(self, time_step, policy_state):
    # In DQN, we always either take a uniformly random action, or the action
    # with the highest Q-value. However, to support more complicated policies,
    # we expose all Q-values as a categorical distribution with Q-values as
    # logits, and apply the GreedyPolicy wrapper in dqn_agent.py to select the
    # action with the highest Q-value.
    observation_and_action_constraint_splitter = (
        self.observation_and_action_constraint_splitter)
    network_observation = time_step.observation

    if observation_and_action_constraint_splitter is not None:
      network_observation, mask = observation_and_action_constraint_splitter(
          network_observation)

    q_values, policy_state = self._q_network(
        network_observation, network_state=policy_state,
        step_type=time_step.step_type)

    logits = q_values

    if observation_and_action_constraint_splitter is not None:
      # Overwrite the logits for invalid actions to logits.dtype.min.
      almost_neg_inf = tf.constant(logits.dtype.min, dtype=logits.dtype)
      logits = tf.compat.v2.where(
          tf.cast(mask, tf.bool), logits, almost_neg_inf)

    if self._flat_action_spec.minimum != 0:
      distribution = shifted_categorical.ShiftedCategorical(
          logits=logits,
          dtype=self._flat_action_spec.dtype,
          shift=self._flat_action_spec.minimum)
    else:
      distribution = tfp.distributions.Categorical(
          logits=logits,
          dtype=self._flat_action_spec.dtype)

    distribution = tf.nest.pack_sequence_as(self._action_spec, [distribution])
    return policy_step.PolicyStep(distribution, policy_state)
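
Example #10 uses `logits.dtype.min` rather than -inf as the masking sentinel. Keeping every logit finite still makes masked actions numerically impossible to sample, while avoiding the NaNs that -inf logits can produce in some downstream computations. A minimal illustration with arbitrary values:

import tensorflow as tf

q_values = tf.constant([[2.0, 1.0, 3.0]])
mask = tf.constant([[1, 0, 1]])

# The most negative finite value representable in the logits dtype.
almost_neg_inf = tf.constant(q_values.dtype.min, dtype=q_values.dtype)
masked_logits = tf.where(tf.cast(mask, tf.bool), q_values, almost_neg_inf)

print(tf.nn.softmax(masked_logits))  # the masked action gets ~0 probability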
Example #11
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        if self.observation_and_action_constraint_splitter is not None:
            observation, _ = self.observation_and_action_constraint_splitter(
                observation)

        predictions, policy_state = self._reward_network(
            observation, time_step.step_type, policy_state)
        batch_size = tf.shape(predictions)[0]

        if isinstance(self._reward_network,
                      heteroscedastic_q_network.HeteroscedasticQNetwork):
            predicted_reward_values = predictions.q_value_logits
        else:
            predicted_reward_values = predictions

        predicted_reward_values.shape.with_rank_at_least(2)
        predicted_reward_values.shape.with_rank_at_most(3)
        if (predicted_reward_values.shape[-1] is not None and
                predicted_reward_values.shape[-1] != self._expected_num_actions):
            raise ValueError(
                'The number of actions ({}) does not match the reward_network output'
                ' size ({}).'.format(self._expected_num_actions,
                                     predicted_reward_values.shape[-1]))

        mask = constr.construct_mask_from_multiple_sources(
            time_step.observation,
            self._observation_and_action_constraint_splitter,
            self._constraints, self._expected_num_actions)

        # Apply the temperature scaling, needed for Boltzmann exploration.
        logits = predicted_reward_values / self._get_temperature_value()

        # Apply masking if needed. Overwrite the logits for invalid actions to
        # logits.dtype.min.
        if mask is not None:
            almost_neg_inf = tf.constant(logits.dtype.min, dtype=logits.dtype)
            logits = tf.compat.v2.where(tf.cast(mask, tf.bool), logits,
                                        almost_neg_inf)

        if self._action_offset != 0:
            distribution = shifted_categorical.ShiftedCategorical(
                logits=logits,
                dtype=self._action_spec.dtype,
                shift=self._action_offset)
        else:
            distribution = tfp.distributions.Categorical(
                logits=logits, dtype=self._action_spec.dtype)

        actions = distribution.sample()

        bandit_policy_values = tf.fill(
            [batch_size, 1], policy_utilities.BanditPolicyType.BOLTZMANN)

        if self._accepts_per_arm_features:
            # Saving the features for the chosen action to the policy_info.
            def gather_observation(obs):
                return tf.gather(params=obs, indices=actions, batch_dims=1)

            chosen_arm_features = tf.nest.map_structure(
                gather_observation,
                observation[bandit_spec_utils.PER_ARM_FEATURE_KEY])
            policy_info = policy_utilities.PerArmPolicyInfo(
                log_probability=distribution.log_prob(actions)
                if policy_utilities.InfoFields.LOG_PROBABILITY
                in self._emit_policy_info else (),
                predicted_rewards_mean=(
                    predicted_reward_values
                    if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                    in self._emit_policy_info else ()),
                bandit_policy_type=(
                    bandit_policy_values
                    if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                    in self._emit_policy_info else ()),
                chosen_arm_features=chosen_arm_features)
        else:
            policy_info = policy_utilities.PolicyInfo(
                log_probability=distribution.log_prob(actions)
                if policy_utilities.InfoFields.LOG_PROBABILITY
                in self._emit_policy_info else (),
                predicted_rewards_mean=(
                    predicted_reward_values
                    if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                    in self._emit_policy_info else ()),
                bandit_policy_type=(
                    bandit_policy_values
                    if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                    in self._emit_policy_info else ()))

        return policy_step.PolicyStep(
            tfp.distributions.Deterministic(loc=actions), policy_state,
            policy_info)
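
Example #11 implements Boltzmann exploration: predicted rewards are divided by a temperature before being used as logits, so the temperature controls how concentrated the sampling distribution is. A small sketch with hypothetical reward predictions:

import tensorflow as tf

predicted_rewards = tf.constant([[1.0, 2.0, 4.0]])

# High temperature flattens the distribution (more exploration);
# low temperature sharpens it toward the greedy action (more exploitation).
for temperature in (10.0, 1.0, 0.1):
  logits = predicted_rewards / temperature
  print(temperature, tf.nn.softmax(logits).numpy())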
Example #12
    def _distribution(self, time_step, policy_state):
        # In DQN, we always either take a uniformly random action, or the action
        # with the highest Q-value. However, to support more complicated policies,
        # we expose all Q-values as a categorical distribution with Q-values as
        # logits, and apply the GreedyPolicy wrapper in dqn_agent.py to select the
        # action with the highest Q-value.

        neg_inf = tf.constant(-np.inf, dtype=tf.float32)
        previous_action = (time_step.observation[self._previous_action_key]
                           if self._use_previous_action else None)
        func_action = (time_step.observation[self._func_action_key]
                       if self._use_previous_action else None)

        assert (self._available_actions_key in time_step.observation) == \
               (self._available_actions_key in self._time_step_spec.observation)
        available_actions = None
        if self._available_actions_key in time_step.observation:
            available_actions = time_step.observation[self._available_actions_key]

        # time_step.observation is a dict of screen, minimap and structured info.
        time_step_obs = time_step.observation
        (spatial_q_values, structured_q_values), q_policy_state = self._mixed_q_network(
            time_step_obs, time_step.step_type, policy_state)
        assert isinstance(spatial_q_values, dict) and isinstance(structured_q_values, dict)
        if available_actions is not None:
            available_actions = tf.convert_to_tensor(available_actions)
            assert all([available_actions.shape == t.shape for t in structured_q_values.values()])
            structured_q_values = tf.nest.map_structure(lambda x: tf.where(tf.equal(available_actions, 1), x, neg_inf),
                                                        structured_q_values)

        # TODO(b/122314058): Validate and enforce that sampling distributions
        # created with the q_network logits generate the right action shapes. This
        # is currently patching the problem.

        # If the action spec says each action should be shaped (1,), add another
        # dimension so the final shape is (B, 1, A), where A is the number of
        # actions. This will make Categorical emit events shaped (B, 1) rather than
        # (B,). Using axis -2 to allow for (B, T, 1, A) shaped q_values.
        assert all([s.shape.ndims == 1 for s in self._flat_action_spec]) \
               or all([s.shape.ndims == 0 for s in self._flat_action_spec]), \
            "all action specs' ndims should be consistently 1 or 0"

        if previous_action is not None:
            tf.assert_equal(
                tf.add_n(
                    [spatial_q_values[k].shape[-1] if k in spatial_q_values else 0
                     for k in self._spatial_names]
                    + [structured_q_values[k].shape[-1] if k in structured_q_values else 0
                       for k in self._structured_names]),
                self._func_arg_mask.shape[-1])

        if func_action is None:
            discrete_func_action = None
        else:
            assert isinstance(previous_action, dict)
            discrete_func_action = (func_action[self._discrete_action_key]
                                    if self._discrete_action_key in func_action else None)
        spatial_q_values, structured_q_values = self._mask_logits(spatial_q_values, structured_q_values,
                                                                  discrete_func_action)

        if self._flat_action_spec[0].shape.ndims == 1:
            for k, v in spatial_q_values.items():
                spatial_q_values[k] = tf.expand_dims(v, -2)
            for k, v in structured_q_values.items():
                structured_q_values[k] = tf.expand_dims(v, -2)

        # Concatenate q_values into a single representation.
        q_values = []
        for name in self._spatial_names:
            if name in spatial_q_values:
                q_values.append(spatial_q_values[name])
        for name in self._structured_names:
            if name in structured_q_values:
                q_values.append(structured_q_values[name])
        q_values = tf.concat(q_values, axis=-1)

        # TODO: mask_split_fn needs to be investigated.
        # logits = spatial_q_values[self._spatial_names[0]]
        # mask_split_fn = self._q_network.mask_split_fn
        #
        # neg_inf = tf.constant(-np.inf, dtype=tf.float32)
        # if mask_split_fn:
        #     _, mask = mask_split_fn(time_step.observation)
        #
        #     # Expand the mask as needed in the same way as q_values above.
        #     if self._flat_action_spec.shape.ndims == 1:
        #         mask = tf.expand_dims(mask, -2)
        #
        #     # Overwrite the logits for invalid actions to -inf.
        #     logits = tf.compat.v2.where(tf.cast(mask, tf.bool), logits, neg_inf)

        # TODO(kbanoop): Handle distributions over nests.
        q_distribution = shifted_categorical.ShiftedCategorical(
            logits=q_values, dtype=tf.int32, shift=0)

        # q_distributions = dict()
        # for k, v in spatial_q_values.items():
        #     print(v.shape)
        #     q_distributions[k] = shifted_categorical.ShiftedCategorical(
        #         logits=v,
        #         dtype=self._action_spec[k].dtype,
        #         shift=self._action_spec[k].minimum[0])
        # for k, v in structured_q_values.items():
        #     print(v.shape)
        #     q_distributions[k] = shifted_categorical.ShiftedCategorical(
        #         logits=v,
        #         dtype=self._action_spec[k].dtype,
        #         shift=self._action_spec[k].minimum[0])
        # q_distribution = tf.nest.pack_sequence_as(self._action_spec, [q_distribution])

        return policy_step.PolicyStep(q_distribution, q_policy_state)