Exemplo n.º 1
0
    def pred_soft_argmax(labels_2D, heatmap, labels_res, patch_size=5, device="cuda"):
        """Predict sub-pixel offsets around labelled points with a soft-argmax.

        Patches of ``heatmap`` are cut out around every non-zero entry of
        ``labels_2D``, normalised, passed through ``do_log`` and handed to
        ``soft_argmax_2d``.  The result is shifted by ``patch_size // 2`` so it
        is expressed relative to the patch centre, and the residuals stored in
        ``labels_res`` at the same points are gathered for comparison.

        input:
            labels_2D: tensor whose non-zero entries mark the points
                (presumably [batch, 1, H, W] -- TODO confirm against caller)
            heatmap: tensor the patches are extracted from
            labels_res: tensor [batch, channel, H, W]
            patch_size: patch size forwarded to ``extract_patches``
            device: device used for the patches and for the loss
        return:
            dict with keys
                'pred': soft-argmax offsets relative to the patch centre
                'points_res': residuals read from ``labels_res`` at the points
                'loss': element-wise ``pred - points_res`` (no reduction)
                'patches': the normalised patches
        """
        from utils.losses import do_log, extract_patches, norm_patches, soft_argmax_2d

        # One row per labelled point: (batch, 0, y, x).
        point_idx = labels_2D[...].nonzero().long()

        # Cut a window out of the heatmap around each point, then normalise.
        raw_patches = extract_patches(
            point_idx.to(device), heatmap.to(device), patch_size=patch_size
        )
        normed_patches = norm_patches(raw_patches)

        # Soft-argmax works on the log-patches and reports a position in patch
        # coordinates; subtracting half the patch size re-centres it.
        offsets = soft_argmax_2d(do_log(normed_patches), normalized_coordinates=False)
        offsets = offsets.squeeze(1) - patch_size // 2

        # Channels go last and a singleton axis is inserted at position 1 so
        # the four index columns address the residual map directly.
        res_channels_last = labels_res.transpose(1, 2).transpose(2, 3).unsqueeze(1)
        residuals = res_channels_last[
            point_idx[:, 0], point_idx[:, 1], point_idx[:, 2], point_idx[:, 3], :
        ]

        return {
            "pred": offsets,
            "points_res": residuals,
            "loss": offsets.to(device) - residuals.to(device),
            "patches": normed_patches,
        }
 def extract_patches(self, label_idx, img):
     """Cut patches out of ``img`` around the given points.

     Thin wrapper around ``utils.losses.extract_patches``: both tensors are
     moved to ``self.device`` and the patch size is read from
     ``self.config['params']['patch_size']``.

     input:
         label_idx: tensor [N, 4]: (batch, 0, y, x)
         img: tensor [batch, channel(1), H, W]
     return:
         the result of ``utils.losses.extract_patches`` for these inputs
     """
     from utils.losses import extract_patches as _extract

     size = self.config['params']['patch_size']
     points_on_device = label_idx.to(self.device)
     img_on_device = img.to(self.device)
     return _extract(points_on_device, img_on_device, patch_size=size)
Exemplo n.º 3
0
    def pred_soft_argmax(self, labels_2D, heatmap):
        """Predict sub-pixel offsets around labelled points with a soft-argmax.

        Patches of ``heatmap`` are extracted around every non-zero entry of
        ``labels_2D`` (patch size ``self.patch_size``, on ``self.device``),
        passed through ``do_log`` and handed to ``soft_argmax_2d``.  The result
        is shifted by ``patch_size // 2`` so it is relative to the patch centre.

        input:
            labels_2D: tensor whose non-zero entries mark the points
            heatmap: tensor the patches are extracted from
        return:
            dict with keys
                'pred': soft-argmax offsets relative to the patch centre
                'patches': the extracted patches
            No loss is computed here.
        """
        from utils.losses import do_log, extract_patches, soft_argmax_2d

        patch_size = self.patch_size
        device = self.device

        # One row per labelled point.
        label_idx = labels_2D[...].nonzero()
        patches = extract_patches(
            label_idx.to(device), heatmap.to(device), patch_size=patch_size
        )

        # The patches go into do_log as extracted; no normalisation step runs
        # in this method.
        patches_log = do_log(patches)

        # soft_argmax_2d reports a position in patch coordinates (not
        # normalised); dropping axis 1 and subtracting half the patch size
        # turns it into an offset from the patch centre.
        dxdy = soft_argmax_2d(patches_log, normalized_coordinates=False)
        dxdy = dxdy.squeeze(1)
        dxdy = dxdy - patch_size // 2

        return {'pred': dxdy, 'patches': patches}
