def run(self, images):
        """Run the detector network on a batch of images and return a heatmap.

        input:
            images: tensor[batch(1), 1, H, W]
        return:
            heatmap tensor; also caches the raw network outputs on
            ``self.outs`` and a numpy copy of the heatmap on ``self.heatmap``.
        """
        from Train_model_heatmap import Train_model_heatmap
        from utils.var_dim import toNumpy
        # Only the flatten_64to1 helper is used; no trainer instance is built.
        train_agent = Train_model_heatmap

        # Inference only -- no gradients needed.
        with torch.no_grad():
            outs = self.net(images)
        semi = outs['semi']
        self.outs = outs

        # 64 channels: plain space-to-depth output; 65 channels presumably
        # includes an extra "dustbin" channel handled by flattenDetection --
        # TODO confirm against the flattenDetection implementation.
        channel = semi.shape[1]
        if channel == 64:
            heatmap = train_agent.flatten_64to1(semi, cell_size=self.cell_size)
        elif channel == 65:
            heatmap = flattenDetection(semi, tensor=True)
        else:
            # Previously fell through to an UnboundLocalError; fail loudly.
            raise ValueError(
                "unexpected number of channels in semi: %d" % channel)

        heatmap_np = toNumpy(heatmap)
        self.heatmap = heatmap_np
        return heatmap
    def process_output(self, sp_processer):
        """
        input:
          N: number of points
        return: -- type: tensorFloat
          pts: tensor [batch, N, 2] (no grad)  (x, y)
          pts_offset: tensor [batch, N, 2] (grad) (x, y)
          pts_desc: tensor [batch, N, 256] (grad)
        """
        from utils.utils import flattenDetection

        out = self.output
        semi, desc = out['semi'], out['desc']

        # dense heatmap from the raw detector head: [batch_size, 1, H, W]
        dense_heatmap = flattenDetection(semi)
        # suppress non-maxima before extracting point locations
        nms_heatmap = sp_processer.heatmap_to_nms(dense_heatmap, tensor=True)
        # predict sub-pixel offsets around each surviving peak
        soft_outs = sp_processer.pred_soft_argmax(nms_heatmap, dense_heatmap)
        # gather points and descriptors, shifted by the predicted residuals
        feats = sp_processer.batch_extract_features(desc, nms_heatmap,
                                                    soft_outs['pred'])

        out.update(feats)
        self.output = out
        return out
