Code example #1
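similarity_collision scores how well a query vector matches one concept: the query is projected into every attribute space, compared with the concept's normalized embedding by a shifted, temperature-scaled cosine similarity, and the per-attribute logits are then combined with jactorch.logsumexp, weighted by the concept's log belong distribution. The snippet assumes torch and jactorch (from the Jacinle library) are imported at module level.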
    def similarity_collision(self, query, identifier):
        mappings = self.get_all_attributes()
        concept = self.get_concept(identifier)

        # Project the query into every attribute space and L2-normalize.
        # shape: [batch, attributes, channel] or [attributes, channel]
        query_mapped = torch.stack([m(query) for m in mappings], dim=-2)
        query_mapped = query_mapped / query_mapped.norm(
            2, dim=-1, keepdim=True)
        # Broadcast the concept's normalized embedding against query_mapped.
        reference = jactorch.add_dim_as_except(concept.normalized_embedding,
                                               query_mapped, -2, -1)

        # Shifted, temperature-scaled cosine similarity per attribute.
        margin = self._margin
        logits = ((query_mapped * reference).sum(dim=-1) - 1 +
                  margin) / margin / self._tau

        # Marginalize over attributes, weighting each by the concept's
        # log-probability of belonging to it.
        belong = jactorch.add_dim_as_except(concept.log_normalized_belong,
                                            logits, -1)
        logits = jactorch.logsumexp(logits + belong, dim=-1)

        return logits
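All three methods use the same scoring rule: cosine similarity, shifted by a margin and sharpened by a temperature. Here is a minimal sketch of that rule in isolation; the function name and the default values of margin and tau are illustrative, not taken from the module:

    import torch

    def cosine_margin_logits(a, b, margin=0.85, tau=0.25):
        # a and b must be L2-normalized, so their dot product is the
        # cosine similarity, in [-1, 1].
        cos = (a * b).sum(dim=-1)
        # Shift so that cos == 1 - margin maps to logit 0, then sharpen
        # by the temperature tau.
        return (cos - 1 + margin) / margin / tau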
Code example #2
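similarity2 compares two query vectors for sameness. Outside training, or when the module-level _query_assisted_same flag is off, it returns plain cosine-margin logits; otherwise it additionally checks, concept by concept, whether the two queries agree on every concept of the given attribute, and returns the minimum of the two scores.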
    def similarity2(self, q1, q2, identifier, _normalized=False):
        """
        Args:
            _normalized (bool): set by `cross_similarity` when q1 and q2
                are already L2-normalized, to skip re-normalization.
        """

        global _query_assisted_same

        # Soft logical connectives: AND as element-wise min, OR as max.
        logits_and = torch.min
        logits_or = torch.max

        if not _normalized:
            q1 = q1 / q1.norm(2, dim=-1, keepdim=True)
            q2 = q2 / q2.norm(2, dim=-1, keepdim=True)

        if not _query_assisted_same or not self.training:
            # Plain cosine-margin logits between the two queries.
            margin = self._margin_cross
            logits = ((q1 * q2).sum(dim=-1) - 1 + margin) / margin / self._tau
            return logits
        else:
            margin = self._margin_cross
            logits1 = ((q1 * q2).sum(dim=-1) - 1 + margin) / margin / self._tau

            # Query-assisted path: additionally check, concept by concept,
            # whether the two queries agree on this attribute.
            _, concepts, attr_id = self.get_concepts_by_attribute(identifier)
            masks = []
            for k, v in concepts.items():
                embedding = v.normalized_embedding[attr_id]
                embedding = jactorch.add_dim_as_except(embedding, q1, -1)

                margin = self._margin
                mask1 = ((q1 * embedding).sum(dim=-1) - 1 +
                         margin) / margin / self._tau
                mask2 = ((q2 * embedding).sum(dim=-1) - 1 +
                         margin) / margin / self._tau

                belong_score = v.normalized_belong[attr_id]
                # Soft XNOR: high iff both queries match the concept or
                # both do not, weighted by the concept's belong score.
                # TODO(Jiayuan Mao @ 08/10): this line may have numerical issue.
                mask = logits_or(
                    logits_and(mask1, mask2),
                    logits_and(-mask1, -mask2),
                ) * belong_score

                masks.append(mask)
            logits2 = torch.stack(masks, dim=-1).sum(dim=-1)

            # TODO(Jiayuan Mao @ 08/09): should we take the average here? or just use the logits2?
            return torch.min(logits1, logits2)
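The min/max combination above acts as a soft XNOR over the two per-concept logits. A tiny check with made-up values:

    import torch

    m1 = torch.tensor([2.0, -2.0, 2.0])
    m2 = torch.tensor([2.0, -2.0, -2.0])
    both = torch.min(m1, m2)         # soft AND: both queries match
    neither = torch.min(-m1, -m2)    # soft AND: neither query matches
    print(torch.max(both, neither))  # soft OR -> tensor([ 2.,  2., -2.])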
Code example #3
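query_attribute scores a query against every concept registered under one attribute and returns the stacked logits together with a word2idx table that maps each concept name to its position along the last dimension.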
    def query_attribute(self, query, identifier):
        mapping, concepts, attr_id = self.get_concepts_by_attribute(identifier)
        # Map the query into the attribute space and L2-normalize it.
        query = mapping(query)
        query = query / query.norm(2, dim=-1, keepdim=True)

        word2idx = {}
        masks = []
        for k, v in concepts.items():
            embedding = v.normalized_embedding[attr_id]
            embedding = jactorch.add_dim_as_except(embedding, query, -1)

            # Shifted, temperature-scaled cosine logit for this concept.
            margin = self._margin
            mask = ((query * embedding).sum(dim=-1) - 1 + margin) / margin / self._tau

            # Add the concept's log-probability of belonging to the
            # queried attribute.
            belong_score = v.log_normalized_belong[attr_id]
            mask = mask + belong_score

            masks.append(mask)
            word2idx[k] = len(word2idx)

        masks = torch.stack(masks, dim=-1)
        return masks, word2idx
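A hypothetical readout on top of the returned pair; model, object_features, and the attribute name 'color' are placeholders, not names from the original code:

    masks, word2idx = model.query_attribute(object_features, 'color')
    idx2word = {i: w for w, i in word2idx.items()}
    # Each concept's logit sits at the index recorded in word2idx, so the
    # argmax over the last dimension picks the best-scoring concept.
    answer = idx2word[masks.argmax(dim=-1).item()]  # single-object query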