def test_warp_perspective(self):
    # generate input data
    batch_size = 1
    height, width = 16, 32
    alpha = tgm.pi / 2  # 90 deg rotation

    # create data patch
    patch = torch.rand(batch_size, 1, height, width)

    # create transformation (rotation)
    M = torch.tensor([[
        [torch.cos(alpha), -torch.sin(alpha), 0.],
        [torch.sin(alpha), torch.cos(alpha), 0.],
        [0., 0., 1.],
    ]])  # Bx3x3

    # apply transformation and inverse
    _, _, h, w = patch.shape
    patch_warped = tgm.warp_perspective(patch, M, dsize=(height, width))
    patch_warped_inv = tgm.warp_perspective(
        patch_warped, tgm.inverse(M), dsize=(height, width))

    # generate mask to compute error (warp the mask, not the patch,
    # so only the valid region is compared)
    mask = torch.ones_like(patch)
    mask_warped_inv = tgm.warp_perspective(
        tgm.warp_perspective(mask, M, dsize=(height, width)),
        tgm.inverse(M), dsize=(height, width))

    res = utils.check_equal_torch(mask_warped_inv * patch,
                                  mask_warped_inv * patch_warped_inv)
    self.assertTrue(res)
def get_warped_images_torch(base_img):
    # "ch" is assumed to alias torch (import torch as ch); the rotation
    # parameters (rotXdeg, x_decay, z_min, z_deviation, start_dist,
    # min_dist, step_size) are module-level globals
    h, w = base_img.shape[1:]
    imgs = []
    rotXhere = rotXdeg
    distances, degrees = [], []
    for dist in range(start_dist, min_dist, -step_size):
        rotXhere -= x_decay
        rotZdeg = (2 * z_deviation) * np.random.random_sample() + z_min
        rotX = np.deg2rad(rotXhere - 90)
        rotZ = np.deg2rad(rotZdeg - 90)
        # final and overall transformation matrix
        H = get_H_matrix(dist, w, h, rotX, rotZ)
        H = ch.from_numpy(H)
        img_warp = tgm.warp_perspective(base_img.unsqueeze(0), H, dsize=(h, w))
        imgs.append(img_warp)
        distances.append(dist)
        degrees.append(rotZdeg)
    return np.array(distances), np.array(degrees), ch.cat(imgs, 0).float()
def test_warp_perspective_crop(self):
    # generate input data
    batch_size = 1
    src_h, src_w = 3, 4
    dst_h, dst_w = 3, 2

    # [x, y] origin
    # top-left, top-right, bottom-right, bottom-left
    points_src = torch.FloatTensor([[
        [1, 0], [2, 0], [2, 2], [1, 2],
    ]])

    # [x, y] destination
    # top-left, top-right, bottom-right, bottom-left
    points_dst = torch.FloatTensor([[
        [0, 0], [dst_w - 1, 0], [dst_w - 1, dst_h - 1], [0, dst_h - 1],
    ]])

    # compute transformation between points
    dst_pix_trans_src_pix = tgm.get_perspective_transform(
        points_src, points_dst)

    # create points grid in normalized coordinates
    grid_src_norm = tgm.create_meshgrid(
        src_h, src_w, normalized_coordinates=True)
    grid_src_norm = torch.unsqueeze(grid_src_norm, dim=0)

    # create points grid in pixel coordinates
    grid_src_pix = tgm.create_meshgrid(
        src_h, src_w, normalized_coordinates=False)
    grid_src_pix = torch.unsqueeze(grid_src_pix, dim=0)

    src_norm_trans_src_pix = tgm.normal_transform_pixel(src_h, src_w)
    src_pix_trans_src_norm = tgm.inverse(src_norm_trans_src_pix)

    dst_norm_trans_dst_pix = tgm.normal_transform_pixel(dst_h, dst_w)

    # transform pixel grid
    grid_dst_pix = tgm.transform_points(
        dst_pix_trans_src_pix, grid_src_pix)
    grid_dst_norm = tgm.transform_points(
        dst_norm_trans_dst_pix, grid_dst_pix)

    # transform norm grid
    dst_norm_trans_src_norm = torch.matmul(
        dst_norm_trans_dst_pix,
        torch.matmul(dst_pix_trans_src_pix, src_pix_trans_src_norm))
    grid_dst_norm2 = tgm.transform_points(
        dst_norm_trans_src_norm, grid_src_norm)

    # grids should be equal
    self.assertTrue(utils.check_equal_torch(grid_dst_norm, grid_dst_norm2))

    # warp tensor
    patch = torch.rand(batch_size, 1, src_h, src_w)
    patch_warped = tgm.warp_perspective(
        patch, dst_pix_trans_src_pix, (dst_h, dst_w))
    self.assertTrue(utils.check_equal_torch(
        patch[:, :, :3, 1:3], patch_warped))
def test_warp_perspective_rotation(batch_shape, device_type):
    # generate input data
    batch_size, channels, height, width = batch_shape
    alpha = 0.5 * tgm.pi * torch.ones(batch_size)  # 90 deg rotation

    # create data patch
    device = torch.device(device_type)
    patch = torch.rand(batch_shape).to(device)

    # create transformation (rotation)
    M = torch.eye(3, device=device).repeat(batch_size, 1, 1)  # Bx3x3
    M[:, 0, 0] = torch.cos(alpha)
    M[:, 0, 1] = -torch.sin(alpha)
    M[:, 1, 0] = torch.sin(alpha)
    M[:, 1, 1] = torch.cos(alpha)

    # apply transformation and inverse
    _, _, h, w = patch.shape
    patch_warped = tgm.warp_perspective(patch, M, dsize=(height, width))
    patch_warped_inv = tgm.warp_perspective(
        patch_warped, torch.inverse(M), dsize=(height, width))

    # generate mask to compute error (warp the mask, not the patch,
    # so only the valid region is compared)
    mask = torch.ones_like(patch)
    mask_warped_inv = tgm.warp_perspective(
        tgm.warp_perspective(mask, M, dsize=(height, width)),
        torch.inverse(M), dsize=(height, width))

    assert utils.check_equal_torch(mask_warped_inv * patch,
                                   mask_warped_inv * patch_warped_inv)

    # evaluate function gradient
    patch = utils.tensor_to_gradcheck_var(patch)  # to var
    M = utils.tensor_to_gradcheck_var(M, requires_grad=False)  # to var
    assert gradcheck(tgm.warp_perspective, (patch, M, (height, width,)),
                     raise_exception=True)
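# The gradcheck above needs float64 inputs. Below is a minimal self-contained
# sketch of the same check; tensor_to_gradcheck_var here is an assumed
# stand-in for the utils helper, not the original implementation:
import torch
from torch.autograd import gradcheck
import torchgeometry as tgm

def tensor_to_gradcheck_var(t, requires_grad=True):
    # cast to double precision so finite differences are accurate
    return t.double().requires_grad_(requires_grad)

patch = tensor_to_gradcheck_var(torch.rand(1, 2, 8, 8))
M = tensor_to_gradcheck_var(torch.eye(3).unsqueeze(0), requires_grad=False)
assert gradcheck(tgm.warp_perspective, (patch, M, (8, 8)),
                 raise_exception=True)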
def test_crop(self, device_type, batch_size, channels):
    # generate input data
    src_h, src_w = 3, 3
    dst_h, dst_w = 3, 3
    device = torch.device(device_type)

    # [x, y] origin
    # top-left, top-right, bottom-right, bottom-left
    points_src = torch.FloatTensor([[
        [0, 0], [src_w - 1, 0], [src_w - 1, src_h - 1], [0, src_h - 1],
    ]])

    # [x, y] destination
    # top-left, top-right, bottom-right, bottom-left
    points_dst = torch.FloatTensor([[
        [0, 0], [dst_w - 1, 0], [dst_w - 1, dst_h - 1], [0, dst_h - 1],
    ]])

    # compute transformation between points
    dst_trans_src = tgm.get_perspective_transform(
        points_src, points_dst).expand(batch_size, -1, -1)

    # warp tensor
    patch = torch.FloatTensor([[[
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16],
    ]]]).expand(batch_size, channels, -1, -1)

    expected = torch.FloatTensor([[[
        [1, 2, 3],
        [5, 6, 7],
        [9, 10, 11],
    ]]])

    # warp and assert
    patch_warped = tgm.warp_perspective(patch, dst_trans_src, (dst_h, dst_w))
    assert_allclose(patch_warped, expected)
def bilinear(I, theta, mode='translation', device='cuda:0'):
    # shift images I (B,C,H,W) by translation theta (B,2); with
    # mode='rototranslation', theta (B,3) also carries a rotation angle
    H = torch.eye(3).to(device)
    H = H.repeat(I.size()[0], 1).view(-1, 3, 3)
    H[:, 0, 2] = theta[:, 0]  # tx
    H[:, 1, 2] = theta[:, 1]  # ty
    if mode == 'rototranslation':
        H[:, 0, 0] = torch.cos(theta[:, 2])   # cos
        H[:, 0, 1] = -torch.sin(theta[:, 2])  # -sin
        H[:, 1, 0] = torch.sin(theta[:, 2])   # sin
        H[:, 1, 1] = torch.cos(theta[:, 2])   # cos
    # bilinear interpolation
    new_I = tgm.warp_perspective(I, H, dsize=(I.size()[2], I.size()[3]))
    return new_I
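# A minimal usage sketch for bilinear() above, run on CPU (the shapes and
# shift values are illustrative only):
import torch
import torchgeometry as tgm

I = torch.rand(4, 3, 64, 64)                     # batch of images (B,C,H,W)
theta = torch.tensor([[1.0, 2.0]]).repeat(4, 1)  # per-image (tx, ty) in pixels
shifted = bilinear(I, theta, mode='translation', device='cpu')
print(shifted.shape)  # torch.Size([4, 3, 64, 64])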
def test_crop_center_resize(self, device_type):
    # generate input data
    dst_h, dst_w = 4, 4
    device = torch.device(device_type)

    # [x, y] origin
    # top-left, top-right, bottom-right, bottom-left
    points_src = torch.FloatTensor([[
        [1, 1], [2, 1], [2, 2], [1, 2],
    ]])

    # [x, y] destination
    # top-left, top-right, bottom-right, bottom-left
    points_dst = torch.FloatTensor([[
        [0, 0], [dst_w - 1, 0], [dst_w - 1, dst_h - 1], [0, dst_h - 1],
    ]])

    # compute transformation between points
    dst_trans_src = tgm.get_perspective_transform(points_src, points_dst)

    # warp tensor
    patch = torch.FloatTensor([[[
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16],
    ]]])

    expected = torch.FloatTensor([[[
        [6.000, 6.333, 6.666, 7.000],
        [7.333, 7.666, 8.000, 8.333],
        [8.666, 9.000, 9.333, 9.666],
        [10.000, 10.333, 10.666, 11.000],
    ]]])

    # warp and assert
    patch_warped = tgm.warp_perspective(patch, dst_trans_src, (dst_h, dst_w))
    assert_allclose(patch_warped, expected)
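# Hand-check of one expected value above. The crop maps the 4x4 output grid
# uniformly onto the source square with corners (1, 1)..(2, 2), so output
# column j samples source x = 1 + j / 3 (and likewise for rows):
row = [5, 6, 7, 8]   # source row y = 1 of the 4x4 patch
x = 1 + 1 / 3.0      # source x for output column j = 1
x0 = int(x)
val = row[x0] * (x0 + 1 - x) + row[x0 + 1] * (x - x0)  # linear interpolation
print(val)  # 6.333..., matching expected[0, 0, 0, 1]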
def test_warp_perspective_crop(batch_size, device_type, channels):
    # generate input data
    src_h, src_w = 3, 4
    dst_h, dst_w = 3, 2
    device = torch.device(device_type)

    # [x, y] origin
    # top-left, top-right, bottom-right, bottom-left
    points_src = torch.rand(batch_size, 4, 2).to(device)
    points_src[:, :, 0] *= dst_h
    points_src[:, :, 1] *= dst_w

    # [x, y] destination
    # top-left, top-right, bottom-right, bottom-left
    points_dst = torch.zeros_like(points_src)
    points_dst[:, 1, 0] = dst_w - 1
    points_dst[:, 2, 0] = dst_w - 1
    points_dst[:, 2, 1] = dst_h - 1
    points_dst[:, 3, 1] = dst_h - 1

    # compute transformation between points
    dst_pix_trans_src_pix = tgm.get_perspective_transform(
        points_src, points_dst)

    # create points grid in normalized coordinates
    grid_src_norm = tgm.create_meshgrid(
        src_h, src_w, normalized_coordinates=True)
    grid_src_norm = grid_src_norm.repeat(batch_size, 1, 1, 1).to(device)

    # create points grid in pixel coordinates
    grid_src_pix = tgm.create_meshgrid(
        src_h, src_w, normalized_coordinates=False)
    grid_src_pix = grid_src_pix.repeat(batch_size, 1, 1, 1).to(device)

    src_norm_trans_src_pix = tgm.normal_transform_pixel(
        src_h, src_w).repeat(batch_size, 1, 1).to(device)
    src_pix_trans_src_norm = torch.inverse(src_norm_trans_src_pix)

    dst_norm_trans_dst_pix = tgm.normal_transform_pixel(
        dst_h, dst_w).repeat(batch_size, 1, 1).to(device)

    # transform pixel grid
    grid_dst_pix = tgm.transform_points(
        dst_pix_trans_src_pix.unsqueeze(1), grid_src_pix)
    grid_dst_norm = tgm.transform_points(
        dst_norm_trans_dst_pix.unsqueeze(1), grid_dst_pix)

    # transform norm grid
    dst_norm_trans_src_norm = torch.matmul(
        dst_norm_trans_dst_pix,
        torch.matmul(dst_pix_trans_src_pix, src_pix_trans_src_norm))
    grid_dst_norm2 = tgm.transform_points(
        dst_norm_trans_src_norm.unsqueeze(1), grid_src_norm)

    # grids should be equal
    # TODO: investigate why precision is that low
    assert utils.check_equal_torch(grid_dst_norm, grid_dst_norm2, 1e-2)

    # warp tensor
    patch = torch.rand(batch_size, channels, src_h, src_w)
    patch_warped = tgm.warp_perspective(
        patch, dst_pix_trans_src_pix, (dst_h, dst_w))
    assert patch_warped.shape == (batch_size, channels, dst_h, dst_w)
def forward_pass(secret_img, secret_target, cover_img, cover_target, Hnet,
                 Rnet, criterion, val_cover=0, i_c=None, position=None,
                 Se_two=None, with_std_noise=True, with_warp_noise=True):
    batch_size_secret, channel_secret, _, _ = secret_img.size()
    batch_size_cover, channel_cover, _, _ = cover_img.size()

    if opt.cuda:
        cover_img = cover_img.cuda()
        secret_img = secret_img.cuda()

    secret_imgv = secret_img.view(batch_size_secret // opt.num_secret,
                                  channel_secret * opt.num_secret,
                                  opt.imageSize, opt.imageSize)
    secret_imgv_nh = secret_imgv.repeat(opt.num_training, 1, 1, 1)

    cover_img = cover_img.view(batch_size_cover // opt.num_cover,
                               channel_cover * opt.num_cover,
                               opt.imageSize, opt.imageSize)

    # if val_cover == 1, always use a cover in validation; otherwise,
    # no_cover=True means training without a cover image
    if opt.no_cover and (val_cover == 0):
        cover_img.fill_(0.0)
    if (opt.plain_cover or opt.noise_cover) and (val_cover == 0):
        cover_img.fill_(0.0)
        b, c, w, h = cover_img.size()
    if opt.plain_cover and (val_cover == 0):
        # build a 4x4 grid of random constant blocks as a plain cover
        img_cols = [
            torch.cat([
                torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()
                for _ in range(4)
            ], dim=2)
            for _ in range(4)
        ]
        img_wh = torch.cat(img_cols, dim=3)
        cover_img = cover_img + img_wh
    if opt.noise_cover and (val_cover == 0):
        # note: the noise amplitude is zeroed out here (factor "* 0")
        cover_img = cover_img + (
            (torch.rand(b, c, w, h) - 0.5) * 2 * 0 / 255).cuda()

    cover_imgv = cover_img

    if opt.cover_dependent:
        H_input = torch.cat((cover_imgv, secret_imgv), dim=1)
    else:
        H_input = secret_imgv
    itm_secret_img = Hnet(H_input)

    if i_c is not None:
        if isinstance(i_c, float):
            # keep only one channel
            itm_secret_img_clone = itm_secret_img.clone()
            itm_secret_img.fill_(0)
            itm_secret_img[:, int(i_c):int(i_c) + 1, :, :] = \
                itm_secret_img_clone[:, int(i_c):int(i_c) + 1, :, :]
        if isinstance(i_c, int):
            # set one channel to zero
            print('zeroing channel', i_c)
            itm_secret_img[:, i_c:i_c + 1, :, :].fill_(0.0)
    if position is not None:
        itm_secret_img[:, :, position:position + 1,
                       position:position + 1].fill_(0.0)
    if Se_two == 2:
        itm_secret_img_half = itm_secret_img[0:batch_size_secret // 2]
        itm_secret_img = itm_secret_img + torch.cat(
            (itm_secret_img_half.clone().fill_(0.0), itm_secret_img_half), 0)
    elif isinstance(Se_two, float):
        itm_secret_img = itm_secret_img + Se_two * torch.rand(
            itm_secret_img.size()).cuda()

    if opt.cover_dependent:
        container_img = itm_secret_img
    else:
        itm_secret_img = itm_secret_img.repeat(opt.num_training, 1, 1, 1)
        container_img = itm_secret_img + cover_imgv
        print('L1 metric of residual={:.4f}'.format(
            itm_secret_img.abs().mean()))
    errH = criterion(container_img, cover_imgv)  # Hiding net

    if with_std_noise:
        std_noise = (torch.rand(1) * 0.05).item()
        noise = torch.randn_like(container_img) * std_noise
        container_img = container_img + noise
    if with_warp_noise:
        # get a random homography matrix
        homography = get_rand_homography_mat(
            opt.imageSize, opt.imageSize * 0.1, opt.bs_secret)
        homography = torch.from_numpy(homography).float().cuda()
        # applying the homography and then undoing it is the key to training
        # the model for universal photographic steganography
        container_img = tgm.warp_perspective(
            container_img, homography[:, 1], (opt.imageSize, opt.imageSize))
        container_img = tgm.warp_perspective(
            container_img, homography[:, 0], (opt.imageSize, opt.imageSize))

    rev_secret_img = Rnet(container_img)
    errR = criterion(rev_secret_img, secret_imgv_nh)  # Reveal net

    # L1 metrics
    diffH = (container_img - cover_imgv).abs().mean() * 255
    diffR = (rev_secret_img - secret_imgv_nh).abs().mean() * 255
    return (cover_imgv, container_img, secret_imgv_nh, rev_secret_img,
            errH, errR, diffH, diffR)
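# get_rand_homography_mat is not defined in this snippet. Below is a
# plausible sketch, assuming it returns an array of shape (batch, 2, 3, 3)
# holding the inverse (index 0) and forward (index 1) homographies for
# random corner jitter; the cv2-based construction is an assumption, not
# the original code:
import cv2
import numpy as np

def get_rand_homography_mat(image_size, max_shift, batch_size):
    src = np.float32([[0, 0], [image_size, 0],
                      [image_size, image_size], [0, image_size]])
    mats = np.zeros((batch_size, 2, 3, 3), dtype=np.float32)
    for i in range(batch_size):
        # jitter the four corners by up to max_shift pixels
        dst = (src + np.random.uniform(-max_shift, max_shift,
                                       size=(4, 2))).astype(np.float32)
        H_fwd = cv2.getPerspectiveTransform(src, dst)  # applies the warp
        mats[i, 1] = H_fwd
        mats[i, 0] = np.linalg.inv(H_fwd)              # undoes the warp
    return mats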
def loss_function(model, batch, device, margin=1, safe_radius=4,
                  scaling_steps=3, plot=False):
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False
    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch].to(device)  # [h1, w1]
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2]

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        fmap_pos1 = grid_positions(h1, w1, device)
        hOrig = int(batch['image1'].shape[2] / 8)
        wOrig = int(batch['image1'].shape[3] / 8)
        fmap_pos1Orig = grid_positions(hOrig, wOrig, device)
        pos1 = upscale_positions(fmap_pos1Orig, scaling_steps=scaling_steps)
        try:
            pos1, pos2, ids = warp(pos1, depth1, intrinsics1, pose1, bbox1,
                                   depth2, intrinsics2, pose2, bbox2)
        except EmptyTensorError:
            continue

        # Top-view homography adjustment
        H1 = output['H1'][idx_in_batch]
        H2 = output['H2'][idx_in_batch]
        try:
            pos1, pos2 = homoAlign(pos1, pos2, H1, H2, device)
        except IndexError:
            continue
        ids = idsAlign(pos1, device)

        img_warp1 = tgm.warp_perspective(batch['image1'].to(device), H1,
                                         dsize=(400, 400))
        img_warp2 = tgm.warp_perspective(batch['image2'].to(device), H2,
                                         dsize=(400, 400))

        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)).long()
        descriptors2 = F.normalize(
            dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]], dim=0)

        positive_distance = 2 - 2 * (descriptors1.t().unsqueeze(1) @
                                     descriptors2.t().unsqueeze(2)).squeeze()

        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(
            torch.abs(fmap_pos2.unsqueeze(2).float() -
                      all_fmap_pos2.unsqueeze(1)),
            dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        negative_distance2 = semiHardMine(distance_matrix,
                                          is_out_of_safe_radius,
                                          positive_distance, margin)

        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(
            torch.abs(fmap_pos1.unsqueeze(2).float() -
                      all_fmap_pos1.unsqueeze(1)),
            dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        negative_distance1 = semiHardMine(distance_matrix,
                                          is_out_of_safe_radius,
                                          positive_distance, margin)

        diff = positive_distance - torch.min(negative_distance1,
                                             negative_distance2)

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        loss = loss + (torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
                       torch.sum(scores1 * scores2))

        has_grad = True
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            drawTraining(img_warp1, img_warp2, pos1, pos2, batch,
                         idx_in_batch, output, save=True)

    if not has_grad:
        raise NoGradientError

    loss = loss / n_valid_samples

    return loss
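# The per-pair term above is a score-weighted hard-margin triplet loss. A
# minimal numeric sketch of just that reduction, with random tensors
# standing in for the distances and detection scores:
import torch
import torch.nn.functional as F

margin = 1.0
n = 256                                  # number of correspondences
positive_distance = torch.rand(n)        # d(anchor, positive)
negative_distance = torch.rand(n) + 1.0  # hardest d(anchor, negative)
scores1 = torch.rand(n)                  # detection scores, image 1
scores2 = torch.rand(n)                  # detection scores, image 2

diff = positive_distance - negative_distance
pair_loss = (torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
             torch.sum(scores1 * scores2))
print(pair_loss)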
def forward(self, img1, img2):
    img_warp1 = tgm.warp_perspective(img1, self.H1, dsize=(400, 400))
    img_warp2 = tgm.warp_perspective(img2, self.H2, dsize=(400, 400))
    return img_warp1, img_warp2, self.H1, self.H2
def forward(self, x, M):
    x = tgm.warp_perspective(
        x, M, dsize=(self.output_height, self.output_width))
    return x
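# Sanity check on the call pattern used by both forward methods above:
# warping with a batch of identity homographies should return the input
# (up to interpolation error), assuming torchgeometry's 0.1.x API:
import torch
import torchgeometry as tgm

x = torch.rand(2, 3, 32, 32)
M = torch.eye(3).unsqueeze(0).repeat(2, 1, 1)  # Bx3x3 identity homographies
y = tgm.warp_perspective(x, M, dsize=(32, 32))
print(torch.allclose(x, y, atol=1e-4))  # True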
import cv2
import numpy as np
import torch
import torchgeometry

import utils
from my_perspective_transform import spatial_transformer_network

print(torch.cuda.is_available())

original_image = cv2.imread('test_im_hidden.png')
original_image = torch.cuda.FloatTensor(original_image)
original_image = original_image.unsqueeze(0)

Ms = utils.opencv_get_rand_transform_matrix(400, 400, 50, 1)
print(Ms)
Ms = torch.cuda.FloatTensor(Ms)

original_image = original_image.permute(0, 3, 1, 2)  # NHWC -> NCHW
transform_image = torchgeometry.warp_perspective(
    original_image, Ms[:, 1, :, :], dsize=(400, 400), flags='bilinear')
transform_image = transform_image.permute(0, 2, 3, 1)  # NCHW -> NHWC
print(transform_image.size())

transform_image = transform_image.cpu().detach().numpy()
transform_image = transform_image.astype(np.uint8)
transform_image = transform_image[0, :, :, :]
cv2.imwrite('transform1.png', transform_image)

original_image = cv2.imread('transform1.png')
original_image = torch.cuda.FloatTensor(original_image)
original_image = original_image.unsqueeze(0)
original_image = original_image.permute(0, 3, 1, 2)
transform_image = spatial_transformer_network(original_image, Ms[:, 1, :, :])
transform_image = transform_image.permute(0, 2, 3, 1)
transform_image = transform_image.cpu().detach().numpy()
def build_model(encoder, decoder, discriminator, secret_input, image_input,
                l2_edge_gain, borders, secret_size, M, loss_scales,
                yuv_scales, args, global_step, writer):
    test_transform = transform_net(image_input, args, global_step)

    input_warped = torchgeometry.warp_perspective(
        image_input, M[:, 1, :, :], dsize=(400, 400), flags='bilinear')
    mask_warped = torchgeometry.warp_perspective(
        torch.ones_like(input_warped), M[:, 1, :, :], dsize=(400, 400),
        flags='bilinear')
    input_warped += (1 - mask_warped) * image_input

    residual_warped = encoder((secret_input, input_warped))
    encoded_warped = residual_warped + input_warped

    residual = torchgeometry.warp_perspective(
        residual_warped, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')

    if borders == 'no_edge':
        encoded_image = image_input + residual
    elif borders == 'black':
        encoded_image = residual_warped + input_warped
        encoded_image = torchgeometry.warp_perspective(
            encoded_image, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        input_unwarped = torchgeometry.warp_perspective(
            image_input, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
    elif borders.startswith('random'):
        mask = torchgeometry.warp_perspective(
            torch.ones_like(residual), M[:, 0, :, :], dsize=(400, 400),
            flags='bilinear')
        # input_unwarped must be computed before use; the original listing
        # referenced it before assignment
        input_unwarped = torchgeometry.warp_perspective(
            input_warped, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        encoded_image = residual_warped + input_unwarped
        encoded_image = torchgeometry.warp_perspective(
            encoded_image, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        ch = 3 if borders.endswith('rgb') else 1
        encoded_image += (1 - mask) * torch.ones_like(residual) * \
            torch.rand([ch])
    elif borders == 'white':
        mask = torchgeometry.warp_perspective(
            torch.ones_like(residual), M[:, 0, :, :], dsize=(400, 400),
            flags='bilinear')
        encoded_image = residual_warped + input_warped
        encoded_image = torchgeometry.warp_perspective(
            encoded_image, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        input_unwarped = torchgeometry.warp_perspective(
            input_warped, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        encoded_image += (1 - mask) * torch.ones_like(residual)
    elif borders == 'image':
        mask = torchgeometry.warp_perspective(
            torch.ones_like(residual), M[:, 0, :, :], dsize=(400, 400),
            flags='bilinear')
        encoded_image = residual_warped + input_warped
        encoded_image = torchgeometry.warp_perspective(
            encoded_image, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        encoded_image += (1 - mask) * torch.roll(image_input, 1, 0)

    if borders == 'no_edge':
        D_output_real, _ = discriminator(image_input)
        D_output_fake, D_heatmap = discriminator(encoded_image)
    else:
        D_output_real, _ = discriminator(input_warped)
        D_output_fake, D_heatmap = discriminator(encoded_warped)

    transformed_image = transform_net(encoded_image, args, global_step)
    decoded_secret = decoder(transformed_image)
    bit_acc, str_acc = get_secret_acc(secret_input, decoded_secret)

    lpips_loss = torch.mean(lpips(image_input, encoded_image))

    cross_entropy = nn.BCELoss()
    if args.cuda:
        cross_entropy = cross_entropy.cuda()
    secret_loss = cross_entropy(decoded_secret, secret_input)

    # raised-cosine falloff that boosts the L2 penalty near image edges
    size = (int(image_input.shape[2]), int(image_input.shape[3]))
    falloff_speed = 4
    falloff_im = np.ones(size)
    for i in range(int(falloff_im.shape[0] / falloff_speed)):
        # ramp: (cos(4*pi*i/size + pi) + 1) / 2
        falloff_im[-i, :] *= (np.cos(4 * np.pi * i / size[0] + np.pi) + 1) / 2
        falloff_im[i, :] *= (np.cos(4 * np.pi * i / size[0] + np.pi) + 1) / 2
    for j in range(int(falloff_im.shape[1] / falloff_speed)):
        falloff_im[:, -j] *= (np.cos(4 * np.pi * j / size[0] + np.pi) + 1) / 2
        falloff_im[:, j] *= (np.cos(4 * np.pi * j / size[0] + np.pi) + 1) / 2
    falloff_im = 1 - falloff_im
    falloff_im = torch.from_numpy(falloff_im).float()
    if args.cuda:
        falloff_im = falloff_im.cuda()
    falloff_im *= l2_edge_gain

    encoded_image_yuv = color.rgb_to_yuv(encoded_image)
    image_input_yuv = color.rgb_to_yuv(image_input)
    im_diff = encoded_image_yuv - image_input_yuv
    im_diff += im_diff * falloff_im.unsqueeze_(0)
    yuv_loss = torch.mean(im_diff ** 2, axis=[0, 2, 3])
    yuv_scales = torch.Tensor(yuv_scales)
    if args.cuda:
        yuv_scales = yuv_scales.cuda()
    image_loss = torch.dot(yuv_loss, yuv_scales)

    D_loss = D_output_real - D_output_fake
    G_loss = D_output_fake

    loss = (loss_scales[0] * image_loss + loss_scales[1] * lpips_loss +
            loss_scales[2] * secret_loss)
    if not args.no_gan:
        loss += loss_scales[3] * G_loss

    writer.add_scalar('loss/image_loss', image_loss, global_step)
    writer.add_scalar('loss/lpips_loss', lpips_loss, global_step)
    writer.add_scalar('loss/secret_loss', secret_loss, global_step)
    writer.add_scalar('loss/G_loss', G_loss, global_step)
    writer.add_scalar('loss/loss', loss, global_step)
    writer.add_scalar('metric/bit_acc', bit_acc, global_step)
    writer.add_scalar('metric/str_acc', str_acc, global_step)
    if global_step % 20 == 0:
        writer.add_image('input/image_input', image_input[0], global_step)
        writer.add_image('input/image_warped', input_warped[0], global_step)
        writer.add_image('encoded/encoded_warped', encoded_warped[0],
                         global_step)
        writer.add_image('encoded/residual_warped', residual_warped[0] + 0.5,
                         global_step)
        writer.add_image('encoded/encoded_image', encoded_image[0],
                         global_step)
        writer.add_image('transformed/transformed_image',
                         transformed_image[0], global_step)
        writer.add_image('transformed/test', test_transform[0], global_step)

    return loss, secret_loss, D_loss, bit_acc, str_acc
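# The image loss above is a dot product of per-channel YUV mean squared
# errors with per-channel weights. A tiny numeric sketch (the weights here
# are illustrative, not the values used in training):
import torch

yuv_loss = torch.tensor([0.01, 0.002, 0.003])   # MSE per Y, U, V channel
yuv_scales = torch.tensor([1.0, 100.0, 100.0])  # example channel weights
image_loss = torch.dot(yuv_loss, yuv_scales)
print(image_loss)  # tensor(0.5100)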
def loss_function(model, batch, device, margin=1, safe_radius=4,
                  scaling_steps=3, plot=False):
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False
    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch].to(device)  # [h1, w1]
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2]

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        fmap_pos1 = grid_positions(h1, w1, device)
        hOrig = int(batch['image1'].shape[2] / 8)
        wOrig = int(batch['image1'].shape[3] / 8)
        fmap_pos1Orig = grid_positions(hOrig, wOrig, device)
        pos1 = upscale_positions(fmap_pos1Orig, scaling_steps=scaling_steps)

        # SURF keypoint detection on image 1; the detected keypoints replace
        # the dense grid positions (cv2.xfeatures2d needs opencv-contrib)
        imgNp1 = imshow_image(batch['image1'][idx_in_batch].cpu().numpy(),
                              preprocessing=batch['preprocessing'])
        imgNp1 = cv2.cvtColor(imgNp1, cv2.COLOR_BGR2RGB)
        surf = cv2.xfeatures2d.SURF_create(100)
        kp = surf.detect(imgNp1, None)
        keyP = [kp[i].pt for i in range(len(kp))]
        keyP = np.asarray(keyP).T
        keyP[[0, 1]] = keyP[[1, 0]]  # (x, y) -> (y, x)
        keyP = np.floor(keyP) + 0.5  # snap to pixel centers
        pos1 = torch.from_numpy(keyP).to(pos1.device).float()

        try:
            pos1, pos2, ids = warp(pos1, depth1, intrinsics1, pose1, bbox1,
                                   depth2, intrinsics2, pose2, bbox2)
        except EmptyTensorError:
            continue
        ids = idsAlign(pos1, device, h1, w1)

        # Top-view homography adjustment
        H1 = output['H1'][idx_in_batch]
        H2 = output['H2'][idx_in_batch]
        try:
            pos1, pos2 = homoAlign(pos1, pos2, H1, H2, device)
        except IndexError:
            continue
        ids = idsAlign(pos1, device, h1, w1)

        img_warp1 = tgm.warp_perspective(batch['image1'].to(device), H1,
                                         dsize=(400, 400))
        img_warp2 = tgm.warp_perspective(batch['image2'].to(device), H2,
                                         dsize=(400, 400))

        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            print(ids.size(0))
            continue

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)).long()
        descriptors2 = F.normalize(
            dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]], dim=0)

        positive_distance = 2 - 2 * (descriptors1.t().unsqueeze(1) @
                                     descriptors2.t().unsqueeze(2)).squeeze()

        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(
            torch.abs(fmap_pos2.unsqueeze(2).float() -
                      all_fmap_pos2.unsqueeze(1)),
            dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        # hardest negative outside the safe radius
        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(
            torch.abs(fmap_pos1.unsqueeze(2).float() -
                      all_fmap_pos1.unsqueeze(1)),
            dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        diff = positive_distance - torch.min(negative_distance1,
                                             negative_distance2)

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        loss = loss + (torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
                       torch.sum(scores1 * scores2))

        has_grad = True
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            drawTraining(img_warp1, img_warp2, pos1, pos2, batch,
                         idx_in_batch, output, save=True)

    if not has_grad:
        raise NoGradientError

    loss = loss / n_valid_samples

    return loss