Beispiel #1
0
    def _update_actor(self, obs, mask):
        """Updates parameters of the actor given samples from the batch.

        Takes one optimizer step on the actor that minimizes
        -Q(obs, actor(obs)), i.e. pushes the actor toward actions the critic
        scores highly. With TD3 only the critic's first Q estimate is used.

        Args:
          obs: A tfe.Variable with a batch of observations.
          mask: A tfe.Variable with a batch of masks. When
            `self.use_absorbing_state` is set, each sample is weighted by
            `1 - max(0, -mask)`, so samples with mask == -1 (absorbing states)
            do not contribute to the loss.

        Returns:
          None. If the total sample weight is (near) zero, returns early
          without touching the actor.
        """
        with tf.GradientTape() as tape:
            if self.use_td3:
                # The TD3 critic returns two Q estimates; the actor follows Q1.
                q_pred, _ = self.critic(obs, self.actor(obs))
            else:
                q_pred = self.critic(obs, self.actor(obs))
            if self.use_absorbing_state:
                # Don't update the actor for absorbing states, and skip the
                # update entirely if all states are absorbing.
                a_mask = 1.0 - tf.maximum(0, -mask)
                a_mask_sum = tf.reduce_sum(a_mask)
                if a_mask_sum < 1e-8:
                    return
                actor_loss = -tf.reduce_sum(q_pred * a_mask) / a_mask_sum
            else:
                actor_loss = -tf.reduce_mean(q_pred)

        grads = tape.gradient(actor_loss, self.actor.variables)
        # Clipping makes training more stable.
        grads, _ = tf.clip_by_global_norm(grads, 40.0)
        self.actor_optimizer.apply_gradients(zip(grads, self.actor.variables),
                                             global_step=self.actor_step)

        with contrib_summary.record_summaries_every_n_global_steps(
                100, self.actor_step):
            contrib_summary.scalar('actor/loss',
                                   actor_loss,
                                   step=self.actor_step)
Beispiel #2
0
    def _update_critic_ddpg(self, obs, action, next_obs, reward, mask):
        """Performs one DDPG critic gradient step on a batch of transitions.

        The bootstrap target is built from the target actor/critic pair
        outside the gradient tape, so only `self.critic` is differentiated.

        Args:
          obs: A tfe.Variable with a batch of observations.
          action: A tfe.Variable with a batch of actions.
          next_obs: A tfe.Variable with a batch of next observations.
          reward: A tfe.Variable with a batch of rewards.
          mask: A tfe.Variable with a batch of masks.
        """
        target_actions = self.actor_target(next_obs)
        if self.use_absorbing_state:
            # Starting from the goal state only non-actions are possible:
            # the target action is zeroed wherever mask <= 0.
            q_next = self.critic_target(
                next_obs, target_actions * tf.maximum(0, mask))
            q_target = reward + self.discount * q_next
        else:
            # Without an absorbing state the future value is scaled by mask,
            # so entries with mask == 0 bootstrap nothing.
            q_next = self.critic_target(next_obs, target_actions)
            q_target = reward + self.discount * mask * q_next

        with tf.GradientTape() as tape:
            critic_loss = tf.losses.mean_squared_error(
                q_target, self.critic(obs, action))

        critic_vars = self.critic.variables
        grads = tape.gradient(critic_loss, critic_vars)
        self.critic_optimizer.apply_gradients(
            zip(grads, critic_vars), global_step=self.critic_step)

        with contrib_summary.record_summaries_every_n_global_steps(
                100, self.critic_step):
            contrib_summary.scalar(
                'critic/loss', critic_loss, step=self.critic_step)
