def test_rotation_matrix2d(batch_size, device_type):
    """Check the coefficients and the gradient of ``kornia.get_rotation_matrix2d``.

    Rotation matrices around the origin are generated for 90 deg, 90 deg with
    a 2x scale and 45 deg. For every batch entry the upper-left 2x2 block is
    compared with the expected coefficients, then ``gradcheck`` is run on the
    inputs of the last case.

    Args:
        batch_size: number of (center, angle, scale) triplets per call.
        device_type: device string handed to ``torch.device``.
    """
    # generate input data
    device = torch.device(device_type)
    center_base = torch.zeros(batch_size, 2).to(device)
    angle_base = torch.ones(batch_size).to(device)
    scale_base = torch.ones(batch_size).to(device)

    # (angle in degrees, scale factor, expected [m00, m01, m10, m11])
    cases = (
        (90., 1., [0.0, 1.0, -1.0, 0.0]),              # 90 deg rotation
        (90., 2., [0.0, 2.0, -2.0, 0.0]),              # 90 deg rotation + 2x scale
        (45., 1., [0.7071, 0.7071, -0.7071, 0.7071]),  # 45 deg rotation
    )
    for angle_deg, scale_factor, expected in cases:
        center = center_base
        angle = angle_deg * angle_base
        scale = scale_factor * scale_base
        M = kornia.get_rotation_matrix2d(center, angle, scale)
        for i in range(batch_size):
            actual = [M[i, r, c].item() for r in (0, 1) for c in (0, 1)]
            # pytest.approx only verifies something when it is compared with
            # ``==`` inside an assert; a bare ``pytest.approx(a, b)`` call is a
            # no-op whose second argument is the relative tolerance.
            assert actual == pytest.approx(expected, abs=1e-4)

    # evaluate function gradient (on the inputs of the last case)
    center = utils.tensor_to_gradcheck_var(center)  # to var
    angle = utils.tensor_to_gradcheck_var(angle)  # to var
    scale = utils.tensor_to_gradcheck_var(scale)  # to var
    assert gradcheck(kornia.get_rotation_matrix2d, (center, angle, scale),
                     raise_exception=True)
def test_rotation_matrix2d(batch_size, device, dtype):
    """Verify rotation-matrix coefficients for three angle/scale setups.

    Each setup builds a batch of matrices with ``kornia.get_rotation_matrix2d``
    (center at the origin) and compares the four entries of the 2x2 block of
    every batch element with ``assert_close``. ``gradcheck`` is finally run on
    the inputs of the 45 deg setup.
    """
    positions = ((0, 0), (0, 1), (1, 0), (1, 1))

    def check_block(matrices, coefficients, **tolerances):
        # compare the 2x2 block of every batch entry, one scalar at a time
        for idx in range(batch_size):
            for (row, col), value in zip(positions, coefficients):
                assert_close(matrices[idx, row, col].item(), value, **tolerances)

    # generate input data
    tensor_opts = {"device": device, "dtype": dtype}
    center_base = torch.zeros(batch_size, 2, **tensor_opts)
    angle_base = torch.ones(batch_size, **tensor_opts)
    scale_base = torch.ones(batch_size, 2, **tensor_opts)
    loose = {"rtol": 1e-4, "atol": 1e-4}

    # 90 deg rotation
    center, angle, scale = center_base, 90.0 * angle_base, scale_base
    M = kornia.get_rotation_matrix2d(center, angle, scale)
    check_block(M, (0.0, 1.0, -1.0, 0.0), **loose)

    # 90 deg rotation + 2x scale
    center, angle, scale = center_base, 90.0 * angle_base, 2.0 * scale_base
    M = kornia.get_rotation_matrix2d(center, angle, scale)
    check_block(M, (0.0, 2.0, -2.0, 0.0), **loose)

    # 45 deg rotation (assert_close is called without explicit tolerances here)
    center, angle, scale = center_base, 45.0 * angle_base, scale_base
    M = kornia.get_rotation_matrix2d(center, angle, scale)
    check_block(M, (0.7071, 0.7071, -0.7071, 0.7071))

    # evaluate function gradient
    center = utils.tensor_to_gradcheck_var(center)  # to var
    angle = utils.tensor_to_gradcheck_var(angle)  # to var
    scale = utils.tensor_to_gradcheck_var(scale)  # to var
    assert gradcheck(kornia.get_rotation_matrix2d, (center, angle, scale),
                     raise_exception=True)
def __call__(self, clip):
    """Rotate every frame of a video clip by one random angle.

    Args:
        clip (torch.tensor): Video clip to be rotated. Size is (C, T, H, W)
    Returns:
        torch.tensor: rotated video clip. Size is (C, T, H, W)
    """
    # kornia works on (B, C, H, W): use the T frames as the batch dimension
    frames = clip.permute(1, 0, 2, 3)
    num_frames, _, h, w = frames.shape
    # build the helper tensors on the clip's device so non-CPU clips work too
    device = frames.device

    # define the rotation center
    center = torch.ones(num_frames, 2, device=device)
    center[..., 0] = w / 2  # x
    center[..., 1] = h / 2  # y

    # define the scale factor
    scale = torch.ones(num_frames, device=device)

    # NOTE(review): randint's upper bound is exclusive, so this draws from
    # [-degree, degree + 1]; confirm the asymmetric range is intended.
    degree = np.random.randint(-self.degree, self.degree + 2)
    degree = torch.ones(num_frames, device=device) * degree

    # compute the transformation matrix
    M = kornia.get_rotation_matrix2d(center, degree, scale)

    # apply the transformation to the frames and restore the (C, T, H, W) layout
    frames = kornia.warp_affine(frames, M, dsize=(h, w))
    return frames.permute(1, 0, 2, 3)
def Raw2Bayer(x, crop_size = cSize, is_rotate = False):
    r'''Convert FlatCam raw data to Bayer.

    The input is split into four half-resolution planes (one per 2x2 lattice
    offset), optionally warped by a fixed rotation, and centre-cropped.

    Args:
        x: 4-D tensor; only its size and strided sub-lattices are used.
        crop_size: (rows, cols) of the centre crop.
        is_rotate: apply the fixed rotation to every plane. The original
            author marked this mode as not working yet.

    Returns:
        The cropped 4-plane tensor, allocated on the CUDA device.
    '''
    # Step 1. Convert the Image & rotate
    c, b, h, w = x.size()
    y = torch.zeros((c, 4, int(h/2), int(w/2)), device = torch.device('cuda'))

    # (row offset, column offset) of the sub-lattice stored in plane 0..3
    lattice_offsets = ((1, 1), (0, 1), (1, 0), (0, 0))

    if is_rotate:
        # ---> THIS MODE DOES NOT WORK YET!!! (2019.07.14)
        scale = torch.ones(1)
        angle = torch.ones(1) * 0.05 * 360  # 0.05 is angle collected from data measurements
        center = torch.ones(1, 2)
        center[..., 0] = int(h / 4)  # x
        center[..., 1] = int(w / 4)  # y
        M = kr.get_rotation_matrix2d(center, angle, scale).cuda()
        _, _, h, w = y.size()
        for plane, (row0, col0) in enumerate(lattice_offsets):
            y[:, plane, :, :] = kr.warp_affine(x[:, :, row0::2, col0::2], M, dsize = (h, w))
    else:
        for plane, (row0, col0) in enumerate(lattice_offsets):
            y[:, plane, :, :] = x[:, 0, row0::2, col0::2]

    # Step 3. Crop the image
    crop_rows, crop_cols = crop_size[0], crop_size[1]
    start_row = int((y.size()[2] - crop_rows) / 2)
    start_col = int((y.size()[3] - crop_cols) / 2)
    end_row = start_row + crop_rows
    end_col = start_col + crop_cols
    return y[:, :, start_row:end_row, start_col:end_col]
