def warp_coor_cells_with_homographies(coor_cells, homographies, uv=False, device='cpu'):
    """
    Warp a set of cell coordinates with a batch of homographies.

    :param coor_cells: (N, 2) coordinates; (y, x) order unless uv=True
    :param homographies: homographies forwarded to warp_points
    :param uv: when True the points are already in (x, y) order and no
        axis swapping is performed
    :param device: device passed through to warp_points
    :return: warped coordinates (batch, N, 2), in the same axis order
        as the input
    """
    from utils.utils import warp_points

    points = coor_cells
    # warp_points expects (x, y); swap the axes when the input is (y, x)
    if uv == False:
        points = torch.stack((points[:, 1], points[:, 0]), dim=1)  # (y, x) -> (x, y)

    warped = warp_points(points, homographies, device)

    if uv == False:
        # restore the (y, x) ordering, now per batch item
        warped = torch.stack((warped[:, :, 1], warped[:, :, 0]), dim=2)
    return warped
def warpLabels(pnts, H, W, homography, bilinear=False):
    """
    Warp keypoint labels by a homography and rasterize them.

    :param pnts: keypoints, (N, 2), assumed (x, y) order — tensor or array-like
    :param H: image height
    :param W: image width
    :param homography: homography, rescaled to the image size by
        homography_scaling before warping
    :param bilinear: if True, additionally compute a bilinear label map
        via get_labels_bi (defined elsewhere in this module)
    :return: dict with
        'labels': label map from scatter_points (defined elsewhere),
        'res': (H, W, 2) sub-pixel residuals at each warped point,
        'warped_pnts': warped, in-bounds points,
        plus 'labels_bi' when bilinear=True
    """
    from utils.utils import homography_scaling_torch as homography_scaling
    from utils.utils import filter_points
    from utils.utils import warp_points

    # accept either a tensor or anything torch.tensor() can consume
    if isinstance(pnts, torch.Tensor):
        pnts = pnts.long()
    else:
        pnts = torch.tensor(pnts).long()

    # warp the points; columns stacked as (x, y) — matches warp_points' convention
    warped_pnts = warp_points(torch.stack((pnts[:, 0], pnts[:, 1]), dim=1),
                              homography_scaling(homography, H, W))  # check the (x, y)
    outs = {}

    # bilinear label map is computed BEFORE out-of-bounds filtering
    if bilinear == True:
        warped_labels_bi = get_labels_bi(warped_pnts, H, W)
        outs['labels_bi'] = warped_labels_bi

    # drop points that fall outside the (W, H) image bounds
    warped_pnts = filter_points(warped_pnts, torch.tensor([W, H]))
    warped_labels = scatter_points(warped_pnts, H, W, res_ext=1)

    # sub-pixel residual of each warped point, stored at its quantized
    # (row=y, col=x) location; quan presumably rounds to integer pixel
    # coordinates — defined elsewhere in this module, TODO confirm
    warped_labels_res = torch.zeros(H, W, 2)
    warped_labels_res[quan(warped_pnts)[:, 1], quan(warped_pnts)[:, 0], :] = warped_pnts - warped_pnts.round()

    outs.update({'labels': warped_labels, 'res': warped_labels_res, 'warped_pnts': warped_pnts})
    return outs