Beispiel #3
0
def train(model, optimizer, dataset, step_counter, log_interval=None):
    """Trains model on `dataset` using `optimizer`.

    Loss and accuracy summaries are recorded every 10 steps of
    `step_counter`.

    Args:
      model: Callable invoked as `model(images, training=True)`; its
        `variables` are what the optimizer updates.
      optimizer: Optimizer whose `apply_gradients` also increments
        `step_counter`.
      dataset: Iterable yielding `(images, labels)` batches.
      step_counter: Global-step variable passed to `apply_gradients` and used
        to pace summary recording.
      log_interval: If truthy, print loss and throughput every
        `log_interval` batches (batch 0 included).
    """
    from tensorflow.contrib import summary as contrib_summary  # pylint: disable=g-import-not-at-top

    start = time.perf_counter()
    steps_since_log = 0
    for (batch, (images, labels)) in enumerate(dataset):
        with contrib_summary.record_summaries_every_n_global_steps(
                10, global_step=step_counter):
            # Record the operations used to compute the loss given the input,
            # so that the gradient of the loss with respect to the variables
            # can be computed.
            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                loss_value = loss(logits, labels)
                contrib_summary.scalar('loss', loss_value)
                contrib_summary.scalar('accuracy',
                                       compute_accuracy(logits, labels))
            grads = tape.gradient(loss_value, model.variables)
            optimizer.apply_gradients(list(zip(grads, model.variables)),
                                      global_step=step_counter)
            steps_since_log += 1
            if log_interval and batch % log_interval == 0:
                elapsed = time.perf_counter() - start
                # Divide by the steps actually run since the previous log
                # rather than by log_interval: the first log (batch 0) covers
                # only one step. Guard against a zero-length interval.
                rate = (steps_since_log / elapsed
                        if elapsed > 0 else float('inf'))
                print('Step #%d\tLoss: %.6f (%.2f steps/sec)' %
                      (batch, loss_value, rate))
                start = time.perf_counter()
                steps_since_log = 0
Beispiel #4
0
    def eval_metrics_host_call_fn(policy_output,
                                  value_output,
                                  pi_tensor,
                                  policy_cost,
                                  value_cost,
                                  l2_cost,
                                  combined_cost,
                                  est_mode=tf.estimator.ModeKeys.TRAIN):
        """Builds streaming policy/value metrics and matching summary ops.

        Args:
          policy_output: Predicted move probabilities. Reduced over axis 1,
            so presumably `[batch, num_moves]` -- TODO confirm.
          value_output: Predicted values; logged as their mean absolute value.
          pi_tensor: Target policy (one-hot for SGF data, soft-max for
            self-play records); only its argmax is used.
          policy_cost: Policy loss term, averaged with `tf.metrics.mean`.
          value_cost: Value loss term, averaged with `tf.metrics.mean`.
          l2_cost: L2 regularization term, averaged with `tf.metrics.mean`.
          combined_cost: Total loss, averaged with `tf.metrics.mean`.
          est_mode: `tf.estimator.ModeKeys` value selecting the return type.

        Returns:
          In EVAL mode, the dict of `tf.metrics` `(value, update_op)` pairs;
          otherwise `summary.all_summary_ops()`. Scalar summaries of each
          metric's update op are created (writer on `FLAGS.model_dir`) in
          either mode.
        """
        # p * log(p) evaluates to 0 * -inf = NaN wherever a probability is
        # exactly 0 (e.g. an underflowed softmax). Use the limit value 0 there
        # so a single such entry does not turn the whole entropy into NaN.
        p_log_p = tf.where(policy_output > 0,
                           policy_output * tf.log(policy_output),
                           tf.zeros_like(policy_output))
        policy_entropy = -tf.reduce_mean(tf.reduce_sum(p_log_p, axis=1))
        # pi_tensor is one_hot when generated from sgfs (for supervised learning)
        # and soft-max when using self-play records. argmax normalizes the two.
        policy_target_top_1 = tf.argmax(pi_tensor, axis=1)
        policy_output_top_1 = tf.argmax(policy_output, axis=1)

        policy_output_in_top3 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=3))

        policy_top_1_confidence = tf.reduce_max(policy_output, axis=1)
        policy_target_top_1_confidence = tf.boolean_mask(
            policy_output,
            tf.one_hot(policy_target_top_1,
                       tf.shape(policy_output)[1]))

        metric_ops = {
            'policy_cost':
            tf.metrics.mean(policy_cost),
            'value_cost':
            tf.metrics.mean(value_cost),
            'l2_cost':
            tf.metrics.mean(l2_cost),
            'policy_entropy':
            tf.metrics.mean(policy_entropy),
            'combined_cost':
            tf.metrics.mean(combined_cost),
            'policy_accuracy_top_1':
            tf.metrics.accuracy(labels=policy_target_top_1,
                                predictions=policy_output_top_1),
            'policy_accuracy_top_3':
            tf.metrics.mean(policy_output_in_top3),
            'policy_top_1_confidence':
            tf.metrics.mean(policy_top_1_confidence),
            'policy_target_top_1_confidence':
            tf.metrics.mean(policy_target_top_1_confidence),
            'value_confidence':
            tf.metrics.mean(tf.abs(value_output)),
        }

        # Create summary ops so that they show up in SUMMARIES collection
        # That way, they get logged automatically during training
        summary_writer = summary.create_file_writer(FLAGS.model_dir)
        with summary_writer.as_default(), \
                summary.record_summaries_every_n_global_steps(FLAGS.summary_steps):
            for metric_name, metric_op in metric_ops.items():
                # metric_op[1] is the update op, so the summary reflects the
                # running value after this batch is folded in.
                summary.scalar(metric_name, metric_op[1])

        if est_mode == tf.estimator.ModeKeys.EVAL:
            return metric_ops
        return summary.all_summary_ops()
