Example #1
    def _critic_train_step(self, exp: Experience, state: SacCriticState,
                           action, log_pi):
        if self._is_continuous:
            critic_input = (exp.observation, exp.action)
            target_critic_input = (exp.observation, action)
        else:
            critic_input = exp.observation
            target_critic_input = exp.observation

        critic1, critic1_state = self._critic_network1(
            critic_input, step_type=exp.step_type, network_state=state.critic1)

        critic2, critic2_state = self._critic_network2(
            critic_input, step_type=exp.step_type, network_state=state.critic2)

        target_critic1, target_critic1_state = self._target_critic_network1(
            target_critic_input,
            step_type=exp.step_type,
            network_state=state.target_critic1)

        target_critic2, target_critic2_state = self._target_critic_network2(
            target_critic_input,
            step_type=exp.step_type,
            network_state=state.target_critic2)

        if not self._is_continuous:
            exp_action = tf.cast(exp.action, tf.int32)
            critic1 = tfa_common.index_with_actions(critic1, exp_action)
            critic2 = tfa_common.index_with_actions(critic2, exp_action)
            sampled_action = tf.cast(action, tf.int32)
            target_critic1 = tfa_common.index_with_actions(
                target_critic1, sampled_action)
            target_critic2 = tfa_common.index_with_actions(
                target_critic2, sampled_action)

        target_critic = (tf.minimum(target_critic1, target_critic2) -
                         tf.stop_gradient(tf.exp(self._log_alpha) * log_pi))

        state = SacCriticState(critic1=critic1_state,
                               critic2=critic2_state,
                               target_critic1=target_critic1_state,
                               target_critic2=target_critic2_state)

        info = SacCriticInfo(critic1=critic1,
                             critic2=critic2,
                             target_critic=target_critic)

        return state, info
Example #2
    def _compute_next_q_values(self, target_policies, time_steps):
        """Compute the q value of the next state for TD error computation.

        Args:
          target_policies: A list of target HeteroQPolicy objects.
          time_steps: A batch of current timesteps
        Returns:
          A tensor of Q values for the given next state.
        """

        q_values_seq, actions = self._sequential_network_activation(target_policies, time_steps)

        action_key = 'raw'
        raw_actions = tf.unstack(actions[action_key], axis=1)
        multi_dim_actions = False
        values = [common.index_with_actions(
            self._append2logits(q_values),
            tf.cast(act, dtype=tf.int32),
            multi_dim_actions=multi_dim_actions) for q_values, act in zip(q_values_seq, raw_actions)]
        # Due to dist.mode() in GreedyPolicy, action 0 is selected for masked actions, so replace the resulting -inf values with 0.
        values = [tf.where(tf.equal(v, NEG_INF), 0.0, v) for v in values]

        # specifically set for sc2 minimaps
        values = [tf.where(tf.equal(time_steps.step_type, StepType.LAST), 0.0, v) for v in values]
        values = [tf.where(tf.equal(time_steps.reward, 1.0), 0.0, v) for v in values]

        values = tf.add_n(values)
        return values
Example #3
  def testTwoOuterDimsUnknownShape(self):
    q_values = tf.convert_to_tensor(
        value=np.array([[[50, 51], [52, 53]]], dtype=np.float32))
    actions = tf.convert_to_tensor(value=np.array([[1, 0]], dtype=np.int32))
    values = common.index_with_actions(q_values, actions)

    self.assertAllClose([[51, 52]], self.evaluate(values))
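The expected output above pins down the semantics: for every outer (batch/time) entry, `index_with_actions` returns the Q-value at the index given by `actions` along the last axis. A minimal sketch of an equivalent computation in plain TensorFlow (using a one-hot mask; illustrative only, not the library's actual implementation):

import tensorflow as tf

q_values = tf.constant([[[50., 51.], [52., 53.]]])   # [outer1=1, outer2=2, num_actions=2]
actions = tf.constant([[1, 0]], dtype=tf.int32)      # [outer1=1, outer2=2]

# Pick q_values[..., a] for each entry via a one-hot mask over the action axis.
num_actions = tf.shape(q_values)[-1]
selected = tf.reduce_sum(
    q_values * tf.one_hot(actions, num_actions, dtype=q_values.dtype), axis=-1)
print(selected.numpy())  # [[51. 52.]]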
Example #4
    def compute_loss(self,
                     observations: types.NestedTensor,
                     actions: types.NestedTensor,
                     rewards: types.Tensor,
                     weights: Optional[types.TensorOrArray] = None,
                     training: bool = False) -> types.Tensor:
        """Computes loss for training the constraint network.

    Args:
      observations: A batch of observations.
      actions: A batch of actions.
      rewards: A batch of rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output batch loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.
      training: Whether the loss is being used for training.

    Returns:
      loss: A `Tensor` containing the loss for the training step.
    """
        with tf.name_scope('constraint_loss'):
            sample_weights = weights if weights else 1
            predicted_values, _ = self._constraint_network(observations,
                                                           training=training)
            action_predicted_values = common.index_with_actions(
                predicted_values, tf.cast(actions, dtype=tf.int32))
            loss = self._error_loss_fn(
                rewards,
                action_predicted_values,
                sample_weights,
                reduction=tf.compat.v1.losses.Reduction.MEAN)
            return loss
Example #5
  def _compute_next_q_values(self, next_time_steps):
    """Compute the q value of the next state for TD error computation.

    Args:
      next_time_steps: A batch of next timesteps

    Returns:
      A tensor of Q values for the given next state.
    """
    # TODO(b/117175589): Add binary tests for DDQN.
    next_target_q_values, _ = self._target_q_network(
        next_time_steps.observation, next_time_steps.step_type)
    batch_size = (
        next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0])
    dummy_state = self._greedy_policy.get_initial_state(batch_size)
    # Find the greedy actions using our greedy policy. This ensures that masked
    # actions (and other logic) are respected.
    best_next_actions = self._greedy_policy.action(
        next_time_steps, dummy_state).action

    # Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions
    # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
    multi_dim_actions = tf.nest.flatten(self._action_spec)[0].shape.ndims > 0
    return common.index_with_actions(
        next_target_q_values,
        best_next_actions,
        multi_dim_actions=multi_dim_actions)
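The recurring comment about `multi_dim_actions` covers the two shapes actions can arrive in: roughly `[batch_size]` for `action_spec.shape=()` and `[batch_size, 1]` for `action_spec.shape=(1,)`. A short illustrative sketch of both call patterns, assuming eager TensorFlow and tf_agents' `common` module:

import tensorflow as tf
from tf_agents.utils import common

q = tf.constant([[50., 51.], [52., 53.]])   # [batch=2, num_actions=2]

# action_spec.shape == (): actions are a vector of scalar indices.
print(common.index_with_actions(q, tf.constant([1, 0])).numpy())   # [51. 52.]

# action_spec.shape == (1,): actions carry a trailing singleton dimension.
print(common.index_with_actions(
    q, tf.constant([[1], [0]]), multi_dim_actions=True).numpy())   # [51. 52.]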
Example #6
 def _apply_action(self, action):
     self._reward = self._environment_dynamics.reward(
         self._observation, self._env_time)
     tf.compat.v1.assign_add(self._env_time,
                             self._environment_dynamics.batch_size)
     return common.index_with_actions(self._reward,
                                      tf.cast(action, dtype=tf.int32))
Example #7
    def loss(self, observations, actions, rewards, weights=None):
        """Computes loss for reward prediction training.

    Args:
      observations: A batch of observations.
      actions: A batch of actions.
      rewards: A batch of rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output batch loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.

    Returns:
      loss: A `LossInfo` containing the loss for the training step.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
        with tf.name_scope('loss'):
            predicted_values, _ = self._reward_network(observations)
            action_predicted_values = common.index_with_actions(
                predicted_values, tf.cast(actions, dtype=tf.int32))

            loss = self._error_loss_fn(rewards, action_predicted_values,
                                       weights if weights else 1)
        return tf_agent.LossInfo(loss, extra=())
Example #8
 def _compute_next_q_values(self, next_time_steps, index, time):
     """spot and flex dispatches of last day (48) + MCP current hour + current state (spot/ flex)"""
     obs_indices = tf.convert_to_tensor(
         list(range(0, 48)) + [48 + time] + [71])
     observation = tf.gather(next_time_steps.observation[index],
                             indices=obs_indices,
                             axis=-1)
     network_observation = observation
     next_target_q_values, _ = self.agent._target_q_network(
         network_observation, next_time_steps.step_type)
     batch_size = (next_target_q_values.shape[0]
                   or tf.shape(next_target_q_values)[0])
     dummy_state = self.agent._target_greedy_policy.get_initial_state(
         batch_size)
     # Find the greedy actions using our target greedy policy. This ensures that
     # action constraints are respected and helps centralize the greedy logic.
     next_individual_qmix_time_step = ts.get_individual_qmix_time_step(
         next_time_steps, index, time)
     greedy_actions = self.agent._target_greedy_policy.action(
         next_individual_qmix_time_step, dummy_state).action
     # Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions
     # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
     multi_dim_actions = tf.nest.flatten(
         self.agent._action_spec)[0].shape.rank > 0
     return common.index_with_actions(next_target_q_values,
                                      greedy_actions,
                                      multi_dim_actions=multi_dim_actions)
Example #9
  def _compute_next_q_values(self, next_time_steps):
    """Compute the q value of the next state for TD error computation.

    Args:
      next_time_steps: A batch of next timesteps

    Returns:
      A tensor of Q values for the given next state.
    """
    network_observation = next_time_steps.observation

    if self._observation_and_action_constraint_splitter is not None:
      network_observation, _ = self._observation_and_action_constraint_splitter(
          network_observation)

    next_target_q_values, _ = self._target_q_network(
        network_observation, next_time_steps.step_type)
    batch_size = (
        next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0])
    dummy_state = self._target_greedy_policy.get_initial_state(batch_size)
    # Find the greedy actions using our target greedy policy. This ensures that
    # action constraints are respected and helps centralize the greedy logic.
    greedy_actions = self._target_greedy_policy.action(
        next_time_steps, dummy_state).action

    # Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions
    # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
    multi_dim_actions = tf.nest.flatten(self._action_spec)[0].shape.rank > 0
    return common.index_with_actions(
        next_target_q_values,
        greedy_actions,
        multi_dim_actions=multi_dim_actions)
Example #10
File: dqn_agent.py Project: wuzh07/agents
  def _compute_next_q_values(self, next_time_steps, info):
    """Compute the q value of the next state for TD error computation.

    Args:
      next_time_steps: A batch of next timesteps
      info: PolicyStep.info that may be used by other agents inherited from
        dqn_agent.

    Returns:
      A tensor of Q values for the given next state.
    """
    del info
    # TODO(b/117175589): Add binary tests for DDQN.
    network_observation = next_time_steps.observation

    if self._observation_and_action_constraint_splitter is not None:
      network_observation, _ = self._observation_and_action_constraint_splitter(
          network_observation)

    next_target_q_values, _ = self._target_q_network(
        network_observation, step_type=next_time_steps.step_type)
    batch_size = (
        next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0])
    dummy_state = self._policy.get_initial_state(batch_size)
    # Find the greedy actions using our greedy policy. This ensures that action
    # constraints are respected and helps centralize the greedy logic.
    best_next_actions = self._policy.action(next_time_steps, dummy_state).action

    # Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions
    # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
    multi_dim_actions = tf.nest.flatten(self._action_spec)[0].shape.rank > 0
    return common.index_with_actions(
        next_target_q_values,
        best_next_actions,
        multi_dim_actions=multi_dim_actions)
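In these DQN-style agents, the values returned here typically feed a TD target of the form reward + discount * Q_target(s', a*), with gradients stopped through the target. A minimal sketch of that step; the helper name and the assumption that `discounts` already includes the discount factor are illustrative, not taken from the snippets above:

import tensorflow as tf

def td_targets(rewards, discounts, next_q_values):
    # rewards, discounts, next_q_values: [batch_size] tensors.
    # discounts is assumed to already include gamma and to be 0 at terminal steps.
    return tf.stop_gradient(rewards + discounts * next_q_values)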
Example #11
def compute_munchausen_td_targets(next_q_values, q_target_values, actions,
                                  rewards, multi_dim_actions, discounts, alpha,
                                  entropy_tau):

    next_max_v_values = tf.expand_dims(tf.reduce_max(next_q_values, 1), -1)
    tau_logsum_next = entropy_tau * tf.reduce_logsumexp(
        (next_q_values - next_max_v_values) / entropy_tau, axis=1)
    # batch x actions
    tau_logsum_next = tf.expand_dims(tau_logsum_next, -1)
    tau_logpi_next = next_q_values - next_max_v_values - tau_logsum_next

    pi_target = tf.nn.softmax(next_q_values / entropy_tau, 1)
    # valid_mask shape: (batch_size, )
    q_target = discounts * tf.reduce_sum(
        (pi_target * (next_q_values - tau_logpi_next)), 1)  # * valid_mask

    v_target_max = tf.expand_dims(tf.reduce_max(q_target_values, 1), -1)
    tau_logsum_target = entropy_tau * tf.reduce_logsumexp(
        (q_target_values - v_target_max) / entropy_tau, 1)
    tau_logsum_target = tf.expand_dims(tau_logsum_target, -1)
    tau_logpi_target = q_target_values - v_target_max - tau_logsum_target

    # munchausen addon uses the current state and actions
    munchausen_addon = common.index_with_actions(
        tau_logpi_target, tf.cast(actions, dtype=tf.int32), multi_dim_actions)
    #rewards = reward_scale_factor * next_time_steps.reward
    munchausen_reward = rewards + alpha * tf.clip_by_value(
        munchausen_addon, clip_value_max=0, clip_value_min=-2)
    td_targets = munchausen_reward + q_target

    return tf.stop_gradient(td_targets)
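A hedged usage sketch for the function above, assuming it and its imports (`tf`, `common`) are in scope; the batch of 2 transitions, 3 discrete actions, and the hyper-parameter values are made up for illustration:

import tensorflow as tf

next_q = tf.constant([[1.0, 2.0, 0.5], [0.2, 0.1, 0.9]])   # target Q(s', .)
q_tgt = tf.constant([[1.1, 1.9, 0.4], [0.3, 0.0, 1.0]])    # target Q(s, .)
actions = tf.constant([1, 2], dtype=tf.int32)              # actions taken in s
rewards = tf.constant([0.0, 1.0])
discounts = tf.constant([0.99, 0.0])                       # 0 at episode end

td_targets = compute_munchausen_td_targets(
    next_q, q_tgt, actions, rewards,
    multi_dim_actions=False, discounts=discounts,
    alpha=0.9, entropy_tau=0.03)   # td_targets has shape [2]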
Example #12
    def _compute_q_values(self, policies, time_steps, actions):
        """Compute the q value of the current state/action for TD error computation.

        Args:
          policies: A list of HeteroQPolicy objects.
          time_steps: A batch of current timesteps
          actions: A batch of actions
        Returns:
          A tensor of Q values for the given next state.
        """

        q_values_seq, _ = self._sequential_network_activation(policies, time_steps, actions)

        action_key = 'raw'
        raw_actions = tf.unstack(actions[action_key], axis=1)
        # Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions
        # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
        multi_dim_actions = False
        values = [common.index_with_actions(
            self._append2logits(q_values),
            tf.cast(act, dtype=tf.int32),
            multi_dim_actions=multi_dim_actions) for q_values, act in zip(q_values_seq, raw_actions)]
        # Due to dist.mode() in GreedyPolicy, action 0 is selected for masked actions, so replace the resulting -inf values with 0.
        values = [tf.where(tf.equal(v, NEG_INF), 0.0, v) for v in values]
        values = tf.add_n(values)

        return values
Example #13
 def _compute_q_values(self, time_steps, actions):
     q_values, _ = self._q_network(time_steps.observation,
                                   time_steps.step_type)
     # Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions
     # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
     multi_dim_actions = self._action_spec.shape.ndims > 0
     return common.index_with_actions(q_values,
                                      tf.cast(actions, dtype=tf.int32),
                                      multi_dim_actions=multi_dim_actions)
Example #14
 def testTwoOuterDimsUnknownShape(self):
   q_values = tf.placeholder(tf.float32, shape=[None, None, None])
   actions = tf.placeholder(tf.int32, shape=[None, None])
   values = common.index_with_actions(q_values, actions)
   with self.test_session() as sess:
     self.assertAllClose(
         [[51, 52]],
         sess.run(values,
                  feed_dict={q_values: [[[50, 51], [52, 53]]],
                             actions: [[1, 0]]}))
Example #15
 def checkCorrect(self,
                  q_values,
                  actions,
                  expected_values,
                  multi_dim_actions=False):
     q_values = tf.constant(q_values, dtype=tf.float32)
     actions = tf.constant(actions, dtype=tf.int32)
     selected_q_values = common.index_with_actions(q_values, actions,
                                                   multi_dim_actions)
     selected_q_values_ = self.evaluate(selected_q_values)
     self.assertAllClose(selected_q_values_, expected_values)
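Beyond scalar actions, `multi_dim_actions=True` also lets each action be a vector of indices, one per trailing action dimension of `q_values`. An illustrative standalone sketch of that case (my reading of the API; values chosen by hand):

import tensorflow as tf
from tf_agents.utils import common

# q_values: [batch=2, dim1=2, dim2=2]; each action is a pair of indices.
q_values = tf.constant([[[1., 2.], [3., 4.]],
                        [[5., 6.], [7., 8.]]])
actions = tf.constant([[0, 1], [1, 0]], dtype=tf.int32)
selected = common.index_with_actions(q_values, actions, multi_dim_actions=True)
print(selected.numpy())  # [2. 7.]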
Example #16
  def _compute_q_values(self, time_steps, actions):
    q_values, _ = self._q_network(time_steps.observation,
                                  time_steps.step_type)
    # Handle action_spec.shape=(), and shape=(1,) by using the
    # multi_dim_actions param.
    multi_dim_actions = tf.nest.flatten(self._action_spec)[0].shape.ndims > 0
    q_values = common.index_with_actions(
        q_values,
        tf.cast(actions, dtype=tf.int32),
        multi_dim_actions=multi_dim_actions)

    return q_values
Example #17
    def _compute_q_values(self, time_steps, actions):
        network_observation = time_steps.observation

        if self._observation_anc_action_constraint_splitter:
            network_observation, _ = self._observation_anc_action_constraint_splitter(
                network_observation)

        q_values, _ = self._q_network(network_observation,
                                      time_steps.step_type)
        multi_dim_actions = self._action_spec.shape.rank > 0
        return common.index_with_actions(q_values,
                                         tf.cast(actions, dtype=tf.int32),
                                         multi_dim_actions=multi_dim_actions)
Example #18
  def _loss_using_reward_layer(self,
                               observations: types.NestedTensor,
                               actions: types.Tensor,
                               rewards: types.Tensor,
                               weights: Optional[types.Float] = None,
                               training: bool = False) -> tf_agent.LossInfo:
    """Computes loss for reward prediction training.

    Args:
      observations: A batch of observations.
      actions: A batch of actions.
      rewards: A batch of rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output batch loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.
      training: Whether the loss is being used for training.

    Returns:
      loss: A `LossInfo` containing the loss for the training step.
    """
    with tf.name_scope('loss'):
      encoded_observation, _ = self._encoding_network(
          observations, training=training)
      encoded_observation = tf.reshape(
          encoded_observation,
          shape=[-1, self._encoding_dim])
      predicted_rewards = self._reward_layer(
          encoded_observation, training=training)
      chosen_actions_predicted_rewards = common.index_with_actions(
          predicted_rewards,
          tf.cast(actions, dtype=tf.int32))

      loss = self._error_loss_fn(rewards,
                                 chosen_actions_predicted_rewards,
                                 weights if weights else 1)
      if self._summarize_grads_and_vars:
        with tf.name_scope('Per_arm_loss/'):
          for k in range(self._num_models):
            loss_mask_for_arm = tf.cast(tf.equal(actions, k), tf.float32)
            loss_for_arm = self._error_loss_fn(
                rewards,
                chosen_actions_predicted_rewards,
                weights=loss_mask_for_arm)
            tf.compat.v2.summary.scalar(
                name='loss_arm_' + str(k),
                data=loss_for_arm,
                step=self.train_step_counter)

    return tf_agent.LossInfo(loss, extra=())
Example #19
  def _compute_q_values(self, time_steps, actions):
    network_observation = time_steps.observation

    if self._observation_and_action_constraint_splitter is not None:
      network_observation, _ = self._observation_and_action_constraint_splitter(
          network_observation)

    q_values, _ = self._q_network(network_observation, time_steps.step_type)
    # Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions
    # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
    multi_dim_actions = self._action_spec.shape.rank > 0
    return common.index_with_actions(
        q_values,
        tf.cast(actions, dtype=tf.int32),
        multi_dim_actions=multi_dim_actions)
Example #20
    def call(self, trajectory):
        """Update the constraint violations metric.

    Args:
      trajectory: A tf_agents.trajectory.Trajectory

    Returns:
      The arguments, for easy chaining.
    """
        feasibility_prob_all_actions = self._constraint(trajectory.observation)
        feasibility_prob_selected_actions = common.index_with_actions(
            feasibility_prob_all_actions,
            tf.cast(trajectory.action, dtype=tf.int32))
        self.constraint_violations.assign(
            tf.reduce_mean(1.0 - feasibility_prob_selected_actions))
        return trajectory
Example #21
def compute_dpp_td_targets(next_p_values, p_target_values, actions, rewards,
                           multi_dim_actions, discounts, alpha, entropy_tau):

    boltzmann_p = tf.reduce_sum(
        tf.nn.softmax(p_target_values / entropy_tau, axis=1) * p_target_values,
        1)

    p_target_values = common.index_with_actions(
        p_target_values, tf.cast(actions, dtype=tf.int32), multi_dim_actions)

    action_gap = alpha * (p_target_values - boltzmann_p)

    next_boltzmann_p = tf.reduce_sum(
        tf.nn.softmax(next_p_values / entropy_tau, axis=1) * next_p_values, 1)

    td_targets = rewards + discounts * next_boltzmann_p + action_gap

    return tf.stop_gradient(td_targets)
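As with the Munchausen helper above, a brief usage sketch with made-up values, assuming the function and its `common` import are in scope:

import tensorflow as tf

next_p = tf.constant([[1.0, 2.0, 0.5], [0.2, 0.1, 0.9]])
p_tgt = tf.constant([[1.1, 1.9, 0.4], [0.3, 0.0, 1.0]])
actions = tf.constant([1, 2], dtype=tf.int32)
rewards = tf.constant([0.0, 1.0])
discounts = tf.constant([0.99, 0.0])

td_targets = compute_dpp_td_targets(
    next_p, p_tgt, actions, rewards,
    multi_dim_actions=False, discounts=discounts,
    alpha=0.9, entropy_tau=0.03)   # shape [2]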
Example #22
    def _compute_next_q_values(self, next_time_steps):
        network_observation = next_time_steps.observation

        if self._observation_anc_action_constraint_splitter:
            network_observation, _ = self._observation_anc_action_constraint_splitter(
                network_observation)

        next_target_q_values, _ = self._target_q_network(
            network_observation, next_time_steps.step_type)
        batch_size = (next_target_q_values.shape[0]
                      or tf.shape(next_target_q_values)[0])
        dummy_state = self._target_greedy_policy.get_initial_state(batch_size)
        greedy_actions = self._target_greedy_policy.action(
            next_time_steps, dummy_state).action
        multi_dim_actions = tf.nest.flatten(
            self._action_spec)[0].shape.rank > 0
        return common.index_with_actions(next_target_q_values,
                                         greedy_actions,
                                         multi_dim_actions=multi_dim_actions)
Example #23
    def _compute_next_q_values(self, next_time_steps):
        """Compute the q value of the next state for TD error computation.

    Args:
      next_time_steps: A batch of next timesteps

    Returns:
      A tensor of Q values for the given next state.
    """
        # TODO(b/117175589): Add binary tests for DDQN.
        next_q_values, _ = self._q_network(next_time_steps.observation,
                                           next_time_steps.step_type)
        best_next_actions = tf.cast(tf.argmax(input=next_q_values, axis=-1),
                                    dtype=tf.int32)
        next_target_q_values, _ = self._target_q_network(
            next_time_steps.observation, next_time_steps.step_type)
        multi_dim_actions = best_next_actions.shape.ndims > 1
        return common.index_with_actions(next_target_q_values,
                                         best_next_actions,
                                         multi_dim_actions=multi_dim_actions)
Example #24
    def _compute_q_values(self, q_network, time_steps, discrete_actions):

        continuous_action_values, _ = self._actor_network(
            time_steps.observation, time_steps.step_type)
        # noisy_target_action_values = tf.nest.map_structure(self._add_noise_to_action,
        #                                              target_action_values)

        time_step_obs = tf.nest.flatten(
            time_steps.observation) + [continuous_action_values]
        if isinstance(q_network, QNetwork):
            time_step_obs = tf.concat(time_step_obs, axis=-1)
        q_values, _ = q_network(time_step_obs, time_steps.step_type)

        # Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions
        # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
        multi_dim_actions = tf.nest.flatten(
            self._action_spec.q_network)[0].shape.ndims > 0
        return common.index_with_actions(
            q_values,
            tf.cast(discrete_actions, dtype=tf.int32),
            multi_dim_actions=multi_dim_actions), continuous_action_values
Example #25
    def __call__(self, observation, actions=None):
        """Returns the probability of input actions being feasible."""
        predicted_quantiles, _ = self._constraint_network(observation,
                                                          training=False)
        batch_dims = nest_utils.get_outer_shape(
            observation, self._time_step_spec.observation)

        if self._baseline_action_fn is not None:
            baseline_action = self._baseline_action_fn(observation)
            baseline_action.shape.assert_is_compatible_with(batch_dims)
        else:
            baseline_action = tf.zeros(batch_dims, dtype=tf.int32)

        predicted_quantiles_for_baseline_actions = common.index_with_actions(
            predicted_quantiles, tf.cast(baseline_action, dtype=tf.int32))
        predicted_quantiles_for_baseline_actions = self._reshape_and_broadcast(
            predicted_quantiles_for_baseline_actions,
            tf.shape(predicted_quantiles))
        is_satisfied = self._comparator_fn(
            predicted_quantiles, predicted_quantiles_for_baseline_actions)
        return tf.cast(is_satisfied, tf.float32)
Example #26
 def _compute_q_values(self,
                       time_steps,
                       actions,
                       index,
                       time,
                       training=False):
     """spot and flex dispatches of last day (48) + MCP current hour + current state (spot/ flex)"""
     obs_indices = tf.convert_to_tensor(
         list(range(0, 48)) + [48 + time] + [71])
     observation = tf.gather(time_steps.observation[index],
                             indices=obs_indices,
                             axis=-1)
     network_observation = observation
     q_values, _ = self.agent._q_network(network_observation,
                                         time_steps.step_type,
                                         training=training)
     # Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions
     # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
     multi_dim_actions = self.agent._action_spec.shape.rank > 0
     actions = tf.reshape(actions, [-1, 1])
     return common.index_with_actions(q_values,
                                      tf.cast(actions, dtype=tf.int32),
                                      multi_dim_actions=multi_dim_actions)
Example #27
  def compute_loss(self,
                   observations: types.NestedTensor,
                   actions: types.NestedTensor,
                   rewards: types.Tensor,
                   weights: Optional[types.TensorOrArray] = None,
                   training: bool = False) -> types.Tensor:
    """Computes loss for training the constraint network.

    Args:
      observations: A batch of observations.
      actions: A batch of actions.
      rewards: A batch of rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output batch loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.
      training: Whether the loss is being used for training.

    Returns:
      loss: A `Tensor` containing the loss for the training step.
    """
    with tf.name_scope('constraint_loss'):
      sample_weights = weights if weights else 1
      predicted_values, _ = self._constraint_network(
          observations, training=training)
      action_predicted_values = common.index_with_actions(
          predicted_values,
          tf.cast(actions, dtype=tf.int32))
      # Reduction is done outside of the loss function because non-scalar
      # weights with unknown shapes may trigger shape validation that fails
      # XLA compilation.
      return tf.reduce_mean(
          tf.multiply(
              self._error_loss_fn(
                  rewards,
                  action_predicted_values,
                  reduction=tf.compat.v1.losses.Reduction.NONE),
              sample_weights))
Example #28
    def _compute_next_q_values(self, target_q_network, next_time_steps):
        """Compute the q value of the next state for TD error computation.

    Args:
      next_time_steps: A batch of next timesteps

    Returns:
      A tensor of Q values for the given next state.
    """

        next_target_continuous_action_values, _ = self._target_actor_network(
            next_time_steps.observation, next_time_steps.step_type)
        noisy_target_action_values = tf.nest.map_structure(
            self._add_noise_to_action, next_target_continuous_action_values)

        time_step_obs = tf.nest.flatten(
            next_time_steps.observation) + [noisy_target_action_values]
        if isinstance(target_q_network, QNetwork):
            time_step_obs = tf.concat(time_step_obs, axis=-1)
        next_target_q_values, _ = target_q_network(time_step_obs,
                                                   next_time_steps.step_type)
        batch_size = (next_target_q_values.shape[0]
                      or tf.shape(next_target_q_values)[0])
        dummy_state = self._target_greedy_policy.get_initial_state(batch_size)
        # Find the greedy actions using our target greedy policy. This ensures that
        # masked actions are respected and helps centralize the greedy logic.
        greedy_discrete_actions = self._target_greedy_policy.action(
            next_time_steps, dummy_state).action

        # Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions
        # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1.
        multi_dim_actions = tf.nest.flatten(
            self._action_spec.q_network)[0].shape.ndims > 0
        return common.index_with_actions(next_target_q_values,
                                         greedy_discrete_actions,
                                         multi_dim_actions=multi_dim_actions)
Example #29
    def _single_objective_loss(self,
                               objective_idx: int,
                               observations: tf.Tensor,
                               actions: tf.Tensor,
                               single_objective_values: tf.Tensor,
                               weights: types.Tensor = None,
                               training: bool = False) -> tf.Tensor:
        """Computes loss for a single objective.

    Args:
      objective_idx: The index into `self._objective_networks` for a specific
        objective network.
      observations: A batch of observations.
      actions: A batch of actions.
      single_objective_values: A batch of objective values shaped as
        [batch_size] for the objective that the network indexed by
        `objective_idx` is predicting.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output batch loss will be scaled by these weights, and the
        final scalar loss is the mean of these values.
      training: Whether the loss is being used for training.

    Returns:
      loss: A `Tensor` containing the loss for the training step.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
        if objective_idx >= self._num_objectives or objective_idx < 0:
            raise ValueError(
                'objective_idx should be between 0 and {}, but is {}'.format(
                    self._num_objectives, objective_idx))
        with tf.name_scope('loss_for_objective_{}'.format(objective_idx)):
            objective_network = self._objective_networks[objective_idx]
            sample_weights = weights if weights is not None else 1
            if self._heteroscedastic[objective_idx]:
                predictions, _ = objective_network(observations,
                                                   training=training)
                predicted_values = predictions.q_value_logits
                predicted_log_variance = predictions.log_variance
                action_predicted_log_variance = common.index_with_actions(
                    predicted_log_variance, tf.cast(actions, dtype=tf.int32))
                sample_weights = sample_weights * 0.5 * tf.exp(
                    -action_predicted_log_variance)
                loss = 0.5 * tf.reduce_mean(action_predicted_log_variance)
                # loss = 1/(2 * var(x)) * (y - f(x))^2 + 1/2 * log var(x)
                # Kendall, Alex, and Yarin Gal. "What Uncertainties Do We Need in
                # Bayesian Deep Learning for Computer Vision?." Advances in Neural
                # Information Processing Systems. 2017. https://arxiv.org/abs/1703.04977
            else:
                predicted_values, _ = objective_network(observations,
                                                        training=training)
                loss = tf.constant(0.0)

            action_predicted_values = common.index_with_actions(
                predicted_values, tf.cast(actions, dtype=tf.int32))

            loss += self._error_loss_fn(single_objective_values,
                                        action_predicted_values,
                                        sample_weights)

        return loss
Example #30
  def loss(self,
           observations,
           actions,
           rewards,
           weights=None,
           training=False):
    """Computes loss for reward prediction training.

    Args:
      observations: A batch of observations.
      actions: A batch of actions.
      rewards: A batch of rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output batch loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.
      training: Whether the loss is being used for training.

    Returns:
      loss: A `LossInfo` containing the loss for the training step.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
    with tf.name_scope('loss'):
      sample_weights = weights if weights else 1
      if self._heteroscedastic:
        predictions, _ = self._reward_network(observations,
                                              training=training)
        predicted_values = predictions.q_value_logits
        predicted_log_variance = predictions.log_variance
        action_predicted_log_variance = common.index_with_actions(
            predicted_log_variance, tf.cast(actions, dtype=tf.int32))
        sample_weights = sample_weights * 0.5 * tf.exp(
            -action_predicted_log_variance)

        loss = 0.5 * tf.reduce_mean(action_predicted_log_variance)
        # loss = 1/(2 * var(x)) * (y - f(x))^2 + 1/2 * log var(x)
        # Kendall, Alex, and Yarin Gal. "What Uncertainties Do We Need in
        # Bayesian Deep Learning for Computer Vision?." Advances in Neural
        # Information Processing Systems. 2017. https://arxiv.org/abs/1703.04977
      else:
        predicted_values, _ = self._reward_network(observations,
                                                   training=training)
        loss = tf.constant(0.0)

      action_predicted_values = common.index_with_actions(
          predicted_values,
          tf.cast(actions, dtype=tf.int32))

      # Apply Laplacian smoothing on the estimated rewards, if applicable.
      if self._laplacian_matrix is not None:
        smoothness_batched = tf.matmul(
            predicted_values,
            tf.matmul(self._laplacian_matrix, predicted_values,
                      transpose_b=True))
        loss += (self._laplacian_smoothing_weight * tf.reduce_mean(
            tf.linalg.tensor_diag_part(smoothness_batched) * sample_weights))

      loss += self._error_loss_fn(
          rewards,
          action_predicted_values,
          sample_weights,
          reduction=tf.compat.v1.losses.Reduction.MEAN)

    return tf_agent.LossInfo(loss, extra=())