def warpLabels(pnts, H, W, homography, bilinear=False):
    """
    Warp keypoints with a homography and rasterise them into label maps.

    :param pnts: keypoints, either a torch tensor or something accepted by
        ``torch.tensor``; only columns 0 and 1 are used and they are
        truncated to integers before warping
    :param H: height of the produced maps
    :param W: width of the produced maps
    :param homography: homography handed to ``homography_scaling`` together
        with ``H`` and ``W``
    :param bilinear: when it compares equal to ``True`` an additional
        ``'labels_bi'`` entry is produced by ``get_labels_bi`` from the
        warped points *before* they are filtered
    :return: dict with
        'labels': output of ``scatter_points`` for the filtered points
        'res': tensor (H, W, 2) holding ``point - point.round()`` at the
            quantised location of every filtered point
        'warped_pnts': the filtered warped points
        'labels_bi': only present when ``bilinear`` is set
    """
    from utils.utils import homography_scaling_torch as homography_scaling
    from utils.utils import filter_points
    from utils.utils import warp_points

    if not isinstance(pnts, torch.Tensor):
        pnts = torch.tensor(pnts)
    pnts = pnts.long()

    first_two_cols = torch.stack((pnts[:, 0], pnts[:, 1]), dim=1)
    scaled_homography = homography_scaling(homography, H, W)
    warped_pnts = warp_points(first_two_cols, scaled_homography)

    outs = {}
    # Deliberately an equality test (not truthiness), as in the original.
    if bilinear == True:
        outs['labels_bi'] = get_labels_bi(warped_pnts, H, W)

    # Bounds are given as [W, H], i.e. column 0 is compared against the
    # width and column 1 against the height.
    warped_pnts = filter_points(warped_pnts, torch.tensor([W, H]))
    warped_labels = scatter_points(warped_pnts, H, W, res_ext=1)

    # Sub-pixel residual of each point, stored at its quantised location.
    residual = warped_pnts - warped_pnts.round()
    rows = quan(warped_pnts)[:, 1]
    cols = quan(warped_pnts)[:, 0]
    warped_labels_res = torch.zeros(H, W, 2)
    warped_labels_res[rows, cols, :] = residual

    outs['labels'] = warped_labels
    outs['res'] = warped_labels_res
    outs['warped_pnts'] = warped_pnts
    return outs
def __getitem__(self, index):
    """
    Generate one synthetic sample on the fly.

    :param index: not used; every call draws a new random shape
    :return: dict with
        'valid_mask': result of ``self.compute_valid_mask`` for an identity
            inverse homography
        'labels_2D': tensor (1, H, W) with ones at the keypoint locations
        'events': float32 tensor built from the augmented event frames
    """

    def imgPhotometric(img):
        """
        Run the configured augmentation followed by the customised one.

        :param img: numpy (H, W)
        :return: augmented image (a trailing channel axis is added first)
        """
        aug_cfg = self.config["augmentation"]
        standard_aug = self.ImgAugTransform(**aug_cfg)
        augmented = standard_aug(img[:, :, np.newaxis])
        custom_aug = self.customizedTransform()
        return custom_aug(augmented, **aug_cfg)

    # The first and last imports are not used below; they are kept so that
    # import-time behaviour stays exactly as before.
    from datasets.data_tools import np_to_tensor
    from utils.utils import filter_points
    from utils.var_dim import squeezeToNumpy

    H, W = self.config["generation"]["image_size"]
    imgs, pnts, evts = synthetic_dataset.generate_random_shape((H, W), 5,
                                                               None)
    # The value is never read, but the draw is kept so the global NumPy
    # random stream advances exactly as it did before.
    _ = np.random.randint(1000)

    # Only the last set of points is used.
    last_pnts = torch.tensor(pnts[-1]).float()
    # Swap the two columns.
    # NOTE(review): the original comment called the result "(x, y)", yet it
    # is bounded by [H, W] and used as labels[col0, col1] below -- confirm
    # the column order returned by generate_random_shape.
    swapped = torch.stack((last_pnts[:, 1], last_pnts[:, 0]), dim=1)
    swapped = filter_points(swapped, torch.tensor([H, W]))
    cells = swapped.round().long()

    labels = torch.zeros(H, W)
    labels[cells[:, 0], cells[:, 1]] = 1

    valid_mask = self.compute_valid_mask(torch.tensor([H, W]),
                                         inv_homography=torch.eye(3))

    # Augment the two channels of every event frame in place (channel 0
    # first, then channel 1).
    for i, evt in enumerate(evts):
        for channel in (0, 1):
            evts[i, channel] = imgPhotometric(evt[channel, :, :]).squeeze()
    evts = torch.from_numpy(evts.astype(np.float32))

    return {
        "valid_mask": valid_mask,
        "labels_2D": labels.unsqueeze(0),
        "events": evts,
    }
def get_labels_bi(warped_pnts, H, W):
    """
    Build a label map from extrapolated points and their residual weights.

    :param warped_pnts: points handed to ``extrapolate_points``
    :param H: map height
    :param W: map width
    :return: output of ``scatter_points`` for the extrapolated points that
        pass ``filter_points`` with bounds ``[W, H]``
    """
    from utils.utils import filter_points

    ext_points, ext_weights = extrapolate_points(warped_pnts)
    bounds = torch.tensor([W, H])
    ext_points, keep = filter_points(ext_points, bounds, return_mask=True)
    # Keep the weights aligned with the points that survived filtering.
    return scatter_points(ext_points, H, W, res_ext=ext_weights[keep])