Beispiel #5
0
            def host_call_fn(gs, loss, lr, ce, bi_list, bo_list, big_list,
                             bog_list):
                """Training host call. Creates scalar and histogram summaries.

                This function is executed on the CPU and should not directly
                reference any Tensors in the rest of the `model_fn`. To pass
                Tensors from the model to the `metric_fn`, provide as part of
                the `host_call`. See
                https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
                for more information.

                Arguments should match the list of `Tensor` objects passed as
                the second element in the tuple passed to `host_call`.

                Args:
                  gs: `Tensor` with shape `[batch]` for the global_step.
                  loss: `Tensor` with shape `[batch]` for the training loss.
                  lr: `Tensor` with shape `[batch]` for the learning_rate.
                  ce: `Tensor` with shape `[batch]` for the current_epoch.
                  bi_list: Iterable of tensors, each logged via
                    `normal_histogram`/`log_histogram` as `bn-input-<i>`.
                  bo_list: Same, tagged `bn-output-<i>`.
                  big_list: Same, tagged `bn-input-grad-<i>`.
                  bog_list: Same, tagged `bn-output-grad-<i>`.

                Returns:
                  List of summary ops to run on the CPU host.
                """
                gs = gs[0]
                # (tag prefix, tensors) pairs; logged in this order, with the
                # index restarting at 0 for each group.
                histogram_groups = (
                    ('bn-input-', bi_list),
                    ('bn-output-', bo_list),
                    ('bn-input-grad-', big_list),
                    ('bn-output-grad-', bog_list),
                )
                # Host call fns are executed params['iterations_per_loop'] times
                # after one TPU loop is finished, setting max_queue value to the
                # same as number of iterations will make the summary writer only
                # flush the data to storage once per loop.
                with summary.create_file_writer(
                        FLAGS.model_dir,
                        max_queue=params['iterations_per_loop']).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('loss', loss[0], step=gs)
                        summary.scalar('learning_rate', lr[0], step=gs)
                        summary.scalar('current_epoch', ce[0], step=gs)

                    # The histograms are emitted while the file writer is still
                    # the default: tf.contrib.summary ops created without a
                    # default writer are no-ops. (Assumes normal_histogram /
                    # log_histogram use tf.contrib.summary -- not visible here.)
                    # TODO record distribution every 1251 steps (steps per epoch)
                    with summary.record_summaries_every_n_global_steps(
                            FLAGS.steps_per_eval):
                        for prefix, activations in histogram_groups:
                            for index, activ in enumerate(activations):
                                tag = prefix + str(index)
                                normal_histogram(activ, tag)
                                log_histogram(activ, tag)
                return summary.all_summary_ops()
