def predict(self):
  """See `model.py` for documentation.

  Builds the prediction graph: encodes support/unlabeled/query images,
  computes class prototypes from the labeled support set, then refines the
  prototypes for `num_cluster_steps` soft-clustering steps in which each
  unlabeled example is down-weighted by a sigmoid mask. The mask's
  scale/bias are produced by a small MLP ("dist_mlp") from statistics of
  the prototype-to-unlabeled distance distribution.

  Returns:
    logits_list: list of query logits tensors, one entry per clustering
      step plus the initial (step-0) logits.
  """
  nclasses = self.nway
  num_cluster_steps = self.config.num_cluster_steps
  # Encoded features for support (labeled), unlabeled, and query sets.
  h_train, h_unlabel, h_test = self.encode(self.x_train, self.x_unlabel,
                                           self.x_test)
  y_train = self.y_train
  protos = self._compute_protos(nclasses, h_train, y_train)
  logits_list = []
  # Step-0 logits: plain prototype logits before any clustering.
  logits_list.append(compute_logits(protos, h_test))

  # Hard assignment for training images: one-hot class membership so the
  # labeled examples always contribute fully to their own class.
  prob_train = [None] * (nclasses)
  for kk in range(nclasses):
    # [B, N, 1]
    prob_train[kk] = tf.expand_dims(
        tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
  prob_train = concat(prob_train, 2)
  y_train_shape = tf.shape(y_train)
  bsize = y_train_shape[0]
  h_all = concat([h_train, h_unlabel], 1)
  # `mask` stays None when num_cluster_steps == 0; the distractor
  # evaluation below is skipped in that case.
  mask = None

  # Calculate pairwise distances between prototypes and unlabeled points.
  protos_1 = tf.expand_dims(protos, 2)
  protos_2 = tf.expand_dims(h_unlabel, 1)
  pair_dist = tf.reduce_sum((protos_1 - protos_2)**2, [3])  # [B, K, N]
  # Normalize distances per-prototype by their mean so the statistics
  # below are scale-invariant.
  mean_dist = tf.reduce_mean(pair_dist, [2], keep_dims=True)
  pair_dist_normalize = pair_dist / mean_dist
  min_dist = tf.reduce_min(pair_dist_normalize, [2],
                           keep_dims=True)  # [B, K, 1]
  max_dist = tf.reduce_max(pair_dist_normalize, [2], keep_dims=True)
  mean_dist, var_dist = tf.nn.moments(pair_dist_normalize, [2],
                                      keep_dims=True)
  # Guard against zero mean/variance (e.g. a single unlabeled point) to
  # avoid division by zero in the skew/kurtosis computations.
  mean_dist += tf.to_float(tf.equal(mean_dist, 0.0))
  var_dist += tf.to_float(tf.equal(var_dist, 0.0))
  skew = tf.reduce_mean(
      ((pair_dist_normalize - mean_dist)**3) / (tf.sqrt(var_dist)**3), [2],
      keep_dims=True)
  kurt = tf.reduce_mean(
      ((pair_dist_normalize - mean_dist)**4) / (var_dist**2) - 3, [2],
      keep_dims=True)
  n_features = 5
  n_out = 3
  # Five distance-distribution statistics per prototype.  [B*K, 5]
  dist_features = tf.reshape(
      concat([min_dist, max_dist, var_dist, skew, kurt], 2),
      [-1, n_features])
  # Do not backprop through the statistics; only the MLP itself learns.
  dist_features = tf.stop_gradient(dist_features)
  hdim = [n_features, 20, n_out]
  act_fn = [tf.nn.tanh, None]
  # Small MLP predicts three quantities per prototype (see below):
  # log-scale, log of the starting bias, and an additive bias schedule.
  thresh = mlp(
      dist_features,
      hdim,
      is_training=True,
      act_fn=act_fn,
      dtype=tf.float32,
      add_bias=True,
      wd=None,
      init_std=[0.01, 0.01],
      init_method=None,
      scope="dist_mlp",
      dropout=None,
      trainable=True)
  # exp keeps scale and starting bias positive.
  scale = tf.exp(thresh[:, 2])
  bias_start = tf.exp(thresh[:, 0])
  bias_add = thresh[:, 1]
  bias_start = tf.reshape(bias_start, [bsize, 1, -1])  # [B, 1, K]
  bias_add = tf.reshape(bias_add, [bsize, 1, -1])
  # Expose mask parameters for inspection/logging.
  self._scale = scale
  self._bias_start = bias_start
  self._bias_add = bias_add

  # Run clustering.
  for tt in range(num_cluster_steps):
    protos_1 = tf.expand_dims(protos, 2)
    protos_2 = tf.expand_dims(h_unlabel, 1)
    pair_dist = tf.reduce_sum((protos_1 - protos_2)**2, [3])  # [B, K, N]
    m_dist = tf.reduce_mean(pair_dist, [2])  # [B, K]
    m_dist_1 = tf.expand_dims(m_dist, 1)  # [B, 1, K]
    # Avoid division by zero when all pairwise distances collapse.
    m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0))
    # Label assignment.  The bias is linearly annealed from `bias_start`
    # toward `bias_start + bias_add` over the clustering steps.
    if num_cluster_steps > 1:
      bias_tt = bias_start + (
          tt / float(num_cluster_steps - 1)) * bias_add
    else:
      bias_tt = bias_start
    negdist = compute_logits(protos, h_unlabel)
    # Soft mask in [0, 1]: how much each unlabeled point participates.
    mask = tf.sigmoid((negdist / m_dist_1 + bias_tt) * scale)
    prob_unlabel, mask = assign_cluster_soft_mask(
        protos, h_unlabel, mask)
    prob_all = concat([prob_train, prob_unlabel * mask], 1)
    # No update if 0 unlabel.
    protos = tf.cond(
        tf.shape(self._x_unlabel)[1] > 0,
        lambda: update_cluster(h_all, prob_all),
        lambda: protos)
    logits_list.append(compute_logits(protos, h_test))

  # Distractor evaluation: a point whose best mask value is below the
  # batch-mean mask is predicted to be a distractor.
  if mask is not None:
    max_mask = tf.reduce_max(mask, [2])
    mean_mask = tf.reduce_mean(max_mask)
    pred_non_distractor = tf.to_float(max_mask > mean_mask)
    acc, recall, precision = eval_distractor(pred_non_distractor,
                                             self.y_unlabel)
    self._non_distractor_acc = acc
    self._distractor_recall = recall
    self._distractor_precision = precision
    self._distractor_pred = max_mask
  return logits_list