def warpLabels(pnts, homography, H, W):
    """
    Warp integer keypoints with a homography and drop out-of-range ones.

    input:
        pnts: numpy array of points; columns 0 and 1 are used and are
            truncated to integers before warping
        homography: numpy homography, converted to a float32 tensor
        H, W: bounds; points are filtered against ``[W, H]``
    output:
        warped_pnts: numpy integer array with the warped, filtered and
            rounded points
    """
    # The docstring must be the first statement of the function; it used to
    # follow this import, which left ``warpLabels.__doc__`` empty.
    import torch
    from utils.utils import warp_points
    from utils.utils import filter_points

    pnts = torch.tensor(pnts).long()
    homography = torch.tensor(homography, dtype=torch.float32)
    warped_pnts = warp_points(
        torch.stack((pnts[:, 0], pnts[:, 1]), dim=1), homography)
    # Bounds are [W, H]: column 0 is checked against the width.
    warped_pnts = filter_points(warped_pnts, torch.tensor([W, H]))
    return warped_pnts.round().long().numpy()
def descriptor_loss_sparse(descriptors,
                           descriptors_warped,
                           homographies,
                           mask_valid=None,
                           cell_size=8,
                           device='cpu',
                           descriptor_dist=4,
                           lamda_d=250,
                           num_matching_attempts=1000,
                           num_masked_non_matches_per_match=10,
                           dist='cos',
                           method='1d',
                           **config):
    """
    Sparse contrastive descriptor loss for one image / warped-image pair.

    :param descriptors:
        Output from descriptor head
        tensor [D, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [D, Hc, Wc]
    :param homographies: homographies, rescaled to the (Hc, Wc) grid with
        ``scale_homography_torch`` before warping the cell coordinates
    :param cell_size: forwarded to ``get_coor_cells``
    :param device: device on which the loss terms are evaluated (the
        coordinate bookkeeping itself is done on the CPU)
    :param lamda_d: weight of the match term
    :param num_matching_attempts: number of correspondences kept after
        ``crop_or_pad_choice``
    :param num_masked_non_matches_per_match: negatives drawn per match
    :param dist: distance type forwarded to ``PixelwiseContrastiveLoss``
    :param method: '2d' uses normalised coordinates and the raw descriptor
        maps; anything else uses flat indices into the reshaped maps
    :param mask_valid, descriptor_dist, config: accepted but not used here
    :return: (loss, lamda_d * match_loss, non_match_loss)
    """
    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice
    from utils.utils import normPts

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]
    img_shape = (Hc, Wc)

    def flatten_map(desc):
        # [D, Hc, Wc] --> [1, Hc*Wc, D]
        return desc.view(-1, Hc * Wc).transpose(0, 1).unsqueeze(0)

    def flat_index(x, y):
        # position of cell (x, y) inside the flattened (row-major) map
        return x + y * Wc

    image_a_pred = flatten_map(descriptors)
    image_b_pred = flatten_map(descriptors_warped)

    # ---- correspondences: cell grid in image a, warped into image b ----
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')
    homographies_H = scale_homography_torch(homographies,
                                            img_shape,
                                            shift=(-1, -1))
    uv_b_matches = warp_coor_cells_with_homographies(
        uv_a, homographies_H.to('cpu'), uv=True, device='cpu')
    uv_b_matches.round_()
    uv_b_matches = uv_b_matches.squeeze(0)

    # discard correspondences that leave the warped map
    uv_b_matches, inside = filter_points(
        uv_b_matches,
        torch.tensor([Wc, Hc]).to(device='cpu'),
        return_mask=True)
    uv_a = uv_a[inside]

    # crop / pad both sets to the same fixed length
    choice = torch.tensor(
        crop_or_pad_choice(uv_b_matches.shape[0],
                           num_matching_attempts,
                           shuffle=True))
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    # ---- match term ----
    if method == '2d':
        matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float())  # [u, v]
        matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float())
        match_loss, _, _ = PixelwiseContrastiveLoss.match_loss(
            descriptors,
            descriptors_warped,
            matches_a.to(device),
            matches_b.to(device),
            dist=dist,
            method='2d')
    else:
        matches_a = flat_index(uv_a[..., 0], uv_a[..., 1])
        matches_b = flat_index(uv_b_matches[..., 0], uv_b_matches[..., 1])
        match_loss, _, _ = PixelwiseContrastiveLoss.match_loss(
            image_a_pred,
            image_b_pred,
            matches_a.long().to(device),
            matches_b.long().to(device),
            dist=dist,
            method='1d')

    # ---- non-match term ----
    b_matches = uv_b_matches.squeeze()
    b_non_match_xy = correspondence_finder.create_non_correspondences(
        (b_matches[:, 0], b_matches[:, 1]),
        img_shape,
        num_non_matches_per_match=num_masked_non_matches_per_match,
        img_b_mask=None)
    a_xy, b_non_xy = create_non_matches((uv_a[:, 0], uv_a[:, 1]),
                                        b_non_match_xy,
                                        num_masked_non_matches_per_match)
    non_matches_a = flat_index(a_xy[0], a_xy[1])
    non_matches_b = flat_index(b_non_xy[0], b_non_xy[1])

    raw_non_match, num_hard_negatives, _, _ = \
        PixelwiseContrastiveLoss.non_match_descriptor_loss(
            image_a_pred,
            image_b_pred,
            non_matches_a.to(device).long().squeeze(),
            non_matches_b.to(device).long().squeeze(),
            M=0.2,
            invert=True,
            dist=dist)
    # the +1 keeps the division defined when there are no hard negatives
    non_match_loss = raw_non_match.sum() / (num_hard_negatives + 1)

    weighted_match_loss = lamda_d * match_loss
    loss = weighted_match_loss + non_match_loss
    return loss, lamda_d * match_loss, non_match_loss
