def update(self, expert_dataset_iter, replay_buffer_iter):
  """Performs a single training step of the discriminator.

  Args:
    expert_dataset_iter: A TensorFlow graph iterable over expert data.
    replay_buffer_iter: A TensorFlow graph iterable over the replay buffer.
  """
  expert_states, expert_actions, _ = next(expert_dataset_iter)
  policy_states, policy_actions, _, _, _ = next(replay_buffer_iter)[0]

  policy_inputs = tf.concat([policy_states, policy_actions], -1)
  expert_inputs = tf.concat([expert_states, expert_actions], -1)

  # Sample random points on the segments between policy and expert inputs
  # for the gradient penalty.
  alpha = tf.random.uniform(shape=(policy_inputs.get_shape()[0], 1))
  inter = alpha * policy_inputs + (1 - alpha) * expert_inputs

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(self.discriminator.variables)

    policy_output = self.discriminator(policy_inputs)
    expert_output = self.discriminator(expert_inputs)
    # Using the standard value for label smoothing instead of 0.25.
    classification_loss = tfgan_losses.modified_discriminator_loss(
        expert_output, policy_output, label_smoothing=0.0)

    # Inner tape computes the discriminator's gradient w.r.t. the
    # interpolated inputs; the outer tape records that computation so the
    # penalty itself can be differentiated w.r.t. the discriminator weights.
    with tf.GradientTape(watch_accessed_variables=False) as tape2:
      tape2.watch(inter)
      output = self.discriminator(inter)
    grad = tape2.gradient(output, [inter])[0]
    grad_penalty = tf.reduce_mean(tf.pow(tf.norm(grad, axis=-1) - 1, 2))

    total_loss = classification_loss + self.grad_penalty_coeff * grad_penalty

  grads = tape.gradient(total_loss, self.discriminator.variables)
  self.optimizer.apply_gradients(zip(grads, self.discriminator.variables))

  self.avg_classification_loss(classification_loss)
  self.avg_gp_loss(grad_penalty)
  self.avg_total_loss(total_loss)

  if tf.equal(self.optimizer.iterations % self.log_interval, 0):
    tf.summary.scalar(
        'train gail/classification loss',
        self.avg_classification_loss.result(),
        step=self.optimizer.iterations)
    self.avg_classification_loss.reset_states()

    tf.summary.scalar(
        'train gail/gradient penalty',
        self.avg_gp_loss.result(),
        step=self.optimizer.iterations)
    self.avg_gp_loss.reset_states()

    tf.summary.scalar(
        'train gail/loss',
        self.avg_total_loss.result(),
        step=self.optimizer.iterations)
    self.avg_total_loss.reset_states()
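
# A minimal, self-contained sketch of the WGAN-GP style gradient penalty used
# above: the discriminator's input gradient is pushed toward unit norm at
# random points between policy and expert samples. `toy_disc`, `batch_size`,
# and `dim` are illustrative stand-ins, not names from the original code.
import tensorflow as tf

batch_size, dim = 4, 8
toy_disc = tf.keras.layers.Dense(1)  # stand-in for self.discriminator

policy_inputs = tf.random.normal((batch_size, dim))
expert_inputs = tf.random.normal((batch_size, dim))

# Sample one random point on the segment between each policy/expert pair.
alpha = tf.random.uniform(shape=(batch_size, 1))
inter = alpha * policy_inputs + (1 - alpha) * expert_inputs

with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(inter)
  output = toy_disc(inter)
grad = tape.gradient(output, [inter])[0]

# Penalize deviation of each example's input-gradient norm from 1.
grad_penalty = tf.reduce_mean(tf.pow(tf.norm(grad, axis=-1) - 1, 2))
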
def update(self, expert_dataset_iter, replay_buffer_iter):
  """Performs a single training step of the discriminator.

  Args:
    expert_dataset_iter: A TensorFlow graph iterable over expert data.
    replay_buffer_iter: A TensorFlow graph iterable over the replay buffer.
  """
  expert_states, expert_actions, _ = next(expert_dataset_iter)
  policy_states, policy_actions, _, _, _ = next(replay_buffer_iter)[0]

  policy_inputs = tf.concat([policy_states, policy_actions], -1)
  expert_inputs = tf.concat([expert_states, expert_actions], -1)

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(self.discriminator.variables)

    # Score policy and expert batches with a single forward pass, then split
    # the logits back apart.
    inputs = tf.concat([policy_inputs, expert_inputs], 0)
    outputs = self.discriminator(inputs)
    policy_output, expert_output = tf.split(
        outputs, num_or_size_splits=2, axis=0)

    # Using the standard value for label smoothing instead of 0.25.
    classification_loss = tfgan_losses.modified_discriminator_loss(
        expert_output, policy_output, label_smoothing=0.0)

  grads = tape.gradient(classification_loss, self.discriminator.variables)
  self.optimizer.apply_gradients(zip(grads, self.discriminator.variables))

  self.avg_loss(classification_loss)

  if tf.equal(self.optimizer.iterations % self.log_interval, 0):
    tf.summary.scalar(
        'train gail/loss',
        self.avg_loss.result(),
        step=self.optimizer.iterations)
    self.avg_loss.reset_states()
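
# A minimal sketch of the single-forward-pass trick used above: policy and
# expert batches are concatenated, scored by the discriminator once, and the
# logits split back apart. Note that tf.split with num_or_size_splits=2
# requires the two batches to have the same size. `toy_disc` and the shapes
# are illustrative assumptions, not part of the original code.
import tensorflow as tf

batch_size, dim = 4, 8
toy_disc = tf.keras.layers.Dense(1)  # stand-in for self.discriminator

policy_inputs = tf.random.normal((batch_size, dim))
expert_inputs = tf.random.normal((batch_size, dim))

inputs = tf.concat([policy_inputs, expert_inputs], 0)  # (2 * batch_size, dim)
outputs = toy_disc(inputs)  # one forward pass instead of two
policy_output, expert_output = tf.split(outputs, num_or_size_splits=2, axis=0)

assert policy_output.shape == (batch_size, 1)
assert expert_output.shape == (batch_size, 1)
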