def noisy_forward(self, data, noise=None, update_batch_stats=False):
    """Compute logits for `data`, optionally perturbed by additive `noise`.

    No clustering is run here: the encoded input is scored against the
    stored `self.protos`, with an all-zero prototype appended along axis 1,
    using the stored `self.radii`.

    Args:
      data: Input tensor; `noise` (if given) is added to it before it is
        encoded by `self.phi`.
      noise: Optional tensor added to `data` before encoding. `None` (the
        default) means no perturbation.
      update_batch_stats: Forwarded unchanged to `self.phi`.

    Returns:
      The logits tensor produced by `compute_logits_radii`.
    """
    # The default is None rather than tf.constant(0.0): a tensor default is
    # created once at definition time in the import-time default graph and
    # cannot be combined with tensors of a model built in another tf.Graph.
    with tf.name_scope("forward"):
        # Append an all-zero prototype along the cluster axis (axis 1).
        protos = concat(
            [self.protos, tf.zeros_like(self.protos[:, 0:1, :])], 1)
        inputs = data if noise is None else data + noise
        encoded = self.phi(inputs, update_batch_stats=update_batch_stats)
        # A leading axis is added so `encoded` lines up with the leading
        # (batch) axis of `protos` -- presumably a batch of one; TODO confirm.
        logits = compute_logits_radii(protos, tf.expand_dims(encoded, 0),
                                      self.radii)
    return logits
def predict(self):
    """See `model.py` for documentation.

    Builds test logits by refining prototypes with a soft k-means that uses
    per-cluster radii plus one extra "distractor" cluster (index `nway`).

    1. Prototypes are computed from the labeled (train) embeddings, and an
       all-zero prototype is appended for the distractor cluster.
    2. For `config.num_cluster_steps` steps, unlabeled embeddings are
       soft-assigned (`assign_cluster_radii`), and prototypes are recomputed
       (`update_cluster`) from train + unlabeled points. Train points keep a
       fixed hard assignment with zero weight on the distractor cluster.
    3. Distractor-detection metrics on the unlabeled set are stored on
       `self`.

    Returns:
      A list of `num_cluster_steps + 1` logits tensors for the test
      embeddings. Entry 0 uses the initial prototypes; entry i uses the
      prototypes after i refinement steps.

    Side effects:
      Sets `self._non_distractor_acc`, `self._distractor_recall`,
      `self._distractor_precision` and `self._distractor_pred`. When
      `FLAGS.learn_radius` is set, creates the variable
      "log_distractor_radius" via `tf.get_variable`.
    """
    nclasses = self.nway
    num_cluster_steps = self.config.num_cluster_steps
    h_train, h_unlabel, h_test = self.get_encoded_inputs(
        self.x_train, self.x_unlabel, self.x_test)
    y_train = self.y_train

    protos = self._compute_protos(nclasses, h_train, y_train)
    # Distractor class has a zero vector as prototype (appended on axis 1).
    protos = concat([protos, tf.zeros_like(protos[:, 0:1, :])], 1)

    # Hard assignment for training points: one [B, N, 1] indicator per real
    # class, plus an all-zero column for the distractor cluster.
    prob_train = [
        tf.expand_dims(tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
        for kk in range(nclasses)
    ]
    prob_train.append(tf.zeros_like(prob_train[0]))
    prob_train = concat(prob_train, 2)

    # Cluster radii: 1.0 for every real class, then the distractor radius.
    bsize = tf.shape(y_train)[0]
    radii = [tf.ones([bsize, 1]) for _ in range(nclasses)]
    if FLAGS.learn_radius:
        # Learned in log space so the radius stays positive.
        log_distractor_radius = tf.get_variable(
            "log_distractor_radius",
            shape=[],
            dtype=tf.float32,
            initializer=tf.constant_initializer(np.log(FLAGS.init_radius)))
        base_distractor_radius = tf.exp(log_distractor_radius)
    else:
        base_distractor_radius = FLAGS.init_radius
    # When there are no unlabeled points (axis 1 of `self._x_unlabel` is
    # empty), a fixed radius of 100000.0 is used instead.
    distractor_radius = tf.cond(
        tf.shape(self._x_unlabel)[1] > 0,
        lambda: base_distractor_radius,
        lambda: 100000.0)
    radii.append(tf.ones([bsize, 1]) * distractor_radius)
    radii = concat(radii, 1)  # [B, nclasses + 1]

    h_all = concat([h_train, h_unlabel], 1)
    logits_list = [compute_logits_radii(protos, h_test, radii)]

    # Run clustering.
    prob_unlabel = None
    for _ in range(num_cluster_steps):
        # Label assignment of unlabeled points under the current prototypes.
        prob_unlabel = assign_cluster_radii(protos, h_unlabel, radii)
        prob_all = concat([prob_train, prob_unlabel], 1)
        protos = update_cluster(h_all, prob_all)
        logits_list.append(compute_logits_radii(protos, h_test, radii))

    if prob_unlabel is None:
        # No refinement step ran (num_cluster_steps <= 0): evaluate the
        # distractor assignment under the initial prototypes instead of
        # reading an unbound name.
        prob_unlabel = assign_cluster_radii(protos, h_unlabel, radii)

    # Distractor evaluation: a point is predicted as a distractor when its
    # highest-scoring cluster is the extra one at index `nclasses`.
    is_distractor = tf.equal(tf.argmax(prob_unlabel, axis=-1), nclasses)
    pred_non_distractor = 1.0 - tf.to_float(is_distractor)
    acc, recall, precision = eval_distractor(pred_non_distractor,
                                             self.y_unlabel)
    self._non_distractor_acc = acc
    self._distractor_recall = recall
    self._distractor_precision = precision
    # NOTE(review): exponentiating the distractor column only yields a
    # probability if `assign_cluster_radii` returns log-probabilities; the
    # concat with the 0/1 `prob_train` above suggests it returns plain
    # probabilities -- confirm against its definition.
    self._distractor_pred = 1.0 - tf.exp(prob_unlabel[:, :, -1])
    return logits_list