Example #1
0
    def criterion(self, endpoint):
        """Compute the training losses for an image pair ("im1" and "im2").

        Three terms are built, each the average of the two images' values:

        * score loss: ``self.det.loss(score, gtsc, visible)``, scaled by
          ``self.SCORE_W``;
        * pair loss: mean of the diagonal of
          ``distance_matrix_vector(lpreddes, rpreddes)``, scaled by
          ``self.PAIR_W``;
        * hard loss: ``self.des.loss(lpdes, rpdes, limc, rimcw)``.

        Args:
            endpoint: mapping holding per-image entries keyed as
                ``"im1_score"``, ``"im2_lpdes"``, ``"im1_rimcw"`` and so on.

        Returns:
            ``(PLT, det_loss, des_loss)``:

            * ``PLT`` is ``{"scalar": {...}}`` with the three un-reduced
              terms, for logging;
            * ``det_loss`` is ``(score_loss + pair_loss).mean()``;
            * ``des_loss`` is ``hard_loss.mean()``.
        """

        def fetch(*keys):
            # Look the entries up in the order given so they unpack
            # positionally into the loss call.
            return [endpoint[key] for key in keys]

        # -- detector score loss, averaged over both images and weighted
        sc1 = self.det.loss(*fetch("im1_score", "im1_gtsc", "im1_visible"))
        sc2 = self.det.loss(*fetch("im2_score", "im2_gtsc", "im2_visible"))
        score_loss = (sc1 + sc2) / 2.0 * self.SCORE_W

        # -- pair loss: mean diagonal distance between the left/right
        #    predicted descriptors of each image, averaged and weighted
        pair1 = distance_matrix_vector(
            *fetch("im1_lpreddes", "im1_rpreddes")).diag().mean()
        pair2 = distance_matrix_vector(
            *fetch("im2_lpreddes", "im2_rpreddes")).diag().mean()
        pair_loss = (pair1 + pair2) / 2.0 * self.PAIR_W

        # -- descriptor hard loss, averaged over both images (unweighted)
        hard1 = self.des.loss(
            *fetch("im1_lpdes", "im1_rpdes", "im1_limc", "im1_rimcw"))
        hard2 = self.des.loss(
            *fetch("im2_lpdes", "im2_rpdes", "im2_limc", "im2_rimcw"))
        hard_loss = (hard1 + hard2) / 2.0

        det_loss = score_loss + pair_loss
        des_loss = hard_loss

        plots = {
            "scalar": {
                "score_loss": score_loss,
                "pair_loss": pair_loss,
                "hard_loss": hard_loss,
            }
        }
        return plots, det_loss.mean(), des_loss.mean()
Example #2
0
def nearest_neighbor_distance_ratio_match(des1, des2, kp2, threshold):
    """Match descriptors with the nearest-neighbour distance-ratio test.

    For each row of ``des1``, the two closest rows of ``des2`` (under
    ``distance_matrix_vector``) are found. The match is accepted when
    ``d_nearest / d_second_nearest < threshold``.

    Args:
        des1: descriptors to match, one per row.
        des2: candidate descriptors, one per row. There must be at least two
            rows so that a second-nearest neighbour exists.
        kp2: rows aligned with ``des2``; the row of each nearest neighbour
            is gathered from it.
        threshold: ratio threshold; smaller values mean stricter matching.

    Returns:
        ``(predict_label, nn_kp2)``:

        * ``predict_label`` is a boolean tensor marking accepted matches. Rows
          whose two nearest distances are both zero give 0/0 = NaN and
          compare False, so they are rejected.
        * ``nn_kp2`` holds the ``kp2`` rows of each nearest neighbour.
    """
    des_dist_matrix = distance_matrix_vector(des1, des2)
    # Only the two smallest distances per row are needed, so select them
    # (ascending) instead of sorting every row.
    nn_dist, nn_idx = des_dist_matrix.topk(2, dim=-1, largest=False)
    dist_ratio = nn_dist[:, 0] / nn_dist[:, 1]
    predict_label = dist_ratio.lt(threshold)
    nn_kp2 = kp2.index_select(dim=0, index=nn_idx[:, 0])
    return predict_label, nn_kp2
Example #3
0
def getAC(im1_ldes, im1_rdes):
    """Return the top-1 matching accuracy between two aligned descriptor sets.

    Row ``i`` of ``im1_ldes`` is treated as the counterpart of row ``i`` of
    ``im1_rdes``. For every left descriptor, the closest right descriptor
    (under ``distance_matrix_vector``) is found. The result is the fraction
    of rows whose closest match is their own counterpart.

    Args:
        im1_ldes: left descriptors, one per row.
        im1_rdes: right descriptors, one per row, aligned with ``im1_ldes``.

    Returns:
        0-dim float tensor in [0, 1].
    """
    im1_distmat = distance_matrix_vector(im1_ldes, im1_rdes)
    # Only the single closest column per row matters, so argmin replaces a
    # full sort.
    nearest = im1_distmat.argmin(dim=1)
    # Create the index range directly on the target device (no host copy).
    expected = torch.arange(nearest.size(0), device=nearest.device)
    return (nearest == expected).float().mean()