def kornia_affine(im, parameter, aug_type, data_type='data'):
    '''
    Get rotation by given angle or scale by given factor along axis-0 using kornia.
    (See https://kornia.readthedocs.io/en/latest/geometry.transform.html)

    Args:
        im: 3-D tensor (axis-0, height, width); it is warped as a batch of one.
        parameter: rotation angle when ``aug_type == 'rotate'``, scale factor
            when ``aug_type == 'scale'``.
        aug_type: ``'rotate'`` or ``'scale'``.
        data_type: ``'data'`` (bilinear interpolation) or ``'label'``
            (nearest interpolation).

    Returns:
        The warped tensor with the same spatial size as ``im``.

    Raises:
        ValueError: if ``aug_type`` or ``data_type`` is not a supported value.
    '''
    # fail early with a clear message instead of an UnboundLocalError later on
    if aug_type not in ('rotate', 'scale'):
        raise ValueError("aug_type must be 'rotate' or 'scale', got %r" % (aug_type,))
    if data_type not in ('data', 'label'):
        raise ValueError("data_type must be 'data' or 'label', got %r" % (data_type,))

    # work on the device of the input instead of assuming CUDA
    device = im.device
    _, h, w = im.shape

    # kornia expects the centre as (x, y): x runs along the width (last axis)
    center = torch.ones(1, 2, device=device)
    center[..., 0] = w // 2
    center[..., 1] = h // 2

    if aug_type == 'rotate':
        scale = torch.ones(1, device=device)
        angle = parameter * scale
    else:
        scale = torch.Tensor([parameter]).to(device)
        angle = 0 * scale

    interpolation = 'bilinear' if data_type == 'data' else 'nearest'

    M = kornia.get_rotation_matrix2d(center, angle, scale)
    im_warped = kornia.warp_affine(im[None, :, :, :].float(), M.to(device),
                                   dsize=(h, w), flags=interpolation)
    return im_warped[0]
def test_rot90(self, device, dtype):
    """Inverting a 90 deg rotation about the origin yields the expected 2x3 matrix."""
    opts = {"device": device, "dtype": dtype}
    center = torch.tensor([[0.0, 0.0]], **opts)
    scale = torch.tensor([[1.0, 1.0]], **opts)
    angle = torch.tensor([90.0], **opts)
    expected = torch.tensor([[[0.0, -1.0, 0.0], [1.0, 0.0, 0.0]]], **opts)

    rotation = kornia.get_rotation_matrix2d(center, angle, scale)
    inverted = kornia.invert_affine_transform(rotation)
    assert_close(inverted, expected, rtol=1e-4, atol=1e-4)
def test_rot90(self, device):
    """Inverting a 90 deg rotation about the origin yields the expected 2x3 matrix."""

    def on_device(values):
        # build on the default device first, then move, as the inputs require
        return torch.tensor(values).to(device)

    center = on_device([[0., 0.]])
    scale = on_device([[1., 1.]])
    angle = on_device([90.])
    expected = on_device([[
        [0., -1., 0.],
        [1., 0., 0.],
    ]])

    rotation = kornia.get_rotation_matrix2d(center, angle, scale)
    inverted = kornia.invert_affine_transform(rotation)
    assert_allclose(inverted, expected)
def forward(self, img: torch.Tensor):
    """Rotate a batch of images by ``self.angle`` around the image centre.

    One rotation matrix is built on the CPU, expanded to the batch size and
    moved to the image's device before warping. The image is cast to float;
    interpolation and padding come from ``self.interpolation`` and
    ``self.padding_mode``.
    """
    batch, _, height, width = img.shape

    center = torch.tensor([[width, height]], dtype=torch.float) / 2
    unit_scale = torch.ones(1)
    rotation = kornia.get_rotation_matrix2d(center, self.angle, unit_scale)

    # share the single matrix across the batch, then match the image device
    rotation = rotation.expand(batch, -1, -1).to(img.device)

    return kornia.warp_affine(
        img.float(),
        rotation,
        dsize=(height, width),
        flags=self.interpolation,
        padding_mode=self.padding_mode,
    )
def test_rot90_batch(self):
    """Invert a 90 deg rotation repeated twice along the batch dimension."""
    center = torch.tensor([[0., 0.]])
    scale = torch.tensor([1.])
    angle = torch.tensor([90.])
    expected = torch.tensor([[
        [0., -1., 0.],
        [1., 0., 0.],
    ]])

    single = kornia.get_rotation_matrix2d(center, angle, scale)
    batched = single.repeat(2, 1, 1)
    inverted = kornia.invert_affine_transform(batched)
    assert_allclose(inverted, expected)
def test_rot90_batch(self, device, dtype):
    """Invert a 90 deg rotation repeated twice along the batch dimension."""
    opts = {"device": device, "dtype": dtype}
    center = torch.tensor([[0., 0.]], **opts)
    scale = torch.tensor([[1., 1.]], **opts)
    angle = torch.tensor([90.], **opts)
    expected = torch.tensor([[
        [0., -1., 0.],
        [1., 0., 0.],
    ]], **opts)

    single = kornia.get_rotation_matrix2d(center, angle, scale)
    batched = single.repeat(2, 1, 1)
    inverted = kornia.invert_affine_transform(batched)
    assert_allclose(inverted, expected, rtol=1e-4, atol=1e-4)
def inner(image_t):
    """Rotate the whole batch by one angle drawn at random from ``angles``.

    ``angles``, ``units`` and ``device`` are taken from the enclosing scope.
    """
    batch, _, height, width = image_t.shape

    # kornia takes degrees
    alpha = _rads2angle(np.random.choice(angles), units)
    angle = torch.ones(batch) * alpha
    scale = torch.ones(batch)

    # rotation centre at the middle of the pixel grid: ((W - 1) / 2, (H - 1) / 2)
    center = torch.ones(batch, 2)
    center[:, 0] = (width - 1) / 2
    center[:, 1] = (height - 1) / 2

    M = kornia.get_rotation_matrix2d(center, angle, scale).to(device)
    return kornia.warp_affine(image_t.float(), M, dsize=(height, width))
def Rotate(x, alphas, step):
    """Warp ``x`` with ``step`` rotation angles evenly spaced over ``alphas``.

    Args:
        x: 4-D tensor (b, c, h, w); it is first passed through ``CopyN(x, step)``
            (presumably yielding ``b * step`` samples; confirm against CopyN).
        alphas: ``(start, end)`` pair of angles.
        step: number of angles per input sample.

    Returns:
        The output of ``kornia.warp_affine`` with spatial size ``(h, w)``.
    """
    b, c, h, w = x.shape
    start, end = alphas
    x = CopyN(x, step)
    total = b * step

    # the same angle ramp is repeated for each of the b input samples
    ramp = torch.linspace(start, end, step, device=x.device)
    angle = ramp.unsqueeze(0).repeat(b, 1).view(total)

    center = torch.zeros((total, 2), device=x.device)
    center[:, 0] = w / 2
    center[:, 1] = h / 2
    scale = torch.ones(total, device=x.device)

    M = kornia.get_rotation_matrix2d(center, angle, scale)
    return kornia.warp_affine(x, M, dsize=(h, w))
