Example #1
    def stitch_pair(
        self,
        images_left: torch.Tensor,
        images_right: torch.Tensor,
        mask_left: Optional[torch.Tensor] = None,
        mask_right: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Compute the transformed images
        input_dict: Dict[str, torch.Tensor] = self.preprocess(
            images_left, images_right)
        out_shape: Tuple[int, int] = (
            images_left.shape[-2],
            images_left.shape[-1] + images_right.shape[-1],
        )
        correspondences: dict = self.on_matcher(input_dict)
        homo: torch.Tensor = self.estimate_transform(**correspondences)
        src_img = warp_perspective(images_right, homo, out_shape)
        dst_img = torch.cat(
            [images_left, torch.zeros_like(images_right)], dim=-1)

        # Compute the transformed masks
        if mask_left is None:
            mask_left = torch.ones_like(images_left)
        if mask_right is None:
            mask_right = torch.ones_like(images_right)
        # 'nearest' to ensure no floating points in the mask
        src_mask = warp_perspective(mask_right, homo, out_shape, mode='nearest')
        dst_mask = torch.cat([mask_left, torch.zeros_like(mask_right)], dim=-1)
        out_img = self.blend_image(src_img, dst_img, src_mask)
        out_mask = (dst_mask + src_mask).bool().to(src_mask.dtype)
        return out_img, out_mask
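The method above relies on class helpers (self.preprocess, self.on_matcher, self.estimate_transform, self.blend_image). As a minimal, self-contained sketch of the same idea, assuming the 3x3 homography is already known rather than estimated from matches, the warp-and-paste step could be written as:

import torch
from kornia.geometry.transform import warp_perspective


def stitch_with_known_homography(left: torch.Tensor,
                                 right: torch.Tensor,
                                 H: torch.Tensor) -> torch.Tensor:
    # Output canvas: same height as the inputs, combined width of both images.
    out_shape = (left.shape[-2], left.shape[-1] + right.shape[-1])
    # Warp the right image into the left image's coordinate frame.
    warped_right = warp_perspective(right, H, out_shape)
    # Place the left image on the canvas and zero-pad the remaining width.
    canvas = torch.cat([left, torch.zeros_like(right)], dim=-1)
    # Overwrite the canvas wherever the warped image has content
    # (a simplification of the blend_image call above).
    content = (warped_right > 0).any(dim=-3, keepdim=True)
    return torch.where(content, warped_right, canvas)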
Example #2
    def training_step(self, batch, batch_nb):
        if self.supervised or len(batch) == 2:
            x, y = batch
            y_hat = self.forward(x)
            bce = nn.BCEWithLogitsLoss()
            loss = 10 * bce(y_hat.squeeze(1), y)
            # fig, ax= plt.subplots(1,x.shape[0])

            # for i in range(x.shape[0]):
            #     img1 = (x[i,0,...].cpu().detach().numpy())
            #     gt1 = (y[i,...].cpu().detach().numpy())
            #     ax[i].imshow(img1)
            #     ax[i].imshow(gt1,alpha=0.3)

            # self.logger.experiment.log_image(
            #     'train images',
            #     fig,
            #     description='x1,x2')

        else:
            #angle=d, translate=t, scale=sc, shear=sh
            x1, x2, affine_matrix = batch
            y1 = self.forward(x1)
            y1_transformed = K.warp_perspective(y1, affine_matrix.squeeze(0),
                                                y1[0, 0, ...].shape).squeeze(1)
            # y1_transformed=kornia.compute_affine_transformation(y1,affine_matrix.squeeze(0))
            y_teach1 = 1. * (self.teacher.forward(x1))
            # y_teach1=torch.from_numpy(np.array(Fvision.affine(Fvision.to_pil_image(y_teach1.squeeze(1).cpu()),angle=d, translate=t, scale=sc, shear=sh))).cuda()
            y_teach2 = K.warp_perspective(y_teach1, affine_matrix.squeeze(0),
                                          y_teach1[0, 0, ...].shape)
            # y_teach2=kornia.compute_affine_transformation(y_teach1,affine_matrix.squeeze(0))
            y_teach1 = 1. * (y_teach1)
            y2 = self.forward(x2)
            loss = self.criterion(
                y2.squeeze(1), y_teach2.squeeze(1)
            )  #+self.criterion(y1.squeeze(1),y_teach1.squeeze(1))+self.criterion(y1_transformed.squeeze(1),1.*(y2>0.5).squeeze(1))
            # self.teacher.load_state_dict(teacher_params)
            # print(len(teacher_params.keys()))
            # epo=self.current_epoch
            # with open(f"models_weights/epoch_{epo}.json",'w+') as f:
            #     json.dump((str(self.teacher.state_dict())),f)
            # img1 = (x1[0,0,...].cpu().detach().numpy())
            # img2= (x1[0,0,...].cpu().detach().numpy())
            # gt1 = (y1[0,0,...].cpu().detach().numpy()>0.5)
            # gt2= (y_teach1[0,0,...].cpu().detach().numpy()>0.5)
            # fig, (ax1,ax2)= plt.subplots(1,2)
            # ax1.imshow(img1)
            # ax1.imshow(gt1,alpha=0.3)
            # ax2.imshow(img2)
            # ax2.imshow(gt2,alpha=0.3)

            # self.logger.experiment.log_image(
            #     'train images',
            #     fig,
            #     description='x1,x2')
            # x,y=batch
            # y_hat = self.forward(x)
            # loss = self.criterion(y_hat.squeeze(1),y)
        self.logger.experiment.log_metric('train_loss', loss)
        return loss
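The unsupervised branch boils down to a consistency loss between the student's prediction on the transformed input and the teacher's warped prediction on the original input. A minimal sketch of that idea, assuming generic student/teacher callables, a criterion such as MSE, and a Bx3x3 homography per batch, might look like:

import torch
import kornia as K


def consistency_loss(student, teacher, criterion, x1, x2, affine_matrix):
    # Student prediction on the transformed view x2.
    y2 = student(x2)
    # Teacher prediction on the original view x1, warped with the same
    # transform so that it lives in the coordinate frame of x2.
    with torch.no_grad():
        y_teach = teacher(x1)
    h, w = y_teach.shape[-2:]
    y_teach_warped = K.geometry.warp_perspective(y_teach, affine_matrix, (h, w))
    return criterion(y2, y_teach_warped)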
Example #3
    def track_next_frame(self, x: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """The frame `x` is prewarped according to the previous frame homography, matched with fast_matcher
        verified with ransac."""
        if self.previous_homography is not None:  # narrow the Optional for mypy
            Hwarp = self.previous_homography.clone()[None]
        # make a bit of border for safety
        Hwarp[:, 0:2, 0:2] = Hwarp[:, 0:2, 0:2] / 0.8
        Hwarp[:, 0:2, 2] -= 10.0
        Hinv = torch.inverse(Hwarp)
        h, w = self.target.shape[2:]
        frame_warped = warp_perspective(x, Hinv, (h, w))
        input_dict: Dict[str, torch.Tensor] = {
            "image0": self.target,
            "image1": frame_warped
        }
        for k, v in self.target_fast_representation.items():
            input_dict[f'{k}0'] = v

        match_dict = self.fast_matcher(input_dict)
        keypoints0 = match_dict['keypoints0'][match_dict['batch_indexes'] == 0]
        keypoints1 = match_dict['keypoints1'][match_dict['batch_indexes'] == 0]
        keypoints1 = transform_points(Hwarp, keypoints1)

        if len(keypoints0) < self.minimum_inliers_num:
            self.reset_tracking()
            return self.no_match()
        H, inliers = self.ransac(keypoints0, keypoints1)
        if inliers.sum().item() < self.minimum_inliers_num:
            self.reset_tracking()
            return self.no_match()
        self.previous_homography = H.clone()
        return H, True
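A stripped-down sketch of the pre-warp step, assuming a generic matcher callable and batched tensors (images BxCxHxW, homographies Bx3x3, keypoints BxNx2), could look like:

import torch
from kornia.geometry.linalg import transform_points
from kornia.geometry.transform import warp_perspective


def prewarp_and_match(target, frame, H_prev, matcher):
    # Bring the new frame roughly into the target's coordinate frame using the
    # previous homography, so the matcher only sees a small residual motion.
    H_inv = torch.inverse(H_prev)
    h, w = target.shape[-2:]
    frame_warped = warp_perspective(frame, H_inv, (h, w))
    matches = matcher({"image0": target, "image1": frame_warped})
    # Undo the pre-warp on the keypoints found in the warped frame so they are
    # expressed in the original frame's coordinates again.
    keypoints1 = transform_points(H_prev, matches["keypoints1"])
    return matches["keypoints0"], keypoints1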
Example #4
    def apply_transform(self,
                        input: Tensor,
                        params: Dict[str, Tensor],
                        transform: Optional[Tensor] = None) -> Tensor:
        _, _, height, width = input.shape
        transform = cast(Tensor, transform)
        return warp_perspective(
            input,
            transform,
            (height, width),
            mode=self.flags["resample"].name.lower(),
            align_corners=self.flags["align_corners"],
        )
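For reference, a direct call with the flags resolved to concrete values (a hypothetical example, not taken from the class above) would be:

import torch
from kornia.geometry.transform import warp_perspective

x = torch.rand(1, 3, 32, 32)      # B x C x H x W
trans = torch.eye(3)[None]        # identity 3x3 homography, B x 3 x 3
# Equivalent direct call with the resample and align_corners flags spelled out.
out = warp_perspective(x, trans, (32, 32), mode="bilinear", align_corners=True)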
Example #5
def apply_perspective(input: torch.Tensor,
                      params: Dict[str, torch.Tensor],
                      return_transform: bool = False) -> UnionType:
    r"""Perform perspective transform of the given torch.Tensor or batch of tensors.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape BxCxHxW.
        params (Dict[str, torch.Tensor]): dictionary with the following keys:
            - 'start_points' (torch.Tensor): [top-left, top-right, bottom-right, bottom-left]
              corners of the original image with shape Bx4x2.
            - 'end_points' (torch.Tensor): [top-left, top-right, bottom-right, bottom-left]
              corners of the transformed image with shape Bx4x2.
            - 'batch_prob' (torch.Tensor): boolean mask of shape B selecting the samples to warp.
        return_transform (bool): if ``True`` also return the matrix describing the transformation
            applied to each sample. Default: False.

    Returns:
        torch.Tensor: Perspectively transformed tensor.
    """

    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    # arrange input data
    x_data: torch.Tensor = input.view(-1, *input.shape[-3:])

    batch_size, _, height, width = x_data.shape

    # compute the homography between the input points
    transform: torch.Tensor = get_perspective_transform(
        params['start_points'], params['end_points']).to(device, dtype)

    out_data: torch.Tensor = x_data.clone()

    # process valid samples
    mask = params['batch_prob'].to(device)

    # TODO: look for a workaround for this hack. In CUDA it fails when no elements found.

    if bool(mask.sum() > 0):
        # apply the computed transform
        height, width = x_data.shape[-2:]
        out_data[mask] = warp_perspective(x_data[mask], transform[mask],
                                          (height, width))

    if return_transform:
        return out_data.view_as(input), transform

    return out_data.view_as(input)
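A hypothetical usage sketch, assuming apply_perspective as defined above is in scope, which builds the expected params dictionary from corner points and a per-sample mask:

import torch

batch_size, h, w = 2, 32, 32
x = torch.rand(batch_size, 3, h, w)
# Corner points in [top-left, top-right, bottom-right, bottom-left] order.
start_points = torch.tensor(
    [[[0., 0.], [w - 1., 0.], [w - 1., h - 1.], [0., h - 1.]]]
).repeat(batch_size, 1, 1)
# Pull the top-left corner inwards to create a mild perspective distortion.
end_points = start_points.clone()
end_points[:, 0] += 4.0
params = {
    "start_points": start_points,
    "end_points": end_points,
    "batch_prob": torch.tensor([True, False]),  # only warp the first sample
}
out = apply_perspective(x, params)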
Example #6
def _apply_affine(input: torch.Tensor,
                  params: Dict[str, torch.Tensor],
                  return_transform: bool = False) -> UnionType:
    r"""Random affine transformation of the image keeping center invariant
        Args:
            input (torch.Tensor): Tensor to be transformed with shape (*, C, H, W).
            degrees (float or tuple): Range of degrees to select from.
                If degrees is a number instead of sequence like (min, max), the range of degrees
                will be (-degrees, +degrees). Set to 0 to deactivate rotations.
            translate (tuple, optional): tuple of maximum absolute fraction for horizontal
                and vertical translations. For example translate=(a, b), then horizontal shift
                is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
                randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
            scale (tuple, optional): scaling factor interval, e.g. (a, b), then scale is
                randomly sampled from the range a <= scale <= b. Will keep original scale by default.
            shear (sequence or float, optional): Range of degrees to select from.
                If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
                will be applied. Else if shear is a tuple or list of 2 values, a shear parallel to the
                x axis in the range (shear[0], shear[1]) will be applied. Else if shear is a tuple or
                list of 4 values, an x-axis shear in (shear[0], shear[1]) and a y-axis shear in
                (shear[2], shear[3]) will be applied. Will not apply shear by default.
            return_transform (bool): if ``True`` also return the matrix describing the transformation
                applied to each sample. Default: False.
            mode (str): interpolation mode to calculate output values
                'bilinear' | 'nearest'. Default: 'bilinear'.
            padding_mode (str): padding mode for outside grid values
                'zeros' | 'border' | 'reflection'. Default: 'zeros'.
    """

    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    # arrange input data
    x_data: torch.Tensor = input.view(-1, *input.shape[-3:])

    height, width = x_data.shape[-2:]
    transform: torch.Tensor = params['transform'].to(device, dtype)

    out_data: torch.Tensor = warp_perspective(x_data, transform,
                                              (height, width))

    if return_transform:
        return out_data.view_as(input), transform

    return out_data.view_as(input)
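A hypothetical usage sketch, assuming _apply_affine as defined above is in scope, which builds a 3x3 rotation about the image centre as params['transform']:

import math

import torch

batch_size, h, w = 1, 32, 32
x = torch.rand(batch_size, 3, h, w)
# 3x3 rotation about the image centre, used as the affine transform.
angle = math.radians(10.0)
cx, cy = (w - 1) / 2.0, (h - 1) / 2.0
cos_a, sin_a = math.cos(angle), math.sin(angle)
transform = torch.tensor([[
    [cos_a, -sin_a, cx - cos_a * cx + sin_a * cy],
    [sin_a, cos_a, cy - sin_a * cx - cos_a * cy],
    [0.0, 0.0, 1.0],
]])
out = _apply_affine(x, {"transform": transform})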