def inv_warp_image_batch(img, mat_homo_inv, device='cpu', mode='bilinear'):
    '''
    Inverse warp images in batch.

    :param img: batch of images
        tensor [batch_size, 1, H, W]; also accepts [H, W] or [batch_size, H, W]
    :param mat_homo_inv: batch of homography matrices
        tensor [batch_size, 3, 3]; a single [3, 3] matrix is also accepted
    :param device: GPU device or CPU
    :param mode: interpolation mode forwarded to F.grid_sample
        ('bilinear' or 'nearest')
    :return: batch of warped images
        tensor [batch_size, channel, H, W]
    '''
    # Normalize the image to 4-D (batch, channel, H, W).
    # BUGFIX: the original code ran `img.view(1, 1, shape[0], shape[1])` for
    # 3-D input as well, which raises a numel-mismatch error; a 3-D input is
    # treated here as a batch of single-channel images (per the docstring's
    # single-channel contract — confirm against callers).
    if len(img.shape) == 2:
        img = img.view(1, 1, img.shape[0], img.shape[1])
    elif len(img.shape) == 3:
        img = img.unsqueeze(1)
    if len(mat_homo_inv.shape) == 2:
        mat_homo_inv = mat_homo_inv.view(1, 3, 3)

    Batch, channel, H, W = img.shape

    # Build an (H, W, 2) grid of (x, y) sampling coordinates normalized to
    # [-1, 1] — the coordinate convention F.grid_sample expects.
    coor_cells = torch.stack(torch.meshgrid(torch.linspace(-1, 1, W), torch.linspace(-1, 1, H)), dim=2)
    coor_cells = coor_cells.transpose(0, 1)  # meshgrid is (W, H)-major; flip to (H, W)
    coor_cells = coor_cells.to(device)
    coor_cells = coor_cells.contiguous()

    # Warp the flattened grid by the inverse homographies, one grid per item.
    # NOTE(review): assumes warp_points returns Batch grids — confirm when a
    # single homography is paired with a multi-image batch.
    src_pixel_coords = warp_points(coor_cells.view([-1, 2]), mat_homo_inv, device)
    src_pixel_coords = src_pixel_coords.view([Batch, H, W, 2])
    src_pixel_coords = src_pixel_coords.float()

    warped_img = F.grid_sample(img, src_pixel_coords, mode=mode, align_corners=True)
    return warped_img
def warpLabels(pnts, homography, H, W):
    """
    Warp keypoints by a homography and keep the in-bounds ones.

    input:
        pnts: numpy array (N, 2), (x, y) order
        homography: numpy 3x3 matrix
    output:
        warped_pnts: numpy array of rounded, in-bounds warped points
    """
    import torch
    from utils.utils import warp_points
    from utils.utils import filter_points

    pts_xy = torch.tensor(pnts).long()
    homo = torch.tensor(homography, dtype=torch.float32)

    # warp with columns stacked as (x, y), then drop out-of-bounds points
    warped = warp_points(torch.stack((pts_xy[:, 0], pts_xy[:, 1]), dim=1), homo)
    kept = filter_points(warped, torch.tensor([W, H]))
    return kept.round().long().numpy()
