コード例 #1
0
    def forward(self, batch):
        """Run the detector and descriptor on both images of a pair.

        ``batch`` is unpacked as ``(im1_data, im1_info, homo12, im2_data,
        im2_info, homo21, im1_raw, im2_raw)``. Each image goes through
        ``self.det``. Ground-truth scale/orientation (``self.gt_scale_orin``)
        and score targets (``self.gtscore``) for one image are derived from
        the *other* image's predictions via the homographies.

        Patch pairs are sampled with ``pair`` twice:

        1. Using the partner image's ground-truth scale/orientation.
        2. Using the partner's predicted values.

        Both halves of every pair are described with ``self.des``.

        Returns:
            dict of named endpoint tensors consumed by the loss functions.
        """
        (data1, info1, h12,
         data2, info2, h21,
         raw1, raw2) = batch
        psize = self.PSIZE

        def describe(patch_pair):
            # Split the stacked pair along dim 1 into its two halves and
            # describe the left half first, then the right half.
            left_half, right_half = patch_pair.chunk(chunks=2, dim=1)
            return self.des(left_half), self.des(right_half)

        rawsc1, scale1, orin1 = self.det(data1)
        rawsc2, scale2, orin2 = self.det(data2)

        # Targets for each image are built from the other image's output.
        gtscale1, gtorin1 = self.gt_scale_orin(scale2, orin2, h12, h21)
        gtscale2, gtorin2 = self.gt_scale_orin(scale1, orin1, h21, h12)

        gtsc1, topkmask1, topkvalue1, visible1 = self.gtscore(rawsc2, h12)
        gtsc2, topkmask2, topkvalue2, visible2 = self.gtscore(rawsc1, h21)

        score1 = self.det.process(rawsc1)[0]
        score2 = self.det.process(rawsc2)[0]

        # Per-image keypoint context shared by both rounds of patch sampling.
        side1 = (topkmask1, topkvalue1, scale1, orin1, info1, raw1)
        side2 = (topkmask2, topkvalue2, scale2, orin2, info2, raw2)

        # Round 1: partner patches at ground-truth scale / orientation.
        gtpair1, limc1, rimcw1 = pair(*side1, h12, gtscale2, gtorin2,
                                      info2, raw2, psize)
        gtpair2, limc2, rimcw2 = pair(*side2, h21, gtscale1, gtorin1,
                                      info1, raw1, psize)
        lpdes1, rpdes1 = describe(gtpair1)
        lpdes2, rpdes2 = describe(gtpair2)

        # Round 2: partner patches at the partner's predicted scale / orientation.
        predpair1, _, _ = pair(*side1, h12, scale2, orin2, info2, raw2, psize)
        predpair2, _, _ = pair(*side2, h21, scale1, orin1, info1, raw1, psize)
        lpreddes1, rpreddes1 = describe(predpair1)
        lpreddes2, rpreddes2 = describe(predpair2)

        return {
            "im1_score": score1,
            "im1_gtsc": gtsc1,
            "im1_visible": visible1,
            "im2_score": score2,
            "im2_gtsc": gtsc2,
            "im2_visible": visible2,
            "im1_limc": limc1,
            "im1_rimcw": rimcw1,
            "im2_limc": limc2,
            "im2_rimcw": rimcw2,
            "im1_lpdes": lpdes1,
            "im1_rpdes": rpdes1,
            "im2_lpdes": lpdes2,
            "im2_rpdes": rpdes2,
            "im1_lpreddes": lpreddes1,
            "im1_rpreddes": rpreddes1,
            "im2_lpreddes": lpreddes2,
            "im2_rpreddes": rpreddes2,
        }