def get_all_rotation_matrices(angles, center_image, dtype):
    """Build one 2x3 rotation matrix per entry of ``angles``.

    Args:
        angles: 1-D tensor of rotation angles; its device is used throughout.
        center_image: indexable whose items 0 and 1 are the x and y of the
            rotation centre.
        dtype: tensor type of the returned stack.

    Returns:
        Tensor of shape (len(angles), 2, 3) of type ``dtype``.
    """
    device = angles.device

    center = torch.ones(1, 2, device=device)
    center[..., 0] = center_image[0]  # x
    center[..., 1] = center_image[1]  # y
    scale = torch.ones(1, 2, device=device)

    (t,) = angles.shape
    list_rot_mat = torch.zeros((t, 2, 3), device=device).type(dtype)
    for idx in range(t):
        theta = torch.ones(1, device=device) * (angles[idx])
        # one matrix per call; it is stored in slot idx of the stack
        list_rot_mat[idx, :, :] = kornia.get_rotation_matrix2d(center, theta, scale)
    return list_rot_mat
def forward(self, x, params):
    """Rotate and scale ``x`` around its centre with reflection padding.

    Args:
        x: 4-D tensor; axes 2 and 3 are used as height and width.
        params: ``(angle, scale)`` pair written into the ``self.angle`` and
            ``self.scale`` buffers in place.
    """
    angle, scale = params
    self.angle.fill_(angle)
    self.scale.fill_(scale)

    height, width = x.shape[2], x.shape[3]
    # define the rotation center
    self.center[..., 0] = width / 2
    self.center[..., 1] = height / 2

    M = kornia.get_rotation_matrix2d(self.center, self.angle, self.scale)
    return kornia.warp_affine(
        x,
        M,
        dsize=(height, width),
        flags='bilinear',
        padding_mode='reflection',
    )
def augment_batch(batch_tensor):
    """Rotate every image of a batch by a random angle in [-10, 10) degrees.

    Args:
        batch_tensor: image batch of shape (B, C, H, W), or a single image of
            shape (C, H, W).

    Returns:
        The rotated tensor, with the same number of dimensions as the input.
    """
    was_unbatched = batch_tensor.ndim == 3
    if was_unbatched:
        batch_tensor = batch_tensor.unsqueeze(0)

    # read the batch size only after a single image has been turned into a
    # batch of one; otherwise the channel count would be used as batch size
    batch_size, _, h, w = batch_tensor.shape
    device = batch_tensor.device

    angle = (torch.rand(batch_size, device=device) * 20) - 10
    center = torch.ones(batch_size, 2, device=device)
    center[..., 0] = w / 2  # x
    center[..., 1] = h / 2  # y
    scale = torch.ones(batch_size, device=device)

    M = kornia.get_rotation_matrix2d(center, angle, scale)
    batch_tensor_warped = kornia.warp_affine(batch_tensor, M, dsize=(h, w))

    if was_unbatched:
        batch_tensor_warped = batch_tensor_warped.squeeze(0)
    return batch_tensor_warped
def forward(self, img: torch.tensor):
    """Rotate a batch by normally distributed random angles scaled by ``self.angle``.

    When ``self.same_throughout_batch`` is set, one angle is drawn and shared
    by every sample; otherwise each sample gets its own draw.
    """
    b, _, h, w = img.shape
    device = img.device

    # create transformation (rotation)
    if self.same_throughout_batch:
        angle = torch.randn(1, device=device) * self.angle
        angle = angle.repeat(b)
    else:
        angle = torch.randn(b, device=device) * self.angle

    # rotation centre and unit scale factor
    center = torch.ones(b, 2, device=device)
    center[..., 0] = img.shape[3] / 2  # x
    center[..., 1] = img.shape[2] / 2  # y
    scale = torch.ones(b, device=device)

    M = kornia.get_rotation_matrix2d(center, angle, scale)
    return kornia.warp_affine(img, M, dsize=(h, w))
def test_rotation_inverse(self, device, dtype):
    """Warping by a 90 deg rotation and then by its inverse restores the image."""
    h, w = 4, 4
    opts = {"device": device, "dtype": dtype}
    img_b = torch.rand(1, 1, h, w, **opts)

    # create rotation matrix of 90deg (anti-clockwise)
    # Same as opencv: cv2.getRotationMatrix2D(((w-1)/2,(h-1)/2), 90., 1.)
    center = torch.tensor([[w - 1, h - 1]], **opts) / 2
    scale = torch.ones((1, 2), **opts)
    angle = 90. * torch.ones(1, **opts)
    aff_ab = kornia.get_rotation_matrix2d(center, angle, scale)

    # warp the tensor
    # Same as opencv: cv2.warpAffine(kornia.tensor_to_image(img_b), aff_ab[0].numpy(), (w, h))
    img_a = kornia.warp_affine(img_b, aff_ab, (h, w))

    # invert the transform: lift to a homography, invert, keep the top two rows
    homography = kornia.convert_affinematrix_to_homography(aff_ab)
    aff_ba = homography.inverse()[..., :2, :]
    img_b_hat = kornia.warp_affine(img_a, aff_ba, (h, w))

    assert_allclose(img_b_hat, img_b, atol=1e-3, rtol=1e-3)
def get_heatmap_transformation_matrix(
    jitter_x: Tensor,
    jitter_y: Tensor,
    scale: Tensor,
    angle: Tensor,
    heatmap_dim: Tensor,
) -> Tensor:
    """Generate the transformation matrix that reverts the transformation on a heatmap.

    Args:
        jitter_x (Tensor): x pixels by which the heatmap should be jittered (batch)
        jitter_y (Tensor): y pixels by which the heatmap should be jittered (batch)
        scale (Tensor): Scale factor from crop margin (batch).
        angle (Tensor): Rotation angle (batch)
        heatmap_dim (Tensor): Height and width of heatmap (1x2)

    Returns:
        [Tensor]: Transformation matrix (batch x 2 x 3).
    """
    # Translation part: pure shift, no rotation, unit scale.
    shift_x = jitter_x.view(-1, 1)
    shift_y = jitter_y.view(-1, 1)
    translations = torch.cat([shift_x, shift_y], axis=1).float()
    # NOTE: The function below returns a batch x 3 x 3 matrix.
    translation_matrix = kornia.get_affine_matrix2d(
        translations=translations,
        center=torch.zeros_like(translations),
        angle=torch.zeros_like(jitter_x[:, 0]),
        scale=torch.ones_like(translations),
    )

    # Rotation part: rotate and scale around the middle of the heatmap.
    half_dim = (heatmap_dim / 2).view(1, 2)
    center_of_rotation = torch.ones_like(translations) * half_dim
    # NOTE: The function below returns a batch x 2 x 3 matrix.
    rotation_matrix = kornia.get_rotation_matrix2d(
        center=center_of_rotation.float(),
        angle=angle.float(),
        scale=scale.repeat(1, 2).float(),
    )

    # Compose: the translation is applied first, then the rotation.
    return torch.bmm(rotation_matrix, translation_matrix)
