Example #1
    def test_warp_perspective(self):
        # generate input data
        batch_size = 1
        height, width = 16, 32
        alpha = tgm.pi / 2  # 90 deg rotation

        # create data patch
        patch = torch.rand(batch_size, 1, height, width)

        # create transformation (rotation)
        M = torch.tensor([[
            [torch.cos(alpha), -torch.sin(alpha), 0.],
            [torch.sin(alpha), torch.cos(alpha), 0.],
            [0., 0., 1.],
        ]])  # Bx3x3

        # apply transformation and inverse
        _, _, h, w = patch.shape
        patch_warped = tgm.warp_perspective(patch, M, dsize=(height, width))
        patch_warped_inv = tgm.warp_perspective(patch_warped, tgm.inverse(M),
                                                dsize=(height, width))

        # generate mask to compute error
        mask = torch.ones_like(patch)
        mask_warped_inv = tgm.warp_perspective(
            tgm.warp_perspective(mask, M, dsize=(height, width)),
            tgm.inverse(M), dsize=(height, width))

        res = utils.check_equal_torch(mask_warped_inv * patch,
                                      mask_warped_inv * patch_warped_inv)
        self.assertTrue(res)
Example #2
def get_warped_images_torch(base_img):
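    # Render a series of perspective-warped views of a single CHW image at decreasing
    # camera distance: the X tilt decays by x_decay each step and a random Z rotation is
    # drawn from [z_min, z_min + 2 * z_deviation). start_dist, min_dist, step_size,
    # rotXdeg, x_decay, z_deviation, z_min and get_H_matrix are expected to be defined
    # at module level; `ch` is the torch alias used by this script.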
    h, w = base_img.shape[1:]
    imgs = []
    rotXhere = rotXdeg

    distances, degrees = [], []
    for dist in range(start_dist, min_dist, -step_size):

        rotXhere -= x_decay
        rotZdeg = (2 * z_deviation) * np.random.random_sample() + z_min

        rotX = np.deg2rad(rotXhere - 90)
        rotZ = np.deg2rad(rotZdeg - 90)

        # Final and overall transformation matrix
        H = get_H_matrix(dist, w, h, rotX, rotZ)
        H = ch.from_numpy(H)

        img_warp = tgm.warp_perspective(base_img.unsqueeze(0), H, dsize=(h, w))

        imgs.append(img_warp)
        distances.append(dist)
        degrees.append(rotZdeg)

    return np.array(distances), np.array(degrees), ch.cat(imgs, 0).float()
Example #3
    def test_warp_perspective_crop(self):
        # generate input data
        batch_size = 1
        src_h, src_w = 3, 4
        dst_h, dst_w = 3, 2

        # [x, y] origin
        # top-left, top-right, bottom-right, bottom-left
        points_src = torch.FloatTensor([[
            [1, 0], [2, 0], [2, 2], [1, 2],
        ]])

        # [x, y] destination
        # top-left, top-right, bottom-right, bottom-left
        points_dst = torch.FloatTensor([[
            [0, 0], [dst_w - 1, 0], [dst_w - 1, dst_h - 1], [0, dst_h - 1],
        ]])

        # compute transformation between points
        dst_pix_trans_src_pix = tgm.get_perspective_transform(
            points_src, points_dst)

        # create points grid in normalized coordinates
        grid_src_norm = tgm.create_meshgrid(src_h, src_w,
                                            normalized_coordinates=True)
        grid_src_norm = torch.unsqueeze(grid_src_norm, dim=0)

        # create points grid in pixel coordinates
        grid_src_pix = tgm.create_meshgrid(src_h, src_w,
                                           normalized_coordinates=False)
        grid_src_pix = torch.unsqueeze(grid_src_pix, dim=0)

        src_norm_trans_src_pix = tgm.normal_transform_pixel(src_h, src_w)
        src_pix_trans_src_norm = tgm.inverse(src_norm_trans_src_pix)

        dst_norm_trans_dst_pix = tgm.normal_transform_pixel(dst_h, dst_w)

        # transform pixel grid
        grid_dst_pix = tgm.transform_points(
            dst_pix_trans_src_pix, grid_src_pix)
        grid_dst_norm = tgm.transform_points(
            dst_norm_trans_dst_pix, grid_dst_pix)

        # transform the normalized grid in one step with the composed transform
        # (dst_norm <- dst_pix <- src_pix <- src_norm); must match the two-step path above
        dst_norm_trans_src_norm = torch.matmul(
            dst_norm_trans_dst_pix, torch.matmul(
                dst_pix_trans_src_pix, src_pix_trans_src_norm))
        grid_dst_norm2 = tgm.transform_points(
            dst_norm_trans_src_norm, grid_src_norm)

        # grids should be equal
        self.assertTrue(utils.check_equal_torch(
            grid_dst_norm, grid_dst_norm2))

        # warp tensor
        patch = torch.rand(batch_size, 1, src_h, src_w)
        patch_warped = tgm.warp_perspective(
            patch, dst_pix_trans_src_pix, (dst_h, dst_w))
        self.assertTrue(utils.check_equal_torch(
            patch[:, :, :3, 1:3], patch_warped))
Example #4
def test_warp_perspective_rotation(batch_shape, device_type):
    # generate input data
    batch_size, channels, height, width = batch_shape
    alpha = 0.5 * tgm.pi * torch.ones(batch_size)  # 90 deg rotation

    # create data patch
    device = torch.device(device_type)
    patch = torch.rand(batch_shape).to(device)

    # create transformation (rotation)
    M = torch.eye(3, device=device).repeat(batch_size, 1, 1)  # Bx3x3
    M[:, 0, 0] = torch.cos(alpha)
    M[:, 0, 1] = -torch.sin(alpha)
    M[:, 1, 0] = torch.sin(alpha)
    M[:, 1, 1] = torch.cos(alpha)

    # apply transformation and inverse
    _, _, h, w = patch.shape
    patch_warped = tgm.warp_perspective(patch, M, dsize=(height, width))
    patch_warped_inv = tgm.warp_perspective(patch_warped,
                                            torch.inverse(M),
                                            dsize=(height, width))

    # generate mask to compute error
    mask = torch.ones_like(patch)
    mask_warped_inv = tgm.warp_perspective(
        tgm.warp_perspective(mask, M, dsize=(height, width)),
        torch.inverse(M), dsize=(height, width))

    assert utils.check_equal_torch(mask_warped_inv * patch,
                                   mask_warped_inv * patch_warped_inv)

    # evaluate function gradient
    patch = utils.tensor_to_gradcheck_var(patch)  # to var
    M = utils.tensor_to_gradcheck_var(M, requires_grad=False)  # to var
    assert gradcheck(tgm.warp_perspective, (patch, M, (
        height,
        width,
    )),
                     raise_exception=True)
Example #5
    def test_crop(self, device_type, batch_size, channels):
        # generate input data
        src_h, src_w = 3, 3
        dst_h, dst_w = 3, 3
        device = torch.device(device_type)

        # [x, y] origin
        # top-left, top-right, bottom-right, bottom-left
        points_src = torch.FloatTensor([[
            [0, 0],
            [0, src_w - 1],
            [src_h - 1, src_w - 1],
            [src_h - 1, 0],
        ]])

        # [x, y] destination
        # top-left, top-right, bottom-right, bottom-left
        points_dst = torch.FloatTensor([[
            [0, 0],
            [0, dst_w - 1],
            [dst_h - 1, dst_w - 1],
            [dst_h - 1, 0],
        ]])

        # compute transformation between points
        dst_trans_src = tgm.get_perspective_transform(points_src,
                                                      points_dst).expand(
                                                          batch_size, -1, -1)

        # warp tensor
        patch = torch.FloatTensor([[[
            [1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12],
            [13, 14, 15, 16],
        ]]]).expand(batch_size, channels, -1, -1)

        expected = torch.FloatTensor([[[
            [1, 2, 3],
            [5, 6, 7],
            [9, 10, 11],
        ]]])

        # warp and assert
        patch_warped = tgm.warp_perspective(patch, dst_trans_src,
                                            (dst_h, dst_w))
        assert_allclose(patch_warped, expected)
Example #6
def bilinear(I, theta, mode='translation', device='cuda:0'):
    # Shift images I of shape (B, C, H, W) by per-image translation theta of shape (B, 2);
    # in 'rototranslation' mode theta is (B, 3) with the rotation angle in theta[:, 2].
    H = torch.eye(3).to(device)
    H = H.repeat(I.size()[0], 1).view(-1, 3, 3)  # Bx3x3 identity homographies
    H[:, 0, 2] = theta[:, 0]  # tx
    H[:, 1, 2] = theta[:, 1]  # ty

    if mode == 'rototranslation':
        H[:, 0, 0] = torch.cos(theta[:, 2])  # cos
        H[:, 0, 1] = -torch.sin(theta[:, 2])  # -sin
        H[:, 1, 0] = torch.sin(theta[:, 2])  # sin
        H[:, 1, 1] = torch.cos(theta[:, 2])  # cos
    new_I = tgm.warp_perspective(I, H,
                                 dsize=(I.size()[2],
                                        I.size()[3]))  # bilinear interpolation
    return new_I
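A minimal usage sketch for the helper above (illustrative, not from the original source); it assumes torchgeometry is imported as tgm and that a CUDA device is available (pass device='cpu' and keep the tensors on CPU otherwise):

imgs = torch.rand(2, 3, 64, 64).to('cuda:0')       # batch of images (B, C, H, W)
shifts = torch.tensor([[5.0, -3.0],                 # per-image (tx, ty) in pixels
                       [0.0, 10.0]]).to('cuda:0')
shifted = bilinear(imgs, shifts, mode='translation', device='cuda:0')
print(shifted.shape)  # torch.Size([2, 3, 64, 64])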
Example #7
    def test_crop_center_resize(self, device_type):
        # generate input data
        dst_h, dst_w = 4, 4
        device = torch.device(device_type)

        # [x, y] origin
        # top-left, top-right, bottom-right, bottom-left
        points_src = torch.FloatTensor([[
            [1, 1],
            [1, 2],
            [2, 2],
            [2, 1],
        ]])

        # [x, y] destination
        # top-left, top-right, bottom-right, bottom-left
        points_dst = torch.FloatTensor([[
            [0, 0],
            [0, dst_w - 1],
            [dst_h - 1, dst_w - 1],
            [dst_h - 1, 0],
        ]])

        # compute transformation between points
        dst_trans_src = tgm.get_perspective_transform(points_src, points_dst)

        # warp tensor
        patch = torch.FloatTensor([[[
            [1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12],
            [13, 14, 15, 16],
        ]]])

        expected = torch.FloatTensor([[[
            [6.000, 6.333, 6.666, 7.000],
            [7.333, 7.666, 8.000, 8.333],
            [8.666, 9.000, 9.333, 9.666],
            [10.000, 10.333, 10.666, 11.000],
        ]]])

        # warp and assert
        patch_warped = tgm.warp_perspective(patch, dst_trans_src,
                                            (dst_h, dst_w))
        assert_allclose(patch_warped, expected)
Example #8
def test_warp_perspective_crop(batch_size, device_type, channels):
    # generate input data
    src_h, src_w = 3, 4
    dst_h, dst_w = 3, 2
    device = torch.device(device_type)

    # [x, y] origin
    # top-left, top-right, bottom-right, bottom-left
    points_src = torch.rand(batch_size, 4, 2).to(device)
    points_src[:, :, 0] *= dst_h
    points_src[:, :, 1] *= dst_w

    # [x, y] destination
    # top-left, top-right, bottom-right, bottom-left
    points_dst = torch.zeros_like(points_src)
    points_dst[:, 1, 0] = dst_w - 1
    points_dst[:, 2, 0] = dst_w - 1
    points_dst[:, 2, 1] = dst_h - 1
    points_dst[:, 3, 1] = dst_h - 1

    # compute transformation between points
    dst_pix_trans_src_pix = tgm.get_perspective_transform(
        points_src, points_dst)

    # create points grid in normalized coordinates
    grid_src_norm = tgm.create_meshgrid(src_h,
                                        src_w,
                                        normalized_coordinates=True)
    grid_src_norm = grid_src_norm.repeat(batch_size, 1, 1, 1).to(device)

    # create points grid in pixel coordinates
    grid_src_pix = tgm.create_meshgrid(src_h,
                                       src_w,
                                       normalized_coordinates=False)
    grid_src_pix = grid_src_pix.repeat(batch_size, 1, 1, 1).to(device)

    src_norm_trans_src_pix = tgm.normal_transform_pixel(src_h, src_w).repeat(
        batch_size, 1, 1).to(device)
    src_pix_trans_src_norm = torch.inverse(src_norm_trans_src_pix)

    dst_norm_trans_dst_pix = tgm.normal_transform_pixel(dst_h, dst_w).repeat(
        batch_size, 1, 1).to(device)

    # transform pixel grid
    grid_dst_pix = tgm.transform_points(dst_pix_trans_src_pix.unsqueeze(1),
                                        grid_src_pix)
    grid_dst_norm = tgm.transform_points(dst_norm_trans_dst_pix.unsqueeze(1),
                                         grid_dst_pix)

    # transform norm grid
    dst_norm_trans_src_norm = torch.matmul(
        dst_norm_trans_dst_pix,
        torch.matmul(dst_pix_trans_src_pix, src_pix_trans_src_norm))
    grid_dst_norm2 = tgm.transform_points(dst_norm_trans_src_norm.unsqueeze(1),
                                          grid_src_norm)

    # grids should be equal
    # TODO: investigate why the precision is this low
    assert utils.check_equal_torch(grid_dst_norm, grid_dst_norm2, 1e-2)

    # warp tensor
    patch = torch.rand(batch_size, channels, src_h, src_w)
    patch_warped = tgm.warp_perspective(patch, dst_pix_trans_src_pix,
                                        (dst_h, dst_w))
    assert patch_warped.shape == (batch_size, channels, dst_h, dst_w)
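The three-matrix product above converts a homography given in pixel coordinates into one acting on normalized [-1, 1] coordinates (the form used for grid sampling). A standalone sketch of that same composition, assuming the tgm.normal_transform_pixel behavior seen in these tests (names are illustrative):

def normalize_homography_sketch(dst_pix_trans_src_pix, src_hw, dst_hw):
    # dst_pix_trans_src_pix: Bx3x3 homography mapping source pixels to destination pixels
    src_h, src_w = src_hw
    dst_h, dst_w = dst_hw
    src_norm_trans_src_pix = tgm.normal_transform_pixel(src_h, src_w)  # 1x3x3
    src_pix_trans_src_norm = torch.inverse(src_norm_trans_src_pix)
    dst_norm_trans_dst_pix = tgm.normal_transform_pixel(dst_h, dst_w)
    # chain: dst_norm <- dst_pix <- src_pix <- src_norm (broadcasts over the batch dim)
    return torch.matmul(dst_norm_trans_dst_pix,
                        torch.matmul(dst_pix_trans_src_pix, src_pix_trans_src_norm))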
Example #9
def forward_pass(secret_img,
                 secret_target,
                 cover_img,
                 cover_target,
                 Hnet,
                 Rnet,
                 criterion,
                 val_cover=0,
                 i_c=None,
                 position=None,
                 Se_two=None,
                 with_std_noise=True,
                 with_warp_noise=True):
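    # Hide secret_img inside cover_img with Hnet and recover it with Rnet, returning the
    # hiding/reveal losses and the per-pixel L1 gaps. `opt` (num_secret, num_cover,
    # num_training, imageSize, cover_dependent, ...) and get_rand_homography_mat come from
    # the surrounding module. When with_warp_noise is set, the container is warped by a
    # random homography and warped back before the reveal network sees it.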

    batch_size_secret, channel_secret, _, _ = secret_img.size()
    batch_size_cover, channel_cover, _, _ = cover_img.size()

    if opt.cuda:
        cover_img = cover_img.cuda()
        secret_img = secret_img.cuda()
        #concat_img = concat_img.cuda()

    secret_imgv = secret_img.view(batch_size_secret // opt.num_secret,
                                  channel_secret * opt.num_secret,
                                  opt.imageSize, opt.imageSize)
    secret_imgv_nh = secret_imgv.repeat(opt.num_training, 1, 1, 1)

    cover_img = cover_img.view(batch_size_cover // opt.num_cover,
                               channel_cover * opt.num_cover, opt.imageSize,
                               opt.imageSize)

    if opt.no_cover and (
            val_cover == 0
    ):  # if val_cover = 1, always use cover in val; otherwise, no_cover True >>> not using cover in training
        cover_img.fill_(0.0)
    if (opt.plain_cover or opt.noise_cover) and (val_cover == 0):
        cover_img.fill_(0.0)
    b, c, w, h = cover_img.size()
    if opt.plain_cover and (val_cover == 0):
        img_w1 = torch.cat(
            (torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()),
            dim=2)
        img_w2 = torch.cat(
            (torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()),
            dim=2)
        img_w3 = torch.cat(
            (torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()),
            dim=2)
        img_w4 = torch.cat(
            (torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
             torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()),
            dim=2)
        img_wh = torch.cat((img_w1, img_w2, img_w3, img_w4), dim=3)
        cover_img = cover_img + img_wh
    if opt.noise_cover and (val_cover == 0):
        cover_img = cover_img + (
            (torch.rand(b, c, w, h) - 0.5) * 2 * 0 / 255).cuda()

    cover_imgv = cover_img

    if opt.cover_dependent:
        H_input = torch.cat((cover_imgv, secret_imgv), dim=1)
    else:
        H_input = secret_imgv

    itm_secret_img = Hnet(H_input)
    if i_c is not None:
        if isinstance(i_c, float):
            # keep only channel i_c, zero out the rest
            itm_secret_img_clone = itm_secret_img.clone()
            itm_secret_img.fill_(0)
            itm_secret_img[:, int(i_c):int(i_c) + 1, :, :] = \
                itm_secret_img_clone[:, int(i_c):int(i_c) + 1, :, :]
        if isinstance(i_c, int):
            print('zeroing out channel', i_c)
            # set channel i_c to zero
            itm_secret_img[:, i_c:i_c + 1, :, :].fill_(0.0)

    if position is not None:
        itm_secret_img[:, :, position:position + 1,
                       position:position + 1].fill_(0.0)
    if Se_two == 2:
        itm_secret_img_half = itm_secret_img[0:batch_size_secret // 2, :, :, :]
        itm_secret_img = itm_secret_img + torch.cat(
            (itm_secret_img_half.clone().fill_(0.0), itm_secret_img_half), 0)
    elif isinstance(Se_two, float):
        itm_secret_img = itm_secret_img + Se_two * torch.rand(
            itm_secret_img.size()).cuda()
    if opt.cover_dependent:
        container_img = itm_secret_img
    else:
        itm_secret_img = itm_secret_img.repeat(opt.num_training, 1, 1, 1)
        container_img = itm_secret_img + cover_imgv
        print('L1 metric of residual={:.4f}'.format(
            itm_secret_img.abs().mean().item()))
    errH = criterion(container_img, cover_imgv)  # Hiding net

    if with_std_noise:
        std_noise = (torch.rand(1) * 0.05).item()
        noise = torch.randn_like(container_img) * std_noise
        container_img = container_img + noise

    if with_warp_noise:
        # Get random homography matrix
        homography = get_rand_homography_mat(opt.imageSize,
                                             opt.imageSize * 0.1,
                                             opt.bs_secret)
        homography = torch.from_numpy(homography).float().cuda()

        # Apply the homography and then undo it
        # (DCMMC) here is the key to train model for Universal photographic steganography
        container_img = tgm.warp_perspective(container_img, homography[:, 1],
                                             (opt.imageSize, opt.imageSize))
        container_img = tgm.warp_perspective(container_img, homography[:, 0],
                                             (opt.imageSize, opt.imageSize))

    rev_secret_img = Rnet(container_img)
    #secret_imgv = Variable(secret_img)
    errR = criterion(rev_secret_img, secret_imgv_nh)  # Reveal net

    # L1 metric
    diffH = (container_img - cover_imgv).abs().mean() * 255
    diffR = (rev_secret_img - secret_imgv_nh).abs().mean() * 255
    return cover_imgv, container_img, secret_imgv_nh, rev_secret_img, errH, errR, diffH, diffR
Example #10
def loss_function(
		model, batch, device, margin=1, safe_radius=4, scaling_steps=3, plot=False
):
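	# Triplet-margin loss over dense descriptors: ground-truth correspondences are obtained
	# by warping positions from image1 to image2 with depth/intrinsics/pose, aligned to a
	# top view with the homographies H1/H2, and each positive pair is contrasted against
	# negatives mined outside safe_radius; terms are weighted by the detection scores.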
	output = model({
		'image1': batch['image1'].to(device),
		'image2': batch['image2'].to(device)
	})

	loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
	has_grad = False

	n_valid_samples = 0
	for idx_in_batch in range(batch['image1'].size(0)):
		# Annotations
		depth1 = batch['depth1'][idx_in_batch].to(device)  # [h1, w1]
		intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
		pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device)  # [4, 4]
		bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2]

		depth2 = batch['depth2'][idx_in_batch].to(device)
		intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
		pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
		bbox2 = batch['bbox2'][idx_in_batch].to(device)

		# Network output
		dense_features1 = output['dense_features1'][idx_in_batch]
		c, h1, w1 = dense_features1.size()
		scores1 = output['scores1'][idx_in_batch].view(-1)

		dense_features2 = output['dense_features2'][idx_in_batch]
		_, h2, w2 = dense_features2.size()
		scores2 = output['scores2'][idx_in_batch]


		all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
		descriptors1 = all_descriptors1

		all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

		# Warp the positions from image 1 to image 2
		fmap_pos1 = grid_positions(h1, w1, device)

		hOrig, wOrig = int(batch['image1'].shape[2]/8), int(batch['image1'].shape[3]/8)
		fmap_pos1Orig = grid_positions(hOrig, wOrig, device)
		pos1 = upscale_positions(fmap_pos1Orig, scaling_steps=scaling_steps)

		try:
			pos1, pos2, ids = warp(
				pos1,
				depth1, intrinsics1, pose1, bbox1,
				depth2, intrinsics2, pose2, bbox2
			)
		except EmptyTensorError:
			continue

		H1 = output['H1'][idx_in_batch]
		H2 = output['H2'][idx_in_batch]

		try:
			pos1, pos2 = homoAlign(pos1, pos2, H1, H2, device)
		except IndexError:
			continue

		ids = idsAlign(pos1, device)

		img_warp1 = tgm.warp_perspective(batch['image1'].to(device), H1, dsize=(400, 400))
		img_warp2 = tgm.warp_perspective(batch['image2'].to(device), H2, dsize=(400, 400))

		# drawTraining(img_warp1, img_warp2, pos1, pos2, batch, idx_in_batch, output)
		# exit(1)

		fmap_pos1 = fmap_pos1[:, ids]
		descriptors1 = descriptors1[:, ids]
		scores1 = scores1[ids]

		# Skip the pair if not enough GT correspondences are available
		if ids.size(0) < 128:
			continue

		# Descriptors at the corresponding positions
		fmap_pos2 = torch.round(
			downscale_positions(pos2, scaling_steps=scaling_steps)
		).long()

		descriptors2 = F.normalize(
			dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
			dim=0
		)

		positive_distance = 2 - 2 * (
			descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
		).squeeze()

		all_fmap_pos2 = grid_positions(h2, w2, device)
		position_distance = torch.max(
			torch.abs(
				fmap_pos2.unsqueeze(2).float() -
				all_fmap_pos2.unsqueeze(1)
			),
			dim=0
		)[0]
		is_out_of_safe_radius = position_distance > safe_radius
		distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
		# negative_distance2 = torch.min(
		# 	distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
		# 	dim=1
		# )[0]

		negative_distance2 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

		all_fmap_pos1 = grid_positions(h1, w1, device)
		position_distance = torch.max(
			torch.abs(
				fmap_pos1.unsqueeze(2).float() -
				all_fmap_pos1.unsqueeze(1)
			),
			dim=0
		)[0]
		is_out_of_safe_radius = position_distance > safe_radius
		distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
		# negative_distance1 = torch.min(
		# 	distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
		# 	dim=1
		# )[0]

		negative_distance1 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

		diff = positive_distance - torch.min(
			negative_distance1, negative_distance2
		)

		# if(batch['batch_idx']%20 == 0):
		# 	print("positive_distance: {} | negative_distance: {}".format(positive_distance, torch.min(negative_distance1, negative_distance2)))

		scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

		loss = loss + (
			torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
			(torch.sum(scores1 * scores2) )
		)

		has_grad = True
		n_valid_samples += 1

		if plot and batch['batch_idx'] % batch['log_interval'] == 0:
			# drawTraining(batch['image1'], batch['image2'], pos1, pos2, batch, idx_in_batch, output, save=True)
			drawTraining(img_warp1, img_warp2, pos1, pos2, batch, idx_in_batch, output, save=True)

	if not has_grad:
		raise NoGradientError

	loss = loss / (n_valid_samples )

	return loss
Example #11
    def forward(self, img1, img2):
        img_warp1 = tgm.warp_perspective(img1, self.H1, dsize=(400, 400))
        img_warp2 = tgm.warp_perspective(img2, self.H2, dsize=(400, 400))

        return img_warp1, img_warp2, self.H1, self.H2
Example #12
    def forward(self, x, M):
        x = tgm.warp_perspective(
            x, M, dsize=(self.output_height, self.output_width))
        return x
Example #13
import cv2
import numpy as np
import torch
import torchgeometry

import utils
from my_perspective_transform import spatial_transformer_network

print(torch.cuda.is_available())

original_image = cv2.imread('test_im_hidden.png')
original_image = torch.cuda.FloatTensor(original_image)
original_image = original_image.unsqueeze(0)

Ms = utils.opencv_get_rand_transform_matrix(400, 400, 50, 1)
print(Ms)
Ms = torch.cuda.FloatTensor(Ms)
original_image = original_image.permute(0, 3, 1, 2)
transform_image = torchgeometry.warp_perspective(original_image,
                                                 Ms[:, 1, :, :],
                                                 dsize=(400, 400),
                                                 flags='bilinear')
transform_image = transform_image.permute(0, 2, 3, 1)
print(transform_image.size())
transform_image = transform_image.cpu().detach().numpy()
transform_image = transform_image.astype(np.uint8)
transform_image = transform_image[0, :, :, :]
cv2.imwrite('transform1.png', transform_image)

original_image = cv2.imread('transform1.png')
original_image = torch.cuda.FloatTensor(original_image)
original_image = original_image.unsqueeze(0)
original_image = original_image.permute(0, 3, 1, 2)
transform_image = spatial_transformer_network(original_image, Ms[:, 1, :, :])
transform_image = transform_image.permute(0, 2, 3, 1)
transform_image = transform_image.cpu().detach().numpy()
Example #14
def build_model(encoder, decoder, discriminator, secret_input, image_input, l2_edge_gain, 
                borders, secret_size, M, loss_scales, yuv_scales, args, global_step, writer):
    test_transform = transform_net(image_input, args, global_step)

    input_warped = torchgeometry.warp_perspective(image_input, M[:, 1, :, :], dsize=(400, 400), flags='bilinear')
    mask_warped = torchgeometry.warp_perspective(torch.ones_like(input_warped), M[:, 1, :, :], dsize=(400, 400), flags='bilinear')
    input_warped += (1 - mask_warped) * image_input

    residual_warped = encoder((secret_input, input_warped))
    encoded_warped = residual_warped + input_warped

    residual = torchgeometry.warp_perspective(residual_warped, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')

    if borders == 'no_edge':
        encoded_image = image_input + residual
    elif borders == 'black':
        encoded_image = residual_warped + input_warped
        encoded_image = torchgeometry.warp_perspective(encoded_image, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        input_unwarped = torchgeometry.warp_perspective(image_input, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
    elif borders.startswith('random'):
        mask  = torchgeometry.warp_perspective(torch.ones_like(residual), M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        encoded_image = residual_warped + input_warped
        encoded_image = torchgeometry.warp_perspective(encoded_image, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        input_unwarped = torchgeometry.warp_perspective(input_warped, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        ch = 3 if borders.endswith('rgb') else 1
        encoded_image += (1-mask) * torch.ones_like(residual) * torch.rand([ch])
    elif borders == 'white':
        mask = torchgeometry.warp_perspective(torch.ones_like(residual), M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        encoded_image = residual_warped + input_warped
        encoded_image = torchgeometry.warp_perspective(encoded_image, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        input_unwarped = torchgeometry.warp_perspective(input_warped, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        encoded_image += (1 - mask) * torch.ones_like(residual)
    elif borders == 'image':
        mask = torchgeometry.warp_perspective(torch.ones_like(residual), M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        encoded_image = residual_warped + input_warped
        encoded_image = torchgeometry.warp_perspective(encoded_image, M[:, 0, :, :], dsize=(400, 400), flags='bilinear')
        encoded_image += (1-mask) * torch.roll(image_input, 1, 0)        

    if borders == 'no_edge':
        D_output_real, _ = discriminator(image_input)
        D_output_fake, D_heatmap = discriminator(encoded_image)
    else:
        D_output_real, _ = discriminator(input_warped)
        D_output_fake, D_heatmap = discriminator(encoded_warped)
        
    transformed_image = transform_net(encoded_image, args, global_step)
    decoded_secret = decoder(transformed_image)
    bit_acc, str_acc = get_secret_acc(secret_input, decoded_secret)
    
    lpips_loss = torch.mean(lpips(image_input, encoded_image))

    cross_entropy = nn.BCELoss()
    if args.cuda:
        cross_entropy = cross_entropy.cuda()
    secret_loss = cross_entropy(decoded_secret, secret_input)

    size = (int(image_input.shape[2]), int(image_input.shape[3]))
    gain = 10
    falloff_speed = 4
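    # Build an edge falloff map: a cosine ramp that is ~1 at the image border and 0 in the
    # interior (over the first/last size/falloff_speed rows and columns), scaled by
    # l2_edge_gain so the YUV L2 term below penalizes residual energy near the edges more
    # heavily.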
    falloff_im = np.ones(size)
    for i in range(int(falloff_im.shape[0]/falloff_speed)): #for i in range 100
        falloff_im[-i,:] *= (np.cos(4*np.pi*i/size[0]+np.pi)+1)/2# [cos[(4*pi*i/400)+pi] + 1]/2
        falloff_im[i,:] *= (np.cos(4*np.pi*i/size[0]+np.pi)+1)/2 # [cos[(4*pi*i/400)+pi] + 1]/2
    for j in range(int(falloff_im.shape[1]/falloff_speed)):
        falloff_im[:,-j] *= (np.cos(4*np.pi*j/size[0]+np.pi)+1)/2
        falloff_im[:,j] *= (np.cos(4*np.pi*j/size[0]+np.pi)+1)/2
    falloff_im = 1-falloff_im
    falloff_im = torch.from_numpy(falloff_im).float()
    if args.cuda:
        falloff_im = falloff_im.cuda()
    falloff_im *= l2_edge_gain

    encoded_image_yuv = color.rgb_to_yuv(encoded_image)
    image_input_yuv = color.rgb_to_yuv(image_input)          
    im_diff = encoded_image_yuv - image_input_yuv
    im_diff += im_diff * falloff_im.unsqueeze_(0)
    yuv_loss = torch.mean((im_diff)**2, axis=[0, 2, 3])
    yuv_scales = torch.Tensor(yuv_scales)
    if args.cuda:
        yuv_scales = yuv_scales.cuda()
    image_loss = torch.dot(yuv_loss, yuv_scales)

    D_loss = D_output_real - D_output_fake
    G_loss = D_output_fake
    loss = loss_scales[0] * image_loss + loss_scales[1] * lpips_loss + loss_scales[2] * secret_loss
    if not args.no_gan:
        loss += loss_scales[3] * G_loss

    writer.add_scalar('loss/image_loss', image_loss, global_step)
    writer.add_scalar('loss/lpips_loss', lpips_loss, global_step)
    writer.add_scalar('loss/secret_loss', secret_loss, global_step)
    writer.add_scalar('loss/G_loss', G_loss, global_step)
    writer.add_scalar('loss/loss', loss, global_step)

    writer.add_scalar('metric/bit_acc', bit_acc, global_step)
    writer.add_scalar('metric/str_acc', str_acc, global_step)
    if global_step % 20 == 0:
        writer.add_image('input/image_input', image_input[0], global_step)
        writer.add_image('input/image_warped', input_warped[0], global_step)
        writer.add_image('encoded/encoded_warped', encoded_warped[0], global_step)
        writer.add_image('encoded/residual_warped', residual_warped[0] + 0.5, global_step)
        writer.add_image('encoded/encoded_image', encoded_image[0], global_step)
        writer.add_image('transformed/transformed_image', transformed_image[0], global_step)
        writer.add_image('transformed/test', test_transform[0], global_step)

    return loss, secret_loss, D_loss, bit_acc, str_acc
Example #15
def loss_function(model,
                  batch,
                  device,
                  margin=1,
                  safe_radius=4,
                  scaling_steps=3,
                  plot=False):
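    # Variant of the loss_function shown in Example #10: instead of a dense grid, candidate
    # positions in image1 come from OpenCV SURF keypoints (cv2.xfeatures2d.SURF_create),
    # which are then warped to image2 and aligned with the homographies H1/H2 before the
    # descriptor margin loss is computed.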
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch].to(device)  # [h1, w1]
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2]

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        fmap_pos1 = grid_positions(h1, w1, device)

        hOrig, wOrig = int(batch['image1'].shape[2] / 8), int(
            batch['image1'].shape[3] / 8)
        fmap_pos1Orig = grid_positions(hOrig, wOrig, device)
        pos1 = upscale_positions(fmap_pos1Orig, scaling_steps=scaling_steps)

        # Keypoint detection (SURF; SIFT and ORB alternatives are commented out below)

        imgNp1 = imshow_image(batch['image1'][idx_in_batch].cpu().numpy(),
                              preprocessing=batch['preprocessing'])
        imgNp1 = cv2.cvtColor(imgNp1, cv2.COLOR_BGR2RGB)
        # surf = cv2.xfeatures2d.SIFT_create()
        surf = cv2.xfeatures2d.SURF_create(100)
        # surf = cv2.ORB_create()

        kp = surf.detect(imgNp1, None)
        keyP = [(kp[i].pt) for i in range(len(kp))]
        keyP = np.asarray(keyP).T
        keyP[[0, 1]] = keyP[[1, 0]]
        keyP = np.floor(keyP) + 0.5

        pos1 = torch.from_numpy(keyP).to(pos1.device).float()

        try:
            pos1, pos2, ids = warp(pos1, depth1, intrinsics1, pose1, bbox1,
                                   depth2, intrinsics2, pose2, bbox2)
        except EmptyTensorError:
            continue

        ids = idsAlign(pos1, device, h1, w1)

        # cv2.drawKeypoints(imgNp1, kp, imgNp1)
        # cv2.imshow('Keypoints', imgNp1)
        # cv2.waitKey(0)

        # drawTraining(batch['image1'], batch['image2'], pos1, pos2, batch, idx_in_batch, output, save=False)

        # exit(1)

        # Top view homography adjustment

        H1 = output['H1'][idx_in_batch]
        H2 = output['H2'][idx_in_batch]

        try:
            pos1, pos2 = homoAlign(pos1, pos2, H1, H2, device)
        except IndexError:
            continue

        ids = idsAlign(pos1, device, h1, w1)

        img_warp1 = tgm.warp_perspective(batch['image1'].to(device),
                                         H1,
                                         dsize=(400, 400))
        img_warp2 = tgm.warp_perspective(batch['image2'].to(device),
                                         H2,
                                         dsize=(400, 400))

        # drawTraining(img_warp1, img_warp2, pos1, pos2, batch, idx_in_batch, output)
        # exit(1)

        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            print(ids.size(0))
            continue

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)).long()

        descriptors2 = F.normalize(dense_features2[:, fmap_pos2[0, :],
                                                   fmap_pos2[1, :]],
                                   dim=0)

        positive_distance = 2 - 2 * (descriptors1.t().unsqueeze(
            1) @ descriptors2.t().unsqueeze(2)).squeeze()

        # positive_distance = getPositiveDistance(descriptors1, descriptors2)

        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(torch.abs(
            fmap_pos2.unsqueeze(2).float() - all_fmap_pos2.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius

        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        # distance_matrix = getDistanceMatrix(descriptors1, all_descriptors2)

        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        # negative_distance2 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(torch.abs(
            fmap_pos1.unsqueeze(2).float() - all_fmap_pos1.unsqueeze(1)),
                                      dim=0)[0]
        is_out_of_safe_radius = position_distance > safe_radius

        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        # distance_matrix = getDistanceMatrix(descriptors2, all_descriptors1)

        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1)[0]

        # negative_distance1 = semiHardMine(distance_matrix, is_out_of_safe_radius, positive_distance, margin)

        diff = positive_distance - torch.min(negative_distance1,
                                             negative_distance2)

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        loss = loss + (torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
                       (torch.sum(scores1 * scores2)))

        has_grad = True
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            # drawTraining(batch['image1'], batch['image2'], pos1, pos2, batch, idx_in_batch, output, save=True)
            drawTraining(img_warp1,
                         img_warp2,
                         pos1,
                         pos2,
                         batch,
                         idx_in_batch,
                         output,
                         save=True)

    if not has_grad:
        raise NoGradientError

    loss = loss / (n_valid_samples)

    return loss