# ---- Example #3 (snippet separator; score: 0) ----
    def run(self, inp, onlyHeatmap=False, train=True):
        """ Process a batch of image tensors to extract points and descriptors.
        Input
          inp - [batch, 1, H, W] float32 tensor image in range [0,1].
          onlyHeatmap - if True, return only the dense detection heatmap.
          train - if True, run the forward pass with gradients enabled.
        Output
          pts - list (len batch) of 3xN numpy arrays [x_i, y_i, confidence_i]^T.
          pts_desc - list (len batch) of NxD numpy descriptor arrays
                     (D presumably 256 -- see coarse_desc note below; confirm).
          dense_desc - tensor (batch, D, H, W) of unit-normalized descriptors.
          heatmap - tensor (batch, 1, H, W) of point confidences in [0,1].
          """
        # assert img.ndim == 2, 'Image must be grayscale.'
        # assert img.dtype == np.float32, 'Image must be float32.'
        # H, W = img.shape[0], img.shape[1]
        # inp = img.copy()
        # inp = (inp.reshape(1, H, W))
        # inp = torch.from_numpy(inp)
        # inp = torch.autograd.Variable(inp).view(1, 1, H, W)
        # if self.cuda:
        inp = inp.to(self.device)
        batch_size, H, W = inp.shape[0], inp.shape[2], inp.shape[3]
        if train:
            # gradients kept: this path feeds the training loop
            # outs = self.net.forward(inp, subpixel=self.subpixel)
            outs = self.net.forward(inp)
            # semi, coarse_desc = outs[0], outs[1]
            semi, coarse_desc = outs['semi'], outs['desc']
        else:
            # Forward pass of network (inference only, no grads).
            with torch.no_grad():
                # outs = self.net.forward(inp, subpixel=self.subpixel)
                outs = self.net.forward(inp)
                # semi, coarse_desc = outs[0], outs[1]
                semi, coarse_desc = outs['semi'], outs['desc']

        # as tensor
        from utils.utils import labels2Dto3D, flattenDetection
        from utils.d2s import DepthToSpace
        # flatten detection: channel-wise logits -> dense heatmap
        heatmap = flattenDetection(semi, tensor=True)
        self.heatmap = heatmap
        # depth2space = DepthToSpace(8)
        # print(semi.shape)
        # heatmap = depth2space(semi[:,:-1,:,:]).squeeze(0)
        ## need to change for batches

        if onlyHeatmap:
            return heatmap

        # extract keypoints, one array per image in the batch
        # pts = [self.getPtsFromHeatmap(heatmap[i,:,:,:].cpu().detach().numpy().squeeze()).transpose() for i in range(batch_size)]
        # pts = [self.getPtsFromHeatmap(heatmap[i,:,:,:].cpu().detach().numpy().squeeze()) for i in range(batch_size)]
        # print("heapmap shape: ", heatmap.shape)
        pts = [
            self.getPtsFromHeatmap(heatmap[i, :, :, :].cpu().detach().numpy())
            for i in range(batch_size)
        ]
        self.pts = pts

        if self.subpixel:
            # NOTE(review): outs is a dict ('semi'/'desc') above but indexed
            # with outs[2] here -- confirm the network returns the residual
            # head under that index/key when subpixel is enabled.
            labels_res = outs[2]
            self.pts_subpixel = [
                self.subpixel_predict(toNumpy(labels_res[i, ...]), pts[i])
                for i in range(batch_size)
            ]
        '''
        pts:
            list [batch_size, np(N_i, 3)] -- each point (x, y, probability)
        '''

        # interpolate description
        '''
        coarse_desc:
            tensor (Batch_size, 256, Hc, Wc)
        dense_desc:
            tensor (batch_size, 256, H, W)
        '''
        # upsample the coarse descriptor grid back to full image resolution
        # m = nn.Upsample(scale_factor=(1, self.cell, self.cell), mode='bilinear')
        dense_desc = nn.functional.interpolate(coarse_desc,
                                               scale_factor=(self.cell,
                                                             self.cell),
                                               mode='bilinear')

        # norm the descriptor
        def norm_desc(desc):
            """L2-normalize descriptors along the channel dimension."""
            dn = torch.norm(desc, p=2, dim=1)  # Compute the norm.
            desc = desc.div(torch.unsqueeze(dn,
                                            1))  # Divide by norm to normalize.
            return desc

        dense_desc = norm_desc(dense_desc)

        # extract descriptors at the (x, y) keypoint locations;
        # pts[i] rows are (x, y, conf), so row 1 indexes rows (y) and row 0
        # indexes columns (x) of the dense descriptor map.
        dense_desc_cpu = dense_desc.cpu().detach().numpy()
        # pts_desc = [dense_desc_cpu[i, :, pts[i][:, 1].astype(int), pts[i][:, 0].astype(int)] for i in range(len(pts))]
        pts_desc = [
            dense_desc_cpu[i, :, pts[i][1, :].astype(int),
                           pts[i][0, :].astype(int)].transpose()
            for i in range(len(pts))
        ]

        if self.subpixel:
            return self.pts_subpixel, pts_desc, dense_desc, heatmap
        return pts, pts_desc, dense_desc, heatmap
# ---- Example #4 (snippet separator; score: 0) ----
 def get_heatmap(self, semi, det_loss_type="softmax"):
     """Convert raw detector logits into a dense heatmap.

     "l2"-loss models use the plain space-to-depth flatten; every other
     loss type goes through flattenDetection.
     """
     if det_loss_type == "l2":
         return self.flatten_64to1(semi)
     return flattenDetection(semi)
