def __getitem__(self, index):
 
        img_orig_PIL = self.get_PIL_img(index)
        img_orig_rgb = np.array(img_orig_PIL)

        # sample pixels (oversample 4x; some keypoints will land outside the frame after augmentation and are discarded below)
        uv_orig = self.get_random_pixels(width=img_orig_rgb.shape[1], height=img_orig_rgb.shape[0], num_samples=self.num_matching_pixels*4)
        keypoints_on_images_orig = [ia.KeypointsOnImage.from_coords_array(coords=uv_orig, shape=img_orig_rgb.shape)]
        
        # augment image to generate a pair of images for training and find correspondences
        aug_seq_det_a = self.aug_seq.to_deterministic()
        img_a_rgb = aug_seq_det_a.augment_image(img_orig_rgb)
        keypoints_a = aug_seq_det_a.augment_keypoints(keypoints_on_images_orig)[0]
        uv_a = np.rint(keypoints_a.get_coords_array()).astype(int)
        
        aug_seq_det_b = self.aug_seq.to_deterministic()
        img_b_rgb = aug_seq_det_b.augment_image(img_orig_rgb)
        keypoints_b = aug_seq_det_b.augment_keypoints(keypoints_on_images_orig)[0]
        uv_b = np.rint(keypoints_b.get_coords_array()).astype(int)
        
        # remove pixels outside frame
        within_a = np.logical_and(uv_a >= [0, 0], uv_a < [img_orig_rgb.shape[1], img_orig_rgb.shape[0]])
        within_b = np.logical_and(uv_b >= [0, 0], uv_b < [img_orig_rgb.shape[1], img_orig_rgb.shape[0]])
        valid_ids = np.where(np.logical_and(within_a, within_b).all(axis=1))[0]
        uv_orig = uv_orig[valid_ids][:self.num_matching_pixels]
        uv_a = uv_a[valid_ids][:self.num_matching_pixels]
        uv_b = uv_b[valid_ids][:self.num_matching_pixels]
        
        if self.debug:
            self.aug_seq.show_grid(img_orig_rgb, cols=4, rows=4)
            img_a_cv = cv2.cvtColor(img_a_rgb, cv2.COLOR_RGB2BGR)
            img_b_cv = cv2.cvtColor(img_b_rgb, cv2.COLOR_RGB2BGR)
            for pix_a, pix_b in zip(uv_a, uv_b):
                # random hue, converted to BGR to match the cv2 display images; cv2 expects plain int tuples
                color = tuple(int(c) for c in cv2.cvtColor(np.array([[[np.random.randint(256), 255, 255]]], dtype=np.uint8), cv2.COLOR_HSV2BGR)[0][0])
                cv2.circle(img_a_cv, (int(pix_a[0]), int(pix_a[1])), 5, color, -1)
                cv2.circle(img_b_cv, (int(pix_b[0]), int(pix_b[1])), 5, color, -1)

            cv2.imshow('a', img_a_cv)
            cv2.imshow('b', img_b_cv)
            k = cv2.waitKey(0)

        # convert to torch Tensor
        uv_a = (torch.from_numpy(uv_a[:, 0]).type(torch.LongTensor), torch.from_numpy(uv_a[:, 1]).type(torch.LongTensor))
        uv_b = (torch.from_numpy(uv_b[:, 0]).type(torch.FloatTensor), torch.from_numpy(uv_b[:, 1]).type(torch.FloatTensor))

        # find non_correspondences
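        # (for each match pixel in image b, this samples num_non_matches_per_match other pixels in image b
        # to serve as non-matches for the corresponding pixel in image a)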
        uv_b_non_matches = correspondence_finder.create_non_correspondences(uv_b, img_b_rgb.shape, num_non_matches_per_match=self.num_non_matches_per_match)

        # convert PIL.Image to torch.FloatTensor
        image_a_rgb = self.rgb_image_to_tensor(img_a_rgb)
        image_b_rgb = self.rgb_image_to_tensor(img_b_rgb)

        image_width = img_orig_rgb.shape[1]
        image_height = img_orig_rgb.shape[0]
        matches_a = SaladDataset.flatten_uv_tensor(uv_a, image_width)
        matches_b = SaladDataset.flatten_uv_tensor(uv_b, image_width)

        uv_a_long, uv_b_non_matches_long = self.create_non_matches(uv_a, uv_b_non_matches, self.num_non_matches_per_match)

        non_matches_a = SaladDataset.flatten_uv_tensor(uv_a_long, image_width).squeeze(1)
        non_matches_b = SaladDataset.flatten_uv_tensor(uv_b_non_matches_long, image_width).squeeze(1)

        # 5 is the custom data-type flag for distortion-image training
        return 5, image_a_rgb, image_b_rgb, matches_a, matches_b, non_matches_a, non_matches_b
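A minimal usage sketch of how the flattened match indices returned above could be consumed. Here `net` is a hypothetical stand-in for a dense-descriptor network, and `dataset` is assumed to be a constructed instance of the dataset class above; neither is part of this class.

import torch
import torch.nn as nn

# Hypothetical stand-in for a dense-descriptor network producing a (1, D, H, W) descriptor image.
net = nn.Conv2d(3, 16, kernel_size=3, padding=1)

# `dataset` is assumed to be constructed elsewhere.
data_type, img_a, img_b, matches_a, matches_b, non_matches_a, non_matches_b = dataset[0]

des_a = net(img_a.unsqueeze(0))        # (1, D, H, W)
des_b = net(img_b.unsqueeze(0))        # (1, D, H, W)
D = des_a.shape[1]

# Flattening H*W puts each pixel's descriptor at its single index n = u + image_width * v,
# so the match indices can be used directly for lookup.
des_a_flat = des_a.view(D, -1).t()     # (H*W, D)
des_b_flat = des_b.view(D, -1).t()     # (H*W, D)

match_loss = (des_a_flat[matches_a] - des_b_flat[matches_b]).pow(2).sum(1).mean()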
Example #2
    def __getitem__(self, index):
        """
        The method through which the dataset is accessed for training.

        The index param is not currently used, and instead each dataset[i] is the result of
        a random sampling over:
        - random scene
        - random rgbd frame from that scene
        - random rgbd frame (different enough pose) from that scene
        - various randomization in the match generation and non-match generation procedure

        Returns a large number of values as a tuple.

        0th return arg: the type of data sampled (this can be used as a flag for different loss functions)
        0th rtype: string

        1st, 2nd return args: image_a_rgb, image_b_rgb
        1st, 2nd rtype: 3-dimensional torch.FloatTensor of shape (image_height, image_width, 3)

        3rd, 4th return args: matches_a, matches_b
        3rd, 4th rtype: 1-dimensional torch.LongTensor of shape (num_matches)

        5th, 6th return args: non_matches_a, non_matches_b
        5th, 6th rtype: 1-dimensional torch.LongTensor of shape (num_non_matches)

        Return values 3,4,5,6 are all in the "single index" format for pixels. That is

        (u,v) --> n = u + image_width * v
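        For example, with image_width = 640, the pixel (u, v) = (13, 2) maps to n = 13 + 640 * 2 = 1293.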

        """

        # stores metadata about this data
        metadata = dict()


        # pick a scene
        scene_name = self.get_random_scene_name()
        metadata['scene_name'] = scene_name

        # image a
        image_a_idx = self.get_random_image_index(scene_name)
        image_a_rgb, image_a_depth, image_a_mask, image_a_pose = self.get_rgbd_mask_pose(scene_name, image_a_idx)

        metadata['image_a_idx'] = image_a_idx

        # image b
        image_b_idx = self.get_img_idx_with_different_pose(scene_name, image_a_pose, num_attempts=50)
        # image_b_idx = self.get_next_img_idx(scene_name, image_a_idx)
        metadata['image_b_idx'] = image_b_idx
        if image_b_idx is None:
            logging.info("no frame with sufficiently different pose found, returning")
            # TODO: return something cleaner than no-data
            image_a_rgb_tensor = self.rgb_image_to_tensor(image_a_rgb)
            return self.return_empty_data(image_a_rgb_tensor, image_a_rgb_tensor)

        image_b_rgb, image_b_depth, image_b_mask, image_b_pose = self.get_rgbd_mask_pose(scene_name, image_b_idx)

        image_a_depth_numpy = np.asarray(image_a_depth)
        image_b_depth_numpy = np.asarray(image_b_depth)

        # find correspondences
        uv_a, uv_b = correspondence_finder.batch_find_pixel_correspondences(image_a_depth_numpy, image_a_pose, 
                                                                           image_b_depth_numpy, image_b_pose, 
                                                                           num_attempts=self.num_matching_attempts, img_a_mask=np.asarray(image_a_mask))
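        # uv_a and uv_b are each a (u, v) tuple of 1-D tensors of matched pixel coordinates,
        # or None if no valid correspondences were found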

        if uv_a is None:
            logging.info("no matches found, returning")
            image_a_rgb_tensor = self.rgb_image_to_tensor(image_a_rgb)
            return self.return_empty_data(image_a_rgb_tensor, image_a_rgb_tensor)

        if self.debug:
            # downsample so can plot
            num_matches_to_plot = 10
            indexes_to_keep = (torch.rand(num_matches_to_plot)*len(uv_a[0])).floor().type(torch.LongTensor)
            uv_a = (torch.index_select(uv_a[0], 0, indexes_to_keep), torch.index_select(uv_a[1], 0, indexes_to_keep))
            uv_b = (torch.index_select(uv_b[0], 0, indexes_to_keep), torch.index_select(uv_b[1], 0, indexes_to_keep))

        # data augmentation
        if self._domain_randomize:
            image_a_rgb = correspondence_augmentation.random_domain_randomize_background(image_a_rgb, image_a_mask)
            image_b_rgb = correspondence_augmentation.random_domain_randomize_background(image_b_rgb, image_b_mask)


        if not self.debug:
            [image_a_rgb], uv_a                 = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb], uv_a)
            [image_b_rgb, image_b_mask], uv_b   = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_mask], uv_b)
        else: # also mutate depth just for plotting
            [image_a_rgb, image_a_depth], uv_a               = correspondence_augmentation.random_image_and_indices_mutation([image_a_rgb, image_a_depth], uv_a)
            [image_b_rgb, image_b_depth, image_b_mask], uv_b = correspondence_augmentation.random_image_and_indices_mutation([image_b_rgb, image_b_depth, image_b_mask], uv_b)
            image_a_depth_numpy = np.asarray(image_a_depth)
            image_b_depth_numpy = np.asarray(image_b_depth)

        # find non_correspondences

        if index%2:
            metadata['non_match_type'] = 'masked'
            logging.debug("masking non-matches")
            image_b_mask = torch.from_numpy(np.asarray(image_b_mask)).type(torch.FloatTensor)
        else:
            metadata['non_match_type'] = 'non_masked'
            logging.debug("not masking non-matches")
            image_b_mask = None
            
        image_b_shape = image_b_depth_numpy.shape
        image_width  = image_b_shape[1]
        image_height = image_b_shape[0]

        uv_b_non_matches = correspondence_finder.create_non_correspondences(uv_b, image_b_shape, 
            num_non_matches_per_match=self.num_non_matches_per_match, img_b_mask=image_b_mask)

        if self.debug:
            # only want to bring in plotting code if in debug mode
            import correspondence_plotter

            # Just show all images 
            uv_a_long = (torch.t(uv_a[0].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1,1), 
                     torch.t(uv_a[1].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1,1))
            uv_b_non_matches_long = (uv_b_non_matches[0].view(-1,1), uv_b_non_matches[1].view(-1,1) )
            
            # Show correspondences
            if uv_a is not None:
                fig, axes = correspondence_plotter.plot_correspondences_direct(image_a_rgb, image_a_depth_numpy, image_b_rgb, image_b_depth_numpy, uv_a, uv_b, show=False)
                correspondence_plotter.plot_correspondences_direct(image_a_rgb, image_a_depth_numpy, image_b_rgb, image_b_depth_numpy,
                                                  uv_a_long, uv_b_non_matches_long,
                                                  use_previous_plot=(fig,axes),
                                                  circ_color='r')


        # image_a_rgb, image_b_rgb = self.both_to_tensor([image_a_rgb, image_b_rgb])

        # convert PIL.Image to torch.FloatTensor
        image_a_rgb = self.rgb_image_to_tensor(image_a_rgb)
        image_b_rgb = self.rgb_image_to_tensor(image_b_rgb)

        uv_a_long = (torch.t(uv_a[0].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1,1), 
                     torch.t(uv_a[1].repeat(self.num_non_matches_per_match, 1)).contiguous().view(-1,1))
        uv_b_non_matches_long = (uv_b_non_matches[0].view(-1,1), uv_b_non_matches[1].view(-1,1) )

        # flatten correspondences and non_correspondences
        matches_a = uv_a[1].long()*image_width+uv_a[0].long()
        matches_b = uv_b[1].long()*image_width+uv_b[0].long()
        non_matches_a = uv_a_long[1].long()*image_width+uv_a_long[0].long()
        non_matches_a = non_matches_a.squeeze(1)
        non_matches_b = uv_b_non_matches_long[1].long()*image_width+uv_b_non_matches_long[0].long()
        non_matches_b = non_matches_b.squeeze(1)

        return "matches", image_a_rgb, image_b_rgb, matches_a, matches_b, non_matches_a, non_matches_b, metadata