Example #1
0
def test_embed_similarity():
    """Smoke-test `embed_similarity` output shapes and cosine value range."""
    from mmtrack.core import embed_similarity
    keys = torch.randn(20, 256)
    refs = torch.randn(10, 256)

    # dot product, no temperature scaling, transpose refs internally
    scores = embed_similarity(
        keys, refs, method='dot_product', temperature=-1, transpose=True)
    assert scores.size() == (20, 10)

    # caller pre-transposes refs, so transpose=False must give the same shape
    scores = embed_similarity(
        keys, refs.t(), method='dot_product', temperature=-1, transpose=False)
    assert scores.size() == (20, 10)

    # positive temperature enables logit scaling; shape is unchanged
    scores = embed_similarity(
        keys, refs, method='dot_product', temperature=0.07, transpose=True)
    assert scores.size() == (20, 10)

    # cosine similarity is bounded above by 1
    scores = embed_similarity(
        keys, refs, method='cosine', temperature=-1, transpose=True)
    assert scores.size() == (20, 10)
    assert scores.max() <= 1
Example #2
0
    def forward(self, x, ref_x, num_x_per_img, num_x_per_ref_img):
        """Computing the similarity scores between `x` and `ref_x`.

        Args:
            x (Tensor): of shape [N, C, H, W]. N is the number of key frame
                proposals.
            ref_x (Tensor): of shape [M, C, H, W]. M is the number of reference
                frame proposals.
            num_x_per_img (list[int]): The `x` contains proposals of
                multi-images. `num_x_per_img` denotes the number of proposals
                for each key image.
            num_x_per_ref_img (list[int]): The `ref_x` contains proposals of
                multi-images. `num_x_per_ref_img` denotes the number of
                proposals for each reference image.

        Returns:
            list[Tensor]: The predicted similarity_logits of each pair of key
            image and reference image.
        """
        # Split the batched features back into per-image chunks.
        key_feats = self._forward(x, num_x_per_img)
        ref_feats = self._forward(ref_x, num_x_per_ref_img)

        similarity_logits = []
        for key_feat, ref_feat in zip(key_feats, ref_feats):
            logit = embed_similarity(key_feat, ref_feat, method='dot_product')
            # Prepend a column of zeros as an extra first entry
            # (presumably the dummy/background match — confirm with callers).
            pad = logit.new_zeros(key_feat.shape[0], 1)
            similarity_logits.append(torch.cat((pad, logit), dim=1))
        return similarity_logits
    def match(self, key_embeds, ref_embeds, key_sampling_results,
              ref_sampling_results):
        """Calculate the distance matrices for loss measurement.

        Args:
            key_embeds (Tensor): Embeds of positive bboxes in sampling results
                of key image.
            ref_embeds (Tensor): Embeds of all bboxes in sampling results
                of the reference image.
            key_sampling_results (List[obj:SamplingResults]): Assign results
                of all images in a batch after sampling.
            ref_sampling_results (List[obj:SamplingResults]): Assign results
                of all reference images in a batch after sampling.

        Returns:
            Tuple[list[Tensor]]: Calculation results.
            Containing the following list of Tensors:

                - dists (list[Tensor]): Dot-product dists between
                    key_embeds and ref_embeds, each tensor in list has
                    shape (len(key_pos_bboxes), len(ref_bboxes)).
                - cos_dists (list[Tensor]): Cosine dists between
                    key_embeds and ref_embeds, each tensor in list has
                    shape (len(key_pos_bboxes), len(ref_bboxes)).
        """
        # Recover per-image chunks: key side uses only positive boxes,
        # reference side uses all sampled boxes.
        key_counts = [res.pos_bboxes.size(0) for res in key_sampling_results]
        ref_counts = [res.bboxes.size(0) for res in ref_sampling_results]
        key_splits = torch.split(key_embeds, key_counts)
        ref_splits = torch.split(ref_embeds, ref_counts)

        dists, cos_dists = [], []
        use_aux = self.loss_track_aux is not None
        for key_split, ref_split in zip(key_splits, ref_splits):
            dists.append(
                embed_similarity(
                    key_split,
                    ref_split,
                    method='dot_product',
                    temperature=self.softmax_temp))
            # Cosine distances are only needed when the auxiliary
            # tracking loss is configured; otherwise keep list alignment
            # with a None placeholder.
            cos_dists.append(
                embed_similarity(key_split, ref_split, method='cosine')
                if use_aux else None)
        return dists, cos_dists