コード例 #2
0
def get_all_endpoints(det, des, batch, psize=None):
    """Compute every tensor the detector and descriptor criteria consume.

    Functional counterpart of a model ``forward``. It runs ``det`` on both
    images of ``batch`` and derives ground-truth scale/orientation and score
    targets for each image from the other image's predictions via the
    homographies.

    Patch pairs are sampled with ``pair`` twice:

    1. Using the partner's ground-truth scale/orientation.
    2. Using its predicted values.

    Every patch half is described with ``des`` after ``des.input_norm``.

    Args:
        det: detector. It is called as ``det(data) -> (score_maps,
            orint_maps)`` and must expose ``scale_list``,
            ``score_com_strength``, ``scale_com_strength`` and ``process``.
        des: descriptor; must expose ``input_norm``.
        batch: ``(im1_data, im1_info, homo12, im2_data, im2_info, homo21,
            im1_raw, im2_raw)``.
        psize: patch size forwarded to ``pair``. ``None`` (the default)
            means ``cfg.PATCH.SIZE``.

    Returns:
        dict of endpoint tensors keyed by name: scores, ground-truth scores,
        visibility masks, patch image coordinates and descriptors.
    """
    if psize is None:
        psize = cfg.PATCH.SIZE
    im1_data, im1_info, homo12, im2_data, im2_info, homo21, im1_raw, im2_raw = batch

    def detect(data):
        # Raw detector maps -> (raw score, scale, orientation).
        score_maps, orint_maps = det(data)
        return handle_det_out(score_maps, orint_maps, det.scale_list,
                              det.score_com_strength, det.scale_com_strength)

    def targets_from(partner_rawsc, homo):
        # Warp the partner's border-filtered score map into this image.
        # An all-ones map warped the same way marks the visible pixels.
        # The ones map is templated on the partner score it mirrors.
        partner_score = filter_border(partner_rawsc)
        warped_score = warp(partner_score, homo)
        visible = warp(
            partner_score.new_full(partner_score.size(), fill_value=1,
                                   requires_grad=True),
            homo,
        )
        gtsc, topkmask, topkvalue = det.process(warped_score)
        return gtsc, topkmask, topkvalue, visible

    def describe(patch_pair):
        # Split the stacked pair along dim 1, normalize each half, then
        # describe left before right.
        left, right = patch_pair.chunk(chunks=2, dim=1)
        left, right = des.input_norm(left), des.input_norm(right)
        return des(left), des(right)

    im1_rawsc, im1_scale, im1_orin = detect(im1_data)
    im2_rawsc, im2_scale, im2_orin = detect(im2_data)
    im1_gtscale, im1_gtorin = gt_scale_orin(im2_scale, im2_orin, homo12,
                                            homo21)
    im2_gtscale, im2_gtorin = gt_scale_orin(im1_scale, im1_orin, homo21,
                                            homo12)

    im1_gtsc, im1_topkmask, im1_topkvalue, im1_visiblemask = targets_from(
        im2_rawsc, homo12)
    im2_gtsc, im2_topkmask, im2_topkvalue, im2_visiblemask = targets_from(
        im1_rawsc, homo21)
    im1_score = det.process(im1_rawsc)[0]
    im2_score = det.process(im2_rawsc)[0]

    # Per-image keypoint context shared by both rounds of patch sampling.
    im1_side = (im1_topkmask, im1_topkvalue, im1_scale, im1_orin, im1_info,
                im1_raw)
    im2_side = (im2_topkmask, im2_topkvalue, im2_scale, im2_orin, im2_info,
                im2_raw)

    # Round 1: partner patches at ground-truth scale / orientation.
    im1_ppair, im1_limc, im1_rimcw = pair(*im1_side, homo12, im2_gtscale,
                                          im2_gtorin, im2_info, im2_raw, psize)
    im2_ppair, im2_limc, im2_rimcw = pair(*im2_side, homo21, im1_gtscale,
                                          im1_gtorin, im1_info, im1_raw, psize)
    im1_lpdes, im1_rpdes = describe(im1_ppair)
    im2_lpdes, im2_rpdes = describe(im2_ppair)

    # Round 2: partner patches at the partner's predicted scale / orientation.
    im1_predpair, _, _ = pair(*im1_side, homo12, im2_scale, im2_orin,
                              im2_info, im2_raw, psize)
    im2_predpair, _, _ = pair(*im2_side, homo21, im1_scale, im1_orin,
                              im1_info, im1_raw, psize)
    im1_lpreddes, im1_rpreddes = describe(im1_predpair)
    im2_lpreddes, im2_rpreddes = describe(im2_predpair)

    return {
        "im1_score": im1_score,
        "im1_gtsc": im1_gtsc,
        "im1_visible": im1_visiblemask,
        "im2_score": im2_score,
        "im2_gtsc": im2_gtsc,
        "im2_visible": im2_visiblemask,
        "im1_lpreddes": im1_lpreddes,
        "im1_rpreddes": im1_rpreddes,
        "im2_lpreddes": im2_lpreddes,
        "im2_rpreddes": im2_rpreddes,
        "im1_limc": im1_limc,
        "im1_rimcw": im1_rimcw,
        "im2_limc": im2_limc,
        "im2_rimcw": im2_rimcw,
        "im1_lpdes": im1_lpdes,
        "im1_rpdes": im1_rpdes,
        "im2_lpdes": im2_lpdes,
        "im2_rpdes": im2_rpdes,
    }
