def test_valid_mask():
    """Visual smoke test for compute_valid_mask / labels2Dto3D.

    Samples a random homography, computes the eroded valid mask for a
    240x320 image, collapses it to cell space, shows both masks via
    pltImshow, and dumps the homographies + cell mask to 'h2.npz'.
    """
    from utils.utils import pltImshow
    from utils.utils import compute_valid_mask, labels2Dto3D

    batch_size = 1
    # One random 3x3 homography per batch element, stacked to [B, 3, 3].
    homographies = [sample_homography(3) for _ in range(batch_size)]
    mat_H = torch.tensor(np.stack(homographies, axis=0), dtype=torch.float32)
    mat_H_inv = torch.stack(
        [torch.inverse(mat_H[b]) for b in range(batch_size)]
    )

    # Image corners in normalized coordinates (kept for reference; unused).
    corner_img = np.array([(-1, -1), (-1, 1), (1, -1), (1, 1)])

    device = 'cpu'
    shape = torch.tensor([240, 320])
    for _ in range(1):
        # Border mask after warping, eroded by radius 3; visualize it.
        mask_valid = compute_valid_mask(
            shape, inv_homography=mat_H_inv, device=device, erosion_radius=3
        )
        pltImshow(mask_valid[0, :, :])

    # Collapse the pixel mask to cell space: a cell survives only if all
    # cell_size**2 of its pixels are valid (product over the channel dim).
    cell_size = 8
    mask_valid = labels2Dto3D(
        mask_valid.view(batch_size, 1, mask_valid.shape[1], mask_valid.shape[2]),
        cell_size=cell_size,
    )
    mask_valid = torch.prod(mask_valid[:, :cell_size * cell_size, :, :], dim=1)
    pltImshow(mask_valid[0, :, :].cpu().numpy())

    # Persist homographies and the cell-space mask for inspection.
    np.savez_compressed('h2.npz', **{'homographies': mat_H, 'masks': mask_valid})
    print("finish testing valid mask")
def getMasks(self, mask_2D, cell_size, device="cpu"):
    """Convert a 2D valid mask into a flattened cell-space mask for training.

    :param mask_2D: tensor [batch, 1, H, W]
    :param cell_size: 8 (default)
    :param device: device the mask is moved to before conversion
    :return: flattened 3D mask for training
    """
    # Lift the pixel mask into cell space, then reduce with a product over
    # the channel dimension: a cell counts as valid only when every pixel
    # inside it is valid.
    cells = labels2Dto3D(
        mask_2D.to(device), cell_size=cell_size, add_dustbin=False
    ).float()
    return torch.prod(cells, 1)