def make_training_batch(input_tensor, patch, patch_mask):
    """Overlay a randomly translated, scaled and rotated patch on every sample.

    The patch size and the ranges for scale and rotation come from ``PA_cfg``.
    For each sample the patch and its mask are translated, warped with the
    same rotation matrix and blended into the input through the warped mask.

    Returns:
        The overlays of all samples concatenated along dimension 0.
    """
    # determine patch size
    H, W = PA_cfg.image_shape[-2:]
    PATCH_SIZE = int(np.floor(np.sqrt((H*W*PA_cfg.percentage))))
    translate_space = [H-PATCH_SIZE+1, W-PATCH_SIZE+1]
    half_patch = PATCH_SIZE/2

    overlays = []
    for sample_idx in range(input_tensor.size(0)):
        # random translation
        u_t = np.random.randint(low=0, high=translate_space[0])
        v_t = np.random.randint(low=0, high=translate_space[1])

        # random scaling and rotation
        scale_value = np.random.rand() * (PA_cfg.scale_max - PA_cfg.scale_min) + PA_cfg.scale_min
        angle_value = np.random.rand() * (PA_cfg.rotate_max - PA_cfg.rotate_min) + PA_cfg.rotate_min
        scale = torch.Tensor([scale_value])
        angle = torch.Tensor([angle_value])
        center = torch.Tensor([u_t+half_patch, v_t+half_patch]).unsqueeze(0)
        rotation_m = kornia.get_rotation_matrix2d(center, angle, scale)

        # warp three tensors
        temp_input = input_tensor[sample_idx].unsqueeze(0)
        shift = torch.Tensor([u_t, v_t]).unsqueeze(0)
        temp_mask = kornia.translate(patch_mask.unsqueeze(0).float(), translation=shift)
        temp_patch = kornia.translate(patch.unsqueeze(0).float(), translation=shift)
        mask_warpped = kornia.warp_affine(temp_mask.float(), rotation_m, temp_mask.size()[-2:])
        patch_warpped = kornia.warp_affine(temp_patch.float(), rotation_m, temp_patch.size()[-2:])

        # overlay: keep the input outside the mask, the patch inside it
        overlays.append(temp_input * (1 - mask_warpped) + patch_warpped * mask_warpped)

    return torch.cat(overlays, dim=0)
def rotate_tensor_along_y_axis(tensor, gamma):
    """Rotate a 6-D tensor by ``gamma`` in the plane of its D and W axes.

    The tensor is moved to the CPU, its sequence dimension is packed with
    ``utils_basic.pack_seqdim``, the H axis is folded into the channels and the
    resulting (D, W) images are warped around their centre. The original layout
    is restored afterwards and the result is returned on the CUDA device.
    """
    B = tensor.shape[0]
    tensor = tensor.to("cpu")
    assert tensor.ndim == 6, "Tensors should have 6 dimensions."
    tensor = tensor.float()  # B,S,C,D,H,W

    packed = utils_basic.pack_seqdim(tensor, B)
    # Make it BS, C, H, D, W (i.e. BS, C, y, z, x)
    packed = packed.permute(0, 1, 3, 2, 4)
    BS, C, H, D, W = packed.shape

    # merge y dimension with channel dimension and rotate with gamma_
    flat = packed.reshape(BS, C * H, D, W)

    # define the rotation center
    center = torch.ones(1, 2)
    center[..., 0] = flat.shape[3] / 2  # x
    center[..., 1] = flat.shape[2] / 2  # z

    # unit scale factor and rotation angle
    scale = torch.ones(1)
    gamma_ = torch.ones(1) * gamma

    # one matrix, repeated for every packed sample
    M = kornia.get_rotation_matrix2d(center, gamma_, scale).repeat(BS, 1, 1)

    # apply the transformation, then undo the reshape and the permutation
    warped = kornia.warp_affine(flat, M, dsize=(D, W))
    warped = warped.reshape(BS, C, H, D, W).permute(0, 1, 3, 2, 4)
    warped = utils_basic.unpack_seqdim(warped, B)
    return warped.cuda()
def _get_rotation(self):
    """
    Get transformation matrices

    One random integer angle in [-max_angle, max_angle] (degrees) is drawn per
    sample; the rotation centre is the middle of ``self.input_hw``.

    Output:
        rot_mat (float torch.Tensor): tensor of shape (batch_size, 2, 3)
    """
    # define the rotation center
    center = torch.ones(self.batch_size, 2)
    center[..., 0] = self.input_hw[1] / 2  # x
    center[..., 1] = self.input_hw[0] / 2  # y

    # create transformation (rotation)
    # torch.tensor() on Python ints gives an int64 tensor; the angle is stored
    # with the dtype of `center` so it enters the matrix computation as a
    # float tensor like the other two inputs.
    angle = torch.tensor(
        [
            random.randint(-self.max_angle, self.max_angle)
            for _ in range(self.batch_size)
        ],
        dtype=center.dtype,
    )

    # define the scale factor
    scale = torch.ones(self.batch_size)

    # compute the transformation matrix
    tf_matrices = kornia.get_rotation_matrix2d(center, angle, scale)
    return tf_matrices
