Code Example #1
    def _compute_gradients(self, optimizer, loss, agent_vars, gate_grads=True):
        head_loss = loss["head_loss"]
        z_loss = loss["z_loss"]

        # Get the gradients for the bootstrapped heads and the conv net
        net_grads = DQN_IDS._compute_gradients(self,
                                               optimizer,
                                               head_loss,
                                               agent_vars,
                                               gate_grads=False)

        # Get the gradients for the distributional FC layers
        z_vars = tf_utils.scope_vars(agent_vars,
                                     scope='agent_net/distribution_value')
        z_grads = C51._compute_gradients(self,
                                         optimizer,
                                         z_loss,
                                         z_vars,
                                         gate_grads=False)

        grads = net_grads + z_grads

        if gate_grads:
            grads = tf_utils.gate_gradients(grads)

        return grads
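The `tf_utils.gate_gradients` helper is not shown here. A minimal sketch of what such a gating helper might do, assuming `grads` is a list of `(gradient, variable)` pairs as returned by `optimizer.compute_gradients()` (an illustration, not the repo's actual implementation):

    import tensorflow as tf

    def gate_gradients(grads):
        # Separate the (gradient, variable) pairs
        gradients = [g for g, _ in grads]
        variables = [v for _, v in grads]
        # tf.tuple forces all gradients to be fully computed before any of
        # them is made available, so no update uses partially stale values
        gated = tf.tuple(gradients)
        return list(zip(gated, variables))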
Code Example #2
    def _compute_loss(self, estimate, target, name):
        q, logits_z = estimate["q_values"], estimate["logits"]
        target_q, target_p = target["target_q"], target["target_p"]

        head_loss = DQN_IDS._compute_loss(self, q, target_q, name)
        z_loss = C51._compute_loss(self, logits_z, target_p, "train/z_loss")

        return dict(head_loss=head_loss, z_loss=z_loss)
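The loss is returned as a dict because each term drives a different variable subset in `_compute_gradients` above. For reference, the per-head TD loss computed by `DQN_IDS._compute_loss` could be sketched as follows, assuming `q` and `target_q` both have shape `[batch_size, n_heads]` (a hypothetical sketch, not the repo's code):

    # Huber loss averaged over batch entries and bootstrapped heads
    head_loss = tf.losses.huber_loss(labels=target_q, predictions=q)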
Code Example #3
  def _compute_loss(self, estimate, target, name):
    q, z                = estimate["q_values"], estimate["quantiles"]
    target_q, target_z  = target["target_q"], target["target_z"]
    head_loss = DQN_IDS._compute_loss(self, q, target_q, name)
    z_loss    = QRDQN._compute_loss(self, z, target_z, "train/z_loss")

    return dict(head_loss=head_loss, z_loss=z_loss)
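For context, `QRDQN._compute_loss` corresponds to the quantile regression loss of QR-DQN. A minimal sketch of that loss, assuming `z` and `target_z` both have shape `[batch_size, N]` with `N` known statically (an illustration, not the repo's implementation):

    import tensorflow as tf

    def quantile_huber_loss(target_z, z, kappa=1.0):
      # Quantile midpoints tau_i = (i + 0.5) / N
      N   = z.shape.as_list()[-1]
      tau = (tf.range(N, dtype=tf.float32) + 0.5) / float(N)
      # Pairwise TD errors between target samples and predicted quantiles: [B, N, N]
      td  = target_z[:, None, :] - z[:, :, None]
      # Elementwise Huber loss with threshold kappa
      huber  = tf.where(tf.abs(td) <= kappa,
                        0.5 * tf.square(td),
                        kappa * (tf.abs(td) - 0.5 * kappa))
      # Asymmetric quantile weights |tau - 1{td < 0}|
      weight = tf.abs(tau[None, :, None] - tf.cast(td < 0.0, tf.float32))
      # Mean over target samples, sum over quantiles, mean over the batch
      return tf.reduce_mean(tf.reduce_sum(tf.reduce_mean(weight * huber, axis=2), axis=1))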
Code Example #4
  def _compute_estimate(self, agent_net):
    """Get the Q-value and distribution estimates for the selected action
    Args:
      agent_net: dict of `tf.Tensor`s. Output from the agent network. Shapes:
        `[batch_size, n_heads, n_actions]` and `[batch_size, n_actions, N]`
    Returns:
      dict of `tf.Tensor`s of shapes `[batch_size, n_heads]` and `[batch_size, N]`
    """
    q, z = agent_net["q_values"], agent_net["logits"]
    q = DQN_IDS._compute_estimate(self, q)  # out: [None, n_heads]
    z = C51._compute_estimate(self, z)      # logits; out: [None, N]
    return dict(q_values=q, logits=z)
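In both parent calls, "the selected action" means gathering one slice per batch entry. A hypothetical sketch of that gather for the distributional output, assuming `logits` has shape `[batch_size, n_actions, N]` and that the chosen actions live in an int32 placeholder `self.a_ph` of shape `[batch_size]` (both the placeholder name and the exact gather are assumptions):

    inds = tf.stack([tf.range(tf.shape(logits)[0]), self.a_ph], axis=-1)
    z    = tf.gather_nd(logits, inds)  # out: [batch_size, N]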
Code Example #5
  def _compute_target(self, target_net):
    """Compute the backups
    Args:
      target_net: dict of `tf.Tensor`s. Output from the target network. Shapes:
        `[batch_size, n_heads, n_actions]` and `[batch_size, n_actions, N]`
    Returns:
      dict of `tf.Tensor`s of shapes `[batch_size, n_heads]` and `[batch_size, N]`
    """
    target_q, target_z = target_net["q_values"], target_net["logits"]
    # Inside DQN_IDS._compute_target, the call to self._select_target resolves
    # to C51_IDS._select_target()
    backup_q = DQN_IDS._compute_target(self, target_q)
    # NOTE: Do NOT call C51._compute_target(self, target_z) - its call to
    # self._select_target() would also resolve to C51_IDS._select_target(),
    # which is incorrect for the distributional backup
    target_z = C51._select_target(self, target_z)
    backup_z = C51._compute_backup(self, target_z)
    backup_z = tf.stop_gradient(backup_z)
    return dict(target_q=backup_q, target_p=backup_z)
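The resolution-order note above can be reproduced with a self-contained sketch (the class bodies are hypothetical; only the dispatch behavior matters):

    class C51:
      def _select_target(self, target_z):
        return "C51 selection"
      def _compute_target(self, target_z):
        # self._select_target is dispatched dynamically, even when
        # _compute_target itself is invoked as an unbound parent call
        return self._select_target(target_z)

    class C51_IDS(C51):
      def _select_target(self, target_z):
        return "C51_IDS selection"

    agent = C51_IDS()
    print(C51._compute_target(agent, None))  # -> "C51_IDS selection"
    print(C51._select_target(agent, None))   # -> "C51 selection" (override bypassed)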
Code Example #6
  def _act_train(self, agent_net, name):
    # agent_net is a dict of tensors with shapes [None, n_heads, n_actions] and [None, n_actions, N]

    z_var     = self._compute_z_variance(agent_net["quantiles"], normalize=True)  # [None, n_actions]
    self.rho2 = tf.maximum(z_var, 0.25)

    action    = DQN_IDS._act_train(self, agent_net["q_values"], name)

    # Add debugging data for TB
    tf.summary.histogram("debug/a_rho2", self.rho2)
    tf.summary.scalar("debug/z_var", tf.reduce_mean(z_var))

    # Append the plottable tensors for episode recordings
    p_rho2  = tf.identity(self.rho2[0], name="plot/train/rho2")
    p_a     = self.plot_conf.true_train_spec["train_actions"]["a_mean"]["a"]
    self.plot_conf.true_train_spec["train_actions"]["a_rho2"] = dict(height=p_rho2, a=p_a)

    return action
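`self._compute_z_variance` is defined elsewhere in the agent. A hypothetical sketch of a quantile-based variance estimate, assuming `quantiles` has shape `[batch_size, n_actions, N]` (the normalization scheme in particular is a guess, not the repo's):

    def _compute_z_variance(self, quantiles, normalize=True):
      # Empirical variance of the return distribution per (state, action)
      mean = tf.reduce_mean(quantiles, axis=-1, keepdims=True)    # [B, A, 1]
      var  = tf.reduce_mean(tf.square(quantiles - mean), axis=-1) # [B, A]
      if normalize:
        # Hypothetical normalization: scale by the mean variance over actions
        var = var / (tf.reduce_mean(var, axis=-1, keepdims=True) + 1e-8)
      return var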
Code Example #7
  def _act_eval(self, agent_net, name):
    return DQN_IDS._act_eval(self, agent_net["q_values"], name)