def test_embed_similarity():
    import torch

    from mmtrack.core import embed_similarity

    key_embeds = torch.randn(20, 256)
    ref_embeds = torch.randn(10, 256)

    sims = embed_similarity(
        key_embeds,
        ref_embeds,
        method='dot_product',
        temperature=-1,
        transpose=True)
    assert sims.size() == (20, 10)

    sims = embed_similarity(
        key_embeds,
        ref_embeds.t(),
        method='dot_product',
        temperature=-1,
        transpose=False)
    assert sims.size() == (20, 10)

    sims = embed_similarity(
        key_embeds,
        ref_embeds,
        method='dot_product',
        temperature=0.07,
        transpose=True)
    assert sims.size() == (20, 10)

    sims = embed_similarity(
        key_embeds,
        ref_embeds,
        method='cosine',
        temperature=-1,
        transpose=True)
    assert sims.size() == (20, 10)
    assert sims.max() <= 1
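# A minimal, self-contained sketch (not part of the library) of what the test
# above checks. Assumption: with `transpose=True`, 'dot_product' similarity
# reduces to `key @ ref.T`, and 'cosine' to the same product on L2-normalized
# embeddings, which bounds the scores by 1.
import torch
import torch.nn.functional as F

key = torch.randn(20, 256)
ref = torch.randn(10, 256)

dot_sims = key @ ref.t()  # shape (20, 10)
cos_sims = F.normalize(key, dim=1) @ F.normalize(ref, dim=1).t()
assert dot_sims.shape == (20, 10)
assert cos_sims.max() <= 1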
def forward(self, x, ref_x, num_x_per_img, num_x_per_ref_img):
    """Compute the similarity scores between `x` and `ref_x`.

    Args:
        x (Tensor): of shape [N, C, H, W]. N is the number of key frame
            proposals.
        ref_x (Tensor): of shape [M, C, H, W]. M is the number of reference
            frame proposals.
        num_x_per_img (list[int]): The `x` contains proposals of multiple
            images. `num_x_per_img` denotes the number of proposals for each
            key image.
        num_x_per_ref_img (list[int]): The `ref_x` contains proposals of
            multiple images. `num_x_per_ref_img` denotes the number of
            proposals for each reference image.

    Returns:
        list[Tensor]: The predicted similarity_logits of each pair of key
            image and reference image.
    """
    x_split = self._forward(x, num_x_per_img)
    ref_x_split = self._forward(ref_x, num_x_per_ref_img)

    similarity_logits = []
    for one_x, one_ref_x in zip(x_split, ref_x_split):
        similarity_logit = embed_similarity(
            one_x, one_ref_x, method='dot_product')
        # Prepend a dummy column so index 0 stands for the no-match case.
        dummy = similarity_logit.new_zeros(one_x.shape[0], 1)
        similarity_logit = torch.cat((dummy, similarity_logit), dim=1)
        similarity_logits.append(similarity_logit)
    return similarity_logits
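# A standalone sketch of the pairing logic above, under two assumptions: the
# embeddings are already flat vectors (the `_forward` split is replaced by
# `torch.split`), and `embed_similarity` with 'dot_product' behaves like a
# plain matrix product. It shows how per-image proposal counts drive the
# per-pair similarity matrices and the dummy no-match column.
import torch

x = torch.randn(5, 256)        # key-frame proposal embeddings of 2 images
ref_x = torch.randn(7, 256)    # reference-frame proposal embeddings
num_x_per_img = [2, 3]
num_x_per_ref_img = [4, 3]

x_split = torch.split(x, num_x_per_img)
ref_x_split = torch.split(ref_x, num_x_per_ref_img)

similarity_logits = []
for one_x, one_ref_x in zip(x_split, ref_x_split):
    logit = one_x @ one_ref_x.t()               # dot-product similarity
    dummy = logit.new_zeros(one_x.shape[0], 1)  # no-match column
    similarity_logits.append(torch.cat((dummy, logit), dim=1))

assert [s.shape for s in similarity_logits] == [(2, 5), (3, 4)]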
def match(self, key_embeds, ref_embeds, key_sampling_results,
          ref_sampling_results):
    """Calculate the distance matrices for loss measurement.

    Args:
        key_embeds (Tensor): Embeds of positive bboxes in sampling results
            of key image.
        ref_embeds (Tensor): Embeds of all bboxes in sampling results of
            the reference image.
        key_sampling_results (List[obj:SamplingResults]): Assign results of
            all images in a batch after sampling.
        ref_sampling_results (List[obj:SamplingResults]): Assign results of
            all reference images in a batch after sampling.

    Returns:
        Tuple[list[Tensor]]: Calculation results. Containing the following
            list of Tensors:

            - dists (list[Tensor]): Dot-product dists between key_embeds
                and ref_embeds, each tensor in list has shape
                (len(key_pos_bboxes), len(ref_bboxes)).
            - cos_dists (list[Tensor]): Cosine dists between key_embeds
                and ref_embeds, each tensor in list has shape
                (len(key_pos_bboxes), len(ref_bboxes)).
    """
    num_key_rois = [res.pos_bboxes.size(0) for res in key_sampling_results]
    key_embeds = torch.split(key_embeds, num_key_rois)
    num_ref_rois = [res.bboxes.size(0) for res in ref_sampling_results]
    ref_embeds = torch.split(ref_embeds, num_ref_rois)

    dists, cos_dists = [], []
    for key_embed, ref_embed in zip(key_embeds, ref_embeds):
        dist = embed_similarity(
            key_embed,
            ref_embed,
            method='dot_product',
            temperature=self.softmax_temp)
        dists.append(dist)
        if self.loss_track_aux is not None:
            cos_dist = embed_similarity(
                key_embed, ref_embed, method='cosine')
            cos_dists.append(cos_dist)
        else:
            cos_dists.append(None)
    return dists, cos_dists
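# A hedged, standalone sketch of the per-image matching above: embeddings are
# split by per-image RoI counts, temperature-scaled dot-product distances feed
# the main track loss and cosine distances feed an optional auxiliary loss.
# The temperature value, counts, and the division-by-temperature scaling are
# illustrative assumptions, not the library's exact implementation.
import torch
import torch.nn.functional as F

softmax_temp = 0.07
key_embeds = torch.randn(6, 256)   # positive RoIs of 2 key images
ref_embeds = torch.randn(9, 256)   # all RoIs of 2 reference images
num_key_rois = [2, 4]
num_ref_rois = [5, 4]

dists, cos_dists = [], []
for key, ref in zip(torch.split(key_embeds, num_key_rois),
                    torch.split(ref_embeds, num_ref_rois)):
    dists.append(key @ ref.t() / softmax_temp)  # dot product with temperature
    cos_dists.append(
        F.normalize(key, dim=1) @ F.normalize(ref, dim=1).t())

assert [d.shape for d in dists] == [(2, 5), (4, 4)]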