def noisy_forward(self,
                   data,
                   noise=None,
                   update_batch_stats=False):
     """Compute logits for `data` perturbed by additive `noise`.

     The input is encoded with `self.phi` and scored against `self.protos`
     (extended with one all-zero prototype) using `self.radii`.

     Args:
         data: Input passed to `self.phi` after `noise` is added to it.
         noise: Value added to `data` before encoding. `None` (the default)
             means a scalar zero constant created in the current graph.
         update_batch_stats: Forwarded unchanged to `self.phi`.

     Returns:
         The result of `compute_logits_radii` for the encoded input.
     """
     # Create the default at call time. A tensor used as a default argument
     # is built once, when the function is defined, and therefore belongs to
     # whichever graph was the default at that moment.
     if noise is None:
         noise = tf.constant(0.0)
     with tf.name_scope("forward"):
         # Append an all-zero prototype along axis 1 (presumably the
         # distractor class -- confirm against `predict`).
         protos = concat(
             [self.protos,
              tf.zeros_like(self.protos[:, 0:1, :])], 1)
         encoded = self.phi(data + noise,
                            update_batch_stats=update_batch_stats)
         # The encoding gets a new leading axis before scoring.
         logits = compute_logits_radii(protos, tf.expand_dims(encoded, 0),
                                       self.radii)
     return logits
    def predict(self):
        """See `model.py` for documentation.

        Builds class prototypes from the labeled training embeddings, appends
        an all-zero distractor prototype, and refines the prototypes for
        `self.config.num_cluster_steps` rounds using the unlabeled embeddings.

        Side effects:
            Sets `self._non_distractor_acc`, `self._distractor_recall`,
            `self._distractor_precision` and `self._distractor_pred`.

        Returns:
            A list of test logits. The first entry is computed with the
            initial prototypes; one further entry is appended after each
            clustering step.
        """
        nclasses = self.nway
        num_cluster_steps = self.config.num_cluster_steps
        h_train, h_unlabel, h_test = self.get_encoded_inputs(
            self.x_train, self.x_unlabel, self.x_test)
        y_train = self.y_train
        protos = self._compute_protos(nclasses, h_train, y_train)

        # Distractor class has a zero vector as prototype.
        protos = concat([protos, tf.zeros_like(protos[:, 0:1, :])], 1)

        # Hard assignment for training images: one [B, N, 1] indicator per
        # class, marking the items whose label equals that class.
        prob_train = [
            tf.expand_dims(tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
            for kk in range(nclasses)
        ]
        # Training images get zero mass on the distractor slot. Built once,
        # after the per-class indicators, instead of once per class.
        prob_train.append(tf.zeros_like(prob_train[0]))
        prob_train = concat(prob_train, 2)

        # Initialize cluster radii: every regular class starts at radius 1.
        bsize = tf.shape(y_train)[0]
        radii = [tf.ones([bsize, 1]) for _ in range(nclasses)]

        # Distractor class has a larger radius.
        if FLAGS.learn_radius:
            # The radius is learned in log space and exponentiated.
            log_distractor_radius = tf.get_variable(
                "log_distractor_radius",
                shape=[],
                dtype=tf.float32,
                initializer=tf.constant_initializer(np.log(FLAGS.init_radius)))
            base_radius = tf.exp(log_distractor_radius)
        else:
            base_radius = FLAGS.init_radius
        # When dimension 1 of `self._x_unlabel` is empty, fall back to a very
        # large constant radius. A separate name is used for the branch value
        # so the lambda does not close over the name being reassigned.
        distractor_radius = tf.cond(
            tf.shape(self._x_unlabel)[1] > 0, lambda: base_radius,
            lambda: 100000.0)
        radii.append(tf.ones([bsize, 1]) * distractor_radius)
        radii = concat(radii, 1)  # [B, K]

        h_all = concat([h_train, h_unlabel], 1)
        logits_list = [compute_logits_radii(protos, h_test, radii)]

        # Run clustering.
        prob_unlabel = None
        for _ in range(num_cluster_steps):
            # Label assignment.
            prob_unlabel = assign_cluster_radii(protos, h_unlabel, radii)
            prob_all = concat([prob_train, prob_unlabel], 1)
            protos = update_cluster(h_all, prob_all)
            logits_list.append(compute_logits_radii(protos, h_test, radii))

        if prob_unlabel is None:
            # No clustering step ran. The distractor evaluation below still
            # needs an assignment, so compute one from the initial prototypes
            # rather than failing on an unbound name.
            prob_unlabel = assign_cluster_radii(protos, h_unlabel, radii)

        # Distractor evaluation: an unlabeled item is predicted as a
        # distractor when its arg-max cluster is the last index (`nclasses`).
        is_distractor = tf.equal(tf.argmax(prob_unlabel, axis=-1), nclasses)
        pred_non_distractor = 1.0 - tf.cast(is_distractor, tf.float32)
        acc, recall, precision = eval_distractor(pred_non_distractor,
                                                 self.y_unlabel)
        self._non_distractor_acc = acc
        self._distractor_recall = recall
        self._distractor_precision = precision
        # NOTE(review): `tf.exp` is only meaningful here if
        # `assign_cluster_radii` returns log-probabilities; if it returns
        # probabilities this value is never positive -- confirm.
        self._distractor_pred = 1.0 - tf.exp(prob_unlabel[:, :, -1])
        return logits_list