Beispiel #6
0
    def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cost,
                                  value_cost, l2_cost, combined_cost, step,
                                  est_mode=tf.estimator.ModeKeys.TRAIN):
        """Builds streaming metrics, their summaries, and a periodic reset.

        Args:
          policy_output: Predicted move probabilities. Reduced over axis 1,
            so presumably `[batch, num_moves]` -- TODO confirm.
          value_output: Predicted values; logged as their mean absolute value.
          pi_tensor: Target policy (one-hot or soft-max); only its argmax is
            used.
          policy_cost: Policy loss term.
          value_cost: Value loss term.
          l2_cost: L2 regularization term.
          combined_cost: Total loss.
          step: Global-step tensor; its minimum is used as the summary step.
          est_mode: `tf.estimator.ModeKeys` value selecting the return type.

        Returns:
          In EVAL mode, the dict of `tf.metrics` `(value, update_op)` pairs
          (no summaries are built). Otherwise all summary ops plus an op that
          re-initializes the local variables under the "metrics" scope when
          `eval_step % params['summary_steps'] == 1`.
        """
        # p * log(p) evaluates to 0 * -inf = NaN wherever a probability is
        # exactly 0 (e.g. an underflowed softmax). Use the limit value 0 there
        # so a single such entry does not turn the whole entropy into NaN.
        p_log_p = tf.where(policy_output > 0,
                           policy_output * tf.log(policy_output),
                           tf.zeros_like(policy_output))
        policy_entropy = -tf.reduce_mean(tf.reduce_sum(p_log_p, axis=1))
        # pi_tensor is one_hot when generated from sgfs (for supervised learning)
        # and soft-max when using self-play records. argmax normalizes the two.
        policy_target_top_1 = tf.argmax(pi_tensor, axis=1)

        policy_output_in_top1 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=1))
        policy_output_in_top3 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=3))

        policy_top_1_confidence = tf.reduce_max(policy_output, axis=1)
        policy_target_top_1_confidence = tf.boolean_mask(
            policy_output,
            tf.one_hot(policy_target_top_1, tf.shape(policy_output)[1]))

        with tf.variable_scope("metrics"):
            metric_ops = {
                'policy_cost': tf.metrics.mean(policy_cost),
                'value_cost': tf.metrics.mean(value_cost),
                'l2_cost': tf.metrics.mean(l2_cost),
                'policy_entropy': tf.metrics.mean(policy_entropy),
                'combined_cost': tf.metrics.mean(combined_cost),

                'policy_accuracy_top_1': tf.metrics.mean(policy_output_in_top1),
                'policy_accuracy_top_3': tf.metrics.mean(policy_output_in_top3),
                'policy_top_1_confidence': tf.metrics.mean(policy_top_1_confidence),
                'policy_target_top_1_confidence': tf.metrics.mean(
                    policy_target_top_1_confidence),
                'value_confidence': tf.metrics.mean(tf.abs(value_output)),
            }

        if est_mode == tf.estimator.ModeKeys.EVAL:
            return metric_ops

        # NOTE: global_step is rounded to a multiple of FLAGS.summary_steps.
        eval_step = tf.reduce_min(step)

        # Create summary ops so that they show up in SUMMARIES collection
        # That way, they get logged automatically during training
        summary_writer = summary.create_file_writer(FLAGS.work_dir)
        with summary_writer.as_default(), \
                summary.record_summaries_every_n_global_steps(
                    params['summary_steps'], eval_step):
            for metric_name, metric_op in metric_ops.items():
                summary.scalar(metric_name, metric_op[1], step=eval_step)

        # Reset metrics occasionally so that they are mean of recent batches.
        reset_op = tf.variables_initializer(tf.local_variables("metrics"))
        cond_reset_op = tf.cond(
            tf.equal(eval_step % params['summary_steps'], tf.to_int64(1)),
            lambda: reset_op,
            lambda: tf.no_op())

        return summary.all_summary_ops() + [cond_reset_op]
 def log_performance(rewards, actions, tloss, ploss, vloss, entropy):
     """Write episode performance summaries to `writer` every 10 steps.

     Logs the summed reward, a histogram of actions, the episode length
     (`tf.size(rewards)`), the total loss as given, and the means of the
     policy loss, value loss and policy entropy.
     """
     with writer.as_default(), summary.record_summaries_every_n_global_steps(10):
         summary.scalar('Perf/Total Reward', tf.reduce_sum(rewards))
         summary.histogram('Actions', actions)
         summary.scalar('Perf/Episode Duration', tf.size(rewards))
         summary.scalar('Perf/Total Loss', tloss)
         # The remaining metrics are per-step tensors reported as their mean.
         averaged = (('Perf/Policy Loss', ploss),
                     ('Perf/Value Loss', vloss),
                     ('Perf/Policy Entropy', entropy))
         for tag, values in averaged:
             summary.scalar(tag, tf.reduce_mean(values))
