def batch_center_loss(self, u, label_u, loss_direction, loss_scale, knn_k,
                      normed=False):
    # Distance of every embedding to the batch centroid.
    if normed:
        per_img_avg = tfdist.normed_euclidean2(u, tf.reduce_mean(u, 0))
    else:
        per_img_avg = tfdist.euclidean(u, tf.reduce_mean(u, 0))
    # Optionally map distances onto a (negated) exp(-d) similarity scale.
    if loss_scale == 'similarity':
        per_img_avg_in_scale = tf.negative(
            tf.exp(tf.cast(tf.negative(per_img_avg), dtype=tf.float32)))
    else:
        per_img_avg_in_scale = per_img_avg
    loss_before_direction = tf.reduce_mean(per_img_avg_in_scale)
    # 'increase' flips the sign so that minimising the loss spreads the batch
    # away from its centroid instead of contracting it.
    if loss_direction == 'increase':
        loss = tf.negative(loss_before_direction)
    else:
        loss = loss_before_direction
    return loss
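# For intuition, a minimal NumPy sketch of the default branch above
# (loss_scale != 'similarity', loss_direction != 'increase'), assuming
# tfdist.euclidean returns the row-wise Euclidean distance between its two
# arguments. Illustrative only; this helper is not part of the training graph.
def _batch_center_loss_np(u):
    center = u.mean(axis=0)                       # batch centroid
    per_img = np.linalg.norm(u - center, axis=1)  # distance of each row to it
    return per_img.mean()                         # mean spread around the centroid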
def knn_loss(self, u, label_u, loss_direction, loss_scale, knn_k,
             normed=False):
    # The original code left k unbound when knn_k was None (a NameError at
    # graph-build time); fail explicitly instead.
    if knn_k is None:
        raise ValueError('knn_loss requires knn_k to be set')
    k = knn_k
    # Pairwise Euclidean distances; negating them turns top_k into a
    # k-nearest-neighbour lookup (each sample's nearest hit is itself).
    distances = tfdist.distance(u, u, pair=True, dist_type='euclidean')
    values, indices = tf.math.top_k(-distances, k=k, sorted=True)
    # Distance of each embedding to the centroid of its k nearest neighbours.
    if normed:
        per_img_avg = tfdist.normed_euclidean2(
            u, tf.reduce_mean(tf.gather(u, indices), 1))
    else:
        per_img_avg = tfdist.euclidean(
            u, tf.reduce_mean(tf.gather(u, indices), 1))
    if loss_scale == 'similarity':
        per_img_avg_in_scale = tf.negative(
            tf.exp(tf.cast(tf.negative(per_img_avg), dtype=tf.float32)))
    else:
        per_img_avg_in_scale = per_img_avg
    loss_before_direction = tf.reduce_mean(per_img_avg_in_scale)
    if loss_direction == 'increase':
        loss = tf.negative(loss_before_direction)
    else:
        loss = loss_before_direction
    return loss
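# NumPy sketch of the neighbour selection performed by knn_loss, assuming
# tfdist.distance(..., pair=True) yields the full [N, N] Euclidean matrix.
# Note that each sample's nearest "neighbour" is itself (distance 0), exactly
# as in the tf.math.top_k(-distances, ...) call above. Illustrative only.
def _knn_loss_np(u, k):
    d = np.linalg.norm(u[:, None, :] - u[None, :, :], axis=2)  # [N, N]
    idx = np.argsort(d, axis=1)[:, :k]     # k nearest indices per row
    neigh_mean = u[idx].mean(axis=1)       # centroid of each sample's kNN
    return np.linalg.norm(u - neigh_mean, axis=1).mean()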
def class_knn_loss(self, u, label_u, loss_direction, loss_scale, knn_k,
                   normed=False):
    # As in knn_loss, fail explicitly rather than leaving k unbound.
    if knn_k is None:
        raise ValueError('class_knn_loss requires knn_k to be set')
    k = knn_k
    indices = tf.constant(0, shape=[0, k], dtype=tf.int32)
    iters = tf.cond(tf.equal(self.stage, 0),
                    lambda: self.batch_size,
                    lambda: self.val_batch_size)
    # The pairwise distances are loop-invariant, so build them once here
    # instead of re-creating them on every while_loop iteration.
    distances = tfdist.distance(u, u, pair=True, dist_type='euclidean')

    def condition(indices, i):
        return tf.less(i, iters)

    def body(indices, i):
        # Labels that are active for sample i.
        b1_nt = tf.where(tf.equal(tf.gather(label_u, i, axis=0), 1))
        # Keep distances to samples sharing at least one of those labels;
        # push every other sample to +inf so top_k can never select it.
        corrected_distances_b1 = tf.where(
            tf.squeeze(
                tf.reduce_any(
                    tf.equal(tf.gather(label_u, b1_nt, axis=1), 1), axis=1)),
            tf.gather(distances, i, axis=0),
            tf.multiply(tf.ones_like(tf.gather(distances, i, axis=0)),
                        float('inf')))
        val, ind = tf.math.top_k(-corrected_distances_b1, k=k, sorted=True)
        return [
            tf.concat([indices, tf.reshape(ind, shape=[1, -1])], 0),
            tf.add(i, 1)
        ]

    indices, i = tf.while_loop(
        condition, body, [indices, 0],
        shape_invariants=[tf.TensorShape([None, None]),
                          tf.TensorShape([])])
    # Distance of each embedding to the centroid of its same-class kNN.
    if normed:
        per_img_avg = tfdist.normed_euclidean2(
            u, tf.reduce_mean(tf.gather(u, indices), 1))
    else:
        per_img_avg = tfdist.euclidean(
            u, tf.reduce_mean(tf.gather(u, indices), 1))
    if loss_scale == 'similarity':
        per_img_avg_in_scale = tf.negative(
            tf.exp(tf.cast(tf.negative(per_img_avg), dtype=tf.float32)))
    else:
        per_img_avg_in_scale = per_img_avg
    loss_before_direction = tf.reduce_mean(per_img_avg_in_scale)
    if loss_direction == 'increase':
        loss = tf.negative(loss_before_direction)
    else:
        loss = loss_before_direction
    return loss
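# NumPy sketch of the same-class masking inside class_knn_loss: for binary
# multi-label rows, "shares at least one active label with sample i" is
# equivalent to (label_u @ label_u.T)[i] > 0; all other candidates are pushed
# to +inf before the kNN lookup, mirroring the tf.where above. Illustrative only.
def _class_knn_indices_np(u, label_u, k):
    d = np.linalg.norm(u[:, None, :] - u[None, :, :], axis=2)
    share = (label_u @ label_u.T) > 0      # [N, N] shared-label mask
    d = np.where(share, d, np.inf)
    return np.argsort(d, axis=1)[:, :k]    # per-row same-class kNN indices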
def class_center_loss(self, u, label_u, loss_direction, loss_scale, knn_k,
                      normed=False):
    if self.extract_features or self.extract_hashlayer_features:
        # Build one target per class: the mean embedding of the batch samples
        # carrying that label. Gradients are stopped so the centres act as
        # fixed targets rather than trainable quantities.
        shape1 = label_u.shape[1].value
        targets = tf.constant(0.0, shape=[0, u.shape[1]])
        for i in range(0, shape1):
            targets = tf.concat([
                targets,
                tf.stop_gradient(
                    tf.reshape(
                        tf.reduce_mean(
                            tf.reshape(
                                tf.gather(
                                    u, tf.where(tf.equal(label_u[:, i], 1))),
                                [-1, u.shape[1]]), 0),
                        [1, -1]))
            ], 0)
        # Classes absent from the batch produce NaN means; zero them out.
        corrected_targets = tf.where(tf.is_nan(targets),
                                     tf.zeros_like(targets), targets)
        targets = tf.stop_gradient(corrected_targets)
    else:
        targets = self.targets
    # Per-sample target: the mean of the centres of its active labels.
    # (np.int is removed in recent NumPy releases, so use the builtin int.)
    mean = tf.divide(
        tf.reduce_sum(
            tf.multiply(
                tf.cast(tf.multiply(tf.expand_dims(label_u, 2),
                                    np.ones((1, 1, int(u.shape[1])))),
                        dtype=tf.float32),
                targets), 1),
        tf.reshape(tf.cast(tf.reduce_sum(label_u, 1), dtype=tf.float32),
                   (-1, 1)))
    if normed:
        per_img_avg = tfdist.normed_euclidean2(u, mean)
    else:
        per_img_avg = tfdist.euclidean(u, mean)
    if loss_scale == 'similarity':
        per_img_avg_in_scale = tf.negative(
            tf.exp(tf.cast(tf.negative(per_img_avg), dtype=tf.float32)))
    else:
        per_img_avg_in_scale = per_img_avg
    loss_before_direction = tf.reduce_mean(per_img_avg_in_scale)
    if loss_direction == 'increase':
        loss = tf.negative(loss_before_direction)
    else:
        loss = loss_before_direction
    return loss
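# NumPy sketch of the batch-computed class centres and per-sample targets
# above, assuming binary label rows. Empty classes yield zero centres,
# mirroring the tf.is_nan correction. Illustrative only.
def _class_center_loss_np(u, label_u):
    counts = label_u.sum(axis=0)                               # samples per class
    centres = (label_u.T @ u) / np.maximum(counts, 1)[:, None] # [C, D] class means
    centres[counts == 0] = 0.0                                 # NaN -> 0 analogue
    # Per-sample mean of the centres of its active labels.
    mean = (label_u @ centres) / label_u.sum(axis=1, keepdims=True)
    return np.linalg.norm(u - mean, axis=1).mean()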