Example 1
  def guess_label(self, logit, temp=0.5):
    """Build sharpened pseudo-labels from logits of self.nu augmented views.

    Args:
      logit: stacked logits for the unlabeled batch; reshaped to
        (-1, num_classes) and split into self.nu chunks along axis 0.
      temp: sharpening temperature; predictions are raised to 1/temp.

    Returns:
      A (batch, num_classes) tensor of sharpened class distributions.
    """
    flat = tf.reshape(logit, [-1, self.dataset.num_classes])
    # Normalize each augmented view's logits independently, then re-stack.
    normed = tf.concat(
        [logit_norm(chunk) for chunk in tf.split(flat, self.nu, axis=0)], 0)
    # Average the softmax distributions across the self.nu views.
    per_view = tf.reshape(
        tf.nn.softmax(normed), [self.nu, -1, self.dataset.num_classes])
    averaged = tf.reduce_mean(per_view, axis=0)
    # Temperature sharpening: power by 1/temp, then renormalize rows.
    sharpened = tf.pow(averaged, 1.0 / temp)
    return sharpened / tf.reduce_sum(sharpened, axis=1, keepdims=True)
Example 2
    def crossentropy_minimize(self,
                              u_logits,
                              u_images,
                              l_images,
                              l_labels,
                              u_labels=None):
        """Cross-entropy optimization step implementation for TPU."""
        per_replica_bs = self.batch_size // self.strategy.num_replicas_in_sync

        # Pseudo-labels for the unlabeled batch; gradients must not flow
        # through them, hence the stop_gradient below.
        guessed = self.guess_label(u_logits)
        self.guessed_label = guessed
        guessed = tf.reshape(tf.stop_gradient(guessed),
                             shape=(-1, self.dataset.num_classes))

        one_hot_labels = tf.reshape(
            tf.one_hot(l_labels, self.dataset.num_classes),
            shape=(-1, self.dataset.num_classes))

        # Mix labeled and unlabeled examples (and their labels) together,
        # then run the combined batch through the network once.
        mixed_images, mixed_labels = self.augment(
            [l_images, u_images], [one_hot_labels] + [guessed] * self.nu,
            [self.beta, self.beta])
        raw_logit = self.net(mixed_images, name='model', training=True)

        labeled_count = tf.shape(l_images)[0]

        # Split back into [labeled, unlabeled-half, unlabeled-half] and
        # normalize each piece's logits.
        normed = [logit_norm(piece) for piece in
                  tf.split(raw_logit,
                           [labeled_count, per_replica_bs, per_replica_bs],
                           axis=0)]
        unlabeled_logit = tf.concat(normed[1:], axis=0)

        mixed_l_labels, mixed_u_labels = tf.split(
            mixed_labels, [labeled_count, per_replica_bs * 2], axis=0)

        unlabeled_loss = tf.losses.softmax_cross_entropy(mixed_u_labels,
                                                         unlabeled_logit)
        labeled_loss = tf.losses.softmax_cross_entropy(mixed_l_labels,
                                                       normed[0])

        # Weight the unlabeled term by the configured consistency factor.
        return tf.math.add(labeled_loss,
                           unlabeled_loss * FLAGS.ce_factor,
                           name='crossentropy_minimization_loss')