def compute_logits_for_episode(self, local_support_embeddings,
                               local_query_embeddings, data):
    """Compute query logits for one episode.

    The one-hot support labels and the support embeddings are both passed
    through `distribute_utils.aggregate` (presumably gathering them across
    replicas -- confirm against distribute_utils). Query embeddings are used
    as-is. The result of `self.compute_logits` is returned unchanged.

    Args:
      local_support_embeddings: Support embeddings held by this caller,
        before aggregation.
      local_query_embeddings: Query embeddings; not aggregated.
      data: Episode object; only `data.onehot_support_labels` is read.

    Returns:
      Whatever `self.compute_logits(support_embeddings, query_embeddings,
      support_labels)` returns.
    """
    # Labels are aggregated before embeddings, matching the original order.
    gathered_labels = distribute_utils.aggregate(data.onehot_support_labels)
    gathered_embeddings = distribute_utils.aggregate(local_support_embeddings)
    return self.compute_logits(gathered_embeddings, local_query_embeddings,
                               gathered_labels)
Exemple #2
0
    def compute_logits_for_episode(self, support_embeddings, query_embeddings,
                                   data):
        """Compute CrossTransformer logits.

        Support and query embeddings go through two 1x1-conv projections:
        a "key" projection (scope 'tformer_keys', output width
        `self.query_dim`) and a "value" projection (scope 'tformer_values',
        output width `self.val_dim`). Within each scope the query side reuses
        the parameters created for the support side. The query projection
        therefore shares the key head's weights.

        The one-hot support labels, support keys and support values are then
        passed through `distribute_utils.aggregate`. Query-side tensors are
        not aggregated. Distances come from `self._get_dist_rematerialize`
        when `self.rematerialize` is truthy, otherwise from `self._get_dist`.

        Args:
          support_embeddings: Support-set feature maps.
          query_embeddings: Query-set feature maps.
          data: Episode object; only `data.onehot_support_labels` is read.

        Returns:
          The negated transpose of the distances. It is also stored on
          `self.test_logits`.
        """

        def project(scope_name, out_dim):
            # One 1x1 conv applied to support, then to query with the same
            # parameters; AUTO_REUSE lets repeated graph builds share vars.
            with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
                projected_support, shared_params = functional_backbones.conv(
                    support_embeddings, [1, 1],
                    out_dim,
                    1,
                    weight_decay=self.tformer_weight_decay)
                projected_query, _ = functional_backbones.conv(
                    query_embeddings, [1, 1],
                    out_dim,
                    1,
                    params=shared_params,
                    weight_decay=self.tformer_weight_decay)
            return projected_support, projected_query

        # Keys first, then values -- same variable-creation order as before.
        support_keys, query_queries = project('tformer_keys', self.query_dim)
        support_values, query_values = project('tformer_values', self.val_dim)

        onehot_labels = distribute_utils.aggregate(data.onehot_support_labels)
        support_keys = distribute_utils.aggregate(support_keys)
        support_values = distribute_utils.aggregate(support_values)

        # Recover integer class ids from the one-hot rows.
        class_ids = tf.argmax(onehot_labels, axis=1)
        dist_fn = (self._get_dist_rematerialize if self.rematerialize
                   else self._get_dist)
        distances = dist_fn(query_queries, query_values, support_keys,
                            support_values, class_ids)

        # Transposed so rows index queries -- presumably distances are
        # (num_classes, num_query); confirm against _get_dist. Negated so
        # that smaller distances give larger logits.
        self.test_logits = -tf.transpose(distances)
        return self.test_logits
Exemple #3
0
 def query_shots(self):
     """Return global query shots for the episode.

     `self.query_labels` is passed through `distribute_utils.aggregate`.
     The aggregated labels then go, together with `self.way`, to
     `compute_shot`.
     """
     # Read `way` first to keep the original evaluation order.
     episode_way = self.way
     global_query_labels = distribute_utils.aggregate(self.query_labels)
     return compute_shot(episode_way, global_query_labels)
Exemple #4
0
 def support_shots(self):
     """Return global support shots for the episode.

     `self.support_labels` is passed through `distribute_utils.aggregate`.
     The aggregated labels then go, together with `self.way`, to
     `compute_shot`.
     """
     # Read `way` first to keep the original evaluation order.
     episode_way = self.way
     global_support_labels = distribute_utils.aggregate(self.support_labels)
     return compute_shot(episode_way, global_support_labels)
Exemple #5
0
 def unique_class_ids(self):
     """Return global unique class id's for the episode.

     Support and query class ids are each passed through
     `distribute_utils.aggregate` (support first). They are concatenated
     along the last axis, and the result goes to `compute_unique_class_ids`.
     """
     global_support_ids = distribute_utils.aggregate(self.support_class_ids)
     global_query_ids = distribute_utils.aggregate(self.query_class_ids)
     combined_ids = tf.concat((global_support_ids, global_query_ids), -1)
     return compute_unique_class_ids(combined_ids)