예제 #1
0
def test_build_full_labelled_image(tanzania_image_size, tanzania_nb_labels,
                                   tanzania_raw_image_size):
    """Test the label prediction on a high-resolution image that has to be
    tiled during the inference process.

    The labelled output is composed of label IDs that must correspond to the
    dataset glossary, and its shape must equal the original image size.
    """
    datapath = "./tests/data"
    dataset = "tanzania"
    image_paths = postprocess.get_image_paths(datapath, dataset,
                                              tanzania_image_size,
                                              "tanzania_sample")
    images = postprocess.extract_images(image_paths)
    coordinates = postprocess.extract_coordinates_from_filenames(image_paths)
    model = postprocess.get_trained_model(datapath, dataset,
                                          tanzania_image_size,
                                          tanzania_nb_labels)
    labelled_image = postprocess.build_full_labelled_image(
        images,
        coordinates,
        model,
        tile_size=tanzania_image_size,
        img_width=tanzania_raw_image_size,
        batch_size=2,
    )
    # The reassembled tiles must cover the full raw (square) image.
    assert labelled_image.shape == (
        tanzania_raw_image_size,
        tanzania_raw_image_size,
    )
    # Every distinct predicted value must be a valid label ID, i.e. equal to
    # one of 0, 1, ..., tanzania_nb_labels - 1. np.isin performs this
    # membership test in one vectorised call instead of probing a Python
    # ``range`` once per unique value (slow for numpy scalar types).
    unique_labels = np.unique(labelled_image)
    assert np.isin(unique_labels, np.arange(tanzania_nb_labels)).all()
예제 #2
0
def test_get_image_paths(tanzania_image_size):
    """Check that ``postprocess.get_image_paths`` only yields PNG files.

    Every preprocessed image filename it returns for the tanzania test
    dataset with the "grid_066" selector must end with ".png".
    """
    paths = postprocess.get_image_paths(
        "./tests/data",
        "tanzania",
        tanzania_image_size,
        "grid_066",
    )
    extensions_ok = [path.endswith(".png") for path in paths]
    assert np.all(extensions_ok)
예제 #3
0
def test_get_image_paths(tanzania_image_size):
    """Check that ``postprocess.get_image_paths`` only yields PNG files.

    Looks in the preprocessed tanzania testing folder for the given image
    size, using the "tanzania_sample" selector, and requires every returned
    filename to end with ".png".
    """
    testing_dir = (
        "./tests/data/tanzania/preprocessed/"
        f"{tanzania_image_size}/testing/"
    )
    paths = postprocess.get_image_paths(testing_dir, "tanzania_sample")
    extensions_ok = [path.endswith(".png") for path in paths]
    assert np.all(extensions_ok)
예제 #4
0
def test_extract_images(tanzania_image_size,
                        tanzania_nb_output_testing_images):
    """Test the image extraction function, which retrieves the image data as
    a ``numpy.array`` starting from a list of filenames.

    The resulting array must be shaped as
    (nb_filenames, image_size, image_size, 3).
    """
    paths = postprocess.get_image_paths(
        "./tests/data",
        "tanzania",
        tanzania_image_size,
        "tanzania_sample",
    )
    images = postprocess.extract_images(paths)
    # A single tuple comparison checks the rank (4) and every axis at once.
    expected_shape = (
        tanzania_nb_output_testing_images,
        tanzania_image_size,
        tanzania_image_size,
        3,
    )
    assert images.shape == expected_shape