Beispiel #8
0
 def host_call_fn(global_step, *tensors):
     """Training host call: writes one scalar summary per entry of `tensors`.

     The i-th tensor is logged under `names[i]` (from the enclosing scope)
     using its first element, every `n` steps, into `<summary_dir>/metrics`.
     """
     step = global_step[0]
     metrics_dir = summary_dir + '/metrics'
     with contrib_summary.create_file_writer(metrics_dir).as_default():
         with contrib_summary.record_summaries_every_n_global_steps(
                 n=n, global_step=step):
             for position, tensor in enumerate(tensors):
                 tag = names[position]
                 contrib_summary.scalar(tag, tensor[0], step=step)
             return contrib_summary.all_summary_ops()
Beispiel #9
0
 def host_call_fn(global_step, *tensors):
     """Training host call: writes scalar summaries for the given tensors.

     The i-th tensor is logged under `names[i]` (from the enclosing scope)
     using its first element, every `params.log_every` steps. Entries whose
     name contains 'images' are skipped.
     """
     step = global_step[0]
     with contrib_summary.create_file_writer(
             params.output_dir).as_default():
         with contrib_summary.record_summaries_every_n_global_steps(
                 n=params.log_every, global_step=step):
             for position, tensor in enumerate(tensors):
                 tag = names[position]
                 if 'images' in tag:
                     # Presumably image data, not a scalar; skip it.
                     continue
                 contrib_summary.scalar(tag, tensor[0], step=step)
             return contrib_summary.all_summary_ops()
Beispiel #10
0
 def host_call_fn(gs,
                  loss,
                  lr,
                  mix=None,
                  gt_sources=None,
                  est_sources=None):
     """Training host call. Creates scalar and audio summaries.

     This function is executed on the CPU and should not directly reference
     any Tensors in the rest of the `model_fn`. To pass Tensors from the
     model to the `metric_fn`, provide as part of the `host_call`. See
     https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
     for more information.
     Arguments should match the list of `Tensor` objects passed as the second
     element in the tuple passed to `host_call`.

     Loss and learning rate are always recorded. Audio summaries are recorded
     every `model_config["audio_summaries_every_n_steps"]` steps of `gs`;
     any of `mix`, `gt_sources`, `est_sources` left as None is skipped.

     Args:
       gs: `Tensor` with shape `[batch]` for the global_step.
       loss: `Tensor` with shape `[batch]` for the training loss.
       lr: `Tensor` with shape `[batch]` for the learning_rate.
       mix: Optional `Tensor` with shape `[batch, mix_samples, 1]`.
       gt_sources: Optional `Tensor` with shape
         `[batch, sources_n, output_samples, 1]`.
       est_sources: Optional `Tensor` with shape
         `[batch, sources_n, output_samples, 1]`.
     Returns:
       List of summary ops to run on the CPU host.
     """
     gs = gs[0]
     with summary.create_file_writer(
             model_config["model_base_dir"] + os.path.sep +
             str(model_config["experiment_id"])).as_default():
         with summary.always_record_summaries():
             summary.scalar('loss', loss[0], step=gs)
             summary.scalar('learning_rate', lr[0], step=gs)
         # `gs` is a symbolic tensor here, so a Python `if` on its value cannot
         # gate anything (in TF1 graph mode `tensor == 0` is an identity test
         # and always False). Apply the cadence via the recording condition,
         # evaluated against this host call's step.
         with summary.record_summaries_every_n_global_steps(
                 model_config["audio_summaries_every_n_steps"],
                 global_step=gs):
             sample_rate = model_config['expected_sr']
             max_outputs = model_config["num_sources"]
             if mix is not None:
                 summary.audio('mix', mix, sample_rate,
                               max_outputs=max_outputs, step=gs)
             # Per-source groups that were actually provided; for each source
             # the ground truth is written before the estimate.
             source_groups = [(prefix, sources)
                              for prefix, sources in (('gt_sources', gt_sources),
                                                      ('est_sources', est_sources))
                              if sources is not None]
             if source_groups:
                 num_sources = source_groups[0][1].shape[1].value
                 for source_id in range(num_sources):
                     for prefix, sources in source_groups:
                         summary.audio('{}_{}'.format(prefix, source_id),
                                       sources[:, source_id, :, :],
                                       sample_rate,
                                       max_outputs=max_outputs,
                                       step=gs)
     return summary.all_summary_ops()