Exemplo n.º 4
0
    def train_val_sample(self, sample, n_iter=0, train=False):
        """
        # deprecated: default train_val_sample

        Run one training or validation step on a batch and its warped copy.

        The network is evaluated on ``sample["image"]`` and
        ``sample["warped_img"]``.  The detector losses of both views and the
        descriptor loss (weighted by ``config["model"]["lambda_loss"]``) are
        summed.  When ``self.subpixel`` is set, a sub-pixel loss is computed
        as well and the step's loss is replaced by the patch-regression loss
        of ``self.subnet``.

        :param sample: dict providing "image", "labels_2D", "valid_mask",
            "warped_img", "warped_labels", "warped_valid_mask" and
            "homographies"; "warped_res" is read too when ``self.subpixel``
            is set
        :param n_iter: iteration counter; images are logged to tensorboard
            when it is a multiple of ``config["tensorboard_interval"]``
        :param train: True to back-propagate and step the optimizer, False
            for a validation step (which always logs images)
        :return: the final loss as a Python float (``loss.item()``)
        """
        task = "train" if train else "val"
        tb_interval = self.config["tensorboard_interval"]

        losses = {}

        # Inputs are moved to self.device only where they are consumed.
        img, labels_2D, mask_2D = (
            sample["image"],
            sample["labels_2D"],
            sample["valid_mask"],
        )
        batch_size = img.shape[0]
        self.batch_size = batch_size

        img_warp, labels_warp_2D, mask_warp_2D = (
            sample["warped_img"],
            sample["warped_labels"],
            sample["warped_valid_mask"],
        )
        mat_H = sample["homographies"]

        # zero the parameter gradients
        self.optimizer.zero_grad()

        # Forward pass on both views; autograd is only active when training.
        with torch.set_grad_enabled(train):
            outs = self.net(img.to(self.device))
            outs_warp = self.net(img_warp.to(self.device),
                                 subpixel=self.subpixel)
        semi, coarse_desc = outs[0], outs[1]
        semi_warp, coarse_desc_warp = outs_warp[0], outs_warp[1]

        # detector loss on the original view
        labels3D_in_loss = self.getLabels(labels_2D,
                                          self.cell_size,
                                          device=self.device)
        mask_3D_flattened = self.getMasks(mask_2D,
                                          self.cell_size,
                                          device=self.device)
        loss_det = self.get_loss(semi,
                                 labels3D_in_loss,
                                 mask_3D_flattened,
                                 device=self.device)

        # detector loss on the warped view; from here on mask_3D_flattened
        # refers to the WARPED mask (used for the descriptor loss and logging)
        labels3D_in_loss = self.getLabels(labels_warp_2D,
                                          self.cell_size,
                                          device=self.device)
        mask_3D_flattened = self.getMasks(mask_warp_2D,
                                          self.cell_size,
                                          device=self.device)
        loss_det_warp = self.get_loss(semi_warp,
                                      labels3D_in_loss,
                                      mask_3D_flattened,
                                      device=self.device)

        mask_desc = mask_3D_flattened.unsqueeze(1)

        # descriptor loss (the second return value is not needed here)
        loss_desc, _, positive_dist, negative_dist = self.descriptor_loss(
            coarse_desc,
            coarse_desc_warp,
            mat_H,
            mask_valid=mask_desc,
            device=self.device,
            **self.desc_params)

        loss = (loss_det + loss_det_warp +
                self.config["model"]["lambda_loss"] * loss_desc)

        if self.subpixel:
            # NOTE(review): dense_map and concat_features are computed but
            # never read below -- confirm they can be dropped.
            dense_map = flattenDetection(semi_warp)
            concat_features = torch.cat((img_warp.to(self.device), dense_map),
                                        dim=1)

            # third network output of the warped view is used as the heatmap
            pred_heatmap = outs_warp[2]
            labels_warped_res = sample["warped_res"]

            subpix_loss = self.subpixel_loss_func(
                labels_warp_2D.to(self.device),
                labels_warped_res.to(self.device),
                pred_heatmap.to(self.device),
                patch_size=11,
            )

            # NOTE(review): the point index comes from the UNWARPED labels but
            # is used on img_warp / warped_res below -- confirm this is
            # intended rather than labels_warp_2D.
            label_idx = labels_2D[...].nonzero()
            from utils.losses import extract_patches

            patch_size = 32
            patches = extract_patches(
                label_idx.to(self.device),
                img_warp.to(self.device),
                patch_size=patch_size,
            )
            print("patches: ", patches.shape)

            # Channels last plus a singleton axis at position 1, so the four
            # index columns select one residual vector per point.
            res_channels_last = labels_warped_res.transpose(1, 2).transpose(
                2, 3).unsqueeze(1)
            points_res = res_channels_last[label_idx[:, 0], label_idx[:, 1],
                                           label_idx[:, 2], label_idx[:, 3], :]

            # Only the first num_patches_max patches are fed to the sub-network.
            num_patches_max = 500
            pred_res = self.subnet(patches[:num_patches_max, ...].to(
                self.device))

            # NOTE(review): this REPLACES the combined detector/descriptor
            # loss computed above; subpix_loss is only logged.
            residual_err = (
                points_res[:num_patches_max, ...].to(self.device) - pred_res)
            loss = torch.norm(residual_err, p=2, dim=-1).mean()

            losses.update({"subpix_loss": subpix_loss})

        self.loss = loss

        losses.update({
            "loss": loss,
            "loss_det": loss_det,
            "loss_det_warp": loss_det_warp,
            "positive_dist": positive_dist,
            "negative_dist": negative_dist,
        })

        if train:
            loss.backward()
            self.optimizer.step()

        self.addLosses2tensorboard(losses, task)
        if n_iter % tb_interval == 0 or task == "val":
            logging.info("current iteration: %d, tensorboard_interval: %d",
                         n_iter, tb_interval)
            self.addImg2tensorboard(
                img,
                labels_2D,
                semi,
                img_warp,
                labels_warp_2D,
                mask_warp_2D,
                semi_warp,
                mask_3D_flattened=mask_3D_flattened,
                task=task,
            )

            if self.subpixel:
                self.add_single_image_to_tb(task,
                                            pred_heatmap,
                                            n_iter,
                                            name="subpixel_heatmap")

            self.printLosses(losses, task)

            self.add2tensorboard_nms(img,
                                     labels_2D,
                                     semi,
                                     task=task,
                                     batch_size=batch_size)

        return loss.item()
    def train_val_sample(self, sample, n_iter=0, train=False):
        """Run one training or validation step of the patch-regression network.

        Patches of ``sample['image']`` are extracted around every non-zero
        entry of ``sample['labels_2D']`` and the first 500 of them are fed to
        ``self.net``.  The prediction is compared with the residuals read from
        ``sample['labels_res']`` at the same points; the loss is the mean L2
        norm of the difference.

        :param sample: dict providing 'image', 'labels_2D' and 'labels_res'
        :param n_iter: iteration counter; images and histograms are logged
            when it is a multiple of ``config['tensorboard_interval']``
        :param train: True to back-propagate and step the optimizer, False
            for a validation step (which always logs images and histograms)
        :return: the loss as a Python float (``loss.item()``)
        """
        from utils.losses import extract_patches

        task = 'train' if train else 'val'
        tb_interval = self.config['tensorboard_interval']

        # Inputs are moved to self.device only where they are consumed.
        img, labels_2D = sample['image'], sample['labels_2D']
        labels_res = sample['labels_res']
        self.batch_size = img.shape[0]

        # zero the parameter gradients
        self.optimizer.zero_grad()

        # One row per labelled point; a patch is cut out around each of them.
        label_idx = labels_2D[...].nonzero()
        patch_size = self.config['model']['params']['patch_size']
        patches = extract_patches(
            label_idx.to(self.device),
            img.to(self.device),
            patch_size=patch_size)

        # Channels last plus a singleton axis at position 1, so the four
        # index columns select one residual vector per point.
        res_channels_last = labels_res.transpose(1, 2).transpose(
            2, 3).unsqueeze(1)
        points_res = res_channels_last[label_idx[:, 0], label_idx[:, 1],
                                       label_idx[:, 2], label_idx[:, 3], :]

        # Only the first num_patches_max patches are fed to the network.
        num_patches_max = 500

        # Autograd is only needed when training; validation skips building
        # the graph.
        with torch.set_grad_enabled(train):
            pred_res = self.net(patches[:num_patches_max,
                                        ...].to(self.device))
            residual_err = (
                points_res[:num_patches_max, ...].to(self.device) - pred_res)
            loss = torch.norm(residual_err, p=2, dim=-1).mean()
        self.loss = loss

        losses = {'loss': loss}
        # NOTE: the points_res histograms cover ALL points, while pred_res
        # only covers the first num_patches_max.
        tb_hist = {
            'points_res_0': points_res[:, 0],
            'points_res_1': points_res[:, 1],
            'pred_res_0': pred_res[:, 0],
            'pred_res_1': pred_res[:, 1],
        }
        tb_imgs = {
            'patches': patches[:, ...].unsqueeze(1),
            'img': img,
        }

        if train:
            loss.backward()
            self.optimizer.step()

        self.tb_scalar_dict(losses, task)
        if n_iter % tb_interval == 0 or task == 'val':
            logging.info("current iteration: %d, tensorboard_interval: %d",
                         n_iter, tb_interval)
            self.tb_images_dict(task, tb_imgs, max_img=5)
            self.tb_hist_dict(task, tb_hist)

        return loss.item()