Example #1
    def forward(self, feats, which='left'):
        """
        Forward pass of the attention estimator
        
        :param feats: feature of shape (B, D, H, W)
        :param scales: various scales at which the map should be computed
        """

        _, _, H, W = feats.size()

        # resize feats to each scale
        maps = []
        for scale in self.scales:
            # scale features
            H_resized, W_resized = int(H * scale), int(W * scale)
            feats_resized = F.interpolate(feats, (H_resized, W_resized),
                                          mode='bilinear')

            # extract attention maps, and scale back
            # (B, 1, H, W)
            map_resized = F.interpolate(self.regress(feats_resized), (H, W),
                                        mode='bilinear')
            maps.append(map_resized)

        # (B, N, H, W)
        maps = torch.cat(maps, dim=1)
        # normalize across scales with a softmax
        maps = F.softmax(maps, dim=1)

        # visualize attention map
        args = {'map_' + which: maps[0]}
        logger.update(**args)

        return maps
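Below is a minimal, self-contained sketch of how this multi-scale attention forward could be exercised. The AttentionEstimator wrapper, its 1x1 regression head, and the scale set are assumptions filled in for illustration (only self.scales, self.regress, and the logging call appear in the original method; logging is dropped here):

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionEstimator(nn.Module):
    # hypothetical wrapper: only self.scales and self.regress are implied by the forward above
    def __init__(self, in_dim, scales=(0.5, 1.0, 2.0)):
        super().__init__()
        self.scales = scales
        self.regress = nn.Conv2d(in_dim, 1, kernel_size=1)  # (B, D, H, W) -> (B, 1, H, W)

    def forward(self, feats, which='left'):
        _, _, H, W = feats.size()
        maps = []
        for scale in self.scales:
            H_r, W_r = int(H * scale), int(W * scale)
            feats_r = F.interpolate(feats, (H_r, W_r), mode='bilinear', align_corners=False)
            # regress a single-channel map at this scale and resize it back to (H, W)
            maps.append(F.interpolate(self.regress(feats_r), (H, W),
                                      mode='bilinear', align_corners=False))
        # (B, N, H, W) with a softmax over the N scales
        return F.softmax(torch.cat(maps, dim=1), dim=1)

estimator = AttentionEstimator(in_dim=128)
attention = estimator(torch.randn(2, 128, 32, 32))
print(attention.shape)  # torch.Size([2, 3, 32, 32]); values sum to 1 across dim=1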
Example #2
def hard_negative_mining(desc_seeds,
                         desc_maps,
                         kps,
                         images,
                         thresh=16,
                         interval=16):
    """
    The mined locations should be thresh pixels away from the kps.
    To reduce the computational cost, we sample the negative locations every interval pixels.
    desc_seeds: [B, N, D]
    desc_maps: [B, D, H', W']
    kps: [B, N, 2]
    images: [B, 3, H, W]
    :return descs [B, N, D]
    """
    with torch.no_grad():
        # rescale the kps to the size of desc_map
        ratio_h = desc_maps.shape[2] / images.shape[2]
        ratio_w = desc_maps.shape[3] / images.shape[3]
        ratio = torch.tensor([ratio_w, ratio_h]).cuda()
        kps = kps.clone().detach()
        kps *= ratio

        # hard negative mining
        neg_kps = hard_example_mining_layer(desc_maps, desc_seeds, kps, thresh,
                                            interval).float()  # [B, N, 3]
        # keep the 2D coordinates from the [B, N, 3] mining output
        neg_kps = neg_kps[:, :, 1:]
        # map the mined locations back to image resolution
        neg_kps /= ratio

    logger.update(kps2=neg_kps[0])
    descs = sample_descriptor(desc_maps, neg_kps, images)
    return descs
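The custom hard_example_mining_layer is an external (CUDA) op and cannot be reproduced here, but the coordinate bookkeeping around it is easy to check on its own. A small sketch with made-up shapes, run on CPU for illustration:

import torch

images = torch.zeros(2, 3, 256, 256)      # [B, 3, H, W]
desc_maps = torch.zeros(2, 128, 64, 64)   # [B, D, H', W']
kps = torch.rand(2, 100, 2) * 255         # [B, N, 2], (x, y) in image pixels

# same rescaling as above: bring keypoints from image resolution to feature-map resolution
ratio_h = desc_maps.shape[2] / images.shape[2]
ratio_w = desc_maps.shape[3] / images.shape[3]
ratio = torch.tensor([ratio_w, ratio_h])  # (x, y) ordering matches kps

kps_feat = kps * ratio
print(bool((kps_feat.max() < 64)))        # scaled coordinates stay inside the 64x64 map
# dividing by the same ratio afterwards (neg_kps /= ratio) maps mined locations back to image pixels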
Example #3
    def forward(self, images0, images1, targets):
        """
        images0: [B, 3, H, W]
        images1: [B, 3, H, W]
        targets: {'kps0': [B, N, 2], 'kps1': [B, N, 2]}
        """
        descs0 = self.desc_extractor(images0, 1)
        descs1 = self.desc_extractor(images1, targets['scale'])

        results = dict(
            descs0=descs0,
            descs1=descs1
        )

        loss, distance, similarity = self.desc_evaluator(
            descs0, targets['kps0'], images0,
            descs1, targets['kps1'], targets['kps2'], images1
        )
        losses = dict(loss=loss, distance=distance, similarity=similarity)

        # keep descriptors for visualization
        logger.update(image0=images0[0], image1=images1[0])
        logger.update(kps0=targets['kps0'][0], kps1=targets['kps1'][0])
        logger.update(d03=descs0[0], d13=descs1[0])
        logger.update(H=targets['H'][0])

        return losses, results
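The forward above reads several targets fields beyond the two keypoint sets. A hypothetical assembly of that dict with dummy tensors (the shapes of 'scale' and 'H' are guesses; the dataset code that actually builds them is not shown):

import torch

B, N = 2, 512
targets = {
    'kps0': torch.rand(B, N, 2) * 255,               # keypoints in images0
    'kps1': torch.rand(B, N, 2) * 255,               # matching keypoints in images1
    'kps2': torch.rand(B, N, 2) * 255,               # negative keypoints in images1
    'scale': torch.ones(B),                          # relative scale passed to the second extractor
    'H': torch.eye(3).unsqueeze(0).repeat(B, 1, 1),  # transform logged for visualization
}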
Example #4
    def forward(self, images0, images1, targets):
        """
        images0: [B, 3, H, W]
        images1: [B, 3, H, W]
        targets: {'kps0': [B, N, 2], 'kps1': [B, N, 2]}
        """
        descs0 = self.desc_extractor(images0)
        descs1 = self.desc_extractor(images1)

        scale_pred, scale_loss = self.scale_estimator(descs0, descs1,
                                                      targets['scale'],
                                                      targets['msk'])

        results = dict(descs0=descs0, descs1=descs1)

        losses = dict(loss=scale_loss)

        # keep descriptors for visualization
        logger.update(image0=images0[0], image1=images1[0])
        logger.update(scale_pred=scale_pred[0])
        logger.update(scale=targets['scale'][0])
        logger.update(msk=targets['msk'][0])

        return losses, results
Example #5
    def forward(self, images0, images1, targets=None):
        """
        images0: [N, 3, H, W]
        images1: [N, 3, H, W]
        targets: {"img2": [N, 3, H, W], "kps0": [N, 3000], "kps1": [N, 3000], "kps2": [N, 3000]}
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        descrs0 = self.feature_extractor(images0)
        descrs1 = self.feature_extractor(images1)

        # optional debug visualization of the descriptors (kept commented out):
        # img = draw_desc_torch(
        #     images0[0], descrs0[0], targets['kps0'][0],
        #     images1[0], descrs1[0], targets['kps1'][0],
        #     targets['H'][0],
        # )
        # import matplotlib.pyplot as plt
        # plt.imshow(img)
        # plt.show()

        results = dict(
            descrs0=descrs0,
            descrs1=descrs1,
        )
        if not self.training:
            return results

        loss, pos_loss, neg_loss = self.evaluator(descrs0, targets["kps0"],
                                                  images0, descrs1,
                                                  targets["kps1"],
                                                  targets["kps2"], images1)

        # if targets["iteration"] % 100 == 0:
        #     print(pos_loss, neg_loss)

        losses = dict(loss=loss, distance=pos_loss, similarity=neg_loss)

        # keep descriptors for visualization
        logger.update(image0=images0[0], image1=images1[0])
        logger.update(kps0=targets['kps0'][0], kps1=targets['kps1'][0])
        logger.update(desc0=descrs0[0], desc1=descrs1[0])

        return losses, results
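The training/inference split in this forward (results only in eval mode, (losses, results) in training, and an error if targets are missing during training) can be summarized with a stubbed module. The extractor and the loss below are placeholders, not the real feature_extractor or evaluator:

import torch
import torch.nn as nn

class MatcherSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Conv2d(3, 8, 3, padding=1)  # stand-in backbone

    def forward(self, images0, images1, targets=None):
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        d0, d1 = self.feature_extractor(images0), self.feature_extractor(images1)
        results = dict(descrs0=d0, descrs1=d1)
        if not self.training:
            return results                     # inference: descriptors only
        loss = (d0 - d1).abs().mean()          # placeholder for the descriptor loss
        return dict(loss=loss), results

model = MatcherSketch()
model.eval()
results = model(torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64))
model.train()
losses, results = model(torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64), targets={})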
Example #6
    def forward(self, images0, images1, targets):
        """
        images0: [B, 3, H, W]
        images1: [B, 3, H, W]
        targets: {'kps0': [B, N, 2], 'kps1': [B, N, 2]}
        """
        descs0 = self.desc_extractor(images0, targets['left_scale'], 'left')
        descs1 = self.desc_extractor(images1, targets['scale'], 'right')

        desc_loss, distance, similarity = self.desc_evaluator(
            descs0, targets['kps0'], images0,
            descs1, targets['kps1'], targets['kps2'], images1
        )
        scale_pred, scale_loss = self.scale_estimator(descs0, descs1, targets['scale'], targets['msk'])

        results = dict(
            descs0=descs0,
            descs1=descs1
        )

        # total loss is the sum of the descriptor loss and the scale loss
        loss = desc_loss + scale_loss
        losses = dict(loss=loss, desc_loss=desc_loss, distance=distance, similarity=similarity, scale_loss=scale_loss)

        # keep descriptors for visualization
        logger.update(image0=images0[0], image1=images1[0])
        logger.update(kps0=targets['kps0'][0], kps1=targets['kps1'][0], kps2=targets['kps2'][0])
        logger.update(d03=descs0[0], d13=descs1[0])
        logger.update(H=targets['H'][0])
        logger.update(scale_pred=scale_pred[0])
        logger.update(scale=targets['scale'][0])
        logger.update(msk=targets['msk'][0])

        return losses, results
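For completeness, the joint objective above is the unweighted sum of the descriptor loss and the scale loss, with each term kept in the dict for logging. A trivial illustration with dummy values (no weighting factor appears in the original):

import torch

desc_loss = torch.tensor(0.37)    # dummy stand-ins for the evaluator outputs
scale_loss = torch.tensor(0.12)
loss = desc_loss + scale_loss
losses = dict(loss=loss, desc_loss=desc_loss, scale_loss=scale_loss)
print({k: round(float(v), 2) for k, v in losses.items()})  # {'loss': 0.49, 'desc_loss': 0.37, 'scale_loss': 0.12}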