def similarity_collision(self, query, identifier):
    mappings = self.get_all_attributes()
    concept = self.get_concept(identifier)

    # shape: [batch, attributes, channel] or [attributes, channel]
    query_mapped = torch.stack([m(query) for m in mappings], dim=-2)
    query_mapped = query_mapped / query_mapped.norm(2, dim=-1, keepdim=True)
    reference = jactorch.add_dim_as_except(concept.normalized_embedding, query_mapped, -2, -1)

    # NOTE: the configured self._margin is bypassed and self._tau is mutated
    # as a side effect here; both look like hard-coded tuning overrides.
    margin = 1
    self._tau = 0.2
    logits = ((query_mapped * reference).sum(dim=-1) - 1 + margin) / margin / self._tau

    # Weight each attribute's logit by how much the concept belongs to it,
    # then marginalize over attributes in log space.
    belong = jactorch.add_dim_as_except(concept.log_normalized_belong, logits, -1)
    logits = jactorch.logsumexp(logits + belong, dim=-1)

    return logits
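# A minimal standalone sketch of the shifted-cosine logit transform used in
# `similarity_collision` above. The helper name and default values are
# illustrative, not part of the original code. With unit-normalized inputs the
# dot product is a cosine similarity in [-1, 1]; the transform maps cos = 1 to
# +1/tau, cos = 1 - margin to 0, and cos = 1 - 2*margin to -1/tau, so `margin`
# sets the width of the decision band and `tau` the sharpness of the logits.
def _cosine_logits_sketch(a, b, margin=0.85, tau=0.25):
    a = a / a.norm(2, dim=-1, keepdim=True)
    b = b / b.norm(2, dim=-1, keepdim=True)
    return ((a * b).sum(dim=-1) - 1 + margin) / margin / tau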
def similarity2(self, q1, q2, identifier, _normalized=False):
    """
    Args:
        _normalized (bool): backdoor for function `cross_similarity`.
    """
    global _query_assisted_same

    # Soft logical connectives over logits: min acts as AND, max as OR.
    logits_and = lambda x, y: torch.min(x, y)
    logits_or = lambda x, y: torch.max(x, y)

    if not _normalized:
        q1 = q1 / q1.norm(2, dim=-1, keepdim=True)
        q2 = q2 / q2.norm(2, dim=-1, keepdim=True)

    if not _query_assisted_same or not self.training:
        margin = self._margin_cross
        logits = ((q1 * q2).sum(dim=-1) - 1 + margin) / margin / self._tau
        return logits
    else:
        margin = self._margin_cross
        logits1 = ((q1 * q2).sum(dim=-1) - 1 + margin) / margin / self._tau

        _, concepts, attr_id = self.get_concepts_by_attribute(identifier)
        masks = []
        for k, v in concepts.items():  # was `for k, v in concepts:`, which fails on a dict
            embedding = v.normalized_embedding[attr_id]
            embedding = jactorch.add_dim_as_except(embedding, q1, -1)

            margin = self._margin
            mask1 = ((q1 * embedding).sum(dim=-1) - 1 + margin) / margin / self._tau
            mask2 = ((q2 * embedding).sum(dim=-1) - 1 + margin) / margin / self._tau

            belong_score = v.normalized_belong[attr_id]
            # TODO(Jiayuan Mao @ 08/10): this line may have numerical issue.
            # High when both queries match the concept or both do not.
            mask = logits_or(
                logits_and(mask1, mask2),
                logits_and(-mask1, -mask2),
            ) * belong_score
            masks.append(mask)
        logits2 = torch.stack(masks, dim=-1).sum(dim=-1)

        # TODO(Jiayuan Mao @ 08/09): should we take the average here? or just use the logits2?
        return torch.min(logits1, logits2)
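# A self-contained sketch of the per-concept "same" score computed in the
# query-assisted branch of `similarity2` (the helper name is hypothetical).
# Given the two match logits for one concept, min/max implement soft AND/OR:
# the pair agrees on the concept when both logits are high or both are low.
def _soft_same_sketch(mask1, mask2, belong_score):
    agree = torch.max(torch.min(mask1, mask2), torch.min(-mask1, -mask2))
    # Weight the agreement by how strongly the concept belongs to the
    # queried attribute, mirroring the `* belong_score` term above.
    return agree * belong_score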
def query_attribute(self, query, identifier):
    mapping, concepts, attr_id = self.get_concepts_by_attribute(identifier)
    query = mapping(query)
    query = query / query.norm(2, dim=-1, keepdim=True)

    word2idx = {}
    masks = []
    for k, v in concepts.items():
        embedding = v.normalized_embedding[attr_id]
        embedding = jactorch.add_dim_as_except(embedding, query, -1)

        margin = self._margin
        mask = ((query * embedding).sum(dim=-1) - 1 + margin) / margin / self._tau

        belong_score = v.log_normalized_belong[attr_id]
        mask = mask + belong_score

        masks.append(mask)
        word2idx[k] = len(word2idx)

    masks = torch.stack(masks, dim=-1)
    return masks, word2idx
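# A hedged usage sketch for `query_attribute`; `embedding` and
# `object_feature` are hypothetical stand-ins for an instance of this class
# and a single object's feature vector. `masks` stacks one log-score per
# concept and `word2idx` maps concept names to channels, so a softmax over
# the last dimension reads off a distribution over attribute values.
def _query_attribute_demo_sketch(embedding, object_feature):
    masks, word2idx = embedding.query_attribute(object_feature, 'color')
    prob = torch.softmax(masks, dim=-1)
    idx2word = {idx: word for word, idx in word2idx.items()}
    return idx2word[int(prob.argmax(dim=-1))]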