コード例 #3
0
    def train():
        """Train ``det`` and ``des`` for one epoch, alternating their updates.

        This is a closure. ``train_data``, ``val_data``, ``det``, ``des``,
        ``det_optim``, ``des_optim``, ``cfg``, ``device``, ``epoch``,
        ``mgpu``, ``train_writer`` and ``test_writer`` all come from the
        enclosing scope.

        For each batch:

        - The descriptor is optimized ``cfg.TRAIN.DES`` times, then the
          detector ``cfg.TRAIN.DET`` times. Each step recomputes the forward
          pass.
        - Both learning-rate schedules are applied afterwards.

        Every ``cfg.TRAIN.LOG_INTERVAL`` batches, the losses and descriptor
        accuracy are evaluated without gradients. This uses the first
        training sample and the first validation sample, and the results go
        to the train / test writers respectively.
        """
        start_time = time.time()
        for i_batch, sample_batched in enumerate(train_data, 1):
            det.train()
            des.train()
            batch = parse_batch(sample_batched, device)
            (im1_data, im1_info, homo12, im2_data, im2_info, homo21,
             im1_raw, im2_raw) = batch
            # NOTE(review): anomaly detection slows every backward pass
            # considerably; presumably enabled for debugging only.
            with autograd.detect_anomaly():
                # Descriptor updates on patches sampled at the partner
                # image's ground-truth scale / orientation.
                for _ in range(cfg.TRAIN.DES):
                    des.zero_grad()
                    des_optim.zero_grad()
                    score_maps, orint_maps = det(im1_data)
                    im1_rawsc, im1_scale, im1_orin = handle_det_out(
                        score_maps, orint_maps, det.scale_list,
                        det.score_com_strength, det.scale_com_strength)
                    score_maps, orint_maps = det(im2_data)
                    im2_rawsc, im2_scale, im2_orin = handle_det_out(
                        score_maps, orint_maps, det.scale_list,
                        det.score_com_strength, det.scale_com_strength)
                    im1_gtscale, im1_gtorin = gt_scale_orin(
                        im2_scale, im2_orin, homo12, homo21)
                    im2_gtscale, im2_gtorin = gt_scale_orin(
                        im1_scale, im1_orin, homo21, homo12)
                    im2_score = filter_border(im2_rawsc)
                    im1w_score = warp(im2_score, homo12)
                    im1_gtsc, im1_topkmask, im1_topkvalue = det.process(
                        im1w_score)
                    im1_score = filter_border(im1_rawsc)
                    im2w_score = warp(im1_score, homo21)
                    im2_gtsc, im2_topkmask, im2_topkvalue = det.process(
                        im2w_score)
                    im1_ppair, im1_limc, im1_rimcw = pair(
                        im1_topkmask,
                        im1_topkvalue,
                        im1_scale,
                        im1_orin,
                        im1_info,
                        im1_raw,
                        homo12,
                        im2_gtscale,
                        im2_gtorin,
                        im2_info,
                        im2_raw,
                        cfg.PATCH.SIZE,
                    )
                    im2_ppair, im2_limc, im2_rimcw = pair(
                        im2_topkmask,
                        im2_topkvalue,
                        im2_scale,
                        im2_orin,
                        im2_info,
                        im2_raw,
                        homo21,
                        im1_gtscale,
                        im1_gtorin,
                        im1_info,
                        im1_raw,
                        cfg.PATCH.SIZE,
                    )
                    im1_lpatch, im1_rpatch = im1_ppair.chunk(
                        chunks=2, dim=1)  # each is (N, 32, 32)
                    im2_lpatch, im2_rpatch = im2_ppair.chunk(
                        chunks=2, dim=1)  # each is (N, 32, 32)
                    im1_lpatch = des.input_norm(im1_lpatch)
                    im2_lpatch = des.input_norm(im2_lpatch)
                    im1_rpatch = des.input_norm(im1_rpatch)
                    im2_rpatch = des.input_norm(im2_rpatch)
                    im1_lpdes, im1_rpdes = des(im1_lpatch), des(im1_rpatch)
                    im2_lpdes, im2_rpdes = des(im2_lpatch), des(im2_rpatch)
                    endpoint = {
                        "im1_limc": im1_limc,
                        "im1_rimcw": im1_rimcw,
                        "im2_limc": im2_limc,
                        "im2_rimcw": im2_rimcw,
                        "im1_lpdes": im1_lpdes,
                        "im1_rpdes": im1_rpdes,
                        "im2_lpdes": im2_lpdes,
                        "im2_rpdes": im2_rpdes,
                    }

                    desloss = (des.module.criterion(endpoint)
                               if mgpu else des.criterion(endpoint))
                    desloss.backward()
                    des_optim.step()

                # Detector updates on patches sampled at the partner image's
                # predicted scale / orientation.
                for _ in range(cfg.TRAIN.DET):
                    det.zero_grad()
                    det_optim.zero_grad()

                    score_maps, orint_maps = det(im1_data)
                    im1_rawsc, im1_scale, im1_orin = handle_det_out(
                        score_maps, orint_maps, det.scale_list,
                        det.score_com_strength, det.scale_com_strength)
                    score_maps, orint_maps = det(im2_data)
                    im2_rawsc, im2_scale, im2_orin = handle_det_out(
                        score_maps, orint_maps, det.scale_list,
                        det.score_com_strength, det.scale_com_strength)

                    im2_score = filter_border(im2_rawsc)
                    im1w_score = warp(im2_score, homo12)
                    # All-ones map warped like the score marks visible pixels.
                    im1_visiblemask = warp(
                        im2_score.new_full(im2_score.size(),
                                           fill_value=1,
                                           requires_grad=True),
                        homo12,
                    )
                    im1_gtsc, im1_topkmask, im1_topkvalue = det.process(
                        im1w_score)

                    im1_score = filter_border(im1_rawsc)
                    im2w_score = warp(im1_score, homo21)
                    # Template the ones map on the score map it mirrors.
                    im2_visiblemask = warp(
                        im1_score.new_full(im1_score.size(),
                                           fill_value=1,
                                           requires_grad=True),
                        homo21,
                    )
                    im2_gtsc, im2_topkmask, im2_topkvalue = det.process(
                        im2w_score)

                    im1_score = det.process(im1_rawsc)[0]
                    im2_score = det.process(im2_rawsc)[0]
                    im1_predpair, _, _ = pair(
                        im1_topkmask,
                        im1_topkvalue,
                        im1_scale,
                        im1_orin,
                        im1_info,
                        im1_raw,
                        homo12,
                        im2_scale,
                        im2_orin,
                        im2_info,
                        im2_raw,
                        cfg.PATCH.SIZE,
                    )
                    im2_predpair, _, _ = pair(
                        im2_topkmask,
                        im2_topkvalue,
                        im2_scale,
                        im2_orin,
                        im2_info,
                        im2_raw,
                        homo21,
                        im1_scale,
                        im1_orin,
                        im1_info,
                        im1_raw,
                        cfg.PATCH.SIZE,
                    )
                    # each is (N, 32, 32)
                    im1_lpredpatch, im1_rpredpatch = im1_predpair.chunk(
                        chunks=2, dim=1)
                    im2_lpredpatch, im2_rpredpatch = im2_predpair.chunk(
                        chunks=2, dim=1)
                    im1_lpredpatch = des.input_norm(im1_lpredpatch)
                    im2_lpredpatch = des.input_norm(im2_lpredpatch)
                    im1_rpredpatch = des.input_norm(im1_rpredpatch)
                    im2_rpredpatch = des.input_norm(im2_rpredpatch)
                    im1_lpreddes, im1_rpreddes = des(im1_lpredpatch), des(
                        im1_rpredpatch)
                    im2_lpreddes, im2_rpreddes = des(im2_lpredpatch), des(
                        im2_rpredpatch)
                    endpoint = {
                        "im1_score": im1_score,
                        "im1_gtsc": im1_gtsc,
                        "im1_visible": im1_visiblemask,
                        "im2_score": im2_score,
                        "im2_gtsc": im2_gtsc,
                        "im2_visible": im2_visiblemask,
                        "im1_lpreddes": im1_lpreddes,
                        "im1_rpreddes": im1_rpreddes,
                        "im2_lpreddes": im2_lpreddes,
                        "im2_rpreddes": im2_rpreddes,
                    }

                    detloss = (det.module.criterion(endpoint)
                               if mgpu else det.criterion(endpoint))
                    detloss.backward()
                    det_optim.step()

            Lr_Schechuler(cfg.TRAIN.DET_LR_SCHEDULE, det_optim, epoch, cfg)
            Lr_Schechuler(cfg.TRAIN.DES_LR_SCHEDULE, des_optim, epoch, cfg)

            # Periodic logging (i_batch starts at 1, so it is never 0 here).
            if i_batch % cfg.TRAIN.LOG_INTERVAL == 0:
                elapsed = time.time() - start_time
                det.eval()
                des.eval()
                with torch.no_grad():
                    # Training metrics on a fixed training sample.
                    parsed_trainbatch = parse_unsqueeze(
                        train_data.dataset[0], device)
                    endpoint = get_all_endpoints(det, des, parsed_trainbatch)
                    detloss = (det.module.criterion(endpoint)
                               if mgpu else det.criterion(endpoint))
                    desloss = (des.module.criterion(endpoint)
                               if mgpu else des.criterion(endpoint))
                    train_scalars = {
                        "pair_loss": detloss,
                        "hard_loss": desloss,
                        "Accuracy": getAC(endpoint["im1_lpdes"],
                                          endpoint["im1_rpdes"]),
                        "det_lr": det_optim.param_groups[0]["lr"],
                        "des_lr": des_optim.param_groups[0]["lr"],
                    }
                    if mgpu:
                        mgpu_merge(train_scalars)
                    # train_data is iterated batch-by-batch, so len() counts
                    # batches per epoch.
                    iteration = (epoch - 1) * len(train_data) + (i_batch - 1)
                    writer_log(train_writer, train_scalars, iteration)

                    pstring = (
                        "epoch {:2d} | {:4d}/{:4d} batches | ms {:4.02f} | "
                        "pair {:05.03f} | des {:05.03f} |".format(
                            epoch,
                            i_batch,
                            len(train_data),
                            # time.time() is in seconds; report ms per batch.
                            1000.0 * elapsed / cfg.TRAIN.LOG_INTERVAL,
                            train_scalars["pair_loss"],
                            train_scalars["hard_loss"],
                        ))

                    # Validation metrics on a fixed validation sample; losses
                    # must be computed from the validation endpoints.
                    parsed_valbatch = parse_unsqueeze(val_data.dataset[0],
                                                      device)
                    ept = get_all_endpoints(det, des, parsed_valbatch)
                    detloss = (det.module.criterion(ept)
                               if mgpu else det.criterion(ept))
                    desloss = (des.module.criterion(ept)
                               if mgpu else des.criterion(ept))
                    val_scalars = {
                        "pair_loss": detloss,
                        "hard_loss": desloss,
                        "Accuracy": getAC(ept["im1_lpdes"], ept["im1_rpdes"]),
                    }
                    writer_log(test_writer, val_scalars, iteration)
                    print(f"{gct()} | {pstring}")
                    start_time = time.time()
