def get_support_set_softmax(self, logits, class_ids):
    """Normalize logits with a softmax taken jointly over each support set.

    Images that share a class id are normalized together: the denominator
    sums exp(logit) over axis 1 of every image in the same segment,
    independently for each index of the last axis.

    Args:
        logits: [N_k, H*W, Q] dimensional tensor.
        class_ids: [N_k] tensor giving the support-set-id of each image.

    Returns:
        The softmax of `logits` over the support set, i.e.
        exp(x) / sum(exp(x)) with the sum running over each support set.
    """
    # Segment count shared by both segment reductions below.
    num_sets = tf.reduce_max(class_ids) + 1

    # Per-set maximum, gathered back to one row per image. Subtracting it
    # keeps the exponentials below from overflowing.
    per_image_max = tf.reduce_max(logits, axis=1, keepdims=True)
    per_set_max = tf.math.unsorted_segment_max(
        per_image_max, class_ids, num_sets)
    shifted = logits - tf.gather(per_set_max, class_ids)

    # Log of the per-set normalizer, again gathered back per image.
    per_image_sum = tf.reduce_sum(tf.exp(shifted), axis=1, keepdims=True)
    per_set_sum = tf.math.unsorted_segment_sum(
        per_image_sum, class_ids, num_sets)
    log_normalizer = tf.gather(tf.log(per_set_sum), class_ids)

    return tf.exp(shifted - log_normalizer)
def compute_logits(self, support_embeddings, query_embeddings,
                   onehot_support_labels):
    """Return log class probabilities for every query example.

    Each query attends over all support examples via a softmax of
    embedding dot products; the attention weights are then used to average
    the one-hot support labels, and the log of that average is returned.

    Only the support embeddings are always L2-normalized (this is
    undocumented in the paper but *very important*), so by default the
    similarity is not an exact cosine similarity. Setting the
    `exact_cosine_distance` instance attribute additionally normalizes the
    query embeddings, giving the exact cosine variant for comparison.

    Args:
        support_embeddings: A Tensor of size [num_support_images,
            embedding dim].
        query_embeddings: A Tensor of size [num_query_images, embedding dim].
        onehot_support_labels: One-hot labels with one row per support
            image, i.e. [num_support_images, way] (required by the final
            matmul with the attention matrix).

    Returns:
        The query set logits as a [num_query_images, way] matrix.
    """
    normalized_support = tf.nn.l2_normalize(
        support_embeddings, 1, epsilon=1e-3)

    queries = query_embeddings
    if self.exact_cosine_distance:
        queries = tf.nn.l2_normalize(queries, 1, epsilon=1e-3)

    # [num_query_images, num_support_images]
    pairwise_scores = tf.matmul(queries, normalized_support, transpose_b=True)
    support_weights = tf.nn.softmax(pairwise_scores)

    # [num_query_images, way]
    label_matrix = tf.cast(onehot_support_labels, tf.float32)
    class_probabilities = tf.matmul(support_weights, label_matrix)
    return tf.log(class_probabilities)
def historgram_loss(y, y_hat, k=100., sigma=1 / 2):
    """Placeholder for a histogram-style loss; currently always raises.

    The first statement unconditionally raises ``NotImplementedError``, so
    no argument is ever inspected and the draft implementation that follows
    the ``raise`` never executes.

    Args:
        y: Target tensor (the unreachable draft squeezes its axis 2).
        y_hat: Prediction tensor (the unreachable draft reads
            ``y_hat[:, :, idx]`` for each bin index).
        k: Bin count used by the unreachable draft. Defaults to ``100.``.
        sigma: Scale used in the draft's erf arguments. Defaults to
            ``1 / 2``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
    # NOTE(review): everything below is unreachable because of the
    # unconditional raise above. It is kept as a draft only.
    # NOTE(review): before re-enabling, confirm that (a) `k` is an int when
    # passed to np.linspace as the sample count (the default is the float
    # 100.), (b) the bin width `w = 1 / k` matches the linspace spacing, and
    # (c) `sigma=1 / 2` is not evaluated with Python 2 integer division.
    ps = 0.
    w = 1 / k
    y = tf.squeeze(y, axis=2)
    # y_hat = tf.layers.flatten(y_hat)
    # `k` is rebound here from the bin count to the array of bin positions.
    k = np.linspace(0., 1., k)
    # erf difference between the interval ends 1 and 0, used to normalize `p`.
    s = (tf.erf((1. - y) / (tf.sqrt(2.) * sigma))
         - tf.erf((0. - y) / (tf.sqrt(2.) * sigma)))
    for idx, j in enumerate(k):
        # erf evaluated at the upper (j + w) and lower (j) edge of this bin.
        u = tf.erf(((j + w - y) / (tf.sqrt(2.) * sigma)))
        l = tf.erf(((j - y) / (tf.sqrt(2.) * sigma)))
        p = (u - l) / (2 * s + 1e-6)
        f_x = tf.log(y_hat[:, :, idx])
        # NaN log-values are replaced with zero before being accumulated.
        ps += p * tf.where(tf.is_nan(f_x), tf.zeros_like(f_x), f_x)
    return tf.reduce_mean(-ps)
def preprocess_spatial_observation(input_obs, spec,
                                   categorical_embedding_dims=16,
                                   non_categorical_scaling='log'):
    """Split a spatial observation along axis 1 and preprocess each feature.

    The input is split into one slice per entry of axis 1. For every feature
    ``f`` in ``spec.features`` the slice at ``f.index`` is replaced:

    * categorical features are squeezed on axis 1, passed through
      ``Embedding(f.scale, categorical_embedding_dims)`` and permuted with
      ``Permute((3, 1, 2))`` (presumably moving the embedding axis to the
      position of the squeezed one; confirm against the data layout);
    * non-categorical features become ``log(x + 1e-10)`` when
      ``non_categorical_scaling == 'log'`` or ``x / f.scale`` when it is
      ``'normalize'``; any other value leaves the slice unchanged.

    Args:
        input_obs: Keras tensor whose axis 1 enumerates the features
            (assumed channels-first; TODO confirm with the caller).
        spec: Object exposing ``features``, an iterable of items with
            ``index``, ``is_categorical`` and ``scale`` attributes.
        categorical_embedding_dims: Output size of each categorical
            embedding.
        non_categorical_scaling: ``'log'``, ``'normalize'``, or anything
            else for no scaling.

    Returns:
        The list of per-feature tensors produced by the split, with the
        entries named in ``spec.features`` preprocessed as described above.
    """
    with tf.name_scope('preprocess_spatial_obs'):
        features = Lambda(
            lambda x: tf.split(x, x.get_shape()[1], axis=1))(input_obs)

        for f in spec.features:
            idx = f.index
            if f.is_categorical:
                squeezed = Lambda(
                    lambda x: tf.squeeze(x, axis=1))(features[idx])
                embedded = Embedding(
                    f.scale, categorical_embedding_dims)(squeezed)
                features[idx] = Permute((3, 1, 2))(embedded)
            elif non_categorical_scaling == 'log':
                features[idx] = Lambda(
                    lambda x: tf.log(x + 1e-10))(features[idx])
            elif non_categorical_scaling == 'normalize':
                # Bind the scale as a default argument: a plain closure over
                # the loop variable `f` would see the *last* feature if the
                # layer function is ever executed again after the loop.
                features[idx] = Lambda(
                    lambda x, scale=f.scale: x / scale)(features[idx])

    return features