Example #1
    def create_loss(self, features, labels, params, is_training=True):
        """Build the loss tensors for discriminator and generator.

    This method will set self.d_loss and self.g_loss.

    Args:
      features: Optional dictionary with inputs to the model ("images" should
          contain the real images and "z" the noise for the generator).
      labels: Tensor will labels. Use
          self._get_one_hot_labels(labels) to get a one hot encoded tensor.
      params: Dictionary with hyperparameters passed to TPUEstimator.
          Additional TPUEstimator will set 3 keys: `batch_size`, `use_tpu`,
          `tpu_context`. `batch_size` is the batch size for this core.
      is_training: If True build the model in training mode. If False build the
          model for inference mode (e.g. use trained averages for batch norm).

    Raises:
      ValueError: If set of meta/hyper parameters is not supported.
    """
        images = features["images"]  # Real images.
        generated = features["generated"]  # Fake images.
        if self.conditional:
            y = self._get_one_hot_labels(labels)
            sampled_y = self._get_one_hot_labels(features["sampled_labels"])
            all_y = tf.concat([y, sampled_y], axis=0)
        else:
            y = None
            sampled_y = None
            all_y = None

        if self._deprecated_split_disc_calls:
            with tf.name_scope("disc_for_real"):
                d_real, d_real_logits, _ = self.discriminator(
                    images, y=y, is_training=is_training)
            with tf.name_scope("disc_for_fake"):
                d_fake, d_fake_logits, _ = self.discriminator(
                    generated, y=sampled_y, is_training=is_training)
        else:
            # Compute discriminator output for real and fake images in one batch.
            all_images = tf.concat([images, generated], axis=0)
            d_all, d_all_logits, _ = self.discriminator(
                all_images, y=all_y, is_training=is_training)
            d_real, d_fake = tf.split(d_all, 2)
            d_real_logits, d_fake_logits = tf.split(d_all_logits, 2)

        self.d_loss, _, _, self.g_loss = loss_lib.get_losses(
            d_real=d_real,
            d_fake=d_fake,
            d_real_logits=d_real_logits,
            d_fake_logits=d_fake_logits)

        penalty_loss = penalty_lib.get_penalty_loss(
            x=images,
            x_fake=generated,
            y=y,
            is_training=is_training,
            discriminator=self.discriminator)
        self.d_loss += self._lambda * penalty_loss
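
Note on `loss_lib.get_losses`: the examples unpack its result as (d_loss, _, _, g_loss), so the two discarded values are presumably the real and fake components of the discriminator loss. The concrete loss is selected by the library's configuration and is not shown on this page; as a hedged sketch only (the function name below is ours), the non-saturating GAN loss over the same logits could look like this:

import tensorflow as tf  # TF1-style API, as in the examples.

def non_saturating_losses(d_real_logits, d_fake_logits):
    """Sketch: (d_loss, g_loss) for the non-saturating GAN loss."""
    # Discriminator: push real logits toward 1 and fake logits toward 0.
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_real_logits, labels=tf.ones_like(d_real_logits)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_fake_logits, labels=tf.zeros_like(d_fake_logits)))
    # Generator: maximize log D(G(z)) rather than minimize log(1 - D(G(z))).
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_fake_logits, labels=tf.ones_like(d_fake_logits)))
    return d_loss_real + d_loss_fake, g_loss
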
Example #2
    def create_loss(self, features, labels, params, is_training=True):
        """Build the loss tensors for discriminator and generator.

    This method will set self.d_loss and self.g_loss.

    Args:
      features: Optional dictionary with inputs to the model ("images" should
          contain the real images and "z" the noise for the generator).
      labels: Tensor will labels. These are class indices. Use
          self._get_one_hot_labels(labels) to get a one hot encoded tensor.
      params: Dictionary with hyperparameters passed to TPUEstimator.
          Additional TPUEstimator will set 3 keys: `batch_size`, `use_tpu`,
          `tpu_context`. `batch_size` is the batch size for this core.
      is_training: If True build the model in training mode. If False build the
          model for inference mode (e.g. use trained averages for batch norm).

    Raises:
      ValueError: If set of meta/hyper parameters is not supported.
    """
        images = features["images"]  # Input images.
        generated = features["generated"]  # Fake images.
        if self.conditional:
            y = self._get_one_hot_labels(labels)
            sampled_y = self._get_one_hot_labels(features["sampled_labels"])
        else:
            y = None
            sampled_y = None
            all_y = None

        # Batch size per core.
        bs = images.shape[0].value

        def augment(imgs):
            imgs = random_crop_and_resize(imgs)
            imgs = random_apply(color_distortion, imgs,
                                self._aug_color_jitter_prob)
            imgs = random_apply(color_drop, imgs, self._aug_color_drop_prob)
            return tf.stop_gradient(imgs)

        aug_images, aug_generated = augment(images), augment(generated)

        # Concatenate clean and augmented, real and fake images into one batch.
        all_images = tf.concat([images, generated, aug_images, aug_generated],
                               0)

        if self.conditional:
            all_y = tf.concat([y, sampled_y, y, sampled_y], axis=0)

        # Compute discriminator output for real, fake, and augmented images
        # in one batch.

        d_all, d_all_logits, d_latents = self.discriminator(
            x=all_images, y=all_y, is_training=is_training)

        z_projs = self._latent_projections(d_latents)

        d_real, d_fake, _, _ = tf.split(d_all, 4)
        d_real_logits, d_fake_logits, _, _ = tf.split(d_all_logits, 4)
        z_projs_real, z_projs_fake, z_aug_projs_real, z_aug_projs_fake = tf.split(
            z_projs, 4)

        self.d_loss, _, _, self.g_loss = loss_lib.get_losses(
            d_real=d_real,
            d_fake=d_fake,
            d_real_logits=d_real_logits,
            d_fake_logits=d_fake_logits)

        penalty_loss = penalty_lib.get_penalty_loss(
            x=images,
            x_fake=generated,
            y=y,
            is_training=is_training,
            discriminator=self.discriminator,
            architecture=self._architecture)
        self.d_loss += self._lambda * penalty_loss

        z_projs = tf.concat([z_projs_real, z_projs_fake], 0)
        z_aug_projs = tf.concat([z_aug_projs_real, z_aug_projs_fake], 0)

        sims_logits = tf.matmul(z_projs, z_aug_projs, transpose_b=True)
        logits_max = tf.reduce_max(sims_logits, 1)
        sims_logits = sims_logits - tf.reshape(logits_max, [-1, 1])
        sims_probs = tf.nn.softmax(sims_logits)

        sim_labels = tf.constant(np.arange(bs * 2, dtype=np.int32))
        sims_onehot = tf.one_hot(sim_labels, bs * 2)

        c_real_loss = -tf.reduce_mean(
            tf.reduce_sum(sims_onehot * tf.log(sims_probs + 1e-10), 1))

        self.d_loss += c_real_loss * self._weight_contrastive_loss_d

        self._tpu_summary.scalar("loss/c_real_loss", c_real_loss)
        self._tpu_summary.scalar("loss/penalty", penalty_loss)
Example #3
    def create_loss(self, features, labels, params, is_training=True):
        """Build the loss tensors for discriminator and generator.

    This method will set self.d_loss and self.g_loss.

    Args:
      features: Optional dictionary with inputs to the model ("images" should
          contain the real images and "z" the noise for the generator).
      labels: Tensor will labels. These are class indices. Use
          self._get_one_hot_labels(labels) to get a one hot encoded tensor.
      params: Dictionary with hyperparameters passed to TPUEstimator.
          Additional TPUEstimator will set 3 keys: `batch_size`, `use_tpu`,
          `tpu_context`. `batch_size` is the batch size for this core.
      is_training: If True build the model in training mode. If False build the
          model for inference mode (e.g. use trained averages for batch norm).

    Raises:
      ValueError: If set of meta/hyper parameters is not supported.
    """
        images = features["images"]  # Input images.
        generated = features["generated"]  # Fake images.
        if self.conditional:
            y = self._get_one_hot_labels(labels)
            sampled_y = self._get_one_hot_labels(features["sampled_labels"])
        else:
            y = None
            sampled_y = None
            all_y = None

        # Batch size per core.
        bs = images.shape[0].value
        num_replicas = (
            params["context"].num_replicas if "context" in params else 1)
        assert self._rotated_batch_size % num_replicas == 0
        # Rotated batch size per core.
        rotated_bs = self._rotated_batch_size // num_replicas
        assert rotated_bs % 4 == 0
        # Number of images to rotate. Each image gets rotated 3 times.
        num_rotated_examples = rotated_bs // 4
        logging.info(
            "num_replicas=%s, bs=%s, rotated_bs=%s, "
            "num_rotated_examples=%s, params=%s", num_replicas, bs, rotated_bs,
            num_rotated_examples, params)

        # Augment the images with rotation.
        if "rotation" in self._self_supervision:
            # Put all rotation angles in a single batch: the first bs images
            # are the original upright images, followed by
            # num_rotated_examples * 3 rotated copies (one copy per angle).
            # The same layout is repeated for the generated images.
            assert num_rotated_examples <= bs, (num_rotated_examples, bs)
            images_rotated = utils.rotate_images(
                images[-num_rotated_examples:], rot90_scalars=(1, 2, 3))
            generated_rotated = utils.rotate_images(
                generated[-num_rotated_examples:], rot90_scalars=(1, 2, 3))
            # Labels for rotation loss (unrotated and 3 rotated versions). For
            # NUM_ROTATIONS=4 and num_rotated_examples=2 this is:
            # [0, 0, 1, 1, 2, 2, 3, 3]
            rotate_labels = tf.constant(
                np.repeat(np.arange(NUM_ROTATIONS, dtype=np.int32),
                          num_rotated_examples))
            rotate_labels_onehot = tf.one_hot(rotate_labels, NUM_ROTATIONS)
            all_images = tf.concat(
                [images, images_rotated, generated, generated_rotated], 0)
            if self.conditional:
                y_rotated = tf.tile(y[-num_rotated_examples:], [3, 1])
                # Tile sampled_y (not y) for the rotated generated images.
                sampled_y_rotated = tf.tile(sampled_y[-num_rotated_examples:],
                                            [3, 1])
                all_y = tf.concat([y, y_rotated, sampled_y, sampled_y_rotated],
                                  0)
        else:
            all_images = tf.concat([images, generated], 0)
            if self.conditional:
                all_y = tf.concat([y, sampled_y], axis=0)

        # Compute discriminator output for real and fake images in one batch.
        d_all, d_all_logits, c_all_logits = self.discriminator_with_rotation_head(
            all_images, y=all_y, is_training=is_training)
        d_real, d_fake = tf.split(d_all, 2)
        d_real_logits, d_fake_logits = tf.split(d_all_logits, 2)
        c_real_logits, c_fake_logits = tf.split(c_all_logits, 2)

        # Separate the true/fake scores from whole rotation batch.
        d_real_logits = d_real_logits[:bs]
        d_fake_logits = d_fake_logits[:bs]
        d_real = d_real[:bs]
        d_fake = d_fake[:bs]

        self.d_loss, _, _, self.g_loss = loss_lib.get_losses(
            d_real=d_real,
            d_fake=d_fake,
            d_real_logits=d_real_logits,
            d_fake_logits=d_fake_logits)

        penalty_loss = penalty_lib.get_penalty_loss(
            x=images,
            x_fake=generated,
            y=y,
            is_training=is_training,
            discriminator=self.discriminator,
            architecture=self._architecture)
        self.d_loss += self._lambda * penalty_loss

        # Add rotation augmented loss.
        if "rotation" in self._self_supervision:
            # Take an equal number of examples for each rotation angle.
            assert len(c_real_logits.shape.as_list()) == 2, c_real_logits.shape
            assert len(c_fake_logits.shape.as_list()) == 2, c_fake_logits.shape
            c_real_logits = c_real_logits[-rotated_bs:]
            c_fake_logits = c_fake_logits[-rotated_bs:]
            preds_onreal = tf.cast(tf.argmax(c_real_logits, -1),
                                   rotate_labels.dtype)
            accuracy = tf.reduce_mean(
                tf.cast(tf.equal(rotate_labels, preds_onreal), tf.float32))
            c_real_probs = tf.nn.softmax(c_real_logits)
            c_fake_probs = tf.nn.softmax(c_fake_logits)
            c_real_loss = -tf.reduce_mean(
                tf.reduce_sum(
                    rotate_labels_onehot * tf.log(c_real_probs + 1e-10), 1))
            c_fake_loss = -tf.reduce_mean(
                tf.reduce_sum(
                    rotate_labels_onehot * tf.log(c_fake_probs + 1e-10), 1))
            if self._self_supervision == "rotation_only":
                self.d_loss *= 0.0
                self.g_loss *= 0.0
            self.d_loss += c_real_loss * self._weight_rotation_loss_d
            self.g_loss += c_fake_loss * self._weight_rotation_loss_g
        else:
            c_real_loss = 0.0
            c_fake_loss = 0.0
            accuracy = tf.zeros([])

        self._tpu_summary.scalar("loss/c_real_loss", c_real_loss)
        self._tpu_summary.scalar("loss/c_fake_loss", c_fake_loss)
        self._tpu_summary.scalar("accuracy/d_rotation", accuracy)
        self._tpu_summary.scalar("loss/penalty", penalty_loss)
Example #4
    def create_loss(self, features, labels, params, is_training=True):
        """Build the loss tensors for discriminator and generator.

    This method will set self.d_loss and self.g_loss.

    Args:
      features: Optional dictionary with inputs to the model ("images" should
          contain the real images and "z" the noise for the generator).
      labels: Tensor will labels. Use
          self._get_one_hot_labels(labels) to get a one hot encoded tensor.
      params: Dictionary with hyperparameters passed to TPUEstimator.
          Additional TPUEstimator will set 3 keys: `batch_size`, `use_tpu`,
          `tpu_context`. `batch_size` is the batch size for this core.
      is_training: If True build the model in training mode. If False build the
          model for inference mode (e.g. use trained averages for batch norm).

    Raises:
      ValueError: If set of meta/hyper parameters is not supported.
    """
        images = features["images"]  # Real images.
        generated = features["generated"]  # Fake images.
        eps_generated = features["eps_generated"]  #fake eps

        eps_bs = generated.get_shape().as_list()[0]
        half_bs = eps_bs // 2

        if self.conditional:
            y = self._get_one_hot_labels(labels)
            sampled_y = self._get_one_hot_labels(features["sampled_labels"])
            all_y = tf.concat([y, sampled_y, sampled_y], axis=0)
        else:
            y = None
            sampled_y = None
            all_y = None

        if self._deprecated_split_disc_calls:
            with tf.name_scope("disc_for_real"):
                d_real, d_real_logits, _ = self.discriminator(
                    images, y=y, is_training=is_training)
            with tf.name_scope("disc_for_fake"):
                d_fake, d_fake_logits, D_in_op_g = self.discriminator(
                    generated, y=sampled_y, is_training=is_training)
            with tf.name_scope("disc_eps_for_fake"):
                d_fake_eps, d_fake_logits_eps, D_in_op_eg = self.discriminator(
                    eps_generated, y=sampled_y, is_training=is_training)
        else:
            # Compute discriminator output for real and fake images in one batch.
            all_images = tf.concat([images, generated, eps_generated], axis=0)
            d_all, d_all_logits, D_in_op_all = self.discriminator(
                all_images, y=all_y, is_training=is_training)
            d_real, d_fake, d_fake_eps = tf.split(d_all, 3)
            d_real_logits, d_fake_logits, d_fake_logits_eps = tf.split(
                d_all_logits, 3)
            _, D_in_op_g, D_in_op_eg = tf.split(D_in_op_all, 3)

        self.d_loss, _, _, _ = loss_lib.get_losses(
            d_real=d_real,
            d_fake=tf.concat([d_fake[:half_bs], d_fake_eps[:half_bs]], axis=0),
            d_real_logits=d_real_logits,
            d_fake_logits=tf.concat(
                [d_fake_logits[:half_bs], d_fake_logits_eps[:half_bs]],
                axis=0))

        _, _, _, g_loss = loss_lib.get_losses(d_real=None,
                                              d_fake=d_fake,
                                              d_real_logits=None,
                                              d_fake_logits=d_fake_logits)

        _, _, _, g_loss_eps = loss_lib.get_losses(
            d_real=None,
            d_fake=d_fake_eps,
            d_real_logits=None,
            d_fake_logits=d_fake_logits_eps)

        self.g_loss = g_loss + g_loss_eps

        penalty_loss = penalty_lib.get_penalty_loss(
            x=images,
            x_fake=generated,
            y=y,
            is_training=is_training,
            discriminator=self.discriminator)
        self.d_loss += self._lambda * penalty_loss

        D_in_op_g = tf.reshape(D_in_op_g, [eps_bs, -1], name="reshape_g")
        D_in_op_eg = tf.reshape(D_in_op_eg, [eps_bs, -1], name="reshape_eg")

        if self._choice_of_f == 'subtract':
            f = tf.subtract(D_in_op_g, D_in_op_eg)
        elif self._choice_of_f == 'concat':
            f = tf.concat([D_in_op_g, D_in_op_eg], axis=1)
        else:
            raise ValueError(
                "Unsupported choice_of_f: %s" % self._choice_of_f)

        pred_eps = self.aux_network(f)

        features["labels_aux"] = tf.cast(features["random_mask"],
                                         tf.float32,
                                         name="labels")
        self.aux_loss = tf.losses.sigmoid_cross_entropy(
            features["labels_aux"], pred_eps)
        self.aux_acc = tf.reduce_mean(
            tf.cast(tf.equal(features["labels_aux"],
                             tf.round(tf.sigmoid(pred_eps))),
                    dtype=tf.float32))
        self.g_loss += self._lambda_bce_loss * self.aux_loss
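
Note on the auxiliary head above: it is a binary classifier that predicts, from discriminator features of the two fake batches, which samples `features["random_mask"]` flags. tf.losses.sigmoid_cross_entropy(multi_class_labels, logits) averages the element-wise sigmoid cross-entropy, and the accuracy thresholds the sigmoid at 0.5 via tf.round. A tiny standalone check with made-up values:

import tensorflow as tf  # TF1-style API, as in the examples.

labels = tf.constant([[1.0], [0.0], [1.0], [0.0]])    # hypothetical mask
logits = tf.constant([[2.3], [-1.1], [0.4], [-0.7]])  # hypothetical pred_eps

aux_loss = tf.losses.sigmoid_cross_entropy(labels, logits)
aux_acc = tf.reduce_mean(
    tf.cast(tf.equal(labels, tf.round(tf.sigmoid(logits))), tf.float32))
# All four sigmoids round to their labels here, so aux_acc evaluates to 1.0.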
Example #5
    def create_loss(self,
                    features,
                    labels,
                    params,
                    is_training=True,
                    reuse=False):
        """Build the loss tensors for discriminator and generator.

    This method will set self.d_loss and self.g_loss.

    Args:
      features: Optional dictionary with inputs to the model ("images" should
          contain the real images and "z" the noise for the generator).
      labels: Tensor will labels. Use
          self._get_one_hot_labels(labels) to get a one hot encoded tensor.
      params: Dictionary with hyperparameters passed to TPUEstimator.
          Additional TPUEstimator will set 3 keys: `batch_size`, `use_tpu`,
          `tpu_context`. `batch_size` is the batch size for this core.
      is_training: If True build the model in training mode. If False build the
          model for inference mode (e.g. use trained averages for batch norm).
      reuse: Bool, whether to reuse existing variables for the models.
          This is only used for unrolling discriminator iterations when training
          on TPU.

    Raises:
      ValueError: If set of meta/hyper parameters is not supported.
    """
        images = features["images"]  # Input images.
        if self.conditional:
            y = self._get_one_hot_labels(labels)
            sampled_y = self._get_one_hot_labels(features["sampled_labels"])
            all_y = tf.concat([y, sampled_y], axis=0)
        else:
            y = None
            sampled_y = None
            all_y = None

        if self._experimental_joint_gen_for_disc:
            assert "generated" in features
            generated = features["generated"]
        else:
            logging.warning("Computing fake images for sub step separately.")
            z = features["z"]  # Noise vector.
            generated = self.generator(z,
                                       y=sampled_y,
                                       is_training=is_training,
                                       reuse=reuse)

        if self._deprecated_split_disc_calls:
            with tf.name_scope("disc_for_real"):
                d_real, d_real_logits, _ = self.discriminator(
                    images, y=y, is_training=is_training, reuse=reuse)
            with tf.name_scope("disc_for_fake"):
                d_fake, d_fake_logits, _ = self.discriminator(
                    generated,
                    y=sampled_y,
                    is_training=is_training,
                    reuse=True)
        else:
            # Compute discriminator output for real and fake images in one batch.
            all_images = tf.concat([images, generated], axis=0)
            d_all, d_all_logits, _ = self.discriminator(
                all_images, y=all_y, is_training=is_training, reuse=reuse)
            d_real, d_fake = tf.split(d_all, 2)
            d_real_logits, d_fake_logits = tf.split(d_all_logits, 2)

        self.d_loss, _, _, self.g_loss = loss_lib.get_losses(
            d_real=d_real,
            d_fake=d_fake,
            d_real_logits=d_real_logits,
            d_fake_logits=d_fake_logits)

        discriminator = functools.partial(self.discriminator, y=y)
        penalty_loss = penalty_lib.get_penalty_loss(
            x=images,
            x_fake=generated,
            is_training=is_training,
            discriminator=discriminator,
            architecture=self._architecture)
        self.d_loss += self._lambda * penalty_loss
        self._tpu_summary.scalar("loss/penalty", penalty_loss)