def descriptor_loss_sparse_reliability(descriptors,
                                       descriptors_warped,
                                       reliability,
                                       reliability_warp,
                                       homographies,
                                       aplosser,
                                       mask_valid=None,
                                       cell_size=8,
                                       device='cpu',
                                       descriptor_dist=4,
                                       lamda_d=250,
                                       lamda_r=1,
                                       num_matching_attempts=1000,
                                       num_masked_non_matches_per_match=10,
                                       dist='cos',
                                       method='1d',
                                       sos=True,
                                       reli_base=0.5,
                                       **config):
    """
    Sparse contrastive descriptor loss with an extra reliability (AP) term.

    :param descriptors:
        Output from descriptor head
        tensor [D, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [D, Hc, Wc]
    :param reliability: reliability map of the image, indexed as
        ``reliability[:, y, x]``
    :param reliability_warp: reliability map of the warped image, indexed
        the same way
    :param homographies: homographies, rescaled to the (Hc, Wc) grid before
        warping the cell coordinates
    :param aplosser: callable ``aplosser(scores, gt)`` returning the AP
        value used by the reliability term
    :param cell_size: forwarded to ``get_coor_cells``
    :param device: device on which the loss terms are evaluated
    :param lamda_d: weight of the match term
    :param lamda_r: weight of the reliability term
    :param num_matching_attempts: number of correspondences kept
    :param num_masked_non_matches_per_match: negatives drawn per match
    :param dist: distance type forwarded to ``PixelwiseContrastiveLoss``
    :param method: '2d' samples descriptors at normalised coordinates;
        anything else uses flat indices into the reshaped maps
    :param sos: forwarded to ``PixelwiseContrastiveLoss.match_loss``
    :param reli_base: constant used where the confidence is low
    :param mask_valid, descriptor_dist, config: accepted but not used here
    :return: (loss, lamda_d * match_loss, non_match_loss, lamda_r * ap_loss)
    """
    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice
    from utils.utils import normPts

    def uv_to_tuple(uv):
        return (uv[:, 0], uv[:, 1])

    def tuple_to_uv(uv_tuple):
        return torch.stack([uv_tuple[0], uv_tuple[1]])

    def tuple_to_1d(uv_tuple, W, uv=True):
        if uv:
            return uv_tuple[0] + uv_tuple[1] * W
        return uv_tuple[0] * W + uv_tuple[1]

    def uv_to_1d(points, W, uv=True):
        if uv:
            return points[..., 0] + points[..., 1] * W
        return points[..., 0] * W + points[..., 1]

    def get_match_loss(image_a_pred, image_b_pred, matches_a, matches_b,
                       dist='cos', method='1d', sos=True):
        # match term; the sampled descriptors are returned as well because
        # the reliability term reuses them
        return PixelwiseContrastiveLoss.match_loss(image_a_pred,
                                                   image_b_pred,
                                                   matches_a,
                                                   matches_b,
                                                   dist=dist,
                                                   method=method,
                                                   sos=sos)

    def get_non_matches_corr(img_b_shape, uv_a, uv_b_matches,
                             num_masked_non_matches_per_match=10,
                             device='cpu'):
        # sample non matches in image b and repeat uv_a accordingly
        uv_b_matches = uv_b_matches.squeeze()
        uv_b_non_matches_tuple = \
            correspondence_finder.create_non_correspondences(
                uv_to_tuple(uv_b_matches),
                img_b_shape,
                num_non_matches_per_match=num_masked_non_matches_per_match,
                img_b_mask=None)
        return create_non_matches(uv_to_tuple(uv_a),
                                  uv_b_non_matches_tuple,
                                  num_masked_non_matches_per_match)

    def get_non_match_loss(image_a_pred, image_b_pred, non_matches_a,
                           non_matches_b, dist='cos'):
        non_match_loss, num_hard_negatives, _, _ = \
            PixelwiseContrastiveLoss.non_match_descriptor_loss(
                image_a_pred,
                image_b_pred,
                non_matches_a.long().squeeze(),
                non_matches_b.long().squeeze(),
                M=0.2,
                invert=True,
                dist=dist)
        # the +1 keeps the division defined without hard negatives
        return non_match_loss.sum() / (num_hard_negatives + 1)

    def reliability_loss(descriptors_a, descriptors_b, image_b_pred,
                         reliability_a, reliability_b, uv_a, uv_b,
                         img_b_shape, aplosser, method='1d', device='cpu',
                         reli_base=0.5):
        """
        AP-based reliability term: one positive and one sampled negative
        score per correspondence, weighted by the mean reliability of the
        two matched locations.
        """
        uv_b = uv_b.squeeze()
        uv_b_non_matches_tuple = \
            correspondence_finder.create_non_correspondences(
                uv_to_tuple(uv_b),
                img_b_shape,
                num_non_matches_per_match=1,
                img_b_mask=None)
        _, uv_b_non_matches_tuple = create_non_matches(
            uv_to_tuple(uv_a), uv_b_non_matches_tuple, 1)
        # negative sample coordinates, one (x, y) row per correspondence
        uv_b_non = tuple_to_uv(uv_b_non_matches_tuple).squeeze().transpose(
            1, 0)

        # generate negative descriptors
        if method == '2d':

            def sampleDescriptors(image_a_pred, matches_a, mode, norm=False):
                image_a_pred = image_a_pred.unsqueeze(0)  # [1, D, H, W]
                matches_a.unsqueeze_(0).unsqueeze_(2)
                matches_a_descriptors = torch.nn.functional.grid_sample(
                    image_a_pred, matches_a, mode=mode, align_corners=True)
                matches_a_descriptors = matches_a_descriptors.squeeze(
                ).transpose(0, 1)
                if norm:
                    dn = torch.norm(matches_a_descriptors, p=2, dim=1)
                    matches_a_descriptors = matches_a_descriptors.div(
                        torch.unsqueeze(dn, 1))  # divide by the L2 norm
                return matches_a_descriptors

            matches_b_non = normPts(
                uv_b_non,
                torch.tensor([img_b_shape[1], img_b_shape[0]]).float())
            descriptors_b_non = sampleDescriptors(image_b_pred,
                                                  matches_b_non.to(device),
                                                  mode='bilinear',
                                                  norm=False)
        else:
            # FIX: index the sampled *negative* locations. The previous code
            # read `uv_b_matches` from the enclosing scope, i.e. the
            # positive matches, so negatives were identical to positives.
            matches_b_non = uv_to_1d(uv_b_non, img_b_shape[1])
            # FIX: drop the leading batch axis so the result is [N, D] like
            # the '2d' branch; otherwise the torch.cat below receives
            # tensors with a different number of dimensions.
            descriptors_b_non = torch.index_select(
                image_b_pred, 1,
                matches_b_non.long().to(device)).squeeze(0)

        # mean reliability of the two matched locations
        qconf = reliability_a[:, uv_a[:, 1].long(), uv_a[:, 0].long()] + \
            reliability_b[:, uv_b[:, 1].long(), uv_b[:, 0].long()]
        qconf /= 2

        pscores = (descriptors_a * descriptors_b).sum(-1)[:, None]
        nscores = (descriptors_a * descriptors_b_non).sum(-1)[:, None]
        scores = torch.cat((pscores, nscores), dim=1)
        gt = torch.zeros_like(scores, dtype=torch.uint8)
        gt[:, :pscores.shape[1]] = 1  # the first column holds the positives
        ap = aplosser(scores, gt)
        ap_loss = 1 - ap * qconf - (1 - qconf) * reli_base
        return ap_loss.mean()

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]
    img_shape = (Hc, Wc)

    def descriptor_reshape(descriptors):
        descriptors = descriptors.view(-1, Hc * Wc).transpose(
            0, 1)  # torch [D, H, W] --> [H*W, D]
        return descriptors.unsqueeze(0)  # torch [1, H*W, D]

    image_a_pred = descriptor_reshape(descriptors)
    image_b_pred = descriptor_reshape(descriptors_warped)

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')
    # original scale is for image size, now downscale it to descriptor
    # tensor size
    homographies_H = scale_homography_torch(homographies,
                                            img_shape,
                                            shift=(-1, -1))
    uv_b_matches = warp_coor_cells_with_homographies(
        uv_a, homographies_H.to('cpu'), uv=True, device='cpu')
    uv_b_matches.round_()
    uv_b_matches = uv_b_matches.squeeze(0)

    # filtering out of range points
    uv_b_matches, mask = filter_points(
        uv_b_matches,
        torch.tensor([Wc, Hc]).to(device='cpu'),
        return_mask=True)
    uv_a = uv_a[mask]  # uv_a, uv_b_matches are the gt correspondences

    # crop to the same length
    choice = crop_or_pad_choice(uv_b_matches.shape[0],
                                num_matching_attempts,
                                shuffle=True)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    # match term
    if method == '2d':
        matches_a = normPts(uv_a, torch.tensor([Wc, Hc]).float())  # [u, v]
        matches_b = normPts(uv_b_matches, torch.tensor([Wc, Hc]).float())
        match_loss, matches_a_descriptors, matches_b_descriptors = \
            get_match_loss(descriptors,
                           descriptors_warped,
                           matches_a.to(device),
                           matches_b.to(device),
                           dist=dist,
                           method='2d',
                           sos=sos)
        # the '2d' path samples from the raw [D, Hc, Wc] map
        negatives_source = descriptors_warped
    else:
        matches_a = uv_to_1d(uv_a, Wc)
        matches_b = uv_to_1d(uv_b_matches, Wc)
        match_loss, matches_a_descriptors, matches_b_descriptors = \
            get_match_loss(image_a_pred,
                           image_b_pred,
                           matches_a.long().to(device),
                           matches_b.long().to(device),
                           dist=dist,
                           sos=sos)
        negatives_source = image_b_pred

    # reliability loss
    matches_a_descriptors = matches_a_descriptors.squeeze()
    matches_b_descriptors = matches_b_descriptors.squeeze()
    ap_loss = reliability_loss(matches_a_descriptors,
                               matches_b_descriptors,
                               negatives_source,
                               reliability,
                               reliability_warp,
                               uv_a,
                               uv_b_matches,
                               img_shape,
                               aplosser,
                               method=method,
                               device=device,
                               reli_base=reli_base)

    # non matches
    uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(
        img_shape,
        uv_a,
        uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)
    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)
    non_match_loss = get_non_match_loss(image_a_pred,
                                        image_b_pred,
                                        non_matches_a.to(device),
                                        non_matches_b.to(device),
                                        dist=dist)

    loss = lamda_d * match_loss + lamda_r * ap_loss + non_match_loss
    return loss, lamda_d * match_loss, non_match_loss, lamda_r * ap_loss
