def merge_with_rotation_data(self, real, fake, real_labels, fake_labels,
                                 num_rot_examples):
        """Returns the original data concatenated with the rotated version.

        The last `num_rot_examples` examples of `real` and of `fake` are passed
        to `utils.rotate_images` with rot90_scalars=(1, 2, 3). Each rotated
        chunk is placed right after its un-rotated batch, so the features are
        laid out along axis 0 as [real, real_rotated, fake, fake_rotated].

        Args:
          real: Batch of real images (examples along axis 0).
          fake: Batch of fake images (examples along axis 0).
          real_labels: Labels for `real`; only read when `self.conditional`.
              They are tiled with multiples [3, 1], so presumably 2-D
              (e.g. one-hot) -- confirm against callers.
          fake_labels: Labels for `fake`; only read when `self.conditional`.
          num_rot_examples: Python int, the number of trailing examples of each
              batch to rotate. 0 means "rotate nothing".

        Returns:
          Tuple (all_features, all_labels). `all_labels` is None for
          unconditional models.

        Raises:
          ValueError: If `num_rot_examples` is negative.
        """
        if num_rot_examples < 0:
            raise ValueError(
                "num_rot_examples must be non-negative, got %s." %
                num_rot_examples)
        if num_rot_examples == 0:
            # x[-0:] would select the *entire* batch, so handle the
            # "nothing to rotate" case explicitly.
            all_features = tf.concat([real, fake], 0)
            all_labels = None
            if self.conditional:
                all_labels = tf.concat([real_labels, fake_labels], 0)
            return all_features, all_labels

        # Put all rotation angles in a single batch, the first batch_size are
        # the original up-right images, followed by rotated_batch_size * 3
        # rotated images with 3 different angles. For NUM_ROTATIONS=4 and
        # num_rot_examples=2 we have labels_rotated [0, 0, 1, 1, 2, 2, 3, 3].
        real_to_rot = real[-num_rot_examples:]
        fake_to_rot = fake[-num_rot_examples:]
        real_rotated = utils.rotate_images(real_to_rot,
                                           rot90_scalars=(1, 2, 3))
        fake_rotated = utils.rotate_images(fake_to_rot,
                                           rot90_scalars=(1, 2, 3))
        all_features = tf.concat([real, real_rotated, fake, fake_rotated], 0)
        all_labels = None
        if self.conditional:
            # One copy of the labels per rotation angle (3 angles).
            real_rotated_labels = tf.tile(real_labels[-num_rot_examples:],
                                          [3, 1])
            fake_rotated_labels = tf.tile(fake_labels[-num_rot_examples:],
                                          [3, 1])
            all_labels = tf.concat([
                real_labels, real_rotated_labels, fake_labels,
                fake_rotated_labels
            ], 0)
        return all_features, all_labels
    # Example #2
    # 0
    def create_loss(self, features, labels, params, is_training=True):
        """Build the loss tensors for discriminator and generator.

        This method will set self.d_loss and self.g_loss.

        If "rotation" is part of `self._self_supervision`, the last
        `num_rotated_examples` real and fake images of the per-core batch are
        also rotated via `self.merge_with_rotation_data`. The discriminator's
        rotation head is then trained to predict the rotation angle.

        Args:
          features: Dictionary with inputs to the model. "images" holds the
              real images and "generated" the fake images. Conditional models
              also read the fake images' class indices from "sampled_labels".
          labels: Tensor with labels. These are class indices. Use
              self._get_one_hot_labels(labels) to get a one hot encoded tensor.
          params: Dictionary with hyperparameters passed to TPUEstimator. If it
              has a "context" entry, that entry's `num_replicas` attribute is
              used to split `self._rotated_batch_size` across cores. Otherwise
              a single replica is assumed.
          is_training: If True build the model in training mode. If False build
              the model for inference mode (e.g. use trained averages for batch
              norm).

        Raises:
          ValueError: If set of meta/hyper parameters is not supported. This
              covers a rotated batch size that is not divisible by the number
              of replicas, a per-core rotated batch size that is not a multiple
              of 4, or (with rotation enabled) a number of rotated examples
              that is 0 or larger than the per-core batch.
        """
        images = features["images"]  # Real images.
        generated = features["generated"]  # Fake images.
        y = None
        sampled_y = None
        if self.conditional:
            y = self._get_one_hot_labels(labels)
            sampled_y = self._get_one_hot_labels(features["sampled_labels"])

        # Batch size per core.
        bs = images.shape[0].value
        num_replicas = (params["context"].num_replicas
                        if "context" in params else 1)
        if self._rotated_batch_size % num_replicas != 0:
            raise ValueError(
                "rotated_batch_size (%s) must be divisible by the number of "
                "replicas (%s)." % (self._rotated_batch_size, num_replicas))
        # Rotated batch size per core (up-right image plus 3 rotations).
        rotated_bs = self._rotated_batch_size // num_replicas
        if rotated_bs % 4 != 0:
            raise ValueError(
                "Per-core rotated batch size (%s) must be divisible by 4." %
                rotated_bs)
        # Number of images to rotate. Each image gets rotated 3 times.
        num_rotated_examples = rotated_bs // 4
        logging.info(
            "num_replicas=%s, bs=%s, rotated_bs=%s, "
            "num_rotated_examples=%s, params=%s", num_replicas, bs, rotated_bs,
            num_rotated_examples, params)

        use_rotation = "rotation" in self._self_supervision

        # Augment the images with rotation.
        if use_rotation:
            # Slicing with [-0:] would select the whole batch. We also cannot
            # rotate more examples than the batch holds.
            if not 0 < num_rotated_examples <= bs:
                raise ValueError(
                    "num_rotated_examples (%s) must be in (0, bs=%s]." %
                    (num_rotated_examples, bs))
            # Each half becomes: the bs up-right images followed by the last
            # num_rotated_examples images rotated with 3 different angles.
            # The shared helper pairs the rotated fakes with the *sampled*
            # labels. (The previous inline copy wrongly tiled the real
            # labels `y` for them.)
            all_images, all_y = self.merge_with_rotation_data(
                images, generated, y, sampled_y, num_rotated_examples)
            # Labels for rotation loss (unrotated and 3 rotated versions). For
            # NUM_ROTATIONS=4 and num_rotated_examples=2 this is:
            # [0, 0, 1, 1, 2, 2, 3, 3]
            rotate_labels = tf.constant(
                np.repeat(np.arange(NUM_ROTATIONS, dtype=np.int32),
                          num_rotated_examples))
            rotate_labels_onehot = tf.one_hot(rotate_labels, NUM_ROTATIONS)
        else:
            all_images = tf.concat([images, generated], 0)
            all_y = (tf.concat([y, sampled_y], axis=0)
                     if self.conditional else None)

        # Compute discriminator output for real and fake images in one batch.
        d_all, d_all_logits, c_all_logits = self.discriminator_with_rotation_head(
            all_images, y=all_y, is_training=is_training)
        # The first half of the batch is real data, the second half fake data.
        d_real, d_fake = tf.split(d_all, 2)
        d_real_logits, d_fake_logits = tf.split(d_all_logits, 2)
        c_real_logits, c_fake_logits = tf.split(c_all_logits, 2)

        # Keep only the true/fake scores of the up-right (unrotated) images.
        d_real_logits = d_real_logits[:bs]
        d_fake_logits = d_fake_logits[:bs]
        d_real = d_real[:bs]
        d_fake = d_fake[:bs]

        self.d_loss, _, _, self.g_loss = loss_lib.get_losses(
            d_real=d_real,
            d_fake=d_fake,
            d_real_logits=d_real_logits,
            d_fake_logits=d_fake_logits)

        penalty_loss = penalty_lib.get_penalty_loss(
            x=images,
            x_fake=generated,
            y=y,
            is_training=is_training,
            discriminator=self.discriminator,
            architecture=self._architecture)
        self.d_loss += self._lambda * penalty_loss

        # Add rotation augmented loss.
        if use_rotation:
            # The last rotated_bs rows of each half are the rotated chunk
            # (originals plus their 3 rotations), matching rotate_labels.
            assert len(c_real_logits.shape.as_list()) == 2, c_real_logits.shape
            assert len(c_fake_logits.shape.as_list()) == 2, c_fake_logits.shape
            c_real_logits = c_real_logits[-rotated_bs:]
            c_fake_logits = c_fake_logits[-rotated_bs:]
            preds_onreal = tf.cast(tf.argmax(c_real_logits, -1),
                                   rotate_labels.dtype)
            accuracy = tf.reduce_mean(
                tf.cast(tf.equal(rotate_labels, preds_onreal), tf.float32))
            # NOTE(review): softmax + log(p + 1e-10) is kept for numerical
            # parity with existing runs. Consider
            # tf.nn.softmax_cross_entropy_with_logits for better stability.
            c_real_probs = tf.nn.softmax(c_real_logits)
            c_fake_probs = tf.nn.softmax(c_fake_logits)
            c_real_loss = -tf.reduce_mean(
                tf.reduce_sum(
                    rotate_labels_onehot * tf.log(c_real_probs + 1e-10), 1))
            c_fake_loss = -tf.reduce_mean(
                tf.reduce_sum(
                    rotate_labels_onehot * tf.log(c_fake_probs + 1e-10), 1))
            if self._self_supervision == "rotation_only":
                # Drop the adversarial terms and keep only the rotation losses.
                self.d_loss *= 0.0
                self.g_loss *= 0.0
            self.d_loss += c_real_loss * self._weight_rotation_loss_d
            self.g_loss += c_fake_loss * self._weight_rotation_loss_g
        else:
            c_real_loss = 0.0
            c_fake_loss = 0.0
            accuracy = tf.zeros([])

        self._tpu_summary.scalar("loss/c_real_loss", c_real_loss)
        self._tpu_summary.scalar("loss/c_fake_loss", c_fake_loss)
        self._tpu_summary.scalar("accuracy/d_rotation", accuracy)
        self._tpu_summary.scalar("loss/penalty", penalty_loss)