Beispiel #11
0
    def __init__(self):
        """Builds a dense autoencoder graph for 16x16x3 images in a new session.

        Creates the model, an Adam training op, loss/latent summaries every
        100 global steps, an always-recorded image summary fed through
        `generated_images_summary_data`, then initializes all variables and
        the summary writer.
        """
        thread_count = 8
        graph = tf.Graph()
        session_config = tf.ConfigProto(
            inter_op_parallelism_threads=thread_count,
            intra_op_parallelism_threads=thread_count)
        self.session = tf.Session(graph=graph, config=session_config)

        with graph.as_default():
            self.images = tf.placeholder(
                tf.float32, shape=[None, 16, 16, 3], name="images")

            z_dim = 10
            hidden_sizes = (500, 500)

            def encoder(img):
                hidden = tf.layers.flatten(img)
                for units in hidden_sizes:
                    hidden = tf.layers.dense(hidden, units, activation=tf.nn.relu)
                # The latent code itself also goes through a ReLU.
                return tf.layers.dense(hidden, z_dim, activation=tf.nn.relu)

            def decoder(z):
                hidden = z
                for units in hidden_sizes:
                    hidden = tf.layers.dense(hidden, units, activation=tf.nn.relu)
                flat_logits = tf.layers.dense(hidden, 16 * 16 * 3, activation=None)
                return tf.reshape(flat_logits, [-1, 16, 16, 3])

            self.z = encoder(self.images)
            self.generated_logits = decoder(self.z)
            self.generated_images = tf.nn.sigmoid(
                self.generated_logits, name="generated_images")

            self.loss = tf.losses.mean_squared_error(
                self.images, self.generated_images)

            global_step = tf.train.create_global_step()
            self.training = tf.train.AdamOptimizer().minimize(
                self.loss, global_step=global_step)

            run_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
            logdir = "logs/autoencoder-{}-{}".format(z_dim, run_stamp)

            summary_writer = tfsum.create_file_writer(
                logdir, flush_millis=10 * 1000)
            every_100_steps = tfsum.record_summaries_every_n_global_steps
            with summary_writer.as_default(), every_100_steps(100):
                loss_summary = tfsum.scalar("loss", self.loss)
                latent_summary = tfsum.histogram("latent", self.z)
                self.summaries = [loss_summary, latent_summary]

            # A single HxWx3 image is fed here and given a batch axis below.
            self.generated_images_summary_data = tf.placeholder(
                tf.float32, [None, None, 3])
            with summary_writer.as_default(), tfsum.always_record_summaries():
                batched_image = tf.expand_dims(
                    self.generated_images_summary_data, axis=0)
                self.generated_images_summary = tfsum.image(
                    "generated_image", batched_image)

            self.session.run(tf.global_variables_initializer())

            with summary_writer.as_default():
                tfsum.initialize(session=self.session, graph=self.session.graph)