def descriptor_loss_sparse(descriptors,
                           descriptors_warped,
                           homographies,
                           mask_valid=None,
                           cell_size=8,
                           device='cpu',
                           descriptor_dist=4,
                           lamda_d=250,
                           num_matching_attempts=1000,
                           num_masked_non_matches_per_match=10,
                           **config):
    """
    Sparse contrastive descriptor loss (variant returning a single value).

    :param descriptors:
        Output from descriptor head
        tensor [D, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [D, Hc, Wc]
    :param homographies: homographies, rescaled to the (Hc, Wc) grid before
        warping the cell coordinates
    :param cell_size: forwarded to ``get_coor_cells``
    :param lamda_d: weight of the match term
    :param num_matching_attempts: number of correspondences kept
    :param num_masked_non_matches_per_match: negatives drawn per match
    :param mask_valid, device, descriptor_dist, config: accepted but not
        used here
    :return: ``lamda_d * match_loss + mean(non_match_loss)``
    """
    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice

    def uv_to_tuple(uv):
        return (uv[:, 0], uv[:, 1])

    def tuple_to_1d(uv_tuple, W):
        # The maps are flattened row-major below, so cell (x, y) lives at
        # x + y * W (the previous x * H + y addressed the wrong cell).
        return uv_tuple[0] + uv_tuple[1] * W

    def uv_to_1d(points, W):
        return points[..., 0] + points[..., 1] * W

    def get_match_loss(image_a_pred, image_b_pred, matches_a, matches_b):
        match_loss, _, _ = PixelwiseContrastiveLoss.match_loss(
            image_a_pred, image_b_pred, matches_a.long(), matches_b.long())
        return match_loss

    def get_non_matches_corr(img_b_shape, uv_a, uv_b_matches,
                             num_masked_non_matches_per_match=10):
        # sample non matches in image b and repeat uv_a accordingly
        uv_b_matches = uv_b_matches.squeeze()
        uv_b_non_matches_tuple = \
            correspondence_finder.create_non_correspondences(
                uv_to_tuple(uv_b_matches),
                img_b_shape,
                num_non_matches_per_match=num_masked_non_matches_per_match,
                img_b_mask=None)
        return create_non_matches(uv_to_tuple(uv_a),
                                  uv_b_non_matches_tuple,
                                  num_masked_non_matches_per_match)

    def get_non_match_loss(image_a_pred, image_b_pred, non_matches_a,
                           non_matches_b):
        non_match_loss, _, _, _ = \
            PixelwiseContrastiveLoss.non_match_descriptor_loss(
                image_a_pred,
                image_b_pred,
                non_matches_a.long().squeeze(),
                non_matches_b.long().squeeze())
        return non_match_loss

    Hc, Wc = descriptors.shape[1], descriptors.shape[2]
    img_shape = (Hc, Wc)

    # torch [D, Hc, Wc] --> [1, Hc*Wc, D]
    image_a_pred = descriptors.view(1, -1, Hc * Wc).transpose(1, 2)
    image_b_pred = descriptors_warped.view(1, -1, Hc * Wc).transpose(1, 2)

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True)
    # FIX: `image_shape` was never defined in this function; the homography
    # is rescaled to the descriptor grid.
    # NOTE(review): confirm no module-level `image_shape` was intended.
    homographies_H = scale_homography_torch(homographies,
                                            img_shape,
                                            shift=(-1, -1))
    uv_b_matches = warp_coor_cells_with_homographies(uv_a,
                                                     homographies_H,
                                                     uv=True)
    # FIX: snap warped coordinates to cells before they are turned into
    # flat indices; a fractional y multiplied by the width would otherwise
    # land in an unrelated cell after `.long()`.
    uv_b_matches = uv_b_matches.round()
    uv_b_matches = uv_b_matches.squeeze(0)

    # filtering out of range points
    uv_b_matches, mask = filter_points(uv_b_matches,
                                       torch.tensor([Wc, Hc]),
                                       return_mask=True)
    uv_a = uv_a[mask]

    # crop to the same length
    choice = crop_or_pad_choice(uv_b_matches.shape[0],
                                num_matching_attempts,
                                shuffle=True)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    matches_a = uv_to_1d(uv_a, Wc)
    matches_b = uv_to_1d(uv_b_matches, Wc)
    match_loss = get_match_loss(image_a_pred, image_b_pred, matches_a,
                                matches_b)

    # non matches
    uv_a_tuple, uv_b_non_matches_tuple = get_non_matches_corr(
        img_shape,
        uv_a,
        uv_b_matches,
        num_masked_non_matches_per_match=num_masked_non_matches_per_match)
    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    # FIX: the second index set must be the sampled non-matches in image b;
    # `non_matches_a` used to be passed twice.
    non_match_loss = get_non_match_loss(image_a_pred, image_b_pred,
                                        non_matches_a, non_matches_b)
    non_match_loss = non_match_loss.mean()

    loss = lamda_d * match_loss + non_match_loss
    return loss