Example #4
0
def vis_descriptor_with_patches(endpoint, cfg, saveimg=False, imname=None):
    """Compute top-1 descriptor accuracy and optionally save a patch mosaic.

    Accuracy is the fraction of rows ``i`` whose closest right descriptor
    (under ``distance_matrix_vector``) is row ``i`` itself.

    When ``saveimg is True`` and ``imname`` is given, a mosaic is written to
    ``{cfg.TRAIN.SAVE}/image/{imname}_ac{ac:05.2f}.png``. Each horizontal
    band shows, left to right: the anchor patch, the true target patch, a
    hit/miss icon (``./tools/t.jpg`` / ``./tools/f.jpg``), and the (up to)
    5 closest right patches.

    Args:
        endpoint: mapping providing ``"im1_ldes"`` and ``"im1_rdes"``
            (aligned descriptor rows), plus ``"im1_ppair"`` (normalized
            patch pairs, split into left/right halves along dim 1) when
            saving.
        cfg: config providing ``PATCH.size``, ``TRAIN.SAVE``, ``IMAGE.STD``
            and ``IMAGE.MEAN``.
        saveimg: must be exactly ``True`` to write the mosaic.
        imname: base file name for the mosaic; nothing is saved if ``None``.

    Returns:
        0-dim float tensor with the top-1 accuracy.

    Raises:
        FileNotFoundError: when saving and an icon image cannot be read.
    """
    psize = cfg.PATCH.size
    save = cfg.TRAIN.SAVE

    def imarrange(imgs, topk):
        # Tile a stack of topk*k patches into one (topk*psize, k*psize)
        # image: band i holds the k patches of sample i side by side.
        imgs = imgs.view(topk, -1, psize, psize)
        imgs = imgs.permute(0, 2, 1, 3).contiguous()
        imgs = imgs.view(topk * psize, -1)
        return imgs

    def load_flag_icon(path):
        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising, so check explicitly.
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"cannot read flag icon: {path}")
        img = cv2.resize(img, (psize, psize)).astype(np.uint8)
        # imread returns channels in BGR order.
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Descriptor rows; presumably (topk, 128) each, per the original note.
    im1_ldes, im1_rdes = endpoint["im1_ldes"], endpoint["im1_rdes"]
    im1_distmat = distance_matrix_vector(im1_ldes, im1_rdes)
    row_minidx = im1_distmat.sort(dim=1)[1]  # (topk, topk), ascending
    topk = row_minidx.size(0)
    top5_idx = row_minidx[:, :5]  # (topk, <=5) closest columns per row
    hit = top5_idx[:, 0] == torch.arange(topk, device=top5_idx.device)
    ac = hit.float().mean()

    if saveimg is True and imname is not None:
        # Undo normalization back to 0..255 pixel values.
        im1_ppair = (endpoint["im1_ppair"] * cfg.IMAGE.STD +
                     cfg.IMAGE.MEAN) * 255
        im1_lpatch, im1_rpatch = im1_ppair.chunk(chunks=2, dim=1)
        tim = load_flag_icon("./tools/t.jpg")
        fim = load_flag_icon("./tools/f.jpg")
        # Index 0 = miss icon, index 1 = hit icon.
        flagim = (torch.from_numpy(np.stack((fim, tim), axis=0))
                  .float().to(top5_idx.device))
        # Integer (not boolean-mask) indexing: choose one icon per row.
        flagr = imarrange(flagim[hit.long()], topk)  # (topk*psize, psize)
        anchor = imarrange(im1_lpatch.squeeze(), topk)
        target = imarrange(im1_rpatch.squeeze(), topk)
        neighbours = top5_idx.reshape(-1).to(im1_rpatch.device)
        patches = imarrange(im1_rpatch[neighbours].squeeze(), topk)
        im1_result = torch.cat((anchor, target, flagr, patches), dim=1)
        imname = imname + f"_ac{ac.item():05.2f}.png"
        cv2.imwrite(f"{save}/image/{imname}",
                    im1_result.detach().cpu().numpy())
    return ac
