Example #1
    def forward(self, x):

        B, _, H, W = x.shape

        x = self.encoderK(x)
        xK = self.decoderK(x)

        score = xK['score']
        center_shift = xK['location']
        feat = xK['feature']

        _, _, Hc, Wc = score.shape

        ############ Remove border for score ##############
        border_mask = torch.ones(B, Hc, Wc)
        border_mask[:, 0] = 0
        border_mask[:, Hc - 1] = 0
        border_mask[:, :, 0] = 0
        border_mask[:, :, Wc - 1] = 0
        border_mask = border_mask.unsqueeze(1)
        score = score * border_mask.to(score.device)

        ############ Remap coordinate ##############
        step = (self.cell - 1) / 2.
        center_base = image_grid(B,
                                 Hc,
                                 Wc,
                                 dtype=center_shift.dtype,
                                 device=center_shift.device,
                                 ones=False,
                                 normalized=False).mul(self.cell) + step

        coord_un = center_base.add(center_shift.mul(self.cross_ratio * step))
        coord = coord_un.clone()
        coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W - 1)
        coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H - 1)

        ############ Sampling feature ##############
        if not self.training:
            # Normalize keypoint coordinates to [-1, 1] for grid_sample.
            coord_norm = coord[:, :2].clone()
            coord_norm[:, 0] = (coord_norm[:, 0] / (float(W - 1) / 2.)) - 1.
            coord_norm[:, 1] = (coord_norm[:, 1] / (float(H - 1) / 2.)) - 1.
            coord_norm = coord_norm.permute(0, 2, 3, 1)

            feat = torch.nn.functional.grid_sample(feat,
                                                   coord_norm,
                                                   align_corners=False)

            dn = torch.norm(feat, p=2, dim=1)  # Compute the norm.
            feat = feat.div(torch.unsqueeze(dn, 1))  # Divide by norm to normalize.

        return score, coord, feat
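All of these examples call an image_grid helper that is not shown on this page. Judging from how it is used here (channel 0 holds x, channel 1 holds y, ones appends a homogeneous channel, and normalized switches between pixel and [-1, 1] coordinates), a minimal sketch could look like the following; this is an assumption, not the original implementation:

import torch

def image_grid(b, h, w, dtype, device, ones=True, normalized=False):
    # Sketch of the assumed helper: a (b, 2 or 3, h, w) grid of per-pixel coordinates.
    if normalized:
        xs = torch.linspace(-1, 1, w, dtype=dtype, device=device)
        ys = torch.linspace(-1, 1, h, dtype=dtype, device=device)
    else:
        xs = torch.arange(w, dtype=dtype, device=device)
        ys = torch.arange(h, dtype=dtype, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    channels = [grid_x, grid_y]  # channel 0 = x, channel 1 = y
    if ones:
        channels.append(torch.ones_like(grid_x))  # homogeneous coordinate
    return torch.stack(channels, dim=0).unsqueeze(0).repeat(b, 1, 1, 1)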
Example #2
def ha_augment_sample(data,
                      jitter_paramters=[0.5, 0.5, 0.2, 0.05],
                      patch_ratio=0.7,
                      scaling_amplitude=0.2,
                      max_angle=pi / 4):
    """Apply Homography Adaptation image augmentation."""
    input_img = data['image'].unsqueeze(0)
    _, _, H, W = input_img.shape
    device = input_img.device

    homography = torch.from_numpy(
        sample_homography([H, W],
                          patch_ratio=patch_ratio,
                          scaling_amplitude=scaling_amplitude,
                          max_angle=max_angle)).float().to(device)
    homography_inv = torch.inverse(homography)

    source = image_grid(1,
                        H,
                        W,
                        dtype=input_img.dtype,
                        device=device,
                        ones=False,
                        normalized=True).clone().permute(0, 2, 3, 1)

    # Warp the normalized source grid with the homography, then resample.
    target_warped = warp_homography(source, homography)
    img_warp = torch.nn.functional.grid_sample(input_img,
                                               target_warped,
                                               align_corners=True)

    color_order = [0, 1, 2]
    if np.random.rand() > 0.5:
        random.shuffle(color_order)

    to_gray = False
    if np.random.rand() > 0.5:
        to_gray = True

    input_img = non_spatial_augmentation(input_img,
                                         jitter_paramters=jitter_paramters,
                                         color_order=color_order,
                                         to_gray=to_gray)
    img_warp = non_spatial_augmentation(img_warp,
                                        jitter_paramters=jitter_paramters,
                                        color_order=color_order,
                                        to_gray=to_gray)

    data['image'] = input_img.squeeze()
    data['image_aug'] = img_warp.squeeze()
    data['homography'] = homography
    data['homography_inv'] = homography_inv
    return data
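sample_homography, warp_homography, and non_spatial_augmentation are imported from elsewhere in the repository and are not shown. As a rough orientation, warp_homography receives a (B, H, W, 2) grid of normalized (x, y) points plus a 3x3 homography, so a plausible sketch (an assumption, working in homogeneous coordinates) is:

import torch

def warp_homography(sources, homography):
    # Sketch of the assumed helper: apply a 3x3 homography to a (b, h, w, 2) grid.
    b, h, w, _ = sources.shape
    pts = torch.nn.functional.pad(sources.reshape(b, -1, 2), (0, 1), value=1.0)  # append 1
    warped = pts @ homography.transpose(-2, -1)  # row vectors, so p' = p @ H^T
    warped = warped[..., :2] / warped[..., 2:].clamp(min=1e-8)  # perspective divide
    return warped.reshape(b, h, w, 2)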
Example #3
def ha_augment_sample(data,
                      jitter_paramters=[0.5, 0.5, 0.2, 0.05],
                      patch_ratio=0.7,
                      scaling_amplitude=0.2,
                      max_angle=pi / 4):
    """Apply Homography Adaptation image augmentation."""
    target_img = data['image'].unsqueeze(0)
    _, _, H, W = target_img.shape
    device = target_img.device

    # Generate homography (warps source to target)
    homography = sample_homography([H, W],
                                   patch_ratio=patch_ratio,
                                   scaling_amplitude=scaling_amplitude,
                                   max_angle=max_angle)
    homography = torch.from_numpy(homography).float().to(device)

    source_grid = image_grid(1, H, W,
                    dtype=target_img.dtype,
                    device=device,
                    ones=False, normalized=True).clone().permute(0, 2, 3, 1)

    source_warped = warp_homography(source_grid, homography)
    source_img = torch.nn.functional.grid_sample(target_img, source_warped, align_corners=True)

    color_order = [0, 1, 2]
    if np.random.rand() > 0.5:
        random.shuffle(color_order)

    to_gray = False
    if np.random.rand() > 0.5:
        to_gray = True

    target_img = non_spatial_augmentation(target_img,
                                          jitter_paramters=jitter_paramters,
                                          color_order=color_order,
                                          to_gray=to_gray)
    source_img = non_spatial_augmentation(source_img,
                                          jitter_paramters=jitter_paramters,
                                          color_order=color_order,
                                          to_gray=to_gray)

    data['image'] = target_img.squeeze()
    data['image_aug'] = source_img.squeeze()
    data['homography'] = homography
    return data
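Compared with Example #2, this variant passes align_corners=True to grid_sample and stores only the forward homography. With the helper sketches above, the grid convention can be sanity-checked by warping with the homography and then its inverse (a hypothetical snippet; the resolution is chosen arbitrarily):

from math import pi
import torch

H, W = 240, 320  # arbitrary example resolution
homography = torch.from_numpy(
    sample_homography([H, W], patch_ratio=0.7,
                      scaling_amplitude=0.2, max_angle=pi / 4)).float()
grid = image_grid(1, H, W, dtype=torch.float32, device='cpu',
                  ones=False, normalized=True).permute(0, 2, 3, 1)
roundtrip = warp_homography(warp_homography(grid, homography),
                            torch.inverse(homography))
assert torch.allclose(roundtrip, grid, atol=1e-4)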
Example #4
def ha_augment_sample(data,
                      training_mode,
                      non_spatial_aug,
                      jitter_paramters=[0.5, 0.5, 0.2, 0.05],
                      patch_ratio=0.7,
                      scaling_amplitude=0.2,
                      max_angle=pi / 4):
    """Apply Homography Adaptation or real world viewpoint adaptation image augmentation."""
    target_img = data['image_target'].unsqueeze(0)
    source_img = data['image_source'].unsqueeze(0)
    # Only apply Homography Adaptation in the HA training modes.
    if training_mode in ('HA', 'scene+HA', 'cam+HA', 'con+HA', 'HA_wo_sp'):
        _, _, H, W = target_img.shape
        device = target_img.device

        # Generate homography (warps source to target)
        homography = sample_homography([H, W],
                                       patch_ratio=patch_ratio,
                                       scaling_amplitude=scaling_amplitude,
                                       max_angle=max_angle)
        homography = torch.from_numpy(homography).float().to(device)

        target_grid = image_grid(1,
                                 H,
                                 W,
                                 dtype=target_img.dtype,
                                 device=device,
                                 ones=False,
                                 normalized=True).clone().permute(0, 2, 3, 1)

        target_warped = warp_homography(target_grid, homography)
        target_img = torch.nn.functional.grid_sample(target_img,
                                                     target_warped,
                                                     align_corners=True)
        data['homography'] = torch.inverse(homography)
    if non_spatial_aug:
        color_order = [0, 1, 2]
        if np.random.rand() > 0.5:
            random.shuffle(color_order)

        to_gray = False
        if np.random.rand() > 0.5:
            to_gray = True
        target_img = non_spatial_augmentation(
            target_img,
            jitter_paramters=jitter_paramters,
            color_order=color_order,
            to_gray=to_gray)
        source_img = non_spatial_augmentation(
            source_img,
            jitter_paramters=jitter_paramters,
            color_order=color_order,
            to_gray=to_gray)

    data['image'] = target_img.squeeze()
    data['image_aug'] = source_img.squeeze()
    return data
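This variant expects a real image pair and only applies the synthetic warp in the HA-style training modes. A hypothetical call, with key names and shapes inferred from the code above:

import torch

sample = {'image_target': torch.rand(3, 240, 320),
          'image_source': torch.rand(3, 240, 320)}
sample = ha_augment_sample(sample, training_mode='HA', non_spatial_aug=True)
# sample['image'] is the warped (and jittered) target, sample['image_aug'] the source;
# sample['homography'] holds the inverse warp and is only set in the HA modes.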


Example #5
    def forward(self, x):
        """
        Processes a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Batch of input images (B, 3, H, W)

        Returns
        -------
        score : torch.Tensor
            Score map (B, 1, H_out, W_out)
        coord : torch.Tensor
            Keypoint coordinates (B, 2, H_out, W_out)
        feat : torch.Tensor
            Keypoint descriptors (B, 256, H_out, W_out)
        """
        B, _, H, W = x.shape

        x = self.relu(self.conv1a(x))
        x = self.relu(self.conv1b(x))
        if self.dropout:
            x = self.dropout(x)
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        if self.dropout:
            x = self.dropout(x)
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        skip = self.relu(self.conv3b(x))
        if self.dropout:
            skip = self.dropout(skip)
        x = self.pool(skip)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))
        if self.dropout:
            x = self.dropout(x)

        B, _, Hc, Wc = x.shape

        # Score head: per-cell keypoint confidence via sigmoid.
        score = self.relu(self.convDa(x))
        if self.dropout:
            score = self.dropout(score)
        score = self.convDb(score).sigmoid()

        # Mask out the border cells so no keypoint lands on the image edge.
        border_mask = torch.ones(B, Hc, Wc)
        border_mask[:, 0] = 0
        border_mask[:, Hc - 1] = 0
        border_mask[:, :, 0] = 0
        border_mask[:, :, Wc - 1] = 0
        border_mask = border_mask.unsqueeze(1)
        score = score * border_mask.to(score.device)

        # Location head: per-cell (x, y) offsets in [-1, 1] via tanh.
        center_shift = self.relu(self.convPa(x))
        if self.dropout:
            center_shift = self.dropout(center_shift)
        center_shift = self.convPb(center_shift).tanh()

        # Map cell-relative offsets to absolute pixel coordinates.
        step = (self.cell - 1) / 2.
        center_base = image_grid(B,
                                 Hc,
                                 Wc,
                                 dtype=center_shift.dtype,
                                 device=center_shift.device,
                                 ones=False,
                                 normalized=False).mul(self.cell) + step

        coord_un = center_base.add(center_shift.mul(self.cross_ratio * step))
        coord = coord_un.clone()
        coord[:, 0] = torch.clamp(coord_un[:, 0], min=0, max=W - 1)
        coord[:, 1] = torch.clamp(coord_un[:, 1], min=0, max=H - 1)

        # Descriptor head, optionally upsampled and fused with the conv3 skip.
        feat = self.relu(self.convFa(x))
        if self.dropout:
            feat = self.dropout(feat)
        if self.do_upsample:
            feat = self.upsample(self.convFb(feat))
            feat = torch.cat([feat, skip], dim=1)
        feat = self.relu(self.convFaa(feat))
        feat = self.convFbb(feat)

        if not self.training:
            # Normalize keypoint coordinates to [-1, 1] for grid_sample.
            coord_norm = coord[:, :2].clone()
            coord_norm[:, 0] = (coord_norm[:, 0] / (float(W - 1) / 2.)) - 1.
            coord_norm[:, 1] = (coord_norm[:, 1] / (float(H - 1) / 2.)) - 1.
            coord_norm = coord_norm.permute(0, 2, 3, 1)

            feat = torch.nn.functional.grid_sample(feat,
                                                   coord_norm,
                                                   align_corners=True)

            dn = torch.norm(feat, p=2, dim=1)  # Compute the norm.
            feat = feat.div(torch.unsqueeze(dn, 1))  # Divide by norm to normalize.
        return score, coord, feat
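At inference the dense (score, coord, feat) maps are usually reduced to a sparse keypoint set. A hypothetical post-processing helper (not part of the model) that keeps the k highest-scoring cells per image:

import torch

def top_k_keypoints(score, coord, feat, k=300):
    # Hypothetical: select the k best keypoints from the dense model outputs.
    b, c = feat.shape[:2]
    score = score.reshape(b, -1)         # (B, Hc*Wc)
    coord = coord.reshape(b, 2, -1)      # (B, 2, Hc*Wc)
    feat = feat.reshape(b, c, -1)        # (B, C, Hc*Wc)
    values, idx = score.topk(k, dim=1)   # confidences of the k best cells
    pts = torch.gather(coord, 2, idx.unsqueeze(1).expand(-1, 2, -1))
    desc = torch.gather(feat, 2, idx.unsqueeze(1).expand(-1, c, -1))
    return values, pts, desc             # (B, k), (B, 2, k), (B, C, k)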