def descriptor_loss(descriptors, descriptors_warped, homographies, mask_valid=None,
                    cell_size=8, lamda_d=250, device='cpu', descriptor_dist=4, **config):
    '''
    Compute descriptor loss from descriptors_warped and given homographies.

    :param descriptors: output from descriptor head
        tensor [batch_size, descriptors, Hc, Wc]
    :param descriptors_warped: output from descriptor head of warped image
        tensor [batch_size, descriptors, Hc, Wc]
    :param homographies: known homographies
    :param mask_valid: optional validity mask over full-resolution pixels;
        defaults to all-ones [batch_size, 1, Hc*cell_size, Wc*cell_size]
    :param cell_size: 8 — side length of each descriptor cell in pixels
    :param lamda_d: 250 — weight of the positive (matching) term
    :param device: gpu or cpu
    :param descriptor_dist: pixel distance below which two cell centers are
        treated as a correspondence
    :param config: unused extra config
    :return: loss, correspondence mask, positive-term sum, negative-term sum
    '''
    # put homographies on the computation device
    homographies = homographies.to(device)
    from utils.utils import warp_points
    lamda_d = lamda_d  # 250 — no-op rebinding kept from original
    margin_pos = 1    # hinge margin for matching cell pairs
    margin_neg = 0.2  # hinge margin for non-matching cell pairs
    batch_size, Hc, Wc = descriptors.shape[0], descriptors.shape[2], descriptors.shape[3]
    # full image resolution recovered from the descriptor-map resolution
    H, W = Hc * cell_size, Wc * cell_size

    with torch.no_grad():
        shape = torch.tensor([H, W]).type(torch.FloatTensor).to(device)
        # compute the center pixel of every cell in the image
        coor_cells = torch.stack(torch.meshgrid(torch.arange(Hc), torch.arange(Wc)), dim=2)
        coor_cells = coor_cells.type(torch.FloatTensor).to(device)
        coor_cells = coor_cells * cell_size + cell_size // 2
        ## coor_cells is now a grid containing the coordinates of the Hc x Wc
        ## center pixels of the 8x8 cells of the image
        coor_cells = coor_cells.view([-1, 1, 1, Hc, Wc, 2])  # be careful of the order
        # normalize, flip (y, x) -> (x, y), warp, flip back, de-normalize
        # (normPts/denormPts are defined elsewhere in this module)
        warped_coor_cells = normPts(coor_cells.view([-1, 2]), shape)
        warped_coor_cells = torch.stack((warped_coor_cells[:, 1], warped_coor_cells[:, 0]), dim=1)  # (y, x) to (x, y)
        warped_coor_cells = warp_points(warped_coor_cells, homographies, device)
        warped_coor_cells = torch.stack((warped_coor_cells[:, :, 1], warped_coor_cells[:, :, 0]), dim=2)  # (batch, x, y) to (batch, y, x)
        shape_cell = torch.tensor([H // cell_size, W // cell_size]).type(torch.FloatTensor).to(device)  # NOTE(review): unused
        warped_coor_cells = denormPts(warped_coor_cells, shape)
        # note: opposite broadcast layout to coor_cells above, so subtracting
        # them yields all pairwise (source cell, warped cell) differences
        warped_coor_cells = warped_coor_cells.view([-1, Hc, Wc, 1, 1, 2])

        # pairwise distance between every source cell center and every
        # warped cell center
        cell_distances = coor_cells - warped_coor_cells
        cell_distances = torch.norm(cell_distances, dim=-1)
        # cells closer than descriptor_dist count as correspondences
        mask = cell_distances <= descriptor_dist  # 0.5 # trick
        mask = mask.type(torch.FloatTensor).to(device)

    # compute the pairwise dot product between descriptors: d^t * d
    descriptors = descriptors.transpose(1, 2).transpose(2, 3)
    descriptors = descriptors.view((batch_size, Hc, Wc, 1, 1, -1))
    descriptors_warped = descriptors_warped.transpose(1, 2).transpose(2, 3)
    descriptors_warped = descriptors_warped.view((batch_size, 1, 1, Hc, Wc, -1))
    dot_product_desc = descriptors * descriptors_warped
    dot_product_desc = dot_product_desc.sum(dim=-1)
    ## dot_product_desc.shape = [batch_size, Hc, Wc, Hc, Wc]

    # hinge loss: push matching pairs above margin_pos, non-matching below
    # margin_neg
    positive_dist = torch.max(margin_pos - dot_product_desc, torch.tensor(0.).to(device))
    negative_dist = torch.max(dot_product_desc - margin_neg, torch.tensor(0.).to(device))

    if mask_valid is None:
        mask_valid = torch.ones(batch_size, 1, Hc * cell_size, Wc * cell_size)
    # reshape so mask_valid broadcasts against the pairwise loss tensor
    mask_valid = mask_valid.view(batch_size, 1, 1, mask_valid.shape[2], mask_valid.shape[3])

    loss_desc = lamda_d * mask * positive_dist + (1 - mask) * negative_dist
    loss_desc = loss_desc * mask_valid

    # NOTE(review): the original author flags this normalization as buggy
    # ("##### bug in normalization"); mask_valid.sum() counts full-resolution
    # pixels, not cells — confirm the intended denominator before changing it.
    normalization = (batch_size * (mask_valid.sum() + 1) * Hc * Wc)
    pos_sum = (lamda_d * mask * positive_dist / normalization).sum()
    neg_sum = ((1 - mask) * negative_dist / normalization).sum()
    loss_desc = loss_desc.sum() / normalization
    return loss_desc, mask, pos_sum, neg_sum