Example #1
def test_rotation_matrix2d(batch_size, device_type):
    # generate input data
    device = torch.device(device_type)
    center_base = torch.zeros(batch_size, 2).to(device)
    angle_base = torch.ones(batch_size).to(device)
    scale_base = torch.ones(batch_size).to(device)

    # 90 deg rotation
    center = center_base
    angle = 90. * angle_base
    scale = scale_base
    M = kornia.get_rotation_matrix2d(center, angle, scale)

    for i in range(batch_size):
        assert M[i, 0, 0].item() == pytest.approx(0.0, abs=1e-4)
        assert M[i, 0, 1].item() == pytest.approx(1.0, abs=1e-4)
        assert M[i, 1, 0].item() == pytest.approx(-1.0, abs=1e-4)
        assert M[i, 1, 1].item() == pytest.approx(0.0, abs=1e-4)

    # 90 deg rotation + 2x scale
    center = center_base
    angle = 90. * angle_base
    scale = 2. * scale_base
    M = kornia.get_rotation_matrix2d(center, angle, scale)

    for i in range(batch_size):
        assert M[i, 0, 0].item() == pytest.approx(0.0, abs=1e-4)
        assert M[i, 0, 1].item() == pytest.approx(2.0, abs=1e-4)
        assert M[i, 1, 0].item() == pytest.approx(-2.0, abs=1e-4)
        assert M[i, 1, 1].item() == pytest.approx(0.0, abs=1e-4)

    # 45 deg rotation
    center = center_base
    angle = 45. * angle_base
    scale = scale_base
    M = kornia.get_rotation_matrix2d(center, angle, scale)

    for i in range(batch_size):
        assert M[i, 0, 0].item() == pytest.approx(0.7071, abs=1e-4)
        assert M[i, 0, 1].item() == pytest.approx(0.7071, abs=1e-4)
        assert M[i, 1, 0].item() == pytest.approx(-0.7071, abs=1e-4)
        assert M[i, 1, 1].item() == pytest.approx(0.7071, abs=1e-4)

    # evaluate function gradient
    center = utils.tensor_to_gradcheck_var(center)  # to var
    angle = utils.tensor_to_gradcheck_var(angle)  # to var
    scale = utils.tensor_to_gradcheck_var(scale)  # to var
    assert gradcheck(kornia.get_rotation_matrix2d, (center, angle, scale),
                     raise_exception=True)
Example #2
def test_rotation_matrix2d(batch_size, device, dtype):
    # generate input data
    center_base = torch.zeros(batch_size, 2, device=device, dtype=dtype)
    angle_base = torch.ones(batch_size, device=device, dtype=dtype)
    scale_base = torch.ones(batch_size, 2, device=device, dtype=dtype)

    # 90 deg rotation
    center = center_base
    angle = 90.0 * angle_base
    scale = scale_base
    M = kornia.get_rotation_matrix2d(center, angle, scale)

    for i in range(batch_size):
        assert_close(M[i, 0, 0].item(), 0.0, rtol=1e-4, atol=1e-4)
        assert_close(M[i, 0, 1].item(), 1.0, rtol=1e-4, atol=1e-4)
        assert_close(M[i, 1, 0].item(), -1.0, rtol=1e-4, atol=1e-4)
        assert_close(M[i, 1, 1].item(), 0.0, rtol=1e-4, atol=1e-4)

    # 90 deg rotation + 2x scale
    center = center_base
    angle = 90.0 * angle_base
    scale = 2.0 * scale_base
    M = kornia.get_rotation_matrix2d(center, angle, scale)

    for i in range(batch_size):
        assert_close(M[i, 0, 0].item(), 0.0, rtol=1e-4, atol=1e-4)
        assert_close(M[i, 0, 1].item(), 2.0, rtol=1e-4, atol=1e-4)
        assert_close(M[i, 1, 0].item(), -2.0, rtol=1e-4, atol=1e-4)
        assert_close(M[i, 1, 1].item(), 0.0, rtol=1e-4, atol=1e-4)

    # 45 deg rotation
    center = center_base
    angle = 45.0 * angle_base
    scale = scale_base
    M = kornia.get_rotation_matrix2d(center, angle, scale)

    for i in range(batch_size):
        assert_close(M[i, 0, 0].item(), 0.7071, rtol=1e-4, atol=1e-4)
        assert_close(M[i, 0, 1].item(), 0.7071, rtol=1e-4, atol=1e-4)
        assert_close(M[i, 1, 0].item(), -0.7071, rtol=1e-4, atol=1e-4)
        assert_close(M[i, 1, 1].item(), 0.7071, rtol=1e-4, atol=1e-4)

    # evaluate function gradient
    center = utils.tensor_to_gradcheck_var(center)  # to var
    angle = utils.tensor_to_gradcheck_var(angle)  # to var
    scale = utils.tensor_to_gradcheck_var(scale)  # to var
    assert gradcheck(kornia.get_rotation_matrix2d, (center, angle, scale),
                     raise_exception=True)
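
Both tests above follow the OpenCV convention: with a = s*cos(theta) and b = s*sin(theta), the matrix is [[a, b, (1-a)*cx - b*cy], [-b, a, b*cx + (1-a)*cy]]. A minimal sketch checking kornia against that closed form, assuming a recent kornia where scale has shape (B, 2) and the function is exported at the top level:

import math

import torch
import kornia

cx, cy, theta_deg, s = 3.0, 2.0, 30.0, 1.5
center = torch.tensor([[cx, cy]])
angle = torch.tensor([theta_deg])
scale = torch.tensor([[s, s]])  # (B, 2) in recent kornia versions
M = kornia.get_rotation_matrix2d(center, angle, scale)[0]

# OpenCV-style closed form
a = s * math.cos(math.radians(theta_deg))
b = s * math.sin(math.radians(theta_deg))
expected = torch.tensor([
    [a, b, (1.0 - a) * cx - b * cy],
    [-b, a, b * cx + (1.0 - a) * cy],
])
torch.testing.assert_close(M, expected, rtol=1e-4, atol=1e-4)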
Example #3
    def __call__(self, clip):
        """
    Args:
        clip (torch.tensor): Video clip to be rotated. Size is (C, T, H, W)
    Returns:
        torch.tensor: central cropping of video clip. Size is
        (C, T, crop_size, crop_size)
    """
        clip = clip.permute(1, 0, 2, 3)  # (T, C, H, W): frames become the batch
        # define the rotation center
        center = torch.ones(clip.shape[0], 2)
        center[..., 0] = clip.shape[3] / 2  # x
        center[..., 1] = clip.shape[2] / 2  # y

        # define the scale factor
        scale = torch.ones(clip.shape[0])
        degree = np.random.randint(-self.degree, self.degree + 1)  # inclusive upper bound
        degree = torch.ones(clip.shape[0]) * degree

        # compute the transformation matrix
        M = kornia.get_rotation_matrix2d(center, degree, scale)

        # apply the transformation to original image
        _, _, h, w = clip.shape
        clip = kornia.warp_affine(clip, M, dsize=(h, w))
        clip = clip.permute(1, 0, 2, 3)
        return clip
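
The __call__ above belongs to an unnamed video-transform class. A self-contained functional sketch of the same idea, assuming the older kornia API with a (B,)-shaped scale as in the example (rotate_clip is a hypothetical name):

import numpy as np
import torch
import kornia

def rotate_clip(clip, degree):
    """Rotate a (C, T, H, W) clip by one random angle in [-degree, degree]."""
    clip = clip.permute(1, 0, 2, 3)  # (T, C, H, W): frames become the batch
    t, _, h, w = clip.shape
    center = torch.tensor([[w / 2, h / 2]]).repeat(t, 1)
    scale = torch.ones(t)
    angle = torch.ones(t) * float(np.random.randint(-degree, degree + 1))
    M = kornia.get_rotation_matrix2d(center, angle, scale)
    return kornia.warp_affine(clip, M, dsize=(h, w)).permute(1, 0, 2, 3)

clip = torch.rand(3, 8, 64, 64)  # dummy (C, T, H, W) video
assert rotate_clip(clip, degree=10).shape == clip.shape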
Example #4
def Raw2Bayer(x, crop_size=cSize, is_rotate=False):
    r'''Convert FlatCam raw data to Bayer.'''

    # Step 1. Convert the image & rotate
    b, c, h, w = x.size()

    y = torch.zeros((b, 4, int(h / 2), int(w / 2)), device=torch.device('cuda'))

    if is_rotate:  # ---> this mode does not work yet (2019.07.14)
        scale = torch.ones(1)
        angle = torch.ones(1) * 0.05 * 360              # 0.05 is angle collected from data measurements 
        center = torch.ones(1, 2)
        center[..., 0] = int(h / 4)  # x
        center[..., 1] = int(w / 4)  # y
        M = kr.get_rotation_matrix2d(center, angle, scale).cuda()
        _, _, h, w = y.size()
        
        y[:, 0, :, : ] = kr.warp_affine(x[:, :, 1::2, 1::2], M, dsize = (h, w))
        y[:, 1, :, : ] = kr.warp_affine(x[:, :, 0::2, 1::2], M, dsize = (h, w))
        y[:, 2, :, : ] = kr.warp_affine(x[:, :, 1::2, 0::2], M, dsize = (h, w))
        y[:, 3, :, : ] = kr.warp_affine(x[:, :, 0::2, 0::2], M, dsize = (h, w))

    else:
        y[:, 0, :, : ] = x[:, 0, 1::2, 1::2]
        y[:, 1, :, : ] = x[:, 0, 0::2, 1::2]
        y[:, 2, :, : ] = x[:, 0, 1::2, 0::2]
        y[:, 3, :, : ] = x[:, 0, 0::2, 0::2]

    # Step 3. Crop the image 
    start_row = int((y.size()[2] - crop_size[0]) / 2) 
    end_row = start_row + crop_size[0]
    start_col = int((y.size()[3] - crop_size[1])/2) 
    end_col = start_col + crop_size[1] 
    return y[:,:, start_row:end_row, start_col:end_col]
Example #5
def kornia_affine(im, parameter, aug_type, data_type='data'):
    '''
    Get rotation by given angle or scale by given factor
    along axis-0 using kornia.
    (See https://kornia.readthedocs.io/en/latest/geometry.transform.html)
    '''
    center = torch.ones(1, 2).cuda()
    center[..., 0] = im.shape[1] // 2
    center[..., 1] = im.shape[2] // 2
    if aug_type == 'rotate':
        scale = torch.ones(1).cuda()
        angle = parameter * scale
    elif aug_type == 'scale':
        scale = torch.Tensor([parameter]).cuda()
        angle = 0 * scale
        # vol_warped = kornia.scale(vol[:, 0, :, :, :], scale, center)
    if data_type == 'data':
        interpolation = 'bilinear'
    elif data_type == 'label':
        interpolation = 'nearest'
    M = kornia.get_rotation_matrix2d(center, angle, scale)
    _, h, w = im.shape
    im_warped = kornia.warp_affine(im[None, :, :, :].float(),
                                   M.cuda(),
                                   dsize=(h, w),
                                   flags=interpolation)
    # vol_warped = vol_warped[:, None, :, :, :]
    return im_warped[0]
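
A hedged usage sketch for kornia_affine above; a CUDA device is required because the function hard-codes .cuda(), and the input is a (C, H, W) tensor as implied by how center is computed:

import torch

im = torch.rand(1, 64, 64).cuda()             # (C, H, W) slice
rotated = kornia_affine(im, 15.0, 'rotate')   # rotate by 15 degrees, bilinear
label = torch.randint(0, 4, (1, 64, 64)).cuda()
warped = kornia_affine(label, 15.0, 'rotate', data_type='label')  # nearest interp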
Example #6
 def test_rot90(self, device, dtype):
     angle = torch.tensor([90.0], device=device, dtype=dtype)
     scale = torch.tensor([[1.0, 1.0]], device=device, dtype=dtype)
     center = torch.tensor([[0.0, 0.0]], device=device, dtype=dtype)
     expected = torch.tensor([[[0.0, -1.0, 0.0], [1.0, 0.0, 0.0]]],
                             device=device,
                             dtype=dtype)
     matrix = kornia.get_rotation_matrix2d(center, angle, scale)
     matrix_inv = kornia.invert_affine_transform(matrix)
     assert_close(matrix_inv, expected, rtol=1e-4, atol=1e-4)
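
As a sanity check (a sketch, not part of the test suite), invert_affine_transform should agree with inverting the full 3x3 homography directly, which is the approach used in Example #17 further down:

import torch
import kornia

M = kornia.get_rotation_matrix2d(
    torch.tensor([[0.0, 0.0]]),   # center
    torch.tensor([90.0]),         # angle in degrees
    torch.tensor([[1.0, 1.0]]),   # scale, (B, 2) in recent kornia
)
M_inv = kornia.invert_affine_transform(M)
H_inv = kornia.convert_affinematrix_to_homography(M).inverse()[..., :2, :]
torch.testing.assert_close(M_inv, H_inv, rtol=1e-4, atol=1e-4)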
Example #7
 def test_rot90(self, device):
     angle = torch.tensor([90.]).to(device)
     scale = torch.tensor([[1., 1.]]).to(device)
     center = torch.tensor([[0., 0.]]).to(device)
     expected = torch.tensor([[
         [0., -1., 0.],
         [1., 0., 0.],
     ]]).to(device)
     matrix = kornia.get_rotation_matrix2d(center, angle, scale)
     matrix_inv = kornia.invert_affine_transform(matrix)
     assert_allclose(matrix_inv, expected)
Example #8
 def forward(self, img: torch.Tensor):
     b, c, h, w = img.shape
     center = torch.tensor([[w, h]], dtype=torch.float) / 2
     transformation_matrix = kornia.get_rotation_matrix2d(
         center, self.angle, torch.ones(1))
     transformation_matrix = transformation_matrix.expand(b, -1, -1)
     transformation_matrix = transformation_matrix.to(img.device)
     return kornia.warp_affine(img.float(),
                               transformation_matrix,
                               dsize=(h, w),
                               flags=self.interpolation,
                               padding_mode=self.padding_mode)
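
The forward above depends on self.angle, self.interpolation and self.padding_mode set elsewhere in its (unshown) module. A minimal wrapper that makes it runnable, assuming the older kornia API (scale of shape (B,), warp_affine with a flags argument); the class name FixedRotate and its constructor are assumptions:

import torch
import torch.nn as nn
import kornia

class FixedRotate(nn.Module):
    def __init__(self, angle_deg, interpolation='bilinear', padding_mode='zeros'):
        super().__init__()
        self.angle = torch.tensor([angle_deg])
        self.interpolation = interpolation
        self.padding_mode = padding_mode

    def forward(self, img: torch.Tensor):
        b, c, h, w = img.shape
        center = torch.tensor([[w, h]], dtype=torch.float) / 2
        M = kornia.get_rotation_matrix2d(center, self.angle, torch.ones(1))
        M = M.expand(b, -1, -1).to(img.device)
        return kornia.warp_affine(img.float(), M, dsize=(h, w),
                                  flags=self.interpolation,
                                  padding_mode=self.padding_mode)

out = FixedRotate(45.0)(torch.rand(2, 3, 32, 32))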
Example #9
 def test_rot90_batch(self):
     angle = torch.tensor([90.])
     scale = torch.tensor([1.])
     center = torch.tensor([[0., 0.]])
     expected = torch.tensor([[
         [0., -1., 0.],
         [1., 0., 0.],
     ]])
     matrix = kornia.get_rotation_matrix2d(center, angle,
                                           scale).repeat(2, 1, 1)
     matrix_inv = kornia.invert_affine_transform(matrix)
     assert_allclose(matrix_inv, expected)
Example #10
 def test_rot90_batch(self, device, dtype):
     angle = torch.tensor([90.], device=device, dtype=dtype)
     scale = torch.tensor([[1., 1.]], device=device, dtype=dtype)
     center = torch.tensor([[0., 0.]], device=device, dtype=dtype)
     expected = torch.tensor([[
         [0., -1., 0.],
         [1., 0., 0.],
     ]], device=device, dtype=dtype)
     matrix = kornia.get_rotation_matrix2d(
         center, angle, scale).repeat(2, 1, 1)
     matrix_inv = kornia.invert_affine_transform(matrix)
     assert_allclose(matrix_inv, expected, rtol=1e-4, atol=1e-4)
Example #11
 def inner(image_t):
     b, _, h, w = image_t.shape
     # kornia takes degrees
     alpha = _rads2angle(np.random.choice(angles), units)
     angle = torch.ones(b) * alpha
     scale = torch.ones(b)
     center = torch.ones(b, 2)
     center[..., 0] = (image_t.shape[3] - 1) / 2
     center[..., 1] = (image_t.shape[2] - 1) / 2
     M = kornia.get_rotation_matrix2d(center, angle, scale).to(device)
     rotated_image = kornia.warp_affine(image_t.float(), M, dsize=(h, w))
     return rotated_image
Example #12
def Rotate(x, alphas, step):
    b, c, h, w = x.shape
    start, end = alphas
    x = CopyN(x, step)
    angle = torch.linspace(start,
                           end, step, device=x.device).unsqueeze(0).repeat(
                               b, 1).view(b * step)
    center = torch.zeros((b * step, 2), device=x.device)
    center[:, 0], center[:, 1] = w / 2, h / 2
    scale = torch.ones(b * step, device=x.device)
    M = kornia.get_rotation_matrix2d(center, angle, scale)
    x_hat = kornia.warp_affine(x, M, dsize=(h, w))
    return x_hat
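
CopyN is not defined in this example; from its use it presumably repeats each image step times along the batch dimension. A stand-in under that assumption, plus usage:

import torch
import kornia

def CopyN(x, step):
    # (B, C, H, W) -> (B * step, C, H, W), each image repeated `step` times
    return x.repeat_interleave(step, dim=0)

x = torch.rand(2, 3, 32, 32)
x_sweep = Rotate(x, alphas=(0.0, 90.0), step=5)  # 5 rotations from 0 to 90 deg
assert x_sweep.shape == (10, 3, 32, 32)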
Example #13
def get_all_rotation_matrices(angles, center_image, dtype):
    device = angles.device
    center: torch.tensor = torch.ones(1, 2, device=device)
    center[..., 0] = center_image[0]  # x
    center[..., 1] = center_image[1]  # y
    scale: torch.tensor = torch.ones(1, 2, device=device)
    t, = angles.shape
    list_rot_mat = torch.zeros((t, 2, 3), device=device).type(dtype)
    for i in range(t):
        theta = torch.ones(1, device=device) * (angles[i])
        #alpha = torch.cos(theta)
        #beta = torch.sin(theta)
        list_rot_mat[i, :, :] = kornia.get_rotation_matrix2d(
            center, theta, scale)
    return list_rot_mat
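
Since get_rotation_matrix2d is batched over angles, the Python loop above can usually be collapsed into a single call. A sketch under the same (1, 2)-shaped center/scale convention; the batched function name is hypothetical:

import torch
import kornia

def get_all_rotation_matrices_batched(angles, center_image, dtype):
    device = angles.device
    t, = angles.shape
    center = torch.zeros(t, 2, device=device, dtype=angles.dtype)
    center[:, 0] = center_image[0]  # x
    center[:, 1] = center_image[1]  # y
    scale = torch.ones(t, 2, device=device, dtype=angles.dtype)
    # one batched call instead of t separate ones
    return kornia.get_rotation_matrix2d(center, angles, scale).type(dtype)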
Example #14
    def forward(self, x, params):
        angle, scale = params
        self.angle.fill_(angle)
        self.scale.fill_(scale)

        # define the rotation center
        self.center[..., 0] = x.shape[3] / 2
        self.center[..., 1] = x.shape[2] / 2

        M = kornia.get_rotation_matrix2d(self.center, self.angle, self.scale)
        return kornia.warp_affine(x,
                                  M,
                                  dsize=(x.shape[2], x.shape[3]),
                                  flags='bilinear',
                                  padding_mode='reflection')
Example #15
def augment_batch(batch_tensor):
    orig_ndim = batch_tensor.ndim
    if orig_ndim == 3:
        batch_tensor = batch_tensor.unsqueeze(0)
    # batch size must be read after unsqueeze so it is the true batch dim
    batch_size = batch_tensor.shape[0]
    angle = (torch.rand(batch_size) * 20) - 10
    center = torch.ones(batch_size, 2)
    center[..., 0] = batch_tensor.shape[3] / 2
    center[..., 1] = batch_tensor.shape[2] / 2
    scale = torch.ones(batch_size)
    M = kornia.get_rotation_matrix2d(center, angle, scale)
    _, _, h, w = batch_tensor.shape
    batch_tensor_warped = kornia.warp_affine(batch_tensor, M, dsize=(h, w))
    if orig_ndim == 3:
        batch_tensor_warped = batch_tensor_warped.squeeze(0)
    return batch_tensor_warped
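
A short usage sketch for augment_batch (CPU tensors; torch and kornia imported as in the example):

import torch

batch = torch.rand(4, 3, 32, 32)   # (B, C, H, W)
single = torch.rand(3, 32, 32)     # (C, H, W) is also accepted
assert augment_batch(batch).shape == batch.shape
assert augment_batch(single).shape == single.shape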
Example #16
 def forward(self, img: torch.tensor):
     b, _, h, w = img.shape
     # create transformation (rotation)
     if not self.same_throughout_batch:
         angle = torch.randn(b, device=img.device) * self.angle
     else:
         angle = torch.randn(1, device=img.device) * self.angle
         angle = angle.repeat(b)
     center = torch.ones(b, 2, device=img.device)
     center[..., 0] = img.shape[3] / 2  # x
     center[..., 1] = img.shape[2] / 2  # y
     # define the scale factor
     scale = torch.ones(b, device=img.device)
     M = kornia.get_rotation_matrix2d(center, angle, scale)
     img_warped = kornia.warp_affine(img, M, dsize=(h, w))
     return img_warped
Example #17
    def test_rotation_inverse(self, device, dtype):
        h, w = 4, 4
        img_b = torch.rand(1, 1, h, w, device=device, dtype=dtype)

        # create rotation matrix of 90deg (anti-clockwise)
        center = torch.tensor([[w - 1, h - 1]], device=device, dtype=dtype) / 2
        scale = torch.ones((1, 2), device=device, dtype=dtype)
        angle = 90. * torch.ones(1, device=device, dtype=dtype)
        aff_ab = kornia.get_rotation_matrix2d(center, angle, scale)
        # Same as opencv: cv2.getRotationMatrix2D(((w-1)/2,(h-1)/2), 90., 1.)

        # warp the tensor
        # Same as opencv: cv2.warpAffine(kornia.tensor_to_image(img_b), aff_ab[0].numpy(), (w, h))
        img_a = kornia.warp_affine(img_b, aff_ab, (h, w))

        # invert the transform
        aff_ba = kornia.convert_affinematrix_to_homography(aff_ab).inverse()[..., :2, :]
        img_b_hat = kornia.warp_affine(img_a, aff_ba, (h, w))
        assert_allclose(img_b_hat, img_b, atol=1e-3, rtol=1e-3)
Example #18
def get_heatmap_transformation_matrix(
    jitter_x: Tensor,
    jitter_y: Tensor,
    scale: Tensor,
    angle: Tensor,
    heatmap_dim: Tensor,
) -> Tensor:
    """
    Generates transfromation matric to revert the transformation on heatmap.

    Args:
        jitter_x (Tensor): x Pixels by which heatmap should be jittered (batch)
        jitter_y (Tensor): y Pixels by which heatmap should be jittered (batch)
        scale (Tensor): Scale factor from crop margin (batch).
        angle (Tensor): Rotation angle (batch)
        heatmap_dim (Tensor): Height and width of heatmap (1x2)

    Returns:
        [Tensor]: Transformation matrix (batch x 2 x3).
    """
    # Making a translation matrix
    translations = torch.cat(
        [jitter_x.view(-1, 1), jitter_y.view(-1, 1)], dim=1).float()
    origin = torch.zeros_like(translations)
    zero_angle = torch.zeros_like(jitter_x[:, 0])
    unit_scale = torch.ones_like(translations)
    # NOTE: The function below returns a batch x 3 x 3 matrix.
    translation_matrix = kornia.get_affine_matrix2d(translations=translations,
                                                    center=origin,
                                                    angle=zero_angle,
                                                    scale=unit_scale)
    # Making a rotation matrix.
    center_of_rotation = torch.ones_like(translations) * (
        (heatmap_dim / 2).view(1, 2))
    # NOTE: The function below returns a batch x 2 x 3 matrix.
    rotation_matrix = kornia.get_rotation_matrix2d(
        center=center_of_rotation.float(),
        angle=angle.float(),
        scale=scale.repeat(1, 2).float(),
    )
    # Applying transformations in the order.
    return torch.bmm(rotation_matrix, translation_matrix)
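
A hedged usage sketch; the shapes are inferred from how the arguments are indexed above (jitter and scale as (B, 1) column tensors, heatmap_dim as a 1x2 tensor):

import torch

B = 4
jitter_x = torch.randn(B, 1)
jitter_y = torch.randn(B, 1)
scale = torch.ones(B, 1) * 1.25
angle = torch.ones(B) * 30.0
heatmap_dim = torch.tensor([[64.0, 64.0]])  # height, width
T = get_heatmap_transformation_matrix(jitter_x, jitter_y, scale, angle, heatmap_dim)
assert T.shape == (B, 2, 3)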
Example #19
def make_training_batch(input_tensor, patch, patch_mask):

    # determine patch size
    H, W = PA_cfg.image_shape[-2:]
    PATCH_SIZE = int(np.floor(np.sqrt((H*W*PA_cfg.percentage))))

    translate_space = [H-PATCH_SIZE+1, W-PATCH_SIZE+1]
    bs = input_tensor.size(0)

    training_batch = []
    for b in range(bs):
        
        # random translation
        u_t = np.random.randint(low=0, high=translate_space[0])
        v_t = np.random.randint(low=0, high=translate_space[1])
        # random scaling and rotation
        scale = np.random.rand() * (PA_cfg.scale_max - PA_cfg.scale_min) + PA_cfg.scale_min
        scale = torch.Tensor([scale])
        angle = np.random.rand() * (PA_cfg.rotate_max - PA_cfg.rotate_min) + PA_cfg.rotate_min
        angle = torch.Tensor([angle])
        center = torch.Tensor([u_t+PATCH_SIZE/2, v_t+PATCH_SIZE/2]).unsqueeze(0)
        rotation_m = kornia.get_rotation_matrix2d(center, angle, scale)

        # warp three tensors
        temp_mask = patch_mask.unsqueeze(0)
        temp_input = input_tensor[b].unsqueeze(0)
        temp_patch = patch.unsqueeze(0)

        temp_mask = kornia.translate(temp_mask.float(), translation=torch.Tensor([u_t, v_t]).unsqueeze(0))
        temp_patch = kornia.translate(temp_patch.float(), translation=torch.Tensor([u_t, v_t]).unsqueeze(0))

        mask_warped = kornia.warp_affine(temp_mask.float(), rotation_m, temp_mask.size()[-2:])
        patch_warped = kornia.warp_affine(temp_patch.float(), rotation_m, temp_patch.size()[-2:])

        # overlay
        overlay = temp_input * (1 - mask_warped) + patch_warped * mask_warped
        
        training_batch.append(overlay)
    
    training_batch = torch.cat(training_batch, dim=0)
    return training_batch
Example #20
def rotate_tensor_along_y_axis(tensor, gamma):
    B = tensor.shape[0]
    tensor = tensor.to("cpu")
    assert tensor.ndim == 6, "Tensors should have 6 dimensions."
    tensor = tensor.float()
    # B,S,C,D,H,W
    __p = lambda x: utils_basic.pack_seqdim(x, B)
    __u = lambda x: utils_basic.unpack_seqdim(x, B)
    tensor_ = __p(tensor)
    tensor_ = tensor_.permute(
        0, 1, 3, 2, 4)  # Make it BS, C, H, D, W  (i.e. BS, C, y, z, x)
    BS, C, H, D, W = tensor_.shape

    # merge y dimension with channel dimension and rotate with gamma_
    tensor_y_reduced = tensor_.reshape(BS, C * H, D, W)
    # # gammas will be rotation angles along y axis.
    # gammas = torch.arange(10, 360, 10)

    # define the rotation center
    center = torch.ones(1, 2)
    center[..., 0] = tensor_y_reduced.shape[3] / 2  # x
    center[..., 1] = tensor_y_reduced.shape[2] / 2  # z

    # define the scale factor
    scale = torch.ones(1)

    gamma_ = torch.ones(1) * gamma

    # compute the transformation matrix
    M = kornia.get_rotation_matrix2d(center, gamma_, scale)
    M = M.repeat(BS, 1, 1)
    # apply the transformation to the original tensor
    tensor_y_reduced_warped = kornia.warp_affine(tensor_y_reduced,
                                                 M,
                                                 dsize=(D, W))
    tensor_y_reduced_warped = tensor_y_reduced_warped.reshape(BS, C, H, D, W)
    tensor_y_reduced_warped = tensor_y_reduced_warped.permute(0, 1, 3, 2, 4)
    tensor_y_reduced_warped = __u(tensor_y_reduced_warped)
    return tensor_y_reduced_warped.cuda()
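
The core trick above, folding the y dimension into the channel dimension so a single 2-D warp rotates every horizontal slice about the y axis at once, in minimal form; this sketch assumes the older (B,)-scale kornia API and omits the pack_seqdim/unpack_seqdim helpers:

import torch
import kornia

vol = torch.rand(2, 4, 16, 16, 16)        # (B, C, H=y, D=z, W=x)
B, C, H, D, W = vol.shape
flat = vol.reshape(B, C * H, D, W)        # merge y into channels
center = torch.tensor([[W / 2, D / 2]]).repeat(B, 1)
angle = torch.ones(B) * 30.0              # degrees about the y axis
M = kornia.get_rotation_matrix2d(center, angle, torch.ones(B))
rotated = kornia.warp_affine(flat, M, dsize=(D, W)).reshape(B, C, H, D, W)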
Example #21
    def _get_rotation(self):
        """ Get transformation matrices

        Output:
            rot_mat (float torch.Tensor): tensor of shape (batch_size, 2, 3)
        """
        # define the rotation center
        center = torch.ones(self.batch_size, 2)
        center[..., 0] = self.input_hw[1] / 2  # x
        center[..., 1] = self.input_hw[0] / 2  # y

        # create transformation (rotation)
        angle = torch.tensor([
            random.randint(-self.max_angle, self.max_angle)
            for _ in range(self.batch_size)
        ])
        # define the scale factor
        scale = torch.ones(self.batch_size)

        # compute the transformation matrix
        tf_matrices = kornia.get_rotation_matrix2d(center, angle, scale)
        return tf_matrices
Example #22
    def align_fake(self, margin=40, alignUnaligned=True):

        # get params
        desiredLeftEye = [
            float(self.alignment_params["desiredLeftEye"][0]),
            float(self.alignment_params["desiredLeftEye"][1])
        ]
        rotation_point = self.alignment_params["eyesCenter"]
        angle = -self.alignment_params["angle"]
        h, w = self.fake_B.shape[2:]
        # get original positions
        m1 = round(w * 0.5)
        m2 = round(desiredLeftEye[1] * w)
        # define the scale factor
        scale = 1 / self.alignment_params["scale"]
        width = int(self.alignment_params["shape"][0])
        long_edge_size = width / abs(np.cos(np.deg2rad(angle)))
        w_original = int(scale * long_edge_size)
        h_original = int(scale * long_edge_size)
        # get offset
        tX = w_original * 0.5
        tY = h_original * desiredLeftEye[1]
        # get rotation center
        center = torch.ones(1, 2)
        center[..., 0] = m1
        center[..., 1] = m2
        # compute the transformation matrix (kornia expects tensors, so wrap
        # the scalar angle and scale)
        M: torch.tensor = kornia.get_rotation_matrix2d(
            center, torch.ones(1) * angle, torch.ones(1) * scale).to(self.device)
        M[0, 0, 2] += (tX - m1)
        M[0, 1, 2] += (tY - m2)

        # get insertion point
        x_start = int(rotation_point[0] - (0.5 * w_original))
        y_start = int(rotation_point[1] - (desiredLeftEye[1] * h_original))
        # _, _, h_tensor, w_tensor = self.real_B_unaligned_full.shape

        # # # # # # # # # # # # # # # # # # ## # # # # # # # ## # # # # ## # # # # # # # # # ## # #
        # get safe margin
        h_size_tensor, w_size_tensor = self.real_B_unaligned_full.shape[2:]
        margin = max(
            min(
                y_start - max(0, y_start - margin),
                x_start - max(0, x_start - margin),
                min(y_start + h_original + margin, h_size_tensor) - y_start -
                h_original,
                min(x_start + w_original + margin, w_size_tensor) - x_start -
                w_original,
            ), 0)
        # get face + margin unaligned space
        self.real_B_aligned_margin = self.real_B_unaligned_full[
            :, :,
            y_start - margin:y_start + h_original + margin,
            x_start - margin:x_start + w_original + margin]
        # invert matrix
        M_inverse = kornia.invert_affine_transform(M)
        # update output size to fit the 256 + scaled margin
        old_size = self.real_B_aligned_margin.shape[2]
        new_size = old_size + 2 * round(float(margin * scale))

        _, _, h_tensor, w_tensor = self.real_B_aligned_margin.shape
        self.real_B_aligned_margin = kornia.warp_affine(
            self.real_B_aligned_margin, M_inverse, dsize=(new_size, new_size))
        # padding_mode="reflection")
        self.fake_B_aligned_margin = self.real_B_aligned_margin.clone().requires_grad_(True)

        # update margin as we now scale the image!
        # update start point
        start = round(float(margin * scale * new_size / old_size))
        print(start)

        # point = torch.tensor([0, 0, 1], dtype=torch.float)
        # M_ = M_inverse[0].clone().detach()
        # M_ = torch.cat((M_, torch.tensor([[0, 0, 1]], dtype=torch.float)))
        #
        # M_n = M[0].clone().detach()
        # M_n = torch.cat((M_n, torch.tensor([[0, 0, 1]], dtype=torch.float)))
        #
        # start_tensor = torch.matmul(torch.matmul(point, M_) + margin, M_n)
        # print(start_tensor)
        # start_y, start_x = round(float(start_tensor[0])), round(float(start_tensor[1]))

        # reinsert into tensor
        self.fake_B_aligned_margin[0, :, start:start + 256,
                                   start:start + 256] = self.real_B

        Image.fromarray(tensor2im(self.real_B_aligned_margin)).save(
            "/home/mo/datasets/ff_aligned_unaligned/real.png")
        Image.fromarray(tensor2im(self.fake_B_aligned_margin)).save(
            "/home/mo/datasets/ff_aligned_unaligned/fake.png")

        exit()
        # # # # # # # # # # # # # # # # # # ## # # # # # # # ## # # # # ## # # # # # # # # # ## # #
        if not alignUnaligned:
            # Now apply the transformation to original image
            # clone fake
            fake_B_clone = self.fake_B.clone().requires_grad_(True)
            # apply warp
            fake_B_warped: torch.tensor = kornia.warp_affine(
                fake_B_clone, M, dsize=(h_original, w_original))

            # make sure warping does not exceed real_B_unaligned_full dimensions
            if y_start < 0:
                fake_B_warped = fake_B_warped[:, :, abs(y_start):h_original, :]
                h_original += y_start
                y_start = 0
            if x_start < 0:
                fake_B_warped = fake_B_warped[:, :, :, abs(x_start):w_original]
                w_original += x_start
                x_start = 0
            if y_start + h_original > h_tensor:
                h_original -= (y_start + h_original - h_tensor)
                fake_B_warped = fake_B_warped[:, :, 0:h_original, :]
            if x_start + w_original > w_tensor:
                w_original -= (x_start + w_original - w_tensor)
                fake_B_warped = fake_B_warped[:, :, :, 0:w_original]

            # create mask that is true where fake_B_warped is 0
            # This is the background that is not filled with image after the transformation
            mask = ((fake_B_warped[0][0] == 0) & (fake_B_warped[0][1] == 0) &
                    (fake_B_warped[0][2] == 0))
            # fill fake_B_filled where mask = False with self.real_B_unaligned_full
            fake_B_filled = torch.where(
                mask,
                self.real_B_unaligned_full[:, :, y_start:y_start + h_original,
                                           x_start:x_start + w_original],
                fake_B_warped)

            # reinsert into tensor
            self.fake_B_unaligned = self.real_B_unaligned_full.clone().requires_grad_(True)
            mask = torch.zeros_like(self.fake_B_unaligned, dtype=torch.bool)
            mask[0, :, y_start:y_start + h_original,
                 x_start:x_start + w_original] = True
            self.fake_B_unaligned = self.fake_B_unaligned.masked_scatter(
                mask, fake_B_filled)

            # cutout tensor
            h_size_tensor, w_size_tensor = self.real_B_unaligned_full.shape[2:]
            margin = max(
                min(
                    y_start - max(0, y_start - margin),
                    x_start - max(0, x_start - margin),
                    min(y_start + h_original + margin, h_size_tensor) -
                    y_start - h_original,
                    min(x_start + w_original + margin, w_size_tensor) -
                    x_start - w_original,
                ), 0)
            self.fake_B_unaligned = self.fake_B_unaligned[
                :, :,
                y_start - margin:y_start + h_original + margin,
                x_start - margin:x_start + w_original + margin]
            self.real_B_unaligned = self.real_B_unaligned_full[
                :, :,
                y_start - margin:y_start + h_original + margin,
                x_start - margin:x_start + w_original + margin]
Example #23
def mod_compute_cube_frame_conv_grad_pytorch(xd, xp, xl, matrix, angles,
                                             compute_loss, kernel, mask,
                                             center_image, U_L0, stochastic):
    rank, _ = xl.shape
    t, _ = matrix.shape
    if stochastic:
        #number_of_frames = int(t*stochastic)
        #list_indexes = np.random.choice(t, number_of_frames, replace=False)
        number_data = int(np.floor(1. / stochastic))
        sub_data_index = np.random.randint(0, number_data)
        list_indexes = np.arange(t)[sub_data_index::number_data]
    else:
        list_indexes = range(t)
    xs = xd + xp
    n, _ = xs.shape
    conv_op = lambda x: A(x, kernel)
    adj_conv_op = lambda x: A_(x, kernel)
    if t != angles.shape[0]:
        print('ANGLES do not have t elements!')

    class FFTconv_numpy_torch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            numpy_input = input.detach().numpy()
            result = conv_op(numpy_input)
            return input.new(result)

        @staticmethod
        def backward(ctx, grad_output):
            numpy_go = grad_output.numpy()
            result = adj_conv_op(numpy_go)
            return grad_output.new(result)

    def fft_conv_np_torch(input):
        return FFTconv_numpy_torch.apply(input)

    # needed for rotation:
    center: torch.tensor = torch.ones(1, 2)
    center[..., 0] = center_image[0]  # x
    center[..., 1] = center_image[1]  # y
    scale: torch.tensor = torch.ones(1, 2)

    torch_xs = torch.tensor([[xs]], requires_grad=True)

    torch_data = torch.tensor(matrix.reshape(t, n, n))
    torch_L = torch.tensor(xl, requires_grad=True)
    torch_U_L0 = torch.tensor(U_L0, requires_grad=False)

    loss: torch.tensor = torch.zeros(1, requires_grad=True)

    for k in list_indexes:
        angle: torch.tensor = torch.ones(1) * (angles[k])
        M: torch.tensor = kornia.get_rotation_matrix2d(center, angle, scale)
        rotated_xs = kornia.warp_affine(torch_xs.float(), M, dsize=(n, n))
        loss = loss + compute_loss(
            fft_conv_np_torch(rotated_xs[0, 0, :, :]) +
            torch.mm(torch_U_L0[None, k, :], torch_L).reshape(n, n) -
            torch_data[k, :, :])
    loss.backward()
    torch_grad_xs = torch_xs.grad
    torch_grad_L = torch_L.grad

    np_grad_xs = torch_grad_xs[0, 0, :, :].detach().numpy()
    np_grad_L = torch_grad_L.detach().numpy()
    grad_d_p = np_grad_xs * mask
    return grad_d_p, grad_d_p, np_grad_L.reshape(rank,
                                                 n * n), loss.detach().numpy()
Example #24
def validate_translation(template_trans, source_trans, groundTruth_number,
                         scale_gt, gt_trans, model_template_trans,
                         model_source_trans, model_corr2softmax_trans,
                         acc_x, acc_y, device):
    print("                             ")
    print("                             VALIDATING TRANSLATION")
    print("                             ")
# # for toy dataset
    # b, c, h, w = source_trans.shape
    # center = torch.ones(b,2).to(device)
    # center[:, 0] = h // 2
    # center[:, 1] = w // 2
    # angle_rot = torch.ones(b).to(device) * (-groundTruth_number.to(device))
    # scale_rot = torch.ones(b).to(device) 
    # rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
    # source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))

# for AGDataset
    since = time.time()
    b, c, h, w = source_trans.shape
    center = torch.ones(b,2).to(device)
    center[:, 0] = h // 2
    center[:, 1] = w // 2
    angle_rot = torch.ones(b).to(device) * (-groundTruth_number.to(device))
    scale_rot = torch.ones(b).to(device) / scale_gt.to(device)
    rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
    source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))
    # imshow(template_trans[0,:,:,:])
    # plt.show()
    # imshow(source_trans[0,:,:,:])
    # plt.show()

    # imshow(template,"temp")
    # imshow(source, "src")
    
    template_unet_trans = model_template_trans(template_trans)
    source_unet_trans = model_source_trans(source_trans)
    # imshow(template_unet_trans[0,:,:,:])
    # plt.show()
    # imshow(source_unet_trans[0,:,:,:])
    # plt.show()

    # for tensorboard visualize
    template_visual_trans = template_unet_trans
    source_visual_trans = source_unet_trans

    template_unet_trans = template_unet_trans.permute(0,2,3,1)
    source_unet_trans = source_unet_trans.permute(0,2,3,1)

    template_unet_trans = template_unet_trans.squeeze(-1)
    source_unet_trans = source_unet_trans.squeeze(-1)

    (b, h, w) = template_unet_trans.shape
    logbase_trans = torch.tensor(1.)
    phase_corr_layer_xy = PhaseCorr(device, logbase_trans, model_corr2softmax_trans, trans=True)
    t0, t1, softmax_result_trans, corr_result_trans = phase_corr_layer_xy(template_unet_trans.to(device), source_unet_trans.to(device))

# use phasecorr result

    corr_final_trans = corr_result_trans.clone()
    # corr_visual = corr_final_trans.unsqueeze(-1)
    # corr_visual = corr_visual.permute(0,3,1,2)
    corr_y = torch.sum(corr_final_trans.clone(), 2, keepdim=False)
    # corr_2d = corr_final_trans.clone().reshape(b, h*w)
    # corr_2d = model_corr2softmax(corr_2d)
    corr_y = model_corr2softmax_trans(corr_y)
    input_c = nn.functional.softmax(corr_y.clone(), dim=-1)
    indices_c = np.linspace(0, 1, 256)
    indices_c = torch.tensor(np.reshape(indices_c, (-1, 256))).to(device)
    transformation_y = torch.sum((256 - 1) * input_c * indices_c, dim=-1)
    transformation_y_show = torch.argmax(corr_y, dim=-1)

    corr_x = torch.sum(corr_final_trans.clone(), 1, keepdim=False)
    # corr_final_trans = corr_final_trans.reshape(b, h*w)
    corr_x = model_corr2softmax_trans(corr_x)
    input_r = nn.functional.softmax(corr_x.clone(), dim=-1)
    indices_r = np.linspace(0, 1, 256)
    indices_r = torch.tensor(np.reshape(indices_r, (-1, 256))).to(device)
    transformation_x_show = torch.argmax(corr_x, dim=-1)
    transformation_x = torch.sum((256 - 1) * input_r * indices_r, dim=-1)
    time_elapsed = time.time() - since
    print("time elapsed", time_elapsed)
    print('in val time {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

# only consider angle as the loss
    # softmax_result = torch.sum(corr_result.clone(), 2, keepdim=False)
    # softmax_final = softmax_result.clone()
    # # softmax_visual = softmax_final.unsqueeze(-1)
    # # softmax_visual = softmax_visual.permute(0,3,1,2)
            
    # softmax_final = softmax_final.reshape(b, h*w)
    # softmax_final = model_corr2softmax(softmax_final.clone())
    gt_trans_orig = gt_trans.clone().to(device)

    # err_true = (1-((t0-gt_trans_orig).abs()/(gt_trans_orig+1e-15))).mean()
    # if err_true <= 0:
    #     err_true = torch.Tensor([0.5])
    # print("err_true = ",err_true.item()*100,"%")

    gt_trans_convert = GT_trans_convert(gt_trans_orig, [256, 256])
    gt_trans_convert_y = gt_trans_convert[:,0]
    gt_trans_convert_x = gt_trans_convert[:,1]

    print("trans x", tranformation_x_show)
    print("trans y", tranformation_y_show)
    print("gt_convert x", gt_trans_convert_x)
    print("gt_convert y", gt_trans_convert_y)
    # err_y = (1-(abs(tranformation_y_show-gt_trans_convert_y))/(gt_trans_convert_y+1e-10)).mean()
    # if err_y <= 0:
    #     err_y = torch.Tensor([0.5])
    # print("err y", err_y.item()*100,"%")


    # err_x = (1-(abs(tranformation_x_show-gt_trans_convert_x))/(gt_trans_convert_x+1e-10)).mean()
    # if err_x <= 0:
    #     err_x = torch.Tensor([0.5])
    # print("err x", err_x.item()*100,"%")
    
    arg_x = []
    corr_top5 = corr_x.clone().detach().cpu()
    for i in range(2):
        max_ind = torch.argmax(corr_top5, dim=-1)
        arg_x.append(max_ind)
        for batch_num in range(b):
            corr_top5[batch_num, max_ind[batch_num]] = 0.


    for batch_num in range(b):
        min_x = 100.
        for i in range(2):
            diff_x = abs(gt_trans_convert_x[batch_num].float() -
                         torch.round(arg_x[i][batch_num].float()))
            if diff_x < min_x:
                min_x = diff_x
        # count this sample toward every error-threshold bucket it satisfies
        for class_id in range(20):
            if min_x <= class_id:
                acc_x[class_id] += 1


    arg_y = []
    corr_top5 = corr_y.clone().detach().cpu()
    for i in range(2):
        max_ind = torch.argmax(corr_top5, dim=-1)
        arg_y.append(max_ind)
        for batch_num in range(b):
            corr_top5[batch_num, max_ind[batch_num]] = 0.

    
    for batch_num in range(b):
        min_y = 100.
        for i in range(2):
            diff_y = abs(gt_trans_convert_y[batch_num] -
                         torch.round(arg_y[i][batch_num].float()))
            if diff_y < min_y:
                min_y = diff_y
        # count this sample toward every error-threshold bucket it satisfies
        for class_id in range(20):
            if min_y <= class_id:
                acc_y[class_id] += 1

                
# set the loss function:
    # compute_loss = torch.nn.KLDivLoss(reduction="sum").to(device)
    # compute_loss = torch.nn.BCEWithLogitsLoss(reduction="sum").to(device)
    compute_loss_y = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
    compute_loss_x = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
    # mse_loss = torch.nn.MSELoss(reduce=True)
    # compute_loss = torch.nn.NLLLoss()
    # compute_l1loss_a = torch.nn.L1Loss()
    # compute_l1loss_x = torch.nn.L1Loss()
    # compute_l1loss_y = torch.nn.L1Loss()
    compute_mse = torch.nn.MSELoss()
    compute_l1 = torch.nn.L1Loss()

    # mse_loss = mse_loss(rotation_cal.float(), groundTruth_number.float())
    # l1loss = compute_l1loss(rotation_cal, groundTruth_number)
    loss_l1_x = compute_l1(transformation_x, gt_trans_convert_x)
    loss_l1_y = compute_l1(transformation_y, gt_trans_convert_y)
    loss_mse_x = compute_mse(transformation_x, gt_trans_convert_x)
    loss_mse_y = compute_mse(transformation_y, gt_trans_convert_y)
    loss_x = compute_loss_x(corr_x, gt_trans_convert_x)
    loss_y = compute_loss_y(corr_y, gt_trans_convert_y)
    total_loss = loss_x + loss_y + loss_l1_x +loss_l1_y
    return loss_y, loss_x, total_loss, loss_l1_x,loss_l1_y,loss_mse_x, loss_mse_y
Example #25
import numpy as np
import torch
import matplotlib.pyplot as plt
import kornia

# img: np.ndarray loaded beforehand (e.g. with cv2.imread)
data: torch.tensor = kornia.image_to_tensor(img)  # BxCxHxW

# create transformation (rotation)
alpha: float = 45.0  # in degrees
angle: torch.tensor = torch.ones(1) * alpha

# define the rotation center
center: torch.tensor = torch.ones(1, 2)
center[..., 0] = data.shape[3] / 2  # x
center[..., 1] = data.shape[2] / 2  # y

# define the scale factor
scale: torch.tensor = torch.ones(1)

# compute the transformation matrix
M: torch.tensor = kornia.get_rotation_matrix2d(center, angle, scale)

# apply the transformation to original image
_, _, h, w = data.shape
data_warped: torch.tensor = kornia.warp_affine(data.float(), M, dsize=(h, w))

# convert back to numpy
img_warped: np.ndarray = kornia.tensor_to_image(data_warped.byte()[0])

# create the plot
fig, axs = plt.subplots(1, 2, figsize=(16, 10))
axs = axs.ravel()

axs[0].axis('off')
axs[0].set_title('image source')
axs[0].imshow(img)

axs[1].axis('off')
axs[1].set_title('image rotated')
axs[1].imshow(img_warped)
plt.show()
Example #26
def train_translation(template_trans, source_trans, groundTruth_number,
                      scale_gt, gt_trans, model_template_trans,
                      model_source_trans, model_corr2softmax_trans, phase,
                      dsnt, device):
    print("                             ")
    print("                             TRAINING TRANSLATION")
    print("                             ")
    with torch.set_grad_enabled(phase == 'train'):
        # # for toy dataset
        #         b, c, h, w = source_trans.shape
        #         center = torch.ones(b,2).to(device)
        #         center[:, 0] = h // 2
        #         center[:, 1] = w // 2
        #         angle_rot = torch.ones(b).to(device) * (-groundTruth_number.to(device))
        #         scale_rot = torch.ones(b).to(device)
        #         rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
        #         source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))

        # for AGDataset
        b, c, h, w = source_trans.shape
        center = torch.ones(b, 2).to(device)
        center[:, 0] = h // 2
        center[:, 1] = w // 2
        angle_rot = torch.ones(b).to(device) * (-groundTruth_number.to(device))
        scale_rot = torch.ones(b).to(device) / scale_gt.to(device)
        rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
        source_trans = kornia.warp_affine(source_trans.to(device),
                                          rot_mat,
                                          dsize=(h, w))

        # imshow(template,"temp")
        # imshow(source, "src")

        template_unet_trans = model_template_trans(template_trans)
        source_unet_trans = model_source_trans(source_trans)

        # for tensorboard visualize
        template_visual_trans = template_unet_trans
        source_visual_trans = source_unet_trans

        template_unet_trans = template_unet_trans.permute(0, 2, 3, 1)
        source_unet_trans = source_unet_trans.permute(0, 2, 3, 1)

        template_unet_trans = template_unet_trans.squeeze(-1)
        source_unet_trans = source_unet_trans.squeeze(-1)

        (b, h, w) = template_unet_trans.shape
        logbase_trans = torch.tensor(1.)
        phase_corr_layer_xy = PhaseCorr(device,
                                        logbase_trans,
                                        model_corr2softmax_trans,
                                        trans=True)
        t0, t1, softmax_result_trans, corr_result_trans = phase_corr_layer_xy(
            template_unet_trans.to(device), source_unet_trans.to(device))

        # use phasecorr result
        if not dsnt:

            corr_final_trans = corr_result_trans.clone()
            # corr_visual = corr_final_trans.unsqueeze(-1)
            # corr_visual = corr_visual.permute(0,3,1,2)
            corr_y = torch.sum(corr_final_trans.clone(), 2, keepdim=False)
            # corr_2d = corr_final_trans.clone().reshape(b, h*w)
            # corr_2d = model_corr2softmax(corr_2d)
            corr_y = model_corr2softmax_trans(corr_y)
            input_c = nn.functional.softmax(corr_y.clone(), dim=-1)
            indices_c = np.linspace(0, 1, 256)
            indices_c = torch.tensor(np.reshape(indices_c,
                                                (-1, 256))).to(device)
            transformation_y = torch.sum((256 - 1) * input_c * indices_c,
                                         dim=-1)
            # transformation_y = torch.argmax(corr_y, dim=-1)

            corr_x = torch.sum(corr_final_trans.clone(), 1, keepdim=False)
            # corr_final_trans = corr_final_trans.reshape(b, h*w)
            corr_x = model_corr2softmax_trans(corr_x)
            input_r = nn.functional.softmax(corr_x.clone(), dim=-1)
            indices_r = np.linspace(0, 1, 256)
            indices_r = torch.tensor(np.reshape(indices_r,
                                                (-1, 256))).to(device)
            # transformation_x = torch.argmax(corr_x, dim=-1)
            transformation_x = torch.sum((256 - 1) * input_r * indices_r,
                                         dim=-1)

            # only consider angle as the loss
            # softmax_result = torch.sum(corr_result.clone(), 2, keepdim=False)
            # softmax_final = softmax_result.clone()
            # # softmax_visual = softmax_final.unsqueeze(-1)
            # # softmax_visual = softmax_visual.permute(0,3,1,2)

            # softmax_final = softmax_final.reshape(b, h*w)
            # softmax_final = model_corr2softmax(softmax_final.clone())
            gt_trans_orig = gt_trans.clone().to(device)

            # print("err_true = ",err_true.item()*100,"%")

            gt_trans_convert = GT_trans_convert(gt_trans_orig, [256, 256])
            gt_trans_convert_y = gt_trans_convert[:, 0]
            gt_trans_convert_x = gt_trans_convert[:, 1]

            print("trans x", tranformation_x)
            print("gt_convert x", gt_trans_convert_x, "\n")

            print("trans y", tranformation_y)
            print("gt_convert y", gt_trans_convert_y, "\n")

            # set the loss function:
            # compute_loss = torch.nn.KLDivLoss(reduction="sum").to(device)
            # compute_loss = torch.nn.BCEWithLogitsLoss(reduction="sum").to(device)
            compute_loss_y = torch.nn.CrossEntropyLoss(
                reduction="sum").to(device)
            compute_loss_x = torch.nn.CrossEntropyLoss(
                reduction="sum").to(device)
            # mse_loss = torch.nn.MSELoss(reduce=True)
            # compute_loss = torch.nn.NLLLoss()
            # compute_l1loss_a = torch.nn.L1Loss()
            # compute_l1loss_x = torch.nn.L1Loss()
            compute_mse = torch.nn.MSELoss()
            compute_l1 = torch.nn.L1Loss()

            # mse_loss = mse_loss(rotation_cal.float(), groundTruth_number.float())
            loss_l1_x = compute_l1(transformation_x, gt_trans_convert_x)
            loss_l1_y = compute_l1(transformation_y, gt_trans_convert_y)
            loss_mse_x = compute_mse(transformation_x, gt_trans_convert_x)
            loss_mse_y = compute_mse(transformation_y, gt_trans_convert_y)
            loss_x = compute_loss_x(corr_x, gt_trans_convert_x)
            loss_y = compute_loss_y(corr_y, gt_trans_convert_y)
            total_loss = loss_x + loss_y + loss_l1_x + loss_l1_y
            return loss_y, loss_x, total_loss, loss_l1_x, loss_l1_y, loss_mse_x, loss_mse_y, template_visual_trans, source_visual_trans
        else:
            corr_final_trans = corr_result_trans.clone()
            # corr_visual = corr_final_trans.unsqueeze(-1)
            # corr_visual = corr_visual.permute(0,3,1,2)
            corr_y = torch.sum(corr_final_trans.clone(), 2, keepdim=False)
            # corr_2d = corr_final_trans.clone().reshape(b, h*w)
            # corr_2d = model_corr2softmax(corr_2d)
            corr_y = model_corr2softmax_trans(corr_y)

            corr_x = torch.sum(corr_final_trans.clone(), 1, keepdim=False)
            # corr_final_trans = corr_final_trans.reshape(b, h*w)
            corr_x = model_corr2softmax_trans(corr_x)

            corr_mat_dsnt_trans = corr_result_trans.clone().unsqueeze(-1)
            corr_mat_dsnt_trans_final = model_corr2softmax_trans(
                corr_mat_dsnt_trans)
            corr_mat_dsnt_trans_final = kornia.spatial_softmax2d(
                corr_mat_dsnt_trans_final)
            coors_trans = kornia.spatial_expectation2d(
                corr_mat_dsnt_trans_final, False)
            transformation_x = coors_trans[:, 0, 0]
            transformation_y = coors_trans[:, 0, 1]

            # only consider angle as the loss
            # softmax_result = torch.sum(corr_result.clone(), 2, keepdim=False)
            # softmax_final = softmax_result.clone()
            # # softmax_visual = softmax_final.unsqueeze(-1)
            # # softmax_visual = softmax_visual.permute(0,3,1,2)

            # softmax_final = softmax_final.reshape(b, h*w)
            # softmax_final = model_corr2softmax(softmax_final.clone())
            gt_trans_orig = gt_trans.clone().to(device)

            # print("err_true = ",err_true.item()*100,"%")

            gt_trans_convert = GT_trans_convert(gt_trans_orig, [256, 256])
            gt_trans_convert_y = gt_trans_convert[:, 0]
            gt_trans_convert_x = gt_trans_convert[:, 1]

            print("trans x", tranformation_x)
            print("gt_convert x", gt_trans_convert_x, "\n")

            print("trans y", tranformation_y)
            print("gt_convert y", gt_trans_convert_y, "\n")

            # set the loss function:
            # compute_loss = torch.nn.KLDivLoss(reduction="sum").to(device)
            # compute_loss = torch.nn.BCEWithLogitsLoss(reduction="sum").to(device)
            compute_loss_y = torch.nn.CrossEntropyLoss(
                reduction="sum").to(device)
            compute_loss_x = torch.nn.CrossEntropyLoss(
                reduction="sum").to(device)
            # mse_loss = torch.nn.MSELoss(reduce=True)
            # compute_loss = torch.nn.NLLLoss()
            # compute_l1loss_a = torch.nn.L1Loss()
            # compute_l1loss_x = torch.nn.L1Loss()
            compute_mse = torch.nn.MSELoss()
            compute_l1 = torch.nn.L1Loss()

            # mse_loss = mse_loss(rotation_cal.float(), groundTruth_number.float())
            loss_l1_x = compute_l1(transformation_x, gt_trans_convert_x)
            loss_l1_y = compute_l1(transformation_y, gt_trans_convert_y)
            loss_mse_x = compute_mse(
                transformation_x,
                gt_trans_convert_x.type(torch.FloatTensor).to(device))
            loss_mse_y = compute_mse(
                transformation_y,
                gt_trans_convert_y.type(torch.FloatTensor).to(device))
            loss_x = compute_loss_x(corr_x, gt_trans_convert_x)
            loss_y = compute_loss_y(corr_y, gt_trans_convert_y)
            total_loss = 0.001 * (loss_mse_x + loss_mse_y)
            return loss_y, loss_x, total_loss, loss_l1_x, loss_l1_y, loss_mse_x, loss_mse_y, template_visual_trans, source_visual_trans
Example #27
def detect_translation(template_trans, source_trans, rotation, scale,
                       model_template_trans, model_source_trans,
                       model_corr2softmax_trans, device):
    print("                             ")
    print("                             DETECTING TRANSLATION")
    print("                             ")


# for AGDataset
    b, c, h, w = source_trans.shape
    center = torch.ones(b,2).to(device)
    center[:, 0] = h // 2
    center[:, 1] = w // 2
    angle_rot = torch.ones(b).to(device) * (-rotation.to(device))
    scale_rot = torch.ones(b).to(device) * (1/scale.to(device))
    rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
    source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))
    # imshow(template_trans[0,:,:])
    # time.sleep(2)
    # imshow(source_trans[0,:,:])
    # time.sleep(2)

    # imshow(template,"temp")
    # imshow(source, "src")
    
    template_unet_trans = model_template_trans(template_trans)
    source_unet_trans = model_source_trans(source_trans)

    # for tensorboard visualize
    template_visual_trans = template_unet_trans
    source_visual_trans = source_unet_trans

    template_unet_trans = template_unet_trans.permute(0,2,3,1)
    source_unet_trans = source_unet_trans.permute(0,2,3,1)

    template_unet_trans = template_unet_trans.squeeze(-1)
    source_unet_trans = source_unet_trans.squeeze(-1)

    (b, h, w) = template_unet_trans.shape
    logbase_trans = torch.tensor(1.)
    phase_corr_layer_xy = PhaseCorr(device, logbase_trans, model_corr2softmax_trans)
    t0, t1, softmax_result_trans, corr_result_trans = phase_corr_layer_xy(template_unet_trans.to(device), source_unet_trans.to(device))

# use phasecorr result

    corr_final_trans = corr_result_trans.clone()
    # corr_visual = corr_final_trans.unsqueeze(-1)
    # corr_visual = corr_visual.permute(0,3,1,2)
    corr_y = torch.sum(corr_final_trans.clone(), 2, keepdim=False)
    # corr_2d = corr_final_trans.clone().reshape(b, h*w)
    # corr_2d = model_corr2softmax(corr_2d)
    corr_y = model_corr2softmax_trans(corr_y)
    input_c = nn.functional.softmax(corr_y.clone(), dim=-1)
    indices_c = np.linspace(0, 1, 256)
    indices_c = torch.tensor(np.reshape(indices_c, (-1, 256))).to(device)
    transformation_y = torch.sum((256 - 1) * input_c * indices_c, dim=-1)
    # transformation_y = torch.argmax(corr_y, dim=-1)

    corr_x = torch.sum(corr_final_trans.clone(), 1, keepdim=False)
    # corr_final_trans = corr_final_trans.reshape(b, h*w)
    corr_x = model_corr2softmax_trans(corr_x)
    input_r = nn.functional.softmax(corr_x.clone(), dim=-1)
    indices_r = np.linspace(0, 1, 256)
    indices_r = torch.tensor(np.reshape(indices_r, (-1, 256))).to(device)
    # transformation_x = torch.argmax(corr_x, dim=-1)
    transformation_x = torch.sum((256 - 1) * input_r * indices_r, dim=-1)

    print("trans x", transformation_x)
    print("trans y", transformation_y)

    trans_mat_affine = torch.Tensor([[[1.0, 0.0, transformation_x - 128.0],
                                      [0.0, 1.0, transformation_y - 128.0]]]).to(device)
    template_trans = kornia.warp_affine(template_trans.to(device), trans_mat_affine, dsize=(h, w))
    image_aligned = align_image(template_trans[0,:,:], source_trans[0,:,:])
    # imshow(template_trans[0,:,:])
    # time.sleep(2)
    # imshow(source_trans[0,:,:])
    # time.sleep(2)

    return transformation_y, transformation_x, image_aligned, source_trans