# ---- Example #5 (snippet separator; score: 0) ----
    def add2tensorboard_nms(self,
                            img,
                            labels_2D,
                            semi,
                            task="training",
                            batch_size=1):
        """
        # deprecated: log NMS'd detections and precision/recall to tensorboard.
        :param img: image batch tensor, indexed as [idx, :, :, :]
        :param labels_2D: ground-truth keypoint map batch
        :param semi: raw detector head output (pre-flattenDetection)
        :param task: tag prefix for the tensorboard entries
        :param batch_size: number of samples from the batch to evaluate
        :return: None (writes images/scalars through self.writer)
        """
        from utils.utils import getPtsFromHeatmap
        from utils.utils import box_nms

        boxNms = False  # optional box-NMS comparison path, disabled by default
        n_iter = self.n_iter

        nms_dist = self.config["model"]["nms"]
        conf_thresh = self.config["model"]["detection_threshold"]
        # print("nms_dist: ", nms_dist)
        precision_recall_list = []
        precision_recall_boxnms_list = []
        for idx in range(batch_size):
            semi_flat_tensor = flattenDetection(semi[idx, :, :, :]).detach()
            semi_flat = toNumpy(semi_flat_tensor)
            semi_thd = np.squeeze(semi_flat, 0)
            pts_nms = getPtsFromHeatmap(semi_thd, conf_thresh, nms_dist)
            semi_thd_nms_sample = np.zeros_like(semi_thd)
            # pts_nms rows are (x, y, conf): index the map as [y, x].
            # np.int was removed in NumPy 1.24 -- use the builtin int instead.
            semi_thd_nms_sample[pts_nms[1, :].astype(int),
                                pts_nms[0, :].astype(int)] = 1

            label_sample = torch.squeeze(labels_2D[idx, :, :, :])
            # pts_nms = getPtsFromHeatmap(label_sample.numpy(), conf_thresh, nms_dist)
            # label_sample_rms_sample = np.zeros_like(label_sample.numpy())
            # label_sample_rms_sample[pts_nms[1, :].astype(np.int), pts_nms[0, :].astype(np.int)] = 1
            label_sample_nms_sample = label_sample

            # only log overlay images for the first few samples
            if idx < 5:
                result_overlap = img_overlap(
                    np.expand_dims(label_sample_nms_sample, 0),
                    np.expand_dims(semi_thd_nms_sample, 0),
                    toNumpy(img[idx, :, :, :]),
                )
                self.writer.add_image(
                    task + "-detector_output_thd_overlay-NMS" + "/%d" % idx,
                    result_overlap,
                    n_iter,
                )
            # numpy shape vs torch.Size compare fine: torch.Size is a tuple
            assert semi_thd_nms_sample.shape == label_sample_nms_sample.size()
            precision_recall = precisionRecall_torch(
                torch.from_numpy(semi_thd_nms_sample), label_sample_nms_sample)
            precision_recall_list.append(precision_recall)

            if boxNms:
                semi_flat_tensor_nms = box_nms(semi_flat_tensor.squeeze(),
                                               nms_dist,
                                               min_prob=conf_thresh).cpu()
                semi_flat_tensor_nms = (semi_flat_tensor_nms >=
                                        conf_thresh).float()

                if idx < 5:
                    result_overlap = img_overlap(
                        np.expand_dims(label_sample_nms_sample, 0),
                        semi_flat_tensor_nms.numpy()[np.newaxis, :, :],
                        toNumpy(img[idx, :, :, :]),
                    )
                    self.writer.add_image(
                        task + "-detector_output_thd_overlay-boxNMS" +
                        "/%d" % idx,
                        result_overlap,
                        n_iter,
                    )
                precision_recall_boxnms = precisionRecall_torch(
                    semi_flat_tensor_nms, label_sample_nms_sample)
                precision_recall_boxnms_list.append(precision_recall_boxnms)

        # average precision/recall over the evaluated samples
        precision = np.mean([
            precision_recall["precision"]
            for precision_recall in precision_recall_list
        ])
        recall = np.mean([
            precision_recall["recall"]
            for precision_recall in precision_recall_list
        ])
        self.writer.add_scalar(task + "-precision_nms", precision, n_iter)
        self.writer.add_scalar(task + "-recall_nms", recall, n_iter)
        print("-- [%s-%d-fast NMS] precision: %.4f, recall: %.4f" %
              (task, n_iter, precision, recall))
        if boxNms:
            precision = np.mean([
                precision_recall["precision"]
                for precision_recall in precision_recall_boxnms_list
            ])
            recall = np.mean([
                precision_recall["recall"]
                for precision_recall in precision_recall_boxnms_list
            ])
            self.writer.add_scalar(task + "-precision_boxnms", precision,
                                   n_iter)
            self.writer.add_scalar(task + "-recall_boxnms", recall, n_iter)
            print("-- [%s-%d-boxNMS] precision: %.4f, recall: %.4f" %
                  (task, n_iter, precision, recall))
