예제 #1
0
def test_extract_tile_items(tanzania_example_image, tanzania_example_labels):
    """Check the items extracted for a 1000x1000 tile of the reference image
    (see 'tests/data/tanzania/input/training/').

    The following properties are verified:
    - exactly 7 items are returned, and all of them are valid geometries
    - every item is a 'Polygon' (as opposed to a 'MultiPolygon')
    - the bounding box of the item union lies within the image extent, i.e.
    the parts of overlapping items that fall outside the image were clipped
    """
    dataset = gdal.Open(str(tanzania_example_image))
    image_extent = get_image_features(dataset)

    # Keep only labels that have a geometry; a missing condition is
    # replaced with "Complete".
    raw_labels = gpd.read_file(tanzania_example_labels)
    has_geometry = ~raw_labels.geometry.isna()
    labels = raw_labels.loc[has_geometry, ["condition", "geometry"]]
    missing_condition = [value is None for value in labels.condition]
    labels.loc[missing_condition, "condition"] = "Complete"

    tile_items = extract_tile_items(image_extent, labels, 0, 0, 1000, 1000)

    assert tile_items.shape[0] == 7
    item_geometries = tile_items["geometry"]
    assert all(geometry.is_valid for geometry in item_geometries)
    assert all(
        geometry.geom_type == "Polygon" for geometry in item_geometries
    )

    # Bounds are ordered as (min x, min y, max x, max y).
    min_x, min_y, max_x, max_y = tile_items.unary_union.bounds
    west, east = image_extent["west"], image_extent["east"]
    south, north = image_extent["south"], image_extent["north"]
    assert west <= min_x <= east
    assert south <= min_y <= north
    assert west <= max_x <= east
    assert south <= max_y <= north
예제 #2
0
def test_extract_empty_tile_items(tanzania_example_image,
                                  tanzania_example_labels):
    """Check that a tile without any label yields an empty item set.

    The tile is a 100x100 square at offset (450, 450) in the reference test
    image (see 'tests/data/tanzania/input/training/'); no item is expected
    there.
    """
    dataset = gdal.Open(str(tanzania_example_image))
    image_extent = get_image_features(dataset)

    # Keep only labels that have a geometry; a missing condition is
    # replaced with "Complete".
    raw_labels = gpd.read_file(tanzania_example_labels)
    has_geometry = ~raw_labels.geometry.isna()
    labels = raw_labels.loc[has_geometry, ["condition", "geometry"]]
    missing_condition = [value is None for value in labels.condition]
    labels.loc[missing_condition, "condition"] = "Complete"

    tile_x, tile_y, tile_side = 450, 450, 100
    empty_tile_items = extract_tile_items(
        image_extent, labels, tile_x, tile_y, tile_side, tile_side
    )
    assert empty_tile_items.shape[0] == 0