def train_val_sample(self, sample, n_iter=0, train=False):
    """
    # key function: run one training / validation step on a batch.

    Steps (as implemented below):
      1. unpack image, 2D labels, valid mask (and the warped pair when
         config['data']['warped_pair']['enable'] is on),
      2. forward pass through self.net (under no_grad when not training),
      3. detector loss on the original and (optionally) warped branch,
      4. descriptor loss between the pair when lambda_loss > 0,
      5. optional residual loss (currently disabled: add_res_loss = False),
      6. backward + optimizer step when train=True,
      7. tensorboard logging every `tensorboard_interval` iterations
         (and on every validation step).

    :param sample: batch dict from the dataloader; keys read here include
        'image', 'labels_2D', 'valid_mask', and — depending on config —
        warped/gaussian/homography entries.
    :param n_iter: global iteration index; drives the logging cadence.
    :param train: True for a training step, False for validation.
    :return: total loss as a Python float (loss.item()).
    """
    # helper: array-like -> float tensor (used for nms heatmaps below)
    to_floatTensor = lambda x: torch.tensor(x).type(torch.FloatTensor)

    task = "train" if train else "val"
    tb_interval = self.config["tensorboard_interval"]
    # whether the loader supplies a homography-warped copy of each image
    if_warp = self.config['data']['warped_pair']['enable']

    # fresh per-step logging buffers
    self.scalar_dict, self.images_dict, self.hist_dict = {}, {}, {}

    ## get the inputs
    # logging.info('get input img and label')
    img, labels_2D, mask_2D = (
        sample["image"],
        sample["labels_2D"],
        sample["valid_mask"],
    )
    # img, labels = img.to(self.device), labels_2D.to(self.device)

    # variables
    batch_size, H, W = img.shape[0], img.shape[2], img.shape[3]
    self.batch_size = batch_size
    det_loss_type = self.config["model"]["detector_loss"]["loss_type"]
    # print("batch_size: ", batch_size)
    # cell-space (coarse) dims; NOTE(review): Hc/Wc are unused below
    Hc = H // self.cell_size
    Wc = W // self.cell_size

    # warped images
    # img_warp, labels_warp_2D, mask_warp_2D = sample['warped_img'].to(self.device), \
    #     sample['warped_labels'].to(self.device), \
    #     sample['warped_valid_mask'].to(self.device)
    if if_warp:
        img_warp, labels_warp_2D, mask_warp_2D = (
            sample["warped_img"],
            sample["warped_labels"],
            sample["warped_valid_mask"],
        )

    # homographies
    # mat_H, mat_H_inv = \
    #     sample['homographies'].to(self.device), sample['inv_homographies'].to(self.device)
    if if_warp:
        mat_H, mat_H_inv = sample["homographies"], sample["inv_homographies"]

    # zero the parameter gradients
    self.optimizer.zero_grad()

    # forward + backward + optimize
    if train:
        # print("img: ", img.shape, ", img_warp: ", img_warp.shape)
        outs = self.net(img.to(self.device))
        semi, coarse_desc = outs["semi"], outs["desc"]
        if if_warp:
            outs_warp = self.net(img_warp.to(self.device))
            semi_warp, coarse_desc_warp = outs_warp["semi"], outs_warp["desc"]
    else:
        # validation: no gradients needed
        with torch.no_grad():
            outs = self.net(img.to(self.device))
            semi, coarse_desc = outs["semi"], outs["desc"]
            if if_warp:
                outs_warp = self.net(img_warp.to(self.device))
                semi_warp, coarse_desc_warp = outs_warp["semi"], outs_warp["desc"]
            pass

    # detector loss
    from utils.utils import labels2Dto3D

    # when gaussian labels are enabled, train against the blurred label maps
    if self.gaussian:
        labels_2D = sample["labels_2D_gaussian"]
        if if_warp:
            warped_labels = sample["warped_labels_gaussian"]
    else:
        labels_2D = sample["labels_2D"]
        if if_warp:
            warped_labels = sample["warped_labels"]

    # the softmax detector head gets an extra "no keypoint" dustbin channel
    add_dustbin = False
    if det_loss_type == "l2":
        add_dustbin = False
    elif det_loss_type == "softmax":
        add_dustbin = True

    # 2D label map -> cell-space 3D target expected by the detector loss
    labels_3D = labels2Dto3D(
        labels_2D.to(self.device),
        cell_size=self.cell_size,
        add_dustbin=add_dustbin,
    ).float()
    mask_3D_flattened = self.getMasks(mask_2D, self.cell_size, device=self.device)
    loss_det = self.detector_loss(
        input=outs["semi"],
        target=labels_3D.to(self.device),
        mask=mask_3D_flattened,
        loss_type=det_loss_type,
    )

    # warp: same detector loss on the warped branch
    if if_warp:
        labels_3D = labels2Dto3D(
            warped_labels.to(self.device),
            cell_size=self.cell_size,
            add_dustbin=add_dustbin,
        ).float()
        mask_3D_flattened = self.getMasks(
            mask_warp_2D, self.cell_size, device=self.device
        )
        loss_det_warp = self.detector_loss(
            input=outs_warp["semi"],
            target=labels_3D.to(self.device),
            mask=mask_3D_flattened,
            loss_type=det_loss_type,
        )
    else:
        loss_det_warp = torch.tensor([0]).to(self.device)

    ## get labels, masks, loss for detection
    # labels3D_in_loss = self.getLabels(labels_2D, self.cell_size, device=self.device)
    # mask_3D_flattened = self.getMasks(mask_2D, self.cell_size, device=self.device)
    # loss_det = self.get_loss(semi, labels3D_in_loss, mask_3D_flattened, device=self.device)

    ## warping
    # labels3D_in_loss = self.getLabels(labels_warp_2D, self.cell_size, device=self.device)
    # mask_3D_flattened = self.getMasks(mask_warp_2D, self.cell_size, device=self.device)
    # loss_det_warp = self.get_loss(semi_warp, labels3D_in_loss, mask_3D_flattened, device=self.device)

    # NOTE(review): mask_desc reuses the LAST mask_3D_flattened computed
    # above — the warped-branch mask when if_warp is on. Confirm intended.
    mask_desc = mask_3D_flattened.unsqueeze(1)
    lambda_loss = self.config["model"]["lambda_loss"]
    # print("mask_desc: ", mask_desc.shape)
    # print("mask_warp_2D: ", mask_warp_2D.shape)

    # descriptor loss (only meaningful with an image pair)
    if lambda_loss > 0:
        assert if_warp == True, "need a pair of images"
        loss_desc, mask, positive_dist, negative_dist = self.descriptor_loss(
            coarse_desc,
            coarse_desc_warp,
            mat_H,
            mask_valid=mask_desc,
            device=self.device,
            **self.desc_params
        )
    else:
        ze = torch.tensor([0]).to(self.device)
        loss_desc, positive_dist, negative_dist = ze, ze, ze

    loss = loss_det + loss_det_warp
    if lambda_loss > 0:
        loss += lambda_loss * loss_desc

    ##### try to minimize the error ######
    add_res_loss = False  # residual (sub-pixel) loss is currently disabled
    if add_res_loss and n_iter % 10 == 0:
        print("add_res_loss!!!")
        heatmap_org = self.get_heatmap(semi, det_loss_type)  # tensor []
        heatmap_org_nms_batch = self.heatmap_to_nms(
            self.images_dict, heatmap_org, name="heatmap_org"
        )
        if if_warp:
            heatmap_warp = self.get_heatmap(semi_warp, det_loss_type)
            heatmap_warp_nms_batch = self.heatmap_to_nms(
                self.images_dict, heatmap_warp, name="heatmap_warp"
            )

        # original: pred
        ## check the loss on given labels!
        outs_res = self.get_residual_loss(
            sample["labels_2D"]
            * to_floatTensor(heatmap_org_nms_batch).unsqueeze(1),
            heatmap_org,
            sample["labels_res"],
            name="original_pred",
        )
        loss_res_ori = (outs_res["loss"] ** 2).mean()

        # warped: pred
        if if_warp:
            outs_res_warp = self.get_residual_loss(
                sample["warped_labels"]
                * to_floatTensor(heatmap_warp_nms_batch).unsqueeze(1),
                heatmap_warp,
                sample["warped_res"],
                name="warped_pred",
            )
            loss_res_warp = (outs_res_warp["loss"] ** 2).mean()
        else:
            loss_res_warp = torch.tensor([0]).to(self.device)
        loss_res = loss_res_ori + loss_res_warp
        # print("loss_res requires_grad: ", loss_res.requires_grad)
        loss += loss_res
        self.scalar_dict.update({
            "loss_res_ori": loss_res_ori,
            "loss_res_warp": loss_res_warp
        })
    #######################################

    self.loss = loss

    self.scalar_dict.update({
        "loss": loss,
        "loss_det": loss_det,
        "loss_det_warp": loss_det_warp,
        "positive_dist": positive_dist,
        "negative_dist": negative_dist,
    })

    self.input_to_imgDict(sample, self.images_dict)

    if train:
        loss.backward()
        self.optimizer.step()

    # periodic tensorboard logging (always log during validation)
    if n_iter % tb_interval == 0 or task == "val":
        logging.info(
            "current iteration: %d, tensorboard_interval: %d", n_iter, tb_interval
        )

        # add clean map to tensorboard
        ## semi_warp: flatten, to_numpy
        heatmap_org = self.get_heatmap(semi, det_loss_type)  # tensor []
        heatmap_org_nms_batch = self.heatmap_to_nms(
            self.images_dict, heatmap_org, name="heatmap_org"
        )
        if if_warp:
            heatmap_warp = self.get_heatmap(semi_warp, det_loss_type)
            heatmap_warp_nms_batch = self.heatmap_to_nms(
                self.images_dict, heatmap_warp, name="heatmap_warp"
            )

        def update_overlap(images_dict, labels_warp_2D, heatmap_nms_batch,
                           img_warp, name):
            # overlay labels, nms detections and the image into one
            # visualization per batch element, stored under
            # images_dict[name + "_nms_overlap"]
            # image overlap
            from utils.draw import img_overlap
            # result_overlap = img_overlap(img_r, img_g, img_gray)
            # overlap label, nms, img
            nms_overlap = [
                img_overlap(
                    toNumpy(labels_warp_2D[i]),
                    heatmap_nms_batch[i],
                    toNumpy(img_warp[i]),
                ) for i in range(heatmap_nms_batch.shape[0])
            ]
            nms_overlap = np.stack(nms_overlap, axis=0)
            images_dict.update({name + "_nms_overlap": nms_overlap})

        from utils.var_dim import toNumpy

        update_overlap(
            self.images_dict,
            labels_2D,
            heatmap_org_nms_batch[np.newaxis, ...],
            img,
            "original",
        )
        update_overlap(
            self.images_dict,
            labels_2D,
            toNumpy(heatmap_org),
            img,
            "original_heatmap",
        )
        if if_warp:
            update_overlap(
                self.images_dict,
                labels_warp_2D,
                heatmap_warp_nms_batch[np.newaxis, ...],
                img_warp,
                "warped",
            )
            update_overlap(
                self.images_dict,
                labels_warp_2D,
                toNumpy(heatmap_warp),
                img_warp,
                "warped_heatmap",
            )

        # residuals: log gt residual losses when gaussian labels are on
        from utils.losses import do_log
        if self.gaussian:
            # original: gt
            self.get_residual_loss(
                sample["labels_2D"],
                sample["labels_2D_gaussian"],
                sample["labels_res"],
                name="original_gt",
            )
            if if_warp:
                # warped: gt
                self.get_residual_loss(
                    sample["warped_labels"],
                    sample["warped_labels_gaussian"],
                    sample["warped_res"],
                    name="warped_gt",
                )

        # from utils.losses import do_log
        # patches_log = do_log(patches)

        # original: pred
        ## check the loss on given labels!
        # self.get_residual_loss(
        #     sample["labels_2D"]
        #     * to_floatTensor(heatmap_org_nms_batch).unsqueeze(1),
        #     heatmap_org,
        #     sample["labels_res"],
        #     name="original_pred",
        # )
        # print("heatmap_org_nms_batch: ", heatmap_org_nms_batch.shape)
        # get_residual_loss(to_floatTensor(heatmap_org_nms_batch).unsqueeze(1), heatmap_org,
        #                   sample['labels_res'], name='original_pred')

        # warped: pred
        # self.get_residual_loss(
        #     sample["warped_labels"]
        #     * to_floatTensor(heatmap_warp_nms_batch).unsqueeze(1),
        #     heatmap_warp,
        #     sample["warped_res"],
        #     name="warped_pred",
        # )
        # get_residual_loss(to_floatTensor(heatmap_warp_nms_batch).unsqueeze(1), heatmap_warp,
        #                   sample['warped_res'], name='warped_pred')

        # precision, recall
        # pr_mean = self.batch_precision_recall(
        #     to_floatTensor(heatmap_warp_nms_batch[:, np.newaxis, ...]),
        #     sample["warped_labels"],
        # )
        pr_mean = self.batch_precision_recall(
            to_floatTensor(heatmap_org_nms_batch[:, np.newaxis, ...]),
            sample["labels_2D"],
        )
        print("pr_mean")
        self.scalar_dict.update(pr_mean)

        self.printLosses(self.scalar_dict, task)
        self.tb_images_dict(task, self.images_dict, max_img=2)
        self.tb_hist_dict(task, self.hist_dict)

    # NOTE(review): scalars are pushed every step (outside the interval
    # gate); placement reconstructed from collapsed source — confirm.
    self.tb_scalar_dict(self.scalar_dict, task)

    return loss.item()