def align_fake(self, margin=40, alignUnaligned=True):
    """Map between the aligned face crop and the unaligned full frame.

    Builds the affine matrix ``M`` from ``self.alignment_params`` (eye centre,
    angle, scale, desired left-eye position), cuts the face region plus a
    safety margin out of ``self.real_B_unaligned_full``, warps it with the
    inverse of ``M`` and writes ``self.real_B`` into the warped crop.

    NOTE(review): the method currently prints a value, saves two PNG files to
    a hard-coded path and then calls ``exit()`` unconditionally, so everything
    after that call (the ``if not alignUnaligned`` branch) is never reached.
    This looks like work-in-progress debugging code; confirm before relying
    on it.

    Args:
        margin: requested margin in pixels around the face region; it is
            clamped so the crop stays inside the full frame.
        alignUnaligned: only read by the code after ``exit()``.
    """
    # get params
    desiredLeftEye = [
        float(self.alignment_params["desiredLeftEye"][0]),
        float(self.alignment_params["desiredLeftEye"][1])
    ]
    rotation_point = self.alignment_params["eyesCenter"]
    angle = -self.alignment_params["angle"]
    h, w = self.fake_B.shape[2:]

    # get original positions
    m1 = round(w * 0.5)
    m2 = round(desiredLeftEye[1] * w)

    # define the scale factor
    scale = 1 / self.alignment_params["scale"]
    width = int(self.alignment_params["shape"][0])
    long_edge_size = width / abs(np.cos(np.deg2rad(angle)))
    w_original = int(scale * long_edge_size)
    h_original = int(scale * long_edge_size)

    # get offset
    tX = w_original * 0.5
    tY = h_original * desiredLeftEye[1]

    # get rotation center
    center = torch.ones(1, 2)
    center[..., 0] = m1
    center[..., 1] = m2

    # compute the transformation matrix
    M: torch.tensor = kornia.get_rotation_matrix2d(center, angle, scale).to(self.device)
    # shift the translation column by the offset between target and centre
    M[0, 0, 2] += (tX - m1)
    M[0, 1, 2] += (tY - m2)

    # get insertion point
    x_start = int(rotation_point[0] - (0.5 * w_original))
    y_start = int(rotation_point[1] - (desiredLeftEye[1] * h_original))

    # _, _, h_tensor, w_tensor = self.real_B_unaligned_full.shape

    # ---------------- debugging section (ends with exit()) ----------------
    # get safe margin: the largest margin <= `margin` that keeps the crop
    # inside the full frame on all four sides (never negative)
    h_size_tensor, w_size_tensor = self.real_B_unaligned_full.shape[2:]
    margin = max(
        min(
            y_start - max(0, y_start - margin),
            x_start - max(0, x_start - margin),
            min(y_start + h_original + margin, h_size_tensor) - y_start - h_original,
            min(x_start + w_original + margin, w_size_tensor) - x_start - w_original,
        ), 0)

    # get face + margin unaligned space
    self.real_B_aligned_margin = self.real_B_unaligned_full[:, :, y_start - margin: y_start + h_original + margin, x_start - margin: x_start + w_original + margin]

    # invert matrix
    M_inverse = kornia.invert_affine_transform(M)

    # update output size to fit the 256 + scaled margin
    old_size = self.real_B_aligned_margin.shape[2]
    new_size = old_size + 2 * round(float(margin * scale))
    _, _, h_tensor, w_tensor = self.real_B_aligned_margin.shape
    self.real_B_aligned_margin = kornia.warp_affine(
        self.real_B_aligned_margin, M_inverse, dsize=(new_size, new_size))
    # padding_mode="reflection")
    self.fake_B_aligned_margin = self.real_B_aligned_margin.clone(
    ).requires_grad_(True)

    # update margin as we now scale the image!
    # update start point
    start = round(float(margin * scale * new_size / old_size))
    print(start)

    # point = torch.tensor([0, 0, 1], dtype=torch.float)
    # M_ = M_inverse[0].clone().detach()
    # M_ = torch.cat((M_, torch.tensor([[0, 0, 1]], dtype=torch.float)))
    #
    # M_n = M[0].clone().detach()
    # M_n = torch.cat((M_n, torch.tensor([[0, 0, 1]], dtype=torch.float)))
    #
    # start_tensor = torch.matmul(torch.matmul(point, M_) + margin, M_n)
    # print(start_tensor)
    # start_y, start_x = round(float(start_tensor[0])), round(float(start_tensor[1]))

    # reinsert into tensor (the 256 x 256 window is hard-coded)
    self.fake_B_aligned_margin[0, :, start:start + 256, start:start + 256] = self.real_B
    # NOTE(review): debug output written to a machine-specific absolute path
    Image.fromarray(tensor2im(self.real_B_aligned_margin)).save(
        "/home/mo/datasets/ff_aligned_unaligned/real.png")
    Image.fromarray(tensor2im(self.fake_B_aligned_margin)).save(
        "/home/mo/datasets/ff_aligned_unaligned/fake.png")
    # NOTE(review): unconditional exit; the code below is unreachable
    exit()
    # ----------------------------------------------------------------------

    if not alignUnaligned:
        # Now apply the transformation to original image
        # clone fake
        fake_B_clone = self.fake_B.clone().requires_grad_(True)

        # apply warp
        fake_B_warped: torch.tensor = kornia.warp_affine(
            fake_B_clone, M, dsize=(h_original, w_original))

        # make sure warping does not exceed real_B_unaligned_full dimensions:
        # clip the warped patch and shrink the target window on every side
        # that would fall outside the frame
        if y_start < 0:
            fake_B_warped = fake_B_warped[:, :, abs(y_start):h_original, :]
            h_original += y_start
            y_start = 0
        if x_start < 0:
            fake_B_warped = fake_B_warped[:, :, :, abs(x_start):w_original]
            w_original += x_start
            x_start = 0
        if y_start + h_original > h_tensor:
            h_original -= (y_start + h_original - h_tensor)
            fake_B_warped = fake_B_warped[:, :, 0:h_original, :]
        if x_start + w_original > w_tensor:
            w_original -= (x_start + w_original - w_tensor)
            fake_B_warped = fake_B_warped[:, :, :, 0:w_original]

        # create mask that is true where fake_B_warped is 0
        # This is the background that is not filled with image after the transformation
        mask = ((fake_B_warped[0][0] == 0) & (fake_B_warped[0][1] == 0) &
                (fake_B_warped[0][2] == 0))
        # where the mask is True take self.real_B_unaligned_full, elsewhere
        # keep the warped fake image
        fake_B_filled = torch.where(
            mask,
            self.real_B_unaligned_full[:, :, y_start:y_start + h_original, x_start:x_start + w_original],
            fake_B_warped)

        # reinsert into tensor
        self.fake_B_unaligned = self.real_B_unaligned_full.clone(
        ).requires_grad_(True)
        mask = torch.zeros_like(self.fake_B_unaligned, dtype=torch.bool)
        mask[0, :, y_start:y_start + h_original, x_start:x_start + w_original] = True
        self.fake_B_unaligned = self.fake_B_unaligned.masked_scatter(
            mask, fake_B_filled)

        # cutout tensor (same safe-margin computation as above)
        h_size_tensor, w_size_tensor = self.real_B_unaligned_full.shape[2:]
        margin = max(
            min(
                y_start - max(0, y_start - margin),
                x_start - max(0, x_start - margin),
                min(y_start + h_original + margin, h_size_tensor) - y_start - h_original,
                min(x_start + w_original + margin, w_size_tensor) - x_start - w_original,
            ), 0)
        self.fake_B_unaligned = self.fake_B_unaligned[:, :, y_start - margin:y_start + h_original + margin, x_start - margin:x_start + w_original + margin]
        self.real_B_unaligned = self.real_B_unaligned_full[:, :, y_start - margin:y_start + h_original + margin, x_start - margin:x_start + w_original + margin]