Beispiel #12
0
    def _update_critic_td3(self, obs, action, next_obs, reward, mask):
        """Performs one TD3 critic gradient step on a batch of transitions.

        Uses target-policy smoothing (clipped Gaussian noise on the target
        action, which is then clipped to [-1, 1]) and bootstraps from the
        smaller of the twin target critics' estimates.

        Args:
          obs: A tfe.Variable with a batch of observations.
          action: A tfe.Variable with a batch of actions.
          next_obs: A tfe.Variable with a batch of next observations.
          reward: A tfe.Variable with a batch of rewards.
          mask: A tfe.Variable with a batch of masks.
        """
        # Sample the smoothing noise with NumPy rather than TensorFlow: the
        # state of TensorFlow's random number generator cannot be retrieved.
        raw_noise = np.random.normal(
            size=action.get_shape(), scale=self.policy_noise).astype('float32')
        noise = tf.clip_by_value(contrib_eager_python_tfe.Variable(raw_noise),
                                 -self.policy_noise_clip,
                                 self.policy_noise_clip)
        target_actions = tf.clip_by_value(
            self.actor_target(next_obs) + noise, -1, 1)
        if self.use_absorbing_state:
            # Starting from the goal state we can execute only non-actions.
            target_actions = target_actions * tf.maximum(0, mask)

        q_next1, q_next2 = self.critic_target(next_obs, target_actions)
        # Clipped double-Q: take the element-wise minimum of the two heads.
        q_next = tf.reduce_min(tf.concat([q_next1, q_next2], -1),
                               -1,
                               keepdims=True)
        if self.use_absorbing_state:
            q_target = reward + self.discount * q_next
        else:
            q_target = reward + self.discount * mask * q_next

        with tf.GradientTape() as tape:
            q_pred1, q_pred2 = self.critic(obs, action)
            critic_loss = (tf.losses.mean_squared_error(q_target, q_pred1) +
                           tf.losses.mean_squared_error(q_target, q_pred2))

        critic_vars = self.critic.variables
        grads = tape.gradient(critic_loss, critic_vars)
        self.critic_optimizer.apply_gradients(
            zip(grads, critic_vars), global_step=self.critic_step)

        if self.use_absorbing_state:
            with contrib_summary.record_summaries_every_n_global_steps(
                    100, self.critic_step):
                # Weight 1 for absorbing samples (mask == -1), 0 otherwise.
                absorbing = tf.maximum(0, -mask)
                absorbing_count = tf.reduce_sum(absorbing)
                if absorbing_count.numpy() > 0:
                    contrib_summary.scalar(
                        'critic/absorbing_reward',
                        tf.reduce_sum(reward * absorbing) / absorbing_count,
                        step=self.critic_step)

        with contrib_summary.record_summaries_every_n_global_steps(
                100, self.critic_step):
            contrib_summary.scalar(
                'critic/loss', critic_loss, step=self.critic_step)
 def log_weights(var_list):
     """Write a histogram of each variable to `writer` every 10 global steps.

     Args:
       var_list: Iterable of variables; each is logged under its `name`.
     """
     # The writer and recording contexts do not depend on the variable, so
     # enter them once for all histograms instead of once per variable.
     with writer.as_default(
     ), summary.record_summaries_every_n_global_steps(10):
         for var in var_list:
             summary.histogram(var.name, var)
 def log_gradients(gnorms):
     """Write a 'Gradient Norms' histogram to `writer` every 10 global steps."""
     every_n_steps = summary.record_summaries_every_n_global_steps
     with writer.as_default(), every_n_steps(10):
         summary.histogram('Gradient Norms', gnorms)
Beispiel #15
0
    # Fragment of an SfMNet training script (enclosing def/if not visible):
    # sets up a per-session log directory, data reader, model and optimizer,
    # then builds forward and backward losses for a pair of frames.
    session_name = get_session_name()
    session_logs_path = os.path.join(logs_path, session_name)

    global_step = tf.train.get_or_create_global_step()
    # NOTE(review): this rebinds the name `sharpness_multiplier` from the helper
    # function to its result. If this code runs inside a function, Python
    # treats the name as local and the call on the right-hand side raises
    # UnboundLocalError; at module level it silently shadows the function.
    # Consider storing the result under a different name (later uses of the
    # name are not visible here) -- TODO confirm and fix.
    sharpness_multiplier = sharpness_multiplier(50, global_step, 1e6, 1e5)

    data_reader = DataReader(
        "sequence", batch_size, "/localdata/auguste/kitti-raw")
    model = SfMNet()
    optimizer = tf.train.AdamOptimizer(learning_rate=lr)

    # beholder = Beholder(logs_path)
    # max_queue=0 presumably makes the writer flush after every summary --
    # confirm against the tf.contrib.summary docs.
    writer = summary.create_file_writer(session_logs_path, max_queue=0)
    writer.set_as_default()

    # Summaries created in this block are recorded every 50 global steps.
    with summary.record_summaries_every_n_global_steps(50):

        # Train

        f0, f1 = data_reader.read()
        a = sharpness_multiplier
        # Run the model on (f0, f1) and on the reversed pair (f1, f0) so the
        # losses below can compare the two directions.
        depth, points, flow, obj_p, cam_p, pc_t, motion_maps = model(f0, f1, a)
        depth1, points1, flow1, _, _, pc_t1, motion_maps1 = model(f1, f0, a)

        f_loss, f1_t = frame_loss(f0, f1, points)
        f_loss1, _ = frame_loss(f1, f0, points1)

        # Cross terms: depth from one direction against points/pc_t from the other.
        fb_loss = forward_backward_consistency_loss(depth1, points, pc_t)
        fb_loss1 = forward_backward_consistency_loss(depth, points1, pc_t1)

        ss_loss_d = spatial_smoothness_loss(depth / 100, order=2)
