예제 #1
0
    def test_extract_patch_script_returns_expected_patches(self):
        """extract_patch with a patch size of 4 yields the top-left block first and the bottom-right block last.

        NOTE(review): the [12:16, 12:16] slice presumes self.testImage is a 16x16 fixture — confirm in setUp.
        """
        size = 4
        patches = extract_patch([self.testImage, self.mask], size)

        top_left = self.testImage[0:4, 0:4]
        bottom_right = self.testImage[12:16, 12:16]

        # Index 0 of each extracted entry is the image patch (the mask patch follows it).
        first_image_patch = patches[0][0]
        assert np.array_equal(first_image_patch, top_left)
        last_image_patch = patches[-1][0]
        assert np.array_equal(last_image_patch, bottom_right)
예제 #2
0
def raw_img_to_patches(path_raw_data,
                       path_patched_data,
                       thresh_indices=None,
                       patch_size=512,
                       resampling_resolution=0.1):
    """
    Transform a raw acquisition to a folder of patches of size indicated in the arguments. Also performs resampling.
    Note: this function needs to be run as many times as there are different general pixel sizes
    (thus different acquisition types / resolutions).

    Every sub-folder of ``path_raw_data`` is processed as one acquisition. It must contain:
      - ``pixel_size_in_micrometer.txt``, holding the pixel size as a single float;
      - a file whose name contains ``image`` (the raw image);
      - a file whose name contains ``mask`` (the mask).
    Both images are rescaled by ``pixel_size / resampling_resolution`` and cut into patches with ``extract_patch``.
    The patches are written as ``image_<j>.png`` / ``mask_<j>.png`` into ``path_patched_data/<folder name>``.
    Entries of ``path_raw_data`` that are not directories are ignored.

    :param path_raw_data: Path to where the raw image folders are located.
    :param path_patched_data: Path to where we will store the patched acquisitions.
    :param thresh_indices: List of float, determining the thresholds separating the classes.
        Defaults to [0, 0.2, 0.8] when None.
    :param patch_size: Int, size of the patches to generate (and consequently input size of the network).
    :param resampling_resolution: Float, the resolution we need to resample to so that each sample
        has the same resolution in a dataset. Must be expressed in the same unit as the value stored in
        pixel_size_in_micrometer.txt, since the two are divided.
    :return: Nothing.
    :raises FileNotFoundError: if an acquisition folder lacks its image file or its mask file.
    """
    # A fresh list per call: the default is never shared across calls, even if a callee mutates it.
    if thresh_indices is None:
        thresh_indices = [0, 0.2, 0.8]

    # If string, convert to Path objects
    path_raw_data = convert_path(path_raw_data)
    path_patched_data = convert_path(path_patched_data)

    # Where the patched data is stored; created if it does not exist yet.
    path_patched_data.mkdir(parents=True, exist_ok=True)

    # Loop over each raw image folder
    img_folder_names = [im.name for im in path_raw_data.iterdir()]
    for img_folder in tqdm(img_folder_names):
        path_img_folder = path_raw_data / img_folder
        if not path_img_folder.is_dir():
            continue

        # Read the acquisition's pixel size; the context manager guarantees the file is closed.
        with open(path_img_folder / 'pixel_size_in_micrometer.txt', 'r') as pixel_size_file:
            pixel_size = float(pixel_size_file.read())
        # Used to set the resolution to the general pixel size.
        resample_coeff = pixel_size / resampling_resolution

        # Reset for every folder. Otherwise a folder missing one of the files would silently reuse the
        # image/mask loaded from the previous folder.
        img = None
        mask = None

        # We go through every file in the image folder
        data_names = [d.name for d in path_img_folder.iterdir()]
        for data in data_names:
            if 'image' in data:  # If it's the raw image.
                img = ads.imread(path_img_folder / data)
                img = rescale(img,
                              resample_coeff,
                              preserve_range=True,
                              mode='constant').astype(int)

            elif 'mask' in data:
                mask_init = ads.imread(path_img_folder / data)
                # order=0: presumably nearest-neighbour, so that no new intermediate mask values are created.
                mask = rescale(mask_init,
                               resample_coeff,
                               preserve_range=True,
                               mode='constant',
                               order=0)

                # Set the mask values to the classes' values
                mask = labellize_mask_2d(
                    mask, thresh_indices
                )  # shape (size, size), values float 0.0-1.0

        if img is None or mask is None:
            raise FileNotFoundError(
                "Folder %s must contain both an 'image' file and a 'mask' file." % path_img_folder)

        patches = extract_patch([img, mask], patch_size)

        # The patch extraction is done, now we put the new patches in the corresponding folder,
        # creating it if it does not exist.
        path_patched_folder = path_patched_data / img_folder
        path_patched_folder.mkdir(parents=True, exist_ok=True)

        for j, patch in enumerate(patches):
            ads.imwrite(path_patched_folder.joinpath('image_%s.png' % j),
                        patch[0])
            ads.imwrite(path_patched_folder.joinpath('mask_%s.png' % j),
                        patch[1])