def mod_compute_cube_frame_conv_grad_pytorch(xd, xp, xl, matrix, angles,
                                             compute_loss, kernel, mask,
                                             center_image, U_L0, stochastic):
    """Accumulate the loss over frames and return its gradients via autograd.

    For every selected frame ``k`` the image ``xd + xp`` is rotated by
    ``angles[k]``, convolved with ``kernel`` through the numpy operator ``A``
    (adjoint ``A_`` in the backward pass), a low-rank term
    ``U_L0[k] @ xl`` is added and the data frame ``matrix[k]`` is subtracted;
    ``compute_loss`` of that residual is summed up.

    Args:
        stochastic: when truthy, only every ``floor(1 / stochastic)``-th frame
            is used, starting at a random offset; otherwise all frames.

    Returns:
        ``(grad, grad, grad_L, loss)`` as numpy values: the masked gradient
        with respect to the image (returned twice), the gradient with respect
        to ``xl`` reshaped to ``(rank, n * n)`` and the accumulated loss.
    """
    rank, _ = xl.shape
    t, _ = matrix.shape

    if stochastic:
        number_data = int(np.floor(1. / stochastic))
        sub_data_index = np.random.randint(0, number_data)
        list_indexes = np.arange(t)[sub_data_index::number_data]
    else:
        list_indexes = range(t)

    xs = xd + xp
    n, _ = xs.shape

    if t != angles.shape[0]:
        print('ANGLES do not have t elements!')

    class FFTconv_numpy_torch(torch.autograd.Function):
        """Bridge the numpy convolution ``A`` / its adjoint ``A_`` into autograd."""

        @staticmethod
        def forward(ctx, input):
            as_numpy = input.detach().numpy()
            return input.new(A(as_numpy, kernel))

        @staticmethod
        def backward(ctx, grad_output):
            as_numpy = grad_output.numpy()
            return grad_output.new(A_(as_numpy, kernel))

    # needed for rotation:
    center = torch.ones(1, 2)
    center[..., 0] = center_image[0]  # x
    center[..., 1] = center_image[1]  # y
    scale = torch.ones(1, 2)

    torch_xs = torch.tensor([[xs]], requires_grad=True)
    torch_data = torch.tensor(matrix.reshape(t, n, n))
    torch_L = torch.tensor(xl, requires_grad=True)
    torch_U_L0 = torch.tensor(U_L0, requires_grad=False)

    loss = torch.zeros(1, requires_grad=True)
    for k in list_indexes:
        angle = torch.ones(1) * (angles[k])
        M = kornia.get_rotation_matrix2d(center, angle, scale)
        rotated_xs = kornia.warp_affine(torch_xs.float(), M, dsize=(n, n))

        convolved = FFTconv_numpy_torch.apply(rotated_xs[0, 0, :, :])
        low_rank = torch.mm(torch_U_L0[None, k, :], torch_L).reshape(n, n)
        residual = convolved + low_rank - torch_data[k, :, :]
        loss = loss + compute_loss(residual)

    loss.backward()

    np_grad_xs = torch_xs.grad[0, 0, :, :].detach().numpy()
    np_grad_L = torch_L.grad.detach().numpy()
    grad_d_p = np_grad_xs * mask
    return grad_d_p, grad_d_p, np_grad_L.reshape(rank, n * n), loss.detach().numpy()