コード例 #4
0
    def forward(self, batch):
        """Compute detector/descriptor endpoints for an image pair.

        ``batch`` is unpacked as ``(im1_data, im1_info, homo12, im2_data,
        im2_info, homo21, im1_raw, im2_raw)``. ``self.det`` runs on both
        images. The partner image's predictions, warped through the
        homographies, supply each image's ground-truth scale/orientation
        (``self.gt_scale_orin``) and score targets (``self.gtscore``).

        ``pair`` samples patch pairs twice:

        1. At the partner's ground-truth scale/orientation.
        2. At the partner's predicted scale/orientation.

        ``self.des`` describes both halves of each pair.

        Returns:
            dict of named endpoint tensors for the loss functions.
        """
        (first_data, first_info, fwd_homo,
         second_data, second_info, bwd_homo,
         first_raw, second_raw) = batch
        patch_size = self.PSIZE

        def halves_to_descriptors(stacked):
            # The pair is stacked along dim 1; describe left, then right.
            lhs, rhs = stacked.chunk(chunks=2, dim=1)
            return self.des(lhs), self.des(rhs)

        first_rawsc, first_scale, first_orin = self.det(first_data)
        second_rawsc, second_scale, second_orin = self.det(second_data)

        # Each image's targets come from the other image's predictions.
        first_gtscale, first_gtorin = self.gt_scale_orin(
            second_scale, second_orin, fwd_homo, bwd_homo)
        second_gtscale, second_gtorin = self.gt_scale_orin(
            first_scale, first_orin, bwd_homo, fwd_homo)

        (first_gtsc, first_topkmask, first_topkvalue,
         first_visible) = self.gtscore(second_rawsc, fwd_homo)
        (second_gtsc, second_topkmask, second_topkvalue,
         second_visible) = self.gtscore(first_rawsc, bwd_homo)

        first_score = self.det.process(first_rawsc)[0]
        second_score = self.det.process(second_rawsc)[0]

        first_ctx = (first_topkmask, first_topkvalue, first_scale,
                     first_orin, first_info, first_raw)
        second_ctx = (second_topkmask, second_topkvalue, second_scale,
                      second_orin, second_info, second_raw)

        # Round 1: ground-truth scale / orientation on the partner image.
        first_gtpair, first_limc, first_rimcw = pair(
            *first_ctx, fwd_homo, second_gtscale, second_gtorin,
            second_info, second_raw, patch_size)
        second_gtpair, second_limc, second_rimcw = pair(
            *second_ctx, bwd_homo, first_gtscale, first_gtorin,
            first_info, first_raw, patch_size)
        first_lpdes, first_rpdes = halves_to_descriptors(first_gtpair)
        second_lpdes, second_rpdes = halves_to_descriptors(second_gtpair)

        # Round 2: predicted scale / orientation on the partner image.
        first_predpair, _, _ = pair(
            *first_ctx, fwd_homo, second_scale, second_orin,
            second_info, second_raw, patch_size)
        second_predpair, _, _ = pair(
            *second_ctx, bwd_homo, first_scale, first_orin,
            first_info, first_raw, patch_size)
        first_lpreddes, first_rpreddes = halves_to_descriptors(first_predpair)
        second_lpreddes, second_rpreddes = halves_to_descriptors(
            second_predpair)

        # NOTE: scale/orientation tensors (and their ground truth) are not
        # exported; add them here if a scale/orientation loss is introduced.
        return {
            "im1_score": first_score,
            "im1_gtsc": first_gtsc,
            "im1_visible": first_visible,
            "im2_score": second_score,
            "im2_gtsc": second_gtsc,
            "im2_visible": second_visible,
            "im1_limc": first_limc,
            "im1_rimcw": first_rimcw,
            "im2_limc": second_limc,
            "im2_rimcw": second_rimcw,
            "im1_lpdes": first_lpdes,
            "im1_rpdes": first_rpdes,
            "im2_lpdes": second_lpdes,
            "im2_rpdes": second_rpdes,
            "im1_lpreddes": first_lpreddes,
            "im1_rpreddes": first_rpreddes,
            "im2_lpreddes": second_lpreddes,
            "im2_rpreddes": second_rpreddes,
        }