예제 #3
0
    def test_extract_patch_script_errors_for_incorrect_first_arg_format(self):
        """A bare array, rather than a list of arrays, as first argument must make extract_patch raise ValueError."""
        # Deliberately wrong: the image alone instead of [image, mask].
        bare_array = self.testImage
        size = 4

        with pytest.raises(ValueError):
            extract_patch(bare_array, size)
예제 #4
0
    def test_extract_patch_script_errors_for_patch_size_eq_to_image_dim(self):
        """A patch size equal to the smallest image dimension must make extract_patch raise ValueError."""
        smallest_dim = min(self.testImage.shape)
        inputs = [self.testImage, self.mask]

        with pytest.raises(ValueError):
            extract_patch(inputs, smallest_dim)
예제 #5
0
    def test_extract_patch_script_errors_for_patch_size_smaller_than_3(self):
        """A patch size below 3 (here 2) must make extract_patch raise ValueError."""
        inputs = [self.testImage, self.mask]
        too_small = 2

        with pytest.raises(ValueError):
            extract_patch(inputs, too_small)
예제 #6
0
def rescaling(patch, factor_max=1.2, verbose=0):
    """
    Resamples the image by a factor between 1/factor_max and factor_max. Does not resample if the factor is
    too close to 1. Random sampling increases axons size diversity.

    After rescaling, the result is brought back to the original patch size:
      - when the patch was shrunk, the borders are filled with reflection padding;
      - when it was enlarged, one patch of the original size is picked at random via extract_patch.

    :param patch: List of 2 or 3 ndarrays [image, mask, (weights)]. The last dimension of the mask is used
        as the number of classes, and the mask is padded class by class along it.
    :param factor_max: Float, maximum rescaling factor possible. Minimum is obtained by inverting this max_factor.
    :param verbose: Int. The higher, the more information is displayed about the transformation.
    :return: List of 2 or 3 randomly rescaled input, [image (uint8), mask (uint8), (weights (float32))].
        When the drawn factor keeps the size within +/-5 pixels, the input list itself is returned unchanged.
    """
    has_weights = len(patch) == 3

    low_bound = 1.0 / factor_max
    high_bound = 1.0 * factor_max
    n_classes = patch[1].shape[-1]

    # Randomly choosing the resampling factor.
    scale = np.random.uniform(low_bound, high_bound, 1)[0]
    if verbose >= 1:
        print('rescaling factor: ', scale)

    patch_size = patch[0].shape[0]
    new_patch_size = int(patch_size * scale)

    # If the resampling factor is too close to 1 we do not resample.
    # The +/-5 margin also avoids having q_h = 0 below.
    if patch_size - 5 <= new_patch_size <= patch_size + 5:
        return patch

    image_rescale = rescale(patch[0], scale, preserve_range=True, mode='constant')
    mask_rescale = rescale(patch[1], scale, preserve_range=True, mode='constant')
    weights_rescale = None
    if has_weights:
        weights_rescale = rescale(patch[2], scale, preserve_range=True, mode='constant')

    s_r = mask_rescale.shape[0]
    q_h, r_h = divmod(patch_size - s_r, 2)

    if q_h > 0:
        # We undersampled: pad the rest of the image back up to patch_size.
        pad_width = (q_h, q_h + r_h)
        image_rescale = np.pad(image_rescale, pad_width, mode="reflect")
        # Pad each class channel separately so the class axis itself is not padded.
        mask_rescale = np.stack(
            [np.pad(np.squeeze(e), pad_width, mode="reflect")
             for e in np.split(mask_rescale, n_classes, axis=-1)],
            axis=-1)
        # Weights only exist for 3-element patches; padding them unconditionally raised
        # UnboundLocalError for [image, mask] inputs.
        if has_weights:
            weights_rescale = np.pad(weights_rescale, pad_width, mode="reflect")
    else:
        # We oversampled: extract all the patches coming from the enlarged image and keep a random one.
        to_extract = [image_rescale, mask_rescale]
        if has_weights:
            to_extract.append(weights_rescale)

        patches = extract_patch(to_extract, patch_size)
        chosen = patches[np.random.randint(len(patches), size=1)[0]]

        if has_weights:
            image_rescale, mask_rescale, weights_rescale = chosen
        else:
            image_rescale, mask_rescale = chosen

    mask_rescale = np.array(mask_rescale)
    if has_weights:
        weights_rescale = np.array(weights_rescale)
        return [image_rescale.astype(np.uint8), mask_rescale.astype(np.uint8),
                weights_rescale.astype(np.float32)]
    return [image_rescale.astype(np.uint8), mask_rescale.astype(np.uint8)]