def validate_translation(template_trans, source_trans, groundTruth_number, scale_gt, gt_trans, model_template_trans, model_source_trans, model_corr2softmax_trans, acc_x, acc_y, device ):
    """Run one validation step of the translation branch.

    The source image is first rotated by ``-groundTruth_number`` and scaled by
    ``1 / scale_gt``. Template and source are then passed through their
    networks and correlated with ``PhaseCorr``. The correlation is summed
    along each axis, passed through ``model_corr2softmax_trans`` and turned
    into an expected position over 256 bins. Accuracy counters and several
    losses are computed against ``GT_trans_convert(gt_trans, [256, 256])``.

    ``acc_x`` and ``acc_y`` are updated in place (indices 0..19).

    Returns:
        ``(loss_y, loss_x, total_loss, loss_l1_x, loss_l1_y, loss_mse_x,
        loss_mse_y)`` where ``total_loss`` is the sum of the two cross-entropy
        and the two L1 losses.

    NOTE(review): the 256-bin size is hard-coded in several places; this
    presumably matches the network output size. Confirm.
    """
    print(" ")
    print(" VALIDATING TRANSLATION")
    print(" ")

    # # for toy dataset
    # b, c, h, w = source_trans.shape
    # center = torch.ones(b,2).to(device)
    # center[:, 0] = h // 2
    # center[:, 1] = w // 2
    # angle_rot = torch.ones(b).to(device) * (-groundTruth_number.to(device))
    # scale_rot = torch.ones(b).to(device)
    # rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
    # source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))

    # for AGDatase
    since = time.time()
    b, c, h, w = source_trans.shape
    center = torch.ones(b,2).to(device)
    # NOTE(review): column 0 is filled from the height and column 1 from the
    # width; if kornia expects (x, y) this only matches for square images.
    center[:, 0] = h // 2
    center[:, 1] = w // 2
    # undo the ground-truth rotation and scale before estimating translation
    angle_rot = torch.ones(b).to(device) * (-groundTruth_number.to(device))
    scale_rot = torch.ones(b).to(device) / scale_gt.to(device)
    rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
    source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))

    template_unet_trans = model_template_trans(template_trans)
    source_unet_trans = model_source_trans(source_trans)

    # for tensorboard visualize
    template_visual_trans = template_unet_trans
    source_visual_trans = source_unet_trans

    # move axis 1 to the end and drop it if it has size 1
    template_unet_trans = template_unet_trans.permute(0,2,3,1)
    source_unet_trans = source_unet_trans.permute(0,2,3,1)
    template_unet_trans = template_unet_trans.squeeze(-1)
    source_unet_trans = source_unet_trans.squeeze(-1)
    (b, h, w) = template_unet_trans.shape

    logbase_trans = torch.tensor(1.)
    phase_corr_layer_xy = PhaseCorr(device, logbase_trans, model_corr2softmax_trans, trans=True)
    t0, t1, softmax_result_trans, corr_result_trans = phase_corr_layer_xy(template_unet_trans.to(device), source_unet_trans.to(device))

    # use phasecorr result
    corr_final_trans = corr_result_trans.clone()

    # y estimate: sum the correlation over axis 2, then take the softmax
    # expectation over 256 bin positions (tranformation_y) and the argmax
    # (tranformation_y_show)
    corr_y = torch.sum(corr_final_trans.clone(), 2, keepdim=False)
    corr_y = model_corr2softmax_trans(corr_y)
    input_c = nn.functional.softmax(corr_y.clone(), dim=-1)
    indices_c = np.linspace(0, 1, 256)
    indices_c = torch.tensor(np.reshape(indices_c, (-1, 256))).to(device)
    tranformation_y = torch.sum((256 - 1) * input_c * indices_c, dim=-1)
    tranformation_y_show = torch.argmax(corr_y, dim=-1)

    # x estimate: same procedure with the correlation summed over axis 1
    corr_x = torch.sum(corr_final_trans.clone(), 1, keepdim=False)
    corr_x = model_corr2softmax_trans(corr_x)
    input_r = nn.functional.softmax(corr_x.clone(), dim=-1)
    indices_r = np.linspace(0, 1, 256)
    indices_r = torch.tensor(np.reshape(indices_r, (-1, 256))).to(device)
    tranformation_x_show = torch.argmax(corr_x, dim=-1)
    tranformation_x = torch.sum((256 - 1) * input_r * indices_r, dim=-1)

    time_elapsed = time.time() - since
    print("time elapsed", time_elapsed)
    print('in val time {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    gt_trans_orig = gt_trans.clone().to(device)
    gt_trans_convert = GT_trans_convert(gt_trans_orig, [256, 256])
    gt_trans_convert_y = gt_trans_convert[:,0]
    gt_trans_convert_x = gt_trans_convert[:,1]
    print("trans x", tranformation_x_show)
    print("trans y", tranformation_y_show)
    print("gt_convert x", gt_trans_convert_x)
    print("gt_convert y", gt_trans_convert_y)

    # x accuracy: collect the two strongest bins per sample (the best bin is
    # zeroed before the second argmax)
    arg_x = []
    corr_top5 = corr_x.clone().detach().cpu()
    for i in range(2):
        max_ind = torch.argmax(corr_top5, dim=-1)
        arg_x.append(max_ind)
        for batch_num in range(b):
            corr_top5[batch_num, max_ind[batch_num]] = 0.
    # smallest distance between the ground truth and either candidate bin
    for batch_num in range(b):
        min_x = 100.
        for i in range(2):
            if abs(gt_trans_convert_x[batch_num].float() - torch.round(arg_x[i][batch_num].float())) < min_x:
                min_x = abs(gt_trans_convert_x[batch_num].float() - torch.round(arg_x[i][batch_num].float()))
    # NOTE(review): loop nesting was reconstructed from a flattened source.
    # As laid out here min_x holds the value of the last sample only and is
    # counted b times per threshold; verify against the original file.
    for class_id in range(20):
        for batch_num in range(b):
            if min_x <= class_id:
                acc_x[class_id] += 1

    # y accuracy: same procedure on the y correlation
    arg_y = []
    corr_top5 = corr_y.clone().detach().cpu()
    for i in range(2):
        max_ind = torch.argmax(corr_top5, dim=-1)
        arg_y.append(max_ind)
        for batch_num in range(b):
            corr_top5[batch_num, max_ind[batch_num]] = 0.
    for batch_num in range(b):
        min_y = 100.
        for i in range(2):
            if abs(gt_trans_convert_y[batch_num] - torch.round(arg_y[i][batch_num].float())) < min_y:
                min_y = abs(gt_trans_convert_y[batch_num] - torch.round(arg_y[i][batch_num].float()))
    # NOTE(review): same reconstructed nesting as for the x counters above.
    for class_id in range(20):
        for batch_num in range(b):
            if min_y <= class_id:
                acc_y[class_id] += 1

    # set the loss function:
    compute_loss_y = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
    compute_loss_x = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
    compute_mse = torch.nn.MSELoss()
    compute_l1 = torch.nn.L1Loss()

    # regression-style losses on the softmax expectations
    loss_l1_x = compute_l1(tranformation_x, gt_trans_convert_x)
    loss_l1_y = compute_l1(tranformation_y, gt_trans_convert_y)
    loss_mse_x = compute_mse(tranformation_x, gt_trans_convert_x)
    loss_mse_y = compute_mse(tranformation_y, gt_trans_convert_y)
    # classification-style losses on the raw correlation scores
    loss_x = compute_loss_x(corr_x, gt_trans_convert_x)
    loss_y = compute_loss_y(corr_y, gt_trans_convert_y)
    total_loss = loss_x + loss_y + loss_l1_x +loss_l1_y
    return loss_y, loss_x, total_loss, loss_l1_x,loss_l1_y,loss_mse_x, loss_mse_y
# convert the loaded image to a tensor
data: torch.Tensor = kornia.image_to_tensor(img)  # BxCxHxW

# rotation of 45 degrees around the image centre, unit scale
alpha: float = 45.0  # in degrees
angle: torch.Tensor = torch.ones(1) * alpha
center: torch.Tensor = torch.tensor([[data.shape[3] / 2, data.shape[2] / 2]])  # (x, y)
scale: torch.Tensor = torch.ones(1)

# compute the transformation matrix
M: torch.Tensor = kornia.get_rotation_matrix2d(center, angle, scale)

# apply the transformation to original image
_, _, h, w = data.shape
data_warped: torch.Tensor = kornia.warp_affine(data.float(), M, dsize=(h, w))

# convert back to numpy
img_warped: np.ndarray = kornia.tensor_to_image(data_warped.byte()[0])

# create the plot: the source image goes into the first of two panels
fig, axs = plt.subplots(1, 2, figsize=(16, 10))
axs = axs.ravel()

source_ax = axs[0]
source_ax.axis('off')
source_ax.set_title('image source')
source_ax.imshow(img)
def train_translation(template_trans, source_trans, groundTruth_number, scale_gt, gt_trans,
                      model_template_trans, model_source_trans, model_corr2softmax_trans,
                      phase, dsnt, device):
    """Estimate the x/y translation between template and source and compute losses.

    The source batch is first warped with the inverse of the given rotation and
    scale, both images are passed through their feature networks, and the
    features are phase-correlated. The translation is then read from the
    correlation either with a 1-D soft-argmax per axis (``dsnt`` false) or with
    kornia's 2-D spatial expectation (``dsnt`` true).

    Args:
        template_trans: template image batch fed to ``model_template_trans``.
        source_trans: source image batch; unpacked as ``(b, c, h, w)``.
        groundTruth_number: rotation (degrees) applied to the source; it is
            negated here to undo it.
        scale_gt: scale applied to the source; its reciprocal is used here.
        gt_trans: ground-truth translation, converted with
            ``GT_trans_convert(..., [256, 256])``; column 0 is used as y and
            column 1 as x.
        model_template_trans: feature network for the template.
        model_source_trans: feature network for the source.
        model_corr2softmax_trans: module applied to the correlation outputs.
        phase: gradients are enabled only when this equals ``'train'``.
        dsnt: selects the 2-D spatial-expectation read-out and the MSE-only
            total loss.
        device: torch device used for the computation.

    Returns:
        Tuple ``(loss_y, loss_x, total_loss, loss_l1_x, loss_l1_y, loss_mse_x,
        loss_mse_y, template_visual_trans, source_visual_trans)``. The last two
        are the raw feature-network outputs, kept for visualisation.
    """
    print(" ")
    print(" TRAINING TRANSLATION")
    print(" ")

    with torch.set_grad_enabled(phase == 'train'):
        # Undo the known rotation and scale so that only translation remains.
        b, _, h, w = source_trans.shape
        center = torch.ones(b, 2).to(device)
        # kornia expects the centre as (x, y): x runs along the width and y
        # along the height. The original code had the two swapped, which is
        # only harmless for square inputs.
        center[:, 0] = w // 2
        center[:, 1] = h // 2
        angle_rot = torch.ones(b).to(device) * (-groundTruth_number.to(device))
        scale_rot = torch.ones(b).to(device) / scale_gt.to(device)
        rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
        source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))

        template_unet_trans = model_template_trans(template_trans)
        source_unet_trans = model_source_trans(source_trans)

        # Keep the raw network outputs for tensorboard visualisation.
        template_visual_trans = template_unet_trans
        source_visual_trans = source_unet_trans

        # Move channels last and drop the trailing channel axis.
        # NOTE(review): presumably the networks emit a single channel - confirm.
        template_unet_trans = template_unet_trans.permute(0, 2, 3, 1).squeeze(-1)
        source_unet_trans = source_unet_trans.permute(0, 2, 3, 1).squeeze(-1)
        # The unpacking fails fast if the squeeze above did not yield 3 dims.
        b, h, w = template_unet_trans.shape

        logbase_trans = torch.tensor(1.)
        phase_corr_layer_xy = PhaseCorr(device, logbase_trans, model_corr2softmax_trans, trans=True)
        _, _, _, corr_result_trans = phase_corr_layer_xy(
            template_unet_trans.to(device), source_unet_trans.to(device))

        # Marginalise the correlation along each axis; these logits feed the
        # cross-entropy terms in both modes. Order of the model calls (y, then
        # x) is the same as in the original code.
        corr_y = model_corr2softmax_trans(torch.sum(corr_result_trans, 2, keepdim=False))
        corr_x = model_corr2softmax_trans(torch.sum(corr_result_trans, 1, keepdim=False))

        if dsnt:
            # 2-D read-out: spatial softmax followed by the expected coordinate.
            corr_mat_dsnt_trans = model_corr2softmax_trans(
                corr_result_trans.clone().unsqueeze(-1))
            corr_mat_dsnt_trans = kornia.spatial_softmax2d(corr_mat_dsnt_trans)
            coors_trans = kornia.spatial_expectation2d(corr_mat_dsnt_trans, False)
            transformation_x = coors_trans[:, 0, 0]
            transformation_y = coors_trans[:, 0, 1]
        else:
            # 1-D soft-argmax per axis: expected bin index under the softmax.
            # The 256 bins are hard-coded, matching the [256, 256] used for
            # the ground-truth conversion below.
            bin_positions = torch.tensor(
                np.reshape(np.linspace(0, 1, 256), (-1, 256))).to(device)
            prob_y = nn.functional.softmax(corr_y, dim=-1)
            transformation_y = torch.sum((256 - 1) * prob_y * bin_positions, dim=-1)
            prob_x = nn.functional.softmax(corr_x, dim=-1)
            transformation_x = torch.sum((256 - 1) * prob_x * bin_positions, dim=-1)

        gt_trans_orig = gt_trans.clone().to(device)
        gt_trans_convert = GT_trans_convert(gt_trans_orig, [256, 256])
        gt_trans_convert_y = gt_trans_convert[:, 0]
        gt_trans_convert_x = gt_trans_convert[:, 1]

        print("trans x", transformation_x)
        print("gt_convert x", gt_trans_convert_x, "\n")
        print("trans y", transformation_y)
        print("gt_convert y", gt_trans_convert_y, "\n")

        cross_entropy = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
        compute_mse = torch.nn.MSELoss()
        compute_l1 = torch.nn.L1Loss()

        loss_l1_x = compute_l1(transformation_x, gt_trans_convert_x)
        loss_l1_y = compute_l1(transformation_y, gt_trans_convert_y)

        # Only the dsnt mode casts the MSE target to float, as in the original.
        if dsnt:
            mse_target_x = gt_trans_convert_x.type(torch.FloatTensor).to(device)
            mse_target_y = gt_trans_convert_y.type(torch.FloatTensor).to(device)
        else:
            mse_target_x = gt_trans_convert_x
            mse_target_y = gt_trans_convert_y
        loss_mse_x = compute_mse(transformation_x, mse_target_x)
        loss_mse_y = compute_mse(transformation_y, mse_target_y)

        loss_x = cross_entropy(corr_x, gt_trans_convert_x)
        loss_y = cross_entropy(corr_y, gt_trans_convert_y)

        if dsnt:
            total_loss = 0.001 * (loss_mse_x + loss_mse_y)
        else:
            total_loss = loss_x + loss_y + loss_l1_x + loss_l1_y

        return (loss_y, loss_x, total_loss, loss_l1_x, loss_l1_y, loss_mse_x, loss_mse_y,
                template_visual_trans, source_visual_trans)