dim=1) # (y, x) to (x, y) return coor_cells coor_cells = get_coor_cells(Hc, Wc, cell_size=cell_size, device=device, uv=True) print("coor_cells: ", coor_cells) print("coor_cells: ", coor_cells.shape) from utils.utils import filter_points filtered_points, mask = filter_points(coor_cells, torch.tensor([Wc, Hc]), return_mask=True) def warp_coor_cells_with_homographies(coor_cells, homographies, uv=False): from utils.utils import warp_points # warped_coor_cells = warp_points(coor_cells.view([-1, 2]), homographies, device) # warped_coor_cells = normPts(coor_cells.view([-1, 2]), shape) warped_coor_cells = coor_cells if uv == False: warped_coor_cells = torch.stack( (warped_coor_cells[:, 1], warped_coor_cells[:, 0]), dim=1) # (y, x) to (x, y) warped_coor_cells = warp_points(warped_coor_cells, homographies, device)
def quadruplet_descriptor_loss_sparse(descriptors,
                                      descriptors_warped,
                                      homographies,
                                      mask_valid=None,
                                      cell_size=8,
                                      device='cpu',
                                      descriptor_dist=4,
                                      lamda_d=250,
                                      num_matching_attempts=1000,
                                      num_masked_non_matches_per_match=10,
                                      dist='cos',
                                      method='1d',
                                      **config):
    """
    Quadruplet-style sparse descriptor loss with an adaptive margin.

    :param descriptors:
        Output from descriptor head
        tensor [D, Hc, Wc]
    :param descriptors_warped:
        Output from descriptor head of warped image
        tensor [D, Hc, Wc]
    :param homographies: homographies, rescaled to the (Hc, Wc) grid before
        warping the cell coordinates
    :param cell_size: forwarded to ``get_coor_cells``
    :param device: device on which the descriptors are gathered
    :param lamda_d: weight of the first term
    :param num_matching_attempts: number of correspondences kept
    :param num_masked_non_matches_per_match: negatives drawn per match
    :param mask_valid, descriptor_dist, dist, method, config: accepted but
        not used here
    :return: (loss, lamda_d * first_term, second_term)
    """
    from utils.utils import filter_points
    from utils.utils import crop_or_pad_choice

    n_attempts = num_matching_attempts
    n_neg = num_masked_non_matches_per_match

    Hc, Wc, Dim = (descriptors.shape[1], descriptors.shape[2],
                   descriptors.shape[0])
    img_shape = (Hc, Wc)

    image_a_pred = descriptor_reshape(descriptors, Hc, Wc)
    image_b_pred = descriptor_reshape(descriptors_warped, Hc, Wc)

    # matches
    uv_a = get_coor_cells(Hc, Wc, cell_size, uv=True, device='cpu')
    homographies_H = scale_homography_torch(homographies,
                                            img_shape,
                                            shift=(-1, -1))
    uv_b_matches = warp_coor_cells_with_homographies(
        uv_a, homographies_H.to('cpu'), uv=True, device='cpu')
    uv_b_matches.round_()
    uv_b_matches = uv_b_matches.squeeze(0)

    # filtering out of range points
    uv_b_matches, mask = filter_points(
        uv_b_matches,
        torch.tensor([Wc, Hc]).to(device='cpu'),
        return_mask=True)
    uv_a = uv_a[mask]

    # crop to the same length
    choice = crop_or_pad_choice(uv_b_matches.shape[0],
                                n_attempts,
                                shuffle=True)
    choice = torch.tensor(choice)
    uv_a = uv_a[choice]
    uv_b_matches = uv_b_matches[choice]

    # anchors, their matches and the sampled non-matches (flat indices)
    uv_a_tuple, uv_b_matches_tuple, uv_b_non_matches_tuple = \
        get_first_term_corr(img_shape,
                            uv_a,
                            uv_b_matches,
                            num_masked_non_matches_per_match=n_neg)
    non_matches_a = tuple_to_1d(uv_a_tuple, Wc)
    long_matches_b = tuple_to_1d(uv_b_matches_tuple, Wc)
    non_matches_b = tuple_to_1d(uv_b_non_matches_tuple, Wc)

    def gather(flat_map, flat_idx):
        # pick the descriptors stored at the given flat indices
        return torch.index_select(
            flat_map, 1,
            flat_idx.to(device).long().squeeze()).squeeze()

    non_matches_a_descriptors = gather(image_a_pred, non_matches_a)
    non_matches_b_descriptors = gather(image_b_pred, non_matches_b)
    long_matches_b_descriptors = gather(image_b_pred, long_matches_b)

    # first term: anchor/positive vs. anchor/negative squared distances
    all_neg_pairs = (non_matches_a_descriptors -
                     non_matches_b_descriptors).pow(2).sum(dim=-1)
    all_pos_pairs = (non_matches_a_descriptors -
                     long_matches_b_descriptors).pow(2).sum(dim=-1)
    # adaptive margin: mean gap between negative and positive distances
    alpha = (all_neg_pairs.sum() - all_pos_pairs.sum()) / (n_attempts *
                                                           n_neg)
    first_term = torch.clamp(all_pos_pairs - all_neg_pairs + alpha,
                             min=0).sum() / (n_attempts * n_neg)

    if n_neg < 2:
        # FIX: with a single negative per match there is no pair of
        # distinct negatives, and the normaliser n * (n - 1) * attempts is
        # zero, which used to turn the loss into inf/NaN.
        second_term = first_term.new_zeros(())
    else:
        pos_matches_a = non_matches_a_descriptors.reshape(
            (n_attempts, n_neg, Dim)).repeat(1, n_neg, 1).reshape((-1, Dim))
        pos_matches_b = long_matches_b_descriptors.reshape(
            (n_attempts, n_neg, Dim)).repeat(1, n_neg, 1).reshape((-1, Dim))
        # all ordered pairs (i, j) of negatives belonging to one match
        neg_matches_b = non_matches_b_descriptors.reshape(
            (n_attempts, n_neg, 1, Dim)).repeat(1, 1, n_neg, 1).reshape(
                (-1, Dim))
        neg_matches_b_permute = non_matches_b_descriptors.reshape(
            (n_attempts, 1, n_neg, Dim)).repeat(1, n_neg, 1, 1).reshape(
                (-1, Dim))

        pos_matches = (pos_matches_a - pos_matches_b).pow(2).sum(dim=-1)
        neg_matches = (neg_matches_b -
                       neg_matches_b_permute).pow(2).sum(dim=-1)
        # zero on the diagonal: a negative is not paired with itself
        match_mask = (torch.ones(n_neg, n_neg) - torch.eye(n_neg)).repeat(
            n_attempts, 1, 1).flatten().to(device)
        second_term = (torch.clamp(
            (pos_matches - neg_matches) * match_mask + 0.5 * alpha, min=0) /
                       (n_neg * (n_neg - 1) * n_attempts)).sum()

    loss = lamda_d * first_term + second_term
    return loss, lamda_d * first_term, second_term