예제 #3
0
    def _preprocess_for_training(self, image_filename, output_dir, nb_images):
        """Sample random square tiles from a raster and save them with labels

        Tiles of side ``self.image_size`` are drawn at random offsets. A tile
        that contains at least one labelled item is serialized four times: as
        drawn ("nw"), flipped left-right ("ne"), flipped top-bottom ("sw") and
        flipped both ways ("se"). A tile without any item is serialized once
        ("nw"), and only while fewer than ``0.1 * nb_images`` empty tiles have
        been kept.

        Sampling stops as soon as the image counter reaches ``nb_images`` or
        after ``2 * nb_images`` draws. The counter is only checked between
        draws, so it may exceed ``nb_images`` by up to three images.

        Parameters
        ----------
        image_filename : str
            Full path towards the image on the disk. The label file path is
            derived from it by replacing "images" with "labels" and ".tif"
            with ".geojson".
        output_dir : str
            Output path where preprocessed image must be saved
        nb_images : int
            Number of images to generate from this raster

        Returns
        -------
        list
            Truthy values returned by ``self._serialize``, one per saved image
            (presumably dicts with the filenames and label ids)
        """
        raster = gdal.Open(image_filename)
        raw_img_width = raster.RasterXSize
        raw_img_height = raster.RasterYSize
        image_data = raster.ReadAsArray()
        # Swap first and last axes so that the two leading axes can be sliced
        # with the (x, y) tile offsets below.
        # NOTE(review): assumes a multi-band (band, row, col) array - confirm.
        image_data = np.swapaxes(image_data, 0, 2)
        result_dicts = []
        logger.info(
            "Image filename: %s, size: (%s, %s)",
            image_filename.split("/")[-1], raw_img_width, raw_img_height
        )

        label_filename = image_filename.replace("images", "labels").replace(
            ".tif", ".geojson"
        )
        labels = gpd.read_file(label_filename)
        # Drop labels without geometry; a missing condition becomes "Complete".
        labels = labels.loc[~labels.geometry.isna(), ["condition", "geometry"]]
        none_mask = [lc is None for lc in labels.condition]
        labels.loc[none_mask, "condition"] = "Complete"

        # The raster features only depend on the raster, hence they are
        # computed once instead of at each draw.
        # NOTE(review): assumes the helpers below do not mutate them - confirm.
        raster_features = geometries.get_image_features(raster)

        nb_attempts = 0
        image_counter = 0
        empty_image_counter = 0
        while image_counter < nb_images and nb_attempts < 2 * nb_images:
            # Randomly pick the tile offsets. The upper bound of ``randint``
            # is exclusive, hence the "+ 1": the last valid offset can be
            # drawn, and a raster that is exactly ``self.image_size`` wide or
            # high no longer raises a ValueError.
            x = np.random.randint(0, raw_img_width - self.image_size + 1)
            y = np.random.randint(0, raw_img_height - self.image_size + 1)

            tile_data = image_data[
                x:(x + self.image_size), y:(y + self.image_size)
            ]
            tile_image = Image.fromarray(tile_data)
            tile_items = geometries.extract_tile_items(
                raster_features, labels, x, y, self.image_size, self.image_size
            )
            mask = self.load_mask(tile_items, raster_features, x, y)
            label_dict = utils.build_labels(
                mask, range(self.get_nb_labels()), "tanzania"
            )
            labelled_image = utils.build_image_from_config(mask, self.labels)

            # ``len`` is used on purpose: the truth value of a data frame is
            # ambiguous.
            if len(tile_items) > 0:
                # Augment non-empty tiles with their three mirrored versions.
                tile_image_sw = tile_image.transpose(Image.FLIP_TOP_BOTTOM)
                labelled_image_sw = labelled_image.transpose(
                    Image.FLIP_TOP_BOTTOM
                )
                variants = (
                    ("nw", tile_image, labelled_image),
                    (
                        "ne",
                        tile_image.transpose(Image.FLIP_LEFT_RIGHT),
                        labelled_image.transpose(Image.FLIP_LEFT_RIGHT),
                    ),
                    ("sw", tile_image_sw, labelled_image_sw),
                    (
                        "se",
                        tile_image_sw.transpose(Image.FLIP_LEFT_RIGHT),
                        labelled_image_sw.transpose(Image.FLIP_LEFT_RIGHT),
                    ),
                )
            elif empty_image_counter < 0.1 * nb_images:
                # Keep a limited share of empty tiles, without augmentation.
                variants = (("nw", tile_image, labelled_image),)
                empty_image_counter += 1
            else:
                variants = ()

            for suffix, variant_image, variant_labelled_image in variants:
                tiled_results = self._serialize(
                    variant_image,
                    variant_labelled_image,
                    label_dict,
                    image_filename,
                    output_dir,
                    x,
                    y,
                    suffix,
                )
                if tiled_results:
                    result_dicts.append(tiled_results)
                # The counter is incremented even when ``_serialize`` returns
                # a falsy value.
                image_counter += 1
            nb_attempts += 1
        # Drop the local reference to the GDAL dataset.
        del raster
        logger.info(
            "Generate %s images after %s attempts.", image_counter, nb_attempts
        )
        return result_dicts