Code example #1
def test_valid_mask():
  # numpy / torch imports added so the snippet is self-contained;
  # sample_homography is assumed to come from the project's homography utilities.
  import numpy as np
  import torch
  from utils.utils import pltImshow
  batch_size = 1
  # sample one random homography per batch element
  mat_homographies = [sample_homography(3) for i in range(batch_size)]
  mat_H = np.stack(mat_homographies, axis=0)


  corner_img = np.array([(-1, -1), (-1, 1), (1, -1), (1, 1)])
  # printCorners(corner_img, mat_H)
  # points = warp_points_np(corner_img, mat_homographies)

  # invert the homographies: the valid mask is computed from the inverse warp
  mat_H = torch.tensor(mat_H, dtype=torch.float32)
  mat_H_inv = torch.stack([torch.inverse(mat_H[i, :, :]) for i in range(batch_size)])
  from utils.utils import compute_valid_mask, labels2Dto3D
  device = 'cpu'
  shape = torch.tensor([240, 320])  # image size (H, W)
  for i in range(1):
    r = 3
    # 2D valid mask: ones inside the warped image border, eroded by radius r
    mask_valid = compute_valid_mask(shape, inv_homography=mat_H_inv, device=device, erosion_radius=r)
    pltImshow(mask_valid[0, :, :])
    cell_size = 8
    # fold the mask into cell_size x cell_size cells and keep only fully valid cells
    mask_valid = labels2Dto3D(mask_valid.view(batch_size, 1, mask_valid.shape[1], mask_valid.shape[2]), cell_size=cell_size)
    mask_valid = torch.prod(mask_valid[:, :cell_size * cell_size, :, :], dim=1)
    pltImshow(mask_valid[0, :, :].cpu().numpy())

  mask = {}
  mask.update({'homographies': mat_H, 'masks': mask_valid})
  np.savez_compressed('h2.npz', **mask)
  print("finish testing valid mask")
Code example #2
    def getMasks(self, mask_2D, cell_size, device="cpu"):
        """
        # 2D valid mask is expanded into the (Hc, Wc) cell grid used for training
        :param mask_2D:
            tensor [batch, 1, H, W]
        :param cell_size:
            8 (default)
        :param device:
            device on which the mask is computed
        :return:
            flattened 3D mask for training, tensor [batch, Hc, Wc]
        """
        from utils.utils import labels2Dto3D  # same helper used in the other examples

        # [batch, 1, H, W] -> [batch, cell_size*cell_size, Hc, Wc]
        mask_3D = labels2Dto3D(mask_2D.to(device),
                               cell_size=cell_size,
                               add_dustbin=False).float()
        # product over channels: a cell stays valid only if every pixel inside it is valid
        mask_3D_flattened = torch.prod(mask_3D, 1)
        return mask_3D_flattened
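The space-to-depth step that labels2Dto3D performs here (without a dustbin channel) can be reproduced with plain PyTorch for reference. The sketch below is a conceptual equivalent, not the project's implementation; it assumes PyTorch >= 1.8 for F.pixel_unshuffle, and masks_2D_to_cells is a hypothetical helper name:

import torch
import torch.nn.functional as F

def masks_2D_to_cells(mask_2D, cell_size=8):
    # mask_2D: [batch, 1, H, W] with values in {0, 1}
    # space-to-depth: [batch, 1, H, W] -> [batch, cell_size*cell_size, Hc, Wc]
    mask_3D = F.pixel_unshuffle(mask_2D.float(), cell_size)
    # product over channels: a cell is valid only if all of its pixels are valid
    return torch.prod(mask_3D, dim=1)

mask_2D = torch.ones(1, 1, 240, 320)
mask_2D[..., :40] = 0                      # e.g. a strip outside the warped border
print(masks_2D_to_cells(mask_2D).shape)    # torch.Size([1, 30, 40])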
Code example #3
    def train_val_sample(self, sample, n_iter=0, train=False):
        """
        # key function: one training / validation step on a batch
        :param sample:
            dict of batched tensors from the data loader
            (image, labels_2D, valid_mask, plus the warped pair when enabled)
        :param n_iter:
            current iteration, used for the tensorboard logging interval
        :param train:
            if True, back-propagate and step the optimizer; otherwise run the
            forward pass without gradients
        :return:
            scalar loss value for this batch
        """
        to_floatTensor = lambda x: torch.tensor(x).type(torch.FloatTensor)

        task = "train" if train else "val"
        tb_interval = self.config["tensorboard_interval"]
        if_warp = self.config['data']['warped_pair']['enable']

        self.scalar_dict, self.images_dict, self.hist_dict = {}, {}, {}
        ## get the inputs
        # logging.info('get input img and label')
        img, labels_2D, mask_2D = (
            sample["image"],
            sample["labels_2D"],
            sample["valid_mask"],
        )
        # img, labels = img.to(self.device), labels_2D.to(self.device)

        # variables
        batch_size, H, W = img.shape[0], img.shape[2], img.shape[3]
        self.batch_size = batch_size
        det_loss_type = self.config["model"]["detector_loss"]["loss_type"]
        # print("batch_size: ", batch_size)
        Hc = H // self.cell_size
        Wc = W // self.cell_size

        # warped images
        # img_warp, labels_warp_2D, mask_warp_2D = sample['warped_img'].to(self.device), \
        #     sample['warped_labels'].to(self.device), \
        #     sample['warped_valid_mask'].to(self.device)
        if if_warp:
            img_warp, labels_warp_2D, mask_warp_2D = (
                sample["warped_img"],
                sample["warped_labels"],
                sample["warped_valid_mask"],
            )

        # homographies
        # mat_H, mat_H_inv = \
        # sample['homographies'].to(self.device), sample['inv_homographies'].to(self.device)
        if if_warp:
            mat_H, mat_H_inv = (
                sample["homographies"],
                sample["inv_homographies"],
            )

        # zero the parameter gradients
        self.optimizer.zero_grad()

        # forward + backward + optimize
        if train:
            # print("img: ", img.shape, ", img_warp: ", img_warp.shape)
            outs = self.net(img.to(self.device))
            semi, coarse_desc = outs["semi"], outs["desc"]
            if if_warp:
                outs_warp = self.net(img_warp.to(self.device))
                semi_warp, coarse_desc_warp = outs_warp["semi"], outs_warp["desc"]
        else:
            with torch.no_grad():
                outs = self.net(img.to(self.device))
                semi, coarse_desc = outs["semi"], outs["desc"]
                if if_warp:
                    outs_warp = self.net(img_warp.to(self.device))
                    semi_warp, coarse_desc_warp = outs_warp["semi"], outs_warp["desc"]

        # detector loss
        from utils.utils import labels2Dto3D

        if self.gaussian:
            labels_2D = sample["labels_2D_gaussian"]
            if if_warp:
                warped_labels = sample["warped_labels_gaussian"]
        else:
            labels_2D = sample["labels_2D"]
            if if_warp:
                warped_labels = sample["warped_labels"]

        add_dustbin = False
        if det_loss_type == "l2":
            add_dustbin = False
        elif det_loss_type == "softmax":
            add_dustbin = True

        labels_3D = labels2Dto3D(labels_2D.to(self.device),
                                 cell_size=self.cell_size,
                                 add_dustbin=add_dustbin).float()
        mask_3D_flattened = self.getMasks(mask_2D,
                                          self.cell_size,
                                          device=self.device)
        loss_det = self.detector_loss(
            input=outs["semi"],
            target=labels_3D.to(self.device),
            mask=mask_3D_flattened,
            loss_type=det_loss_type,
        )
        # warp
        if if_warp:
            labels_3D = labels2Dto3D(
                warped_labels.to(self.device),
                cell_size=self.cell_size,
                add_dustbin=add_dustbin,
            ).float()
            mask_3D_flattened = self.getMasks(mask_warp_2D,
                                              self.cell_size,
                                              device=self.device)
            loss_det_warp = self.detector_loss(
                input=outs_warp["semi"],
                target=labels_3D.to(self.device),
                mask=mask_3D_flattened,
                loss_type=det_loss_type,
            )
        else:
            loss_det_warp = torch.tensor([0]).to(self.device)

        ## get labels, masks, loss for detection
        # labels3D_in_loss = self.getLabels(labels_2D, self.cell_size, device=self.device)
        # mask_3D_flattened = self.getMasks(mask_2D, self.cell_size, device=self.device)
        # loss_det = self.get_loss(semi, labels3D_in_loss, mask_3D_flattened, device=self.device)

        ## warping
        # labels3D_in_loss = self.getLabels(labels_warp_2D, self.cell_size, device=self.device)
        # mask_3D_flattened = self.getMasks(mask_warp_2D, self.cell_size, device=self.device)
        # loss_det_warp = self.get_loss(semi_warp, labels3D_in_loss, mask_3D_flattened, device=self.device)

        mask_desc = mask_3D_flattened.unsqueeze(1)
        lambda_loss = self.config["model"]["lambda_loss"]
        # print("mask_desc: ", mask_desc.shape)
        # print("mask_warp_2D: ", mask_warp_2D.shape)

        # descriptor loss
        if lambda_loss > 0:
            assert if_warp, "need a pair of images"
            loss_desc, mask, positive_dist, negative_dist = self.descriptor_loss(
                coarse_desc,
                coarse_desc_warp,
                mat_H,
                mask_valid=mask_desc,
                device=self.device,
                **self.desc_params)
        else:
            ze = torch.tensor([0]).to(self.device)
            loss_desc, positive_dist, negative_dist = ze, ze, ze

        loss = loss_det + loss_det_warp
        if lambda_loss > 0:
            loss += lambda_loss * loss_desc

        ##### try to minimize the error ######
        add_res_loss = False
        if add_res_loss and n_iter % 10 == 0:
            print("add_res_loss!!!")
            heatmap_org = self.get_heatmap(semi, det_loss_type)  # tensor []
            heatmap_org_nms_batch = self.heatmap_to_nms(self.images_dict,
                                                        heatmap_org,
                                                        name="heatmap_org")
            if if_warp:
                heatmap_warp = self.get_heatmap(semi_warp, det_loss_type)
                heatmap_warp_nms_batch = self.heatmap_to_nms(
                    self.images_dict, heatmap_warp, name="heatmap_warp")

            # original: pred
            ## check the loss on given labels!
            outs_res = self.get_residual_loss(
                sample["labels_2D"] *
                to_floatTensor(heatmap_org_nms_batch).unsqueeze(1),
                heatmap_org,
                sample["labels_res"],
                name="original_pred",
            )
            loss_res_ori = (outs_res["loss"]**2).mean()
            # warped: pred
            if if_warp:
                outs_res_warp = self.get_residual_loss(
                    sample["warped_labels"] *
                    to_floatTensor(heatmap_warp_nms_batch).unsqueeze(1),
                    heatmap_warp,
                    sample["warped_res"],
                    name="warped_pred",
                )
                loss_res_warp = (outs_res_warp["loss"]**2).mean()
            else:
                loss_res_warp = torch.tensor([0]).to(self.device)
            loss_res = loss_res_ori + loss_res_warp
            # print("loss_res requires_grad: ", loss_res.requires_grad)
            loss += loss_res
            self.scalar_dict.update({
                "loss_res_ori": loss_res_ori,
                "loss_res_warp": loss_res_warp
            })

        #######################################

        self.loss = loss

        self.scalar_dict.update({
            "loss": loss,
            "loss_det": loss_det,
            "loss_det_warp": loss_det_warp,
            "positive_dist": positive_dist,
            "negative_dist": negative_dist,
        })

        self.input_to_imgDict(sample, self.images_dict)

        if train:
            loss.backward()
            self.optimizer.step()

        if n_iter % tb_interval == 0 or task == "val":
            logging.info("current iteration: %d, tensorboard_interval: %d",
                         n_iter, tb_interval)

            # add clean map to tensorboard
            ## semi_warp: flatten, to_numpy

            heatmap_org = self.get_heatmap(semi, det_loss_type)  # tensor []
            heatmap_org_nms_batch = self.heatmap_to_nms(self.images_dict,
                                                        heatmap_org,
                                                        name="heatmap_org")
            if if_warp:
                heatmap_warp = self.get_heatmap(semi_warp, det_loss_type)
                heatmap_warp_nms_batch = self.heatmap_to_nms(
                    self.images_dict, heatmap_warp, name="heatmap_warp")

            def update_overlap(images_dict, labels_warp_2D, heatmap_nms_batch,
                               img_warp, name):
                # image overlap
                from utils.draw import img_overlap

                # result_overlap = img_overlap(img_r, img_g, img_gray)
                # overlap label, nms, img
                nms_overlap = [
                    img_overlap(
                        toNumpy(labels_warp_2D[i]),
                        heatmap_nms_batch[i],
                        toNumpy(img_warp[i]),
                    ) for i in range(heatmap_nms_batch.shape[0])
                ]
                nms_overlap = np.stack(nms_overlap, axis=0)
                images_dict.update({name + "_nms_overlap": nms_overlap})

            from utils.var_dim import toNumpy
            update_overlap(
                self.images_dict,
                labels_2D,
                heatmap_org_nms_batch[np.newaxis, ...],
                img,
                "original",
            )

            update_overlap(
                self.images_dict,
                labels_2D,
                toNumpy(heatmap_org),
                img,
                "original_heatmap",
            )
            if if_warp:
                update_overlap(
                    self.images_dict,
                    labels_warp_2D,
                    heatmap_warp_nms_batch[np.newaxis, ...],
                    img_warp,
                    "warped",
                )
                update_overlap(
                    self.images_dict,
                    labels_warp_2D,
                    toNumpy(heatmap_warp),
                    img_warp,
                    "warped_heatmap",
                )
            # residuals
            from utils.losses import do_log

            if self.gaussian:
                # original: gt
                self.get_residual_loss(
                    sample["labels_2D"],
                    sample["labels_2D_gaussian"],
                    sample["labels_res"],
                    name="original_gt",
                )
                if if_warp:
                    # warped: gt
                    self.get_residual_loss(
                        sample["warped_labels"],
                        sample["warped_labels_gaussian"],
                        sample["warped_res"],
                        name="warped_gt",
                    )

            # from utils.losses import do_log
            # patches_log = do_log(patches)

            # original: pred
            ## check the loss on given labels!
            # self.get_residual_loss(
            #     sample["labels_2D"]
            #     * to_floatTensor(heatmap_org_nms_batch).unsqueeze(1),
            #     heatmap_org,
            #     sample["labels_res"],
            #     name="original_pred",
            # )
            # print("heatmap_org_nms_batch: ", heatmap_org_nms_batch.shape)
            # get_residual_loss(to_floatTensor(heatmap_org_nms_batch).unsqueeze(1), heatmap_org,
            # sample['labels_res'], name='original_pred')
            # warped: pred
            # self.get_residual_loss(
            #     sample["warped_labels"]
            #     * to_floatTensor(heatmap_warp_nms_batch).unsqueeze(1),
            #     heatmap_warp,
            #     sample["warped_res"],
            #     name="warped_pred",
            # )
            # get_residual_loss(to_floatTensor(heatmap_warp_nms_batch).unsqueeze(1), heatmap_warp,
            # sample['warped_res'], name='warped_pred')

            # precision, recall
            # pr_mean = self.batch_precision_recall(
            #     to_floatTensor(heatmap_warp_nms_batch[:, np.newaxis, ...]),
            #     sample["warped_labels"],
            # )
            pr_mean = self.batch_precision_recall(
                to_floatTensor(heatmap_org_nms_batch[:, np.newaxis, ...]),
                sample["labels_2D"],
            )
            print("pr_mean")
            self.scalar_dict.update(pr_mean)

            self.printLosses(self.scalar_dict, task)
            self.tb_images_dict(task, self.images_dict, max_img=2)
            self.tb_hist_dict(task, self.hist_dict)

        self.tb_scalar_dict(self.scalar_dict, task)

        return loss.item()
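Code example #3 is typically driven by an outer loop over the data loader; a minimal sketch of such a loop is shown below (train_loader, val_loader and the frontend instance trainer are assumed names, not part of the snippets above):

for n_iter, sample in enumerate(train_loader):
    train_loss = trainer.train_val_sample(sample, n_iter=n_iter, train=True)

for n_iter, sample in enumerate(val_loader):
    # validation reuses the same entry point; gradients are disabled internally
    val_loss = trainer.train_val_sample(sample, n_iter=n_iter, train=False)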