Example #5
0
def nearest_neighbor_match_score(des1, des2, kp1w, kp2, visible, COO_THRSH):
    """Count nearest-neighbour descriptor matches that land in the right place.

    Each row of ``des1`` is paired with its closest row of ``des2``. The pair
    counts as correct when both of these hold:

    * ``pairwise_distances`` between columns 1:3 of the ``kp1w`` row and
      the matched ``kp2`` row is ``<= COO_THRSH``;
    * the row is set in ``visible``.

    Returns:
        ``(correct_matches, predict_matches)`` as Python numbers, where
        ``predict_matches`` is the number of visible rows, floored at 1 so it
        can safely be used as a denominator.
    """
    _, nearest = distance_matrix_vector(des1, des2).min(dim=-1)
    matched_kp2 = kp2.index_select(dim=0, index=nearest)

    src_xy = kp1w[:, 1:3].float()
    dst_xy = matched_kp2[:, 1:3].float()
    # Only the i-th to i-th distances are needed: the diagonal.
    coo_dist = pairwise_distances(src_xy, dst_xy).diag()

    is_correct = coo_dist.le(COO_THRSH) * visible
    n_correct = is_correct.sum().item()
    n_visible = visible.sum().item()
    return n_correct, max(n_visible, 1)
Example #6
0
def threshold_match_score(des1, des2, kp1w, kp2, visible, DES_THRSH,
                          COO_THRSH):
    """Score all-pairs matches accepted by a descriptor-distance threshold.

    A pair (i, j) is:

    * *predicted* when its descriptor distance is below ``DES_THRSH`` and
      row i is set in ``visible``;
    * a *correspondence* when ``pairwise_distances`` between columns 1:3 of
      ``kp1w[i]`` and ``kp2[j]`` is ``<= COO_THRSH`` and row i is visible.

    Correct matches are pairs that are both.

    Returns:
        ``(correct_matches, predict_matches, correspond_matches)`` as Python
        numbers. The last two are floored at 1 so they can be used as
        denominators.
    """
    des_dist = distance_matrix_vector(des1, des2)
    # Broadcast the per-row visibility flag across every column.
    row_mask = visible.unsqueeze(-1).repeat(1, des_dist.size(1))

    predicted = des_dist.lt(DES_THRSH) * row_mask

    coo_dist = pairwise_distances(kp1w[:, 1:3].float(), kp2[:, 1:3].float())
    corresponding = coo_dist.le(COO_THRSH) * row_mask

    n_correct = (predicted * corresponding).sum().item()
    n_predicted = predicted.sum().item()
    n_corresponding = corresponding.sum().item()
    return n_correct, max(n_predicted, 1), max(n_corresponding, 1)
Example #7
0
    def loss(self, anchor, positive, anchor_kp, positive_kp):
        """HardNetNeiMask margin loss.

        Steps:

        1. Build the descriptor distance matrix between ``anchor`` and
           ``positive`` rows; its diagonal holds the positive distances.
        2. For every row and column, find the hardest negative. Before doing
           so, mask out the diagonal and every pair whose keypoints (columns
           1:3) are closer than ``self.C`` within the anchor image or within
           the positive image. Masking adds a large penalty so those entries
           never win the min.
        3. The loss is ``mean(clamp(self.MARGIN + pos - hardest_neg, min=0))``.

        If ``self.C`` is 0 the neighbour mask is empty and this is the same as
        the plain hard loss.

        Args:
            anchor: 2-D descriptor matrix; row i pairs with row i of
                ``positive``.
            positive: 2-D descriptor matrix with the same shape as
                ``anchor``.
            anchor_kp: per-row keypoint data for ``anchor``; columns 1:3 are
                used as coordinates.
            positive_kp: per-row keypoint data for ``positive``; columns 1:3
                are used as coordinates.

        Returns:
            0-dim tensor with the mean triplet loss.

        Raises:
            AssertionError: if the inputs differ in shape or are not 2-D.
        """
        assert anchor.size() == positive.size(), \
            "Input sizes between positive and negative must be equal."
        assert anchor.dim() == 2, "Input must be a 2D matrix."

        # Penalty added to masked entries; this assumes genuine descriptor
        # distances stay well below it.
        mask_penalty = 10

        def neighbour_mask(kp):
            # 1.0 where two keypoints of the same image are closer than
            # self.C (per pairwise_distances on columns 1:3), else 0.0.
            xy = kp[:, 1:3].to(torch.float)
            return pairwise_distances(xy, xy).lt(self.C).to(torch.float)

        dist_matrix = distance_matrix_vector(anchor, positive)
        eye = torch.eye(dist_matrix.size(1), device=dist_matrix.device)

        pos = dist_matrix.diag()

        # Exclude the matching pair itself and spatial neighbours from the
        # negative search.
        masked = dist_matrix + eye * mask_penalty
        masked = masked + neighbour_mask(anchor_kp) * mask_penalty
        masked = masked + neighbour_mask(positive_kp) * mask_penalty

        min_per_row = masked.min(dim=1)[0]
        min_per_col = masked.min(dim=0)[0]
        hardest_neg = torch.min(min_per_row, min_per_col)

        # Triplet margin loss against the hardest negative.
        hard_loss = torch.clamp(self.MARGIN + pos - hardest_neg, min=0.0)
        return hard_loss.mean()