# ---- Example #6 (snippet separator; score: 0) ----
    def addImg2tensorboard(
        self,
        img,
        labels_2D,
        semi,
        img_warp=None,
        labels_warp_2D=None,
        mask_warp_2D=None,
        semi_warp=None,
        mask_3D_flattened=None,
        task="training",
    ):
        """
        # deprecated: add images to tensorboard
        Logs thresholded detector overlays for the original and warped views,
        the warped-validity-mask overlay, and also saves the first two
        overlays to test_0.png / test_1.png on disk.
        :param img: image batch; only sample 0 is visualized
        :param labels_2D: ground-truth keypoint map batch
        :param semi: raw detector output for the original view
        :param img_warp: warped image batch (sample 0 visualized)
        :param labels_warp_2D: ground-truth keypoints for the warped view
        :param mask_warp_2D: validity mask of the warped view
        :param semi_warp: raw detector output for the warped view
        :param mask_3D_flattened: per-cell mask; first 5 samples logged
        :param task: tag prefix for the tensorboard entries
        :return: None (side effects: writer.add_image calls + saveImg)
        """
        # print("add images to tensorboard")

        n_iter = self.n_iter
        # flatten raw head outputs (sample 0 only) into dense heatmaps
        semi_flat = flattenDetection(semi[0, :, :, :])
        semi_warp_flat = flattenDetection(semi_warp[0, :, :, :])

        # binarize at the configured detection threshold
        thd = self.config["model"]["detection_threshold"]
        semi_thd = thd_img(semi_flat, thd=thd)
        semi_warp_thd = thd_img(semi_warp_flat, thd=thd)

        result_overlap = img_overlap(toNumpy(labels_2D[0, :, :, :]),
                                     toNumpy(semi_thd),
                                     toNumpy(img[0, :, :, :]))

        self.writer.add_image(task + "-detector_output_thd_overlay",
                              result_overlap, n_iter)
        saveImg(
            result_overlap.transpose([1, 2, 0])[..., [2, 1, 0]] * 255,
            "test_0.png")  # rgb to bgr * 255

        # same overlay for the warped view
        result_overlap = img_overlap(
            toNumpy(labels_warp_2D[0, :, :, :]),
            toNumpy(semi_warp_thd),
            toNumpy(img_warp[0, :, :, :]),
        )
        self.writer.add_image(task + "-warp_detector_output_thd_overlay",
                              result_overlap, n_iter)
        saveImg(
            result_overlap.transpose([1, 2, 0])[..., [2, 1, 0]] * 255,
            "test_1.png")  # rgb to bgr * 255

        # overlay of the INVALID region ((1 - mask) / 2) on the warped image
        mask_overlap = img_overlap(
            toNumpy(1 - mask_warp_2D[0, :, :, :]) / 2,
            np.zeros_like(toNumpy(img_warp[0, :, :, :])),
            toNumpy(img_warp[0, :, :, :]),
        )

        # writer.add_image(task + '_mask_valid_first_layer', mask_warp[0, :, :, :], n_iter)
        # writer.add_image(task + '_mask_valid_last_layer', mask_warp[-1, :, :, :], n_iter)
        ##### print to check
        # print("mask_2D shape: ", mask_warp_2D.shape)
        # print("mask_3D_flattened shape: ", mask_3D_flattened.shape)
        # log masks for at most the first 5 samples of the batch
        for i in range(self.batch_size):
            if i < 5:
                self.writer.add_image(task + "-mask_warp_origin",
                                      mask_warp_2D[i, :, :, :], n_iter)
                self.writer.add_image(task + "-mask_warp_3D_flattened",
                                      mask_3D_flattened[i, :, :], n_iter)
        # self.writer.add_image(task + '-mask_warp_origin-1', mask_warp_2D[1, :, :, :], n_iter)
        # self.writer.add_image(task + '-mask_warp_3D_flattened-1', mask_3D_flattened[1, :, :], n_iter)
        self.writer.add_image(task + "-mask_warp_overlay", mask_overlap,
                              n_iter)