def detect_translation(template_trans, source_trans, rotation, scale,
                       model_template_trans, model_source_trans,
                       model_corr2softmax_trans, device):
    """Detect the x/y translation between template and source and align them.

    The source batch is warped with the inverse of the given rotation and
    scale, both images are passed through their feature networks, the features
    are phase-correlated, and the translation is read from the correlation
    with a 1-D soft-argmax per axis. The template is then shifted by the
    detected translation and the first sample of the batch is passed to
    ``align_image``.

    Args:
        template_trans: template image batch fed to ``model_template_trans``.
        source_trans: source image batch; unpacked as ``(b, c, h, w)``.
        rotation: rotation (degrees) of the source; negated here to undo it.
        scale: scale of the source; its reciprocal is used here.
        model_template_trans: feature network for the template.
        model_source_trans: feature network for the source.
        model_corr2softmax_trans: module applied to the correlation outputs.
        device: torch device used for the computation.

    Returns:
        Tuple ``(transformation_y, transformation_x, image_aligned,
        source_trans)`` where ``source_trans`` is the de-rotated, de-scaled
        source batch.
    """
    print(" ")
    print(" DETECTING TRANSLATION")
    print(" ")

    # Undo the given rotation and scale so that only translation remains.
    b, _, h, w = source_trans.shape
    center = torch.ones(b, 2).to(device)
    # kornia expects the centre as (x, y): x runs along the width and y along
    # the height. The original code had the two swapped, which is only
    # harmless for square inputs.
    center[:, 0] = w // 2
    center[:, 1] = h // 2
    angle_rot = torch.ones(b).to(device) * (-rotation.to(device))
    scale_rot = torch.ones(b).to(device) * (1 / scale.to(device))
    rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
    source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))

    template_unet_trans = model_template_trans(template_trans)
    source_unet_trans = model_source_trans(source_trans)

    # Move channels last and drop the trailing channel axis.
    # NOTE(review): presumably the networks emit a single channel - confirm.
    template_unet_trans = template_unet_trans.permute(0, 2, 3, 1).squeeze(-1)
    source_unet_trans = source_unet_trans.permute(0, 2, 3, 1).squeeze(-1)
    b, h, w = template_unet_trans.shape

    logbase_trans = torch.tensor(1.)
    # NOTE(review): train_translation builds PhaseCorr with trans=True while
    # this call relies on the default - confirm that is intended.
    phase_corr_layer_xy = PhaseCorr(device, logbase_trans, model_corr2softmax_trans)
    _, _, _, corr_result_trans = phase_corr_layer_xy(
        template_unet_trans.to(device), source_unet_trans.to(device))

    # 1-D soft-argmax per axis: expected bin index under the softmax. The 256
    # bins (and the 128 offset below) are hard-coded.
    bin_positions = torch.tensor(
        np.reshape(np.linspace(0, 1, 256), (-1, 256))).to(device)

    corr_y = model_corr2softmax_trans(torch.sum(corr_result_trans, 2, keepdim=False))
    prob_y = nn.functional.softmax(corr_y, dim=-1)
    transformation_y = torch.sum((256 - 1) * prob_y * bin_positions, dim=-1)

    corr_x = model_corr2softmax_trans(torch.sum(corr_result_trans, 1, keepdim=False))
    prob_x = nn.functional.softmax(corr_x, dim=-1)
    transformation_x = torch.sum((256 - 1) * prob_x * bin_positions, dim=-1)

    print("trans x", transformation_x)
    print("trans y", transformation_y)

    # One pure-translation affine matrix per sample. The original built a
    # single matrix from Python scalars, which only works when the batch holds
    # exactly one element. The values are detached, as they were before.
    trans_mat_affine = torch.zeros(b, 2, 3).to(device)
    trans_mat_affine[:, 0, 0] = 1.0
    trans_mat_affine[:, 1, 1] = 1.0
    trans_mat_affine[:, 0, 2] = transformation_x.detach() - 128.0
    trans_mat_affine[:, 1, 2] = transformation_y.detach() - 128.0

    template_trans = kornia.warp_affine(template_trans.to(device), trans_mat_affine, dsize=(h, w))
    # Only the first sample of the batch is aligned and returned as an image.
    image_aligned = align_image(template_trans[0, :, :], source_trans[0, :, :])

    return transformation_y, transformation_x, image_aligned, source_trans