def __getitem__(self, index):
    """Load one sample, augment it as configured and build its keypoint labels.

    The image path and the keypoint file are taken from ``self.samples[index]``.
    Photometric and homographic augmentation are applied depending on
    ``self.config["augmentation"]`` and ``self.action``. When
    ``self.config["warped_pair"]["enable"]`` is set, a second view warped by a
    freshly sampled homography is added together with its labels.

    Side effect: the image height and width are stored in ``self.H`` / ``self.W``.

    :param index: position of the sample in ``self.samples``
    :return: dict with
        image: the (possibly warped) image, passed through ``self.transform``
            when one is set
        labels_2D: tensor(1, H, W), 1 at every rounded keypoint position
        labels_res: tensor(2, H, W), ``pnt - round(pnt)`` stored at every
            rounded keypoint position
        valid_mask: result of ``self.compute_valid_mask``
        labels_2D_gaussian: only when ``self.gaussian_label`` is set
        warped_img, warped_labels, warped_res, warped_valid_mask,
        homographies, inv_homographies: only when the warped pair is enabled
            (plus warped_labels_gaussian / warped_labels_bi when
            ``self.gaussian_label`` is set)
        pts: the filtered (x, y) keypoints, only when ``self.getPts`` is set
    """

    def imgPhotometric(img):
        """
        :param img:
            numpy (H, W)
        :return: image after ``self.ImgAugTransform`` and ``self.customizedTransform``
        """
        augmentation = self.ImgAugTransform(**self.config["augmentation"])
        img = augmentation(img[:, :, np.newaxis])
        cusAug = self.customizedTransform()
        return cusAug(img, **self.config["augmentation"])

    def get_labels(pnts, H, W):
        """Rasterise (x, y) points into an (H, W) map holding 1 at each point."""
        labels = torch.zeros(H, W)
        # Clamp the rounded coordinates to the last valid column / row.
        pnts_int = torch.min(pnts.round().long(),
                             torch.tensor([[W - 1, H - 1]]).long())
        labels[pnts_int[:, 1], pnts_int[:, 0]] = 1
        return labels

    def get_label_res(H, W, pnts):
        """Store every point's rounding residual at its rounded pixel; (2, H, W)."""
        rounded = pnts.round()  # computed once and reused for index and residual
        cells = rounded.long()
        labels_res = torch.zeros(H, W, 2)
        labels_res[cells[:, 1], cells[:, 0], :] = pnts - rounded
        # (H, W, 2) -> (2, H, W)
        return labels_res.transpose(1, 2).transpose(0, 1)

    from datasets.data_tools import np_to_tensor
    from utils.utils import filter_points
    from utils.var_dim import squeezeToNumpy

    entry = self.samples[index]
    img = load_as_float(entry["image"])
    H, W = img.shape[0], img.shape[1]
    self.H = H
    self.W = W

    pnts = torch.tensor(np.load(entry["points"])).float()  # stored as (y, x)
    pnts = torch.stack((pnts[:, 1], pnts[:, 0]), dim=1)  # (x, y)
    pnts = filter_points(pnts, torch.tensor([W, H]))

    sample = {}
    sample["labels_2D"] = get_labels(pnts, H, W).unsqueeze(0)

    aug_config = self.config["augmentation"]

    photo_config = aug_config["photometric"]
    if (photo_config["enable_train"] and self.action == "training") or (
            photo_config["enable_val"] and self.action == "validation"):
        img = imgPhotometric(img)

    homo_config = aug_config["homographic"]
    use_homography = (
        (homo_config["enable_train"] and self.action == "training")
        or (homo_config["enable_val"] and self.action == "validation"))

    if use_homography:
        from numpy.linalg import inv
        from utils.utils import homography_scaling_torch as homography_scaling

        homography = self.sample_homography(
            np.array([2, 2]),
            shift=-1,
            **homo_config["params"],
        )
        # The inverse of the sampled homography is the one applied to the points.
        homography = torch.tensor(inv(homography)).float()
        inv_homography = homography.inverse()

        # ``img`` stays the un-warped image (now a tensor); the warped-pair
        # section below warps from it.
        img = torch.from_numpy(img)
        warped_img = self.inv_warp_image(img.squeeze(), inv_homography,
                                         mode="bilinear")
        warped_img = warped_img.squeeze().numpy()[:, :, np.newaxis]

        warped_pnts = self.warp_points(
            pnts, homography_scaling(homography, H, W))
        warped_pnts = filter_points(warped_pnts, torch.tensor([W, H]))

        if self.transform is not None:
            warped_img = self.transform(warped_img)
        sample["image"] = warped_img

        sample["valid_mask"] = self.compute_valid_mask(
            torch.tensor([H, W]),
            inv_homography=inv_homography,
            erosion_radius=homo_config["valid_border_margin"],
        )

        # Labels must follow the warped points, so the entry set above is replaced.
        sample["labels_2D"] = get_labels(warped_pnts, H, W).unsqueeze(0)
        labels_res = get_label_res(H, W, warped_pnts)
        pnts_post = warped_pnts
    else:
        # NOTE(review): if imgPhotometric already returned (H, W, 1), as the
        # comment in the warped-pair section says, this adds a second trailing
        # axis - confirm against what self.transform expects.
        img = img[:, :, np.newaxis]
        if self.transform is not None:
            img = self.transform(img)
        sample["image"] = img

        # No warp was applied, hence the identity homography.
        sample["valid_mask"] = self.compute_valid_mask(
            torch.tensor([H, W]), inv_homography=torch.eye(3))
        labels_res = get_label_res(H, W, pnts)
        pnts_post = pnts

    if self.gaussian_label:
        from datasets.data_tools import get_labels_bi

        labels_2D_bi = get_labels_bi(pnts_post, H, W)
        labels_gaussian = self.gaussian_blur(squeezeToNumpy(labels_2D_bi))
        sample["labels_2D_gaussian"] = np_to_tensor(labels_gaussian, H, W)

    # Sub-pixel residuals of the (possibly warped) points.
    sample["labels_res"] = labels_res

    if self.config["warped_pair"]["enable"]:
        from datasets.data_tools import warpLabels

        pair_config = self.config["warped_pair"]
        homography = self.sample_homography(
            np.array([2, 2]), shift=-1, **pair_config["params"])
        # As above, the inverse of the sampled homography is the one used.
        homography = np.linalg.inv(homography)
        inv_homography = np.linalg.inv(homography)
        homography = torch.tensor(homography).type(torch.FloatTensor)
        inv_homography = torch.tensor(inv_homography).type(torch.FloatTensor)

        # ``img`` is still a numpy array when homographic augmentation is off
        # and no transform is set; calling ``.type`` on it directly raised
        # AttributeError. ``as_tensor`` accepts arrays and tensors alike.
        warped_img = torch.as_tensor(img).type(torch.FloatTensor)
        warped_img = self.inv_warp_image(warped_img.squeeze(),
                                         inv_homography,
                                         mode="bilinear").unsqueeze(0)
        # NOTE(review): every other check in this method compares self.action
        # with "training" / "validation"; confirm "train" / "val" can occur.
        if (self.enable_photo_train and self.action == "train") or (
                self.enable_photo_val and self.action == "val"):
            warped_img = imgPhotometric(
                warped_img.numpy().squeeze())  # numpy array (H, W, 1)
            warped_img = torch.tensor(warped_img, dtype=torch.float32)
        warped_img = warped_img.view(-1, H, W)

        # The warped view is labelled from the original (un-warped) points.
        warped_set = warpLabels(pnts, H, W, homography, bilinear=True)
        warped_labels = warped_set["labels"]
        # (H, W, 2) -> (2, H, W)
        warped_res = warped_set["res"].transpose(1, 2).transpose(0, 1)

        if self.gaussian_label:
            warped_labels_bi = warped_set["labels_bi"]
            warped_labels_gaussian = self.gaussian_blur(
                squeezeToNumpy(warped_labels_bi))
            sample["warped_labels_gaussian"] = np_to_tensor(
                warped_labels_gaussian, H, W)
            sample["warped_labels_bi"] = warped_labels_bi

        sample.update({
            "warped_img": warped_img,
            "warped_labels": warped_labels,
            "warped_res": warped_res,
        })

        sample["warped_valid_mask"] = self.compute_valid_mask(
            torch.tensor([H, W]),
            inv_homography=inv_homography,
            erosion_radius=pair_config["valid_border_margin"],
        )
        sample.update({
            "homographies": homography,
            "inv_homographies": inv_homography,
        })

    if self.getPts:
        sample["pts"] = pnts

    return sample