Beispiel #16
0
    def update(self, batch, expert_batch):
        """Runs one training step of the discriminator / WGAN potential.

        The loss is the modified GAN discriminator loss (expert samples as
        "real", weighted for subsampling) plus `self.lambd` times a gradient
        penalty evaluated at random interpolations between policy and expert
        (obs, action) inputs.

        Args:
           batch: A batch from training policy.
           expert_batch: A batch from the expert.
        """
        def as_variable(records):
            # Stack a sequence of per-sample arrays into one float32 variable.
            return contrib_eager_python_tfe.Variable(
                np.stack(records).astype('float32'))

        obs = as_variable(batch.obs)
        expert_obs = as_variable(expert_batch.obs)
        expert_mask = as_variable(expert_batch.mask)

        # Since expert trajectories were resampled but no absorbing state,
        # statistics of the states changes, we need to adjust weights accordingly:
        # absorbing expert samples (mask == -1) get weight 1/subsampling_rate.
        absorbing = tf.maximum(0, -expert_mask)
        expert_weight = absorbing / self.subsampling_rate + (1 - absorbing)

        action = as_variable(batch.action)
        expert_action = as_variable(expert_batch.action)

        policy_inputs = tf.concat([obs, action], -1)
        expert_inputs = tf.concat([expert_obs, expert_action], -1)

        # Avoid using tensorflow random functions since it's impossible to get
        # the state of the random number generator used by TensorFlow.
        alpha = np.random.uniform(size=(policy_inputs.get_shape()[0], 1))
        alpha = contrib_eager_python_tfe.Variable(alpha.astype('float32'))
        interpolated = alpha * policy_inputs + (1 - alpha) * expert_inputs

        with tf.GradientTape() as tape:
            policy_output = self.discriminator(policy_inputs)
            expert_output = self.discriminator(expert_inputs)

            with contrib_summary.record_summaries_every_n_global_steps(
                    100, self.disc_step):
                gan_loss = contrib_gan_python_losses_python_losses_impl.modified_discriminator_loss(
                    expert_output,
                    policy_output,
                    label_smoothing=0.0,
                    real_weights=expert_weight)
                contrib_summary.scalar('discriminator/expert_output',
                                       tf.reduce_mean(expert_output),
                                       step=self.disc_step)
                contrib_summary.scalar('discriminator/policy_output',
                                       tf.reduce_mean(policy_output),
                                       step=self.disc_step)

            # Inner tape: gradient of the discriminator w.r.t. its input. It
            # is computed under the outer tape so the penalty itself can be
            # differentiated w.r.t. the discriminator's variables.
            with tf.GradientTape() as penalty_tape:
                penalty_tape.watch(interpolated)
                interpolated_output = self.discriminator(interpolated)
                input_grad = penalty_tape.gradient(
                    interpolated_output, [interpolated])[0]

            grad_penalty = tf.reduce_mean(
                tf.pow(tf.norm(input_grad, axis=-1) - 1, 2))

            loss = gan_loss + self.lambd * grad_penalty

        with contrib_summary.record_summaries_every_n_global_steps(
                100, self.disc_step):
            contrib_summary.scalar('discriminator/grad_penalty',
                                   grad_penalty,
                                   step=self.disc_step)
            # Logged value is the GAN term only, without the penalty.
            contrib_summary.scalar('discriminator/loss',
                                   gan_loss,
                                   step=self.disc_step)

        disc_vars = self.discriminator.variables
        grads = tape.gradient(loss, disc_vars)
        self.discriminator_optimizer.apply_gradients(
            zip(grads, disc_vars), global_step=self.disc_step)