# ---- Example #7 (snippet separator; score: 0) ----
    def train_val_sample(self, sample, n_iter=0, train=False):
        """
        # deprecated: default train_val_sample
        Runs one training/validation step: forward both the image and its
        warped counterpart, compute detector + descriptor (+ optional
        subpixel) losses, optionally back-propagate, and periodically log
        images and metrics to tensorboard.
        :param sample: dict with 'image', 'labels_2D', 'valid_mask',
            'warped_img', 'warped_labels', 'warped_valid_mask',
            'homographies', 'inv_homographies' (and 'warped_res' when
            self.subpixel is set)
        :param n_iter: current iteration, used for tensorboard logging
        :param train: if True, run the optimizer step
        :return: scalar loss value (float)
        """
        task = "train" if train else "val"
        tb_interval = self.config["tensorboard_interval"]

        losses = {}
        ## get the inputs
        img, labels_2D, mask_2D = (
            sample["image"],
            sample["labels_2D"],
            sample["valid_mask"],
        )

        # variables
        batch_size, H, W = img.shape[0], img.shape[2], img.shape[3]
        self.batch_size = batch_size
        Hc = H // self.cell_size
        Wc = W // self.cell_size

        # warped counterparts of the same sample
        img_warp, labels_warp_2D, mask_warp_2D = (
            sample["warped_img"],
            sample["warped_labels"],
            sample["warped_valid_mask"],
        )

        # homographies relating the original and warped frames
        mat_H, mat_H_inv = sample["homographies"], sample["inv_homographies"]

        # zero the parameter gradients
        self.optimizer.zero_grad()

        # forward + backward + optimize
        if train:
            outs, outs_warp = (
                self.net(img.to(self.device)),
                self.net(img_warp.to(self.device), subpixel=self.subpixel),
            )
            semi, coarse_desc = outs[0], outs[1]
            semi_warp, coarse_desc_warp = outs_warp[0], outs_warp[1]
        else:
            with torch.no_grad():
                outs, outs_warp = (
                    self.net(img.to(self.device)),
                    self.net(img_warp.to(self.device), subpixel=self.subpixel),
                )
                semi, coarse_desc = outs[0], outs[1]
                semi_warp, coarse_desc_warp = outs_warp[0], outs_warp[1]

        # detector loss on the original view
        labels3D_in_loss = self.getLabels(labels_2D,
                                          self.cell_size,
                                          device=self.device)
        mask_3D_flattened = self.getMasks(mask_2D,
                                          self.cell_size,
                                          device=self.device)
        loss_det = self.get_loss(semi,
                                 labels3D_in_loss,
                                 mask_3D_flattened,
                                 device=self.device)

        ## detector loss on the warped view
        labels3D_in_loss = self.getLabels(labels_warp_2D,
                                          self.cell_size,
                                          device=self.device)
        mask_3D_flattened = self.getMasks(mask_warp_2D,
                                          self.cell_size,
                                          device=self.device)
        loss_det_warp = self.get_loss(semi_warp,
                                      labels3D_in_loss,
                                      mask_3D_flattened,
                                      device=self.device)

        mask_desc = mask_3D_flattened.unsqueeze(1)

        # descriptor loss between the two views
        loss_desc, mask, positive_dist, negative_dist = self.descriptor_loss(
            coarse_desc,
            coarse_desc_warp,
            mat_H,
            mask_valid=mask_desc,
            device=self.device,
            **self.desc_params)

        loss = (loss_det + loss_det_warp +
                self.config["model"]["lambda_loss"] * loss_desc)

        if self.subpixel:
            # coarse to dense: flatten the warped detector head output
            dense_map = flattenDetection(semi_warp)  # tensor [batch, 1, H, W]
            # concat image and dense heatmap
            concat_features = torch.cat((img_warp.to(self.device), dense_map),
                                        dim=1)  # tensor [batch, n, H, W]
            # residual prediction head of the network
            pred_heatmap = outs_warp[2]  # tensor [batch, 1, H, W]
            labels_warped_res = sample["warped_res"]

            subpix_loss = self.subpixel_loss_func(
                labels_warp_2D.to(self.device),
                labels_warped_res.to(self.device),
                pred_heatmap.to(self.device),
                patch_size=11,
            )

            # extract patches around the labelled keypoints
            label_idx = labels_2D[...].nonzero()
            from utils.losses import extract_patches

            patch_size = 32
            patches = extract_patches(
                label_idx.to(self.device),
                img_warp.to(self.device),
                patch_size=patch_size,
            )  # tensor [N, patch_size, patch_size]
            print("patches: ", patches.shape)

            def label_to_points(labels_res, points):
                # gather the residual vector stored at each point index
                labels_res = labels_res.transpose(1,
                                                  2).transpose(2,
                                                               3).unsqueeze(1)
                points_res = labels_res[points[:, 0], points[:, 1], points[:,
                                                                           2],
                                        points[:, 3], :]  # tensor [N, 2]
                return points_res

            points_res = label_to_points(labels_warped_res, label_idx)

            num_patches_max = 500
            # feed the patches into the subpixel refinement network
            pred_res = self.subnet(patches[:num_patches_max, ...].to(
                self.device))  # tensor [1, N, 2]

            def get_loss(points_res, pred_res):
                # mean L2 distance between target and predicted residuals
                loss = points_res - pred_res
                loss = torch.norm(loss, p=2, dim=-1).mean()
                return loss

            # NOTE(review): this OVERWRITES the detector+descriptor loss
            # computed above -- in subpixel mode only the residual loss is
            # optimized. Kept as-is; confirm this is intentional.
            loss = get_loss(points_res[:num_patches_max, ...].to(self.device),
                            pred_res)

            losses.update({"subpix_loss": subpix_loss})

        self.loss = loss

        # fix: the original dict literal repeated the loss_det /
        # loss_det_warp keys; duplicates were redundant (same values).
        losses.update({
            "loss": loss,
            "loss_det": loss_det,
            "loss_det_warp": loss_det_warp,
            "positive_dist": positive_dist,
            "negative_dist": negative_dist,
        })

        if train:
            loss.backward()
            self.optimizer.step()

        self.addLosses2tensorboard(losses, task)
        if n_iter % tb_interval == 0 or task == "val":
            logging.info("current iteration: %d, tensorboard_interval: %d",
                         n_iter, tb_interval)
            self.addImg2tensorboard(
                img,
                labels_2D,
                semi,
                img_warp,
                labels_warp_2D,
                mask_warp_2D,
                semi_warp,
                mask_3D_flattened=mask_3D_flattened,
                task=task,
            )

            if self.subpixel:
                self.add_single_image_to_tb(task,
                                            pred_heatmap,
                                            n_iter,
                                            name="subpixel_heatmap")

            self.printLosses(losses, task)

            self.add2tensorboard_nms(img,
                                     labels_2D,
                                     semi,
                                     task=task,
                                     batch_size=batch_size)

        return loss.item()