def predict(self):
  """See `model.py` for documentation.

  Builds the prediction graph for the distractor-radius model: an extra
  (K+1)-th cluster with a zero prototype and a larger radius absorbs
  distractor examples during soft clustering of the unlabeled set.

  Returns:
    logits_list: list of query logits tensors, one entry per clustering
      step plus the initial (step-0) logits.
  """
  nclasses = self.nway
  num_cluster_steps = self.config.num_cluster_steps
  h_train, h_unlabel, h_test = self.get_encoded_inputs(
      self.x_train, self.x_unlabel, self.x_test)
  y_train = self.y_train
  protos = self._compute_protos(nclasses, h_train, y_train)

  # Distractor class has a zero vector as prototype.
  protos = concat([protos, tf.zeros_like(protos[:, 0:1, :])], 1)

  # Hard assignment for training images; labeled examples never belong to
  # the distractor class, hence the all-zero last column.
  prob_train = [None] * (nclasses + 1)
  for kk in range(nclasses):
    # [B, N, 1]
    prob_train[kk] = tf.expand_dims(
        tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
  prob_train[-1] = tf.zeros_like(prob_train[0])
  prob_train = concat(prob_train, 2)

  # Initialize cluster radii: unit radius for the real classes.
  radii = [None] * (nclasses + 1)
  y_train_shape = tf.shape(y_train)
  bsize = y_train_shape[0]
  for kk in range(nclasses):
    radii[kk] = tf.ones([bsize, 1]) * 1.0

  # Distractor class has a larger radius, optionally learned in log-space
  # so it stays positive.
  if FLAGS.learn_radius:
    log_distractor_radius = tf.get_variable(
        "log_distractor_radius",
        shape=[],
        dtype=tf.float32,
        initializer=tf.constant_initializer(np.log(FLAGS.init_radius)))
    distractor_radius = tf.exp(log_distractor_radius)
  else:
    distractor_radius = FLAGS.init_radius
  # With no unlabeled examples, make the distractor radius effectively
  # infinite so the distractor cluster attracts nothing.
  distractor_radius = tf.cond(
      tf.shape(self._x_unlabel)[1] > 0, lambda: distractor_radius,
      lambda: 100000.0)
  radii[-1] = tf.ones([bsize, 1]) * distractor_radius
  radii = concat(radii, 1)  # [B, K]

  h_all = concat([h_train, h_unlabel], 1)
  logits_list = []
  # Step-0 logits before any clustering refinement.
  logits_list.append(compute_logits_radii(protos, h_test, radii))

  # Run clustering.  `prob_unlabel` stays None when num_cluster_steps == 0
  # so the distractor evaluation below can be skipped safely (previously
  # this raised NameError; mirrors the `mask is not None` guard in the
  # soft-mask variant of `predict`).
  prob_unlabel = None
  for tt in range(num_cluster_steps):
    # Label assignment.
    prob_unlabel = assign_cluster_radii(protos, h_unlabel, radii)
    prob_all = concat([prob_train, prob_unlabel], 1)
    protos = update_cluster(h_all, prob_all)
    logits_list.append(compute_logits_radii(protos, h_test, radii))

  # Distractor evaluation.
  if prob_unlabel is not None:
    # An unlabeled point assigned to cluster index `nclasses` (the extra
    # cluster) is predicted to be a distractor.
    is_distractor = tf.equal(tf.argmax(prob_unlabel, axis=-1), nclasses)
    pred_non_distractor = 1.0 - tf.to_float(is_distractor)
    acc, recall, precision = eval_distractor(pred_non_distractor,
                                             self.y_unlabel)
    self._non_distractor_acc = acc
    self._distractor_recall = recall
    self._distractor_precision = precision
    # NOTE(review): exp of the last column suggests `prob_unlabel` holds
    # log-probabilities — confirm against `assign_cluster_radii`.
    self._distractor_pred = 1.0 - tf.exp(prob_unlabel[:, :, -1])
  return logits_list