Example #1
  def update(self, expert_dataset_iter, replay_buffer_iter):
    """Performs a single training step for critic and actor.

    Args:
      expert_dataset_iter: A TensorFlow graph iterable over expert data.
      replay_buffer_iter: A TensorFlow graph iterable over the replay buffer.
    """
    expert_states, expert_actions, _ = next(expert_dataset_iter)
    policy_states, policy_actions, _, _, _ = next(replay_buffer_iter)[0]

    policy_inputs = tf.concat([policy_states, policy_actions], -1)
    expert_inputs = tf.concat([expert_states, expert_actions], -1)

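    # Draw random interpolates between policy and expert inputs; these are
    # used below for the gradient penalty on the discriminator.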
    alpha = tf.random.uniform(shape=(policy_inputs.get_shape()[0], 1))
    inter = alpha * policy_inputs + (1 - alpha) * expert_inputs

    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(self.discriminator.variables)
      policy_output = self.discriminator(policy_inputs)
      expert_output = self.discriminator(expert_inputs)

      # Use the standard label smoothing value of 0.0 instead of 0.25.
      classification_loss = tfgan_losses.modified_discriminator_loss(
          expert_output, policy_output, label_smoothing=0.0)

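      # A nested tape measures the discriminator's gradient with respect to
      # the interpolated inputs so its norm can be penalized toward 1.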
      with tf.GradientTape(watch_accessed_variables=False) as tape2:
        tape2.watch(inter)
        output = self.discriminator(inter)

      grad = tape2.gradient(output, [inter])[0]
      grad_penalty = tf.reduce_mean(tf.pow(tf.norm(grad, axis=-1) - 1, 2))
      total_loss = classification_loss + self.grad_penalty_coeff * grad_penalty

    grads = tape.gradient(total_loss, self.discriminator.variables)

    self.optimizer.apply_gradients(zip(grads, self.discriminator.variables))

    self.avg_classification_loss(classification_loss)
    self.avg_gp_loss(grad_penalty)
    self.avg_total_loss(total_loss)

    if tf.equal(self.optimizer.iterations % self.log_interval, 0):
      tf.summary.scalar(
          'train gail/classification loss',
          self.avg_classification_loss.result(),
          step=self.optimizer.iterations)
      self.avg_classification_loss.reset_states()

      tf.summary.scalar(
          'train gail/gradient penalty',
          self.avg_gp_loss.result(),
          step=self.optimizer.iterations)
      self.avg_gp_loss.reset_states()

      tf.summary.scalar(
          'train gail/loss',
          self.avg_total_loss.result(),
          step=self.optimizer.iterations)
      self.avg_total_loss.reset_states()
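
For reference, the sketch below shows one way the state referenced through self in Example #1 (the discriminator network, optimizer, penalty coefficient, logging interval, and running-mean metrics) could be set up. The class name, network sizes, and hyperparameter values are illustrative assumptions, not part of the original example.

import tensorflow as tf


class GAILDiscriminatorTrainer(object):
  """Hypothetical container for the attributes that update() expects."""

  def __init__(self, input_dim, grad_penalty_coeff=10.0, log_interval=1000):
    # Simple MLP discriminator over concatenated (state, action) vectors.
    self.discriminator = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu', input_shape=(input_dim,)),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dense(1),
    ])
    self.grad_penalty_coeff = grad_penalty_coeff
    self.log_interval = log_interval
    self.optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

    # Running means that update() accumulates and logs every log_interval steps.
    self.avg_classification_loss = tf.keras.metrics.Mean('classification_loss')
    self.avg_gp_loss = tf.keras.metrics.Mean('gp_loss')
    self.avg_total_loss = tf.keras.metrics.Mean('total_loss')
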
Example #2
    def update(self, expert_dataset_iter, replay_buffer_iter):
        """Performs a single training step for critic and actor.

    Args:
      expert_dataset_iter: An tensorflow graph iteratable over expert data.
      replay_buffer_iter: An tensorflow graph iteratable over replay buffer.
    """
        expert_states, expert_actions, _ = next(expert_dataset_iter)
        policy_states, policy_actions, _, _, _ = next(replay_buffer_iter)[0]

        policy_inputs = tf.concat([policy_states, policy_actions], -1)
        expert_inputs = tf.concat([expert_states, expert_actions], -1)

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(self.discriminator.variables)
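            # Run policy and expert samples through the discriminator in one
            # forward pass, then split the outputs back apart.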
            inputs = tf.concat([policy_inputs, expert_inputs], 0)
            outputs = self.discriminator(inputs)

            policy_output, expert_output = tf.split(outputs,
                                                    num_or_size_splits=2,
                                                    axis=0)

            # Use the standard label smoothing value of 0.0 instead of 0.25.
            classification_loss = tfgan_losses.modified_discriminator_loss(
                expert_output, policy_output, label_smoothing=0.0)

        grads = tape.gradient(classification_loss,
                              self.discriminator.variables)

        self.optimizer.apply_gradients(zip(grads,
                                           self.discriminator.variables))

        self.avg_loss(classification_loss)

        if tf.equal(self.optimizer.iterations % self.log_interval, 0):
            tf.summary.scalar('train gail/loss',
                              self.avg_loss.result(),
                              step=self.optimizer.iterations)
            self.avg_loss.reset_states()
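
As a usage sketch, the snippet below builds toy expert and replay-buffer iterators whose batches match the unpacking at the top of both update methods. The dimensions, the field order after states and actions, and the trainer object are assumptions for illustration only.

import numpy as np
import tensorflow as tf

batch_size, state_dim, action_dim, n = 32, 11, 3, 256

states = np.random.randn(n, state_dim).astype(np.float32)
actions = np.random.randn(n, action_dim).astype(np.float32)
rewards = np.random.randn(n, 1).astype(np.float32)
masks = np.ones((n, 1), dtype=np.float32)

# Expert iterator: each batch unpacks as (expert_states, expert_actions, _).
expert_dataset_iter = iter(
    tf.data.Dataset.from_tensor_slices((states, actions, rewards))
    .repeat().batch(batch_size))

# Replay iterator: next(replay_buffer_iter)[0] must unpack into five tensors
# whose first two are states and actions; the remaining three fields (assumed
# here to be next states, rewards, and masks) are ignored by update().
dummy_info = np.zeros((n, 1), dtype=np.float32)
replay_buffer_iter = iter(
    tf.data.Dataset.from_tensor_slices(
        ((states, actions, np.roll(states, -1, axis=0), rewards, masks),
         dummy_info))
    .repeat().batch(batch_size))

# trainer is assumed to be an object exposing one of the update() methods
# above, e.g.:
#   trainer.update(expert_dataset_iter, replay_buffer_iter)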