Ejemplo n.º 1
0
    def _getitem(self, key):
        from rasterio.vrt import WarpedVRT

        band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

        if not band_key or any(start == stop for (start, stop) in window):
            # no need to do IO
            shape = (len(band_key), ) + tuple(stop - start
                                              for (start, stop) in window)
            out = np.zeros(shape, dtype=self.dtype)
        else:
            with self.lock:
                riods = self.manager.acquire(needs_lock=False)
                if self.vrt_params is not None:
                    riods = WarpedVRT(riods, **self.vrt_params)
                out = riods.read(band_key, window=window, masked=self.masked)
                if self.masked:
                    out = np.ma.filled(out.astype(self.dtype), np.nan)
                if self.mask_and_scale:
                    for band in np.atleast_1d(band_key):
                        band_iii = band - 1
                        out[band_iii] = (
                            out[band_iii] * riods.scales[band_iii] +
                            riods.offsets[band_iii])

        if squeeze_axis:
            out = np.squeeze(out, axis=squeeze_axis)
        return out[np_inds]
Ejemplo n.º 2
0
def coregister_raster(a_uri, b_uri, dest_file, resampling=Resampling.bilinear):
    """ Coregister raster a to the extent, resolution and projection of raster b.

    Write to dest_file.

    Band 1 of raster a is warped onto the grid of raster b (its CRS,
    transform, width and height) and written as a single-band, tiled,
    LZW-compressed GeoTIFF whose nodata value is copied from raster a.

    a_uri (read), b_uri (read), and dest_file (write) are passed to
    rasterio.open and thus are bound by its semantics.

    resampling is the rasterio Resampling method used by the warp
    (bilinear by default).

    """
    with rasterio.open(a_uri) as ds_a, rasterio.open(b_uri) as ds_b:
        # Close the warped view as soon as band 1 has been read instead of
        # leaving it open until garbage collection.
        with WarpedVRT(
            ds_a,
            crs=ds_b.crs,
            height=ds_b.height,
            width=ds_b.width,
            resampling=resampling,
            transform=ds_b.transform,
        ) as vrt:
            data = vrt.read(1)
        with rasterio.open(
                dest_file,
                "w",
                compress="lzw",
                count=1,
                crs=ds_b.crs,
                driver="GTiff",
                dtype=data.dtype,
                height=ds_b.height,
                width=ds_b.width,
                nodata=ds_a.nodata,
                tiled=True,
                transform=ds_b.transform,
        ) as dst:
            dst.write(data, indexes=1)
Ejemplo n.º 3
0
def extract_window(src, window, transform, nodata):
    """Extract raster data from src within window, and warp to DATA_CRS

    Band 1 of ``src`` is warped with nearest-neighbour resampling onto a grid
    of ``window.width`` x ``window.height`` pixels placed by ``transform``.

    Parameters
    ----------
    src : open rasterio Dataset
    window : rasterio Window object
        Only its ``width`` and ``height`` are used.
    transform : rasterio Transform object
        Affine transform of the output grid (in DATA_CRS).
    nodata : int or float
        Nodata value given to the warped view.

    Returns
    -------
    2d array
        Band 1 of the warped data.
    """
    # The VRT is closed on exit; ``src`` itself stays open for the caller.
    with WarpedVRT(
        src,
        width=window.width,
        height=window.height,
        nodata=nodata,
        transform=transform,
        crs=DATA_CRS,
        resampling=Resampling.nearest,
    ) as vrt:
        # Read only band 1 instead of every band followed by ``[0]``.
        return vrt.read(1)
Ejemplo n.º 4
0
def test_vrt_src_kept_alive(path_rgb_byte_tif):
    """A WarpedVRT can still be read after the dataset it wraps is closed."""
    with rasterio.open(path_rgb_byte_tif) as source:
        warped = WarpedVRT(source, crs="EPSG:3857")

    # ``source`` is closed at this point; the read must not crash.
    pixels = warped.read()
    assert (pixels != 0).any()
    warped.close()
Ejemplo n.º 5
0
        def worker(path):
            """Cut the raster at ``path`` into web-mercator tiles at ``args.zoom``.

            Each tile intersecting the raster is warped to EPSG:3857 and read
            as an ``args.ts`` x ``args.ts`` image. Image tiles (not
            ``args.label``) covered by this raster alone are skipped when
            ``is_border(image)`` is true. Tiles listed for several rasters in
            ``tiles_map`` go to ``splits_path/<index of path>``; the rest go to
            ``args.out``.

            Returns the tiles written that are covered by this raster alone.
            Calls ``sys.exit`` when a tile cannot be read or written.
            """
            with rasterio_open(path) as raster:
                w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
                transform, _, _ = calculate_default_transform(raster.crs, "EPSG:3857", raster.width, raster.height, w, s, e, n)
                tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
                tiled = []

                for tile in tiles:

                    try:
                        w, s, e, n = mercantile.xy_bounds(tile)

                        # inspired by rio-tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
                        # The VRT is closed after each tile rather than leaked.
                        with WarpedVRT(
                            raster,
                            crs="epsg:3857",
                            resampling=Resampling.bilinear,
                            add_alpha=False,
                            transform=from_bounds(w, s, e, n, args.ts, args.ts),
                            width=math.ceil((e - w) / transform.a),
                            height=math.ceil((s - n) / transform.e),
                        ) as warp_vrt:
                            data = warp_vrt.read(
                                out_shape=(len(raster.indexes), args.ts, args.ts), window=warp_vrt.window(w, s, e, n)
                            )
                        image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C
                    except Exception:
                        sys.exit("Error: Unable to tile {} from raster {}.".format(str(tile), raster))

                    tile_key = (str(tile.x), str(tile.y), str(tile.z))
                    if not args.label and len(tiles_map[tile_key]) == 1 and is_border(image):
                        progress.update()
                        continue

                    if len(tiles_map[tile_key]) > 1:
                        out = os.path.join(splits_path, str(tiles_map[tile_key].index(path)))
                    else:
                        out = args.out

                    x, y, z = map(int, tile)
                    xyz_tile = mercantile.Tile(x=x, y=y, z=z)

                    if args.label:
                        ret = tile_label_to_file(out, xyz_tile, palette, image)
                    else:
                        ret = tile_image_to_file(out, xyz_tile, image)

                    if not ret:
                        sys.exit("Error: Unable to write tile {} from raster {}.".format(str(tile), raster))

                    if len(tiles_map[tile_key]) == 1:
                        progress.update()
                        tiled.append(xyz_tile)

            return tiled
Ejemplo n.º 6
0
def test_vrt_mem_src_kept_alive(path_rgb_byte_tif):
    """A WarpedVRT over an in-memory dataset survives closing that dataset."""
    with open(path_rgb_byte_tif, "rb") as handle:
        raw_bytes = handle.read()

    with MemoryFile(raw_bytes) as memfile:
        with memfile.open() as source:
            warped = WarpedVRT(source, crs="EPSG:3857")

    # Both the in-memory dataset and its MemoryFile are closed here.
    pixels = warped.read()
    assert (pixels != 0).any()
    warped.close()
Ejemplo n.º 7
0
        def worker(path):
            """Cut the raster at ``path`` into web-mercator tiles at ``args.zoom``.

            Tiles absent from a non-empty ``cover`` are ignored. Each remaining
            tile is warped to EPSG:3857 and read as a ``width`` x ``height``
            image. Image tiles covered by this raster alone are skipped when
            ``is_nodata`` reports them as empty. Tiles listed for several
            rasters in ``tiles_map`` go to ``splits_path/<index of path>``; the
            rest go to ``args.out``.

            Returns the tiles written that are covered by this raster alone.
            """
            with rasterio_open(path) as raster:
                w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
                tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
                tiled = []

                for tile in tiles:

                    if cover and tile not in cover:
                        continue

                    w, s, e, n = mercantile.xy_bounds(tile)

                    with WarpedVRT(
                        raster,
                        crs="epsg:3857",
                        resampling=Resampling.bilinear,
                        add_alpha=False,
                        transform=from_bounds(w, s, e, n, width, height),
                        width=width,
                        height=height,
                    ) as warp_vrt:
                        # out_shape is (bands, rows, cols) = (count, height, width);
                        # passing width before height transposed non-square tiles.
                        data = warp_vrt.read(
                            out_shape=(len(raster.indexes), height, width), window=warp_vrt.window(w, s, e, n)
                        )
                    image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C

                    tile_key = (str(tile.x), str(tile.y), str(tile.z))
                    if (
                        not args.label
                        and len(tiles_map[tile_key]) == 1
                        and is_nodata(image, args.nodata, args.nodata_threshold, args.keep_borders)
                    ):
                        progress.update()
                        continue

                    if len(tiles_map[tile_key]) > 1:
                        out = os.path.join(splits_path, str(tiles_map[tile_key].index(path)))
                    else:
                        out = args.out

                    x, y, z = map(int, tile)
                    xyz_tile = mercantile.Tile(x=x, y=y, z=z)

                    if args.label:
                        tile_label_to_file(out, xyz_tile, palette, image)
                    else:
                        tile_image_to_file(out, xyz_tile, image)

                    if len(tiles_map[tile_key]) == 1:
                        progress.update()
                        tiled.append(xyz_tile)

            return tiled
Ejemplo n.º 8
0
def test_wrap_file(path_rgb_byte_tif):
    """Wrapping a file in a WarpedVRT exposes the expected dataset properties."""
    expected_bounds = (-8789636.7, 2700460.0, -8524406.4, 2943560.2)
    with rasterio.open(path_rgb_byte_tif) as src:
        vrt = WarpedVRT(src, crs=DST_CRS)
        assert vrt.crs == CRS.from_string(DST_CRS)
        rounded_bounds = tuple(round(coord, 1) for coord in vrt.bounds)
        assert rounded_bounds == expected_bounds
        vrt_name = vrt.name
        assert vrt_name.startswith('WarpedVRT(')
        assert vrt_name.endswith('tests/data/RGB.byte.tif)')
        assert vrt.indexes == (1, 2, 3)
        assert vrt.nodatavals == (0, 0, 0)
        assert vrt.dtypes == ('uint8', 'uint8', 'uint8')
        assert vrt.read().shape == (3, 736, 803)
Ejemplo n.º 9
0
def test_wrap_file(path_rgb_byte_tif):
    """A VirtualVRT has the expected dataset properties."""
    with rasterio.open(path_rgb_byte_tif) as src:
        # ``crs`` replaces the deprecated ``dst_crs`` keyword of WarpedVRT;
        # the VRT is closed when the block exits.
        with WarpedVRT(src, crs='EPSG:3857') as vrt:
            assert vrt.crs == 'EPSG:3857'
            assert tuple(round(x, 1) for x in vrt.bounds) == (
                -8789636.7, 2700460.0, -8524406.4, 2943560.2)
            assert vrt.name.startswith('WarpedVRT(')
            assert vrt.name.endswith('tests/data/RGB.byte.tif)')
            assert vrt.indexes == (1, 2, 3)
            assert vrt.nodatavals == (0, 0, 0)
            assert vrt.dtypes == ('uint8', 'uint8', 'uint8')
            assert vrt.read().shape == (3, 736, 803)
Ejemplo n.º 10
0
def test_wrap_file(path_rgb_byte_tif):
    """Wrapping a file in a WarpedVRT yields the expected CRS, bounds and bands."""
    with rasterio.open(path_rgb_byte_tif) as source:
        warped = WarpedVRT(source, crs=DST_CRS)

        assert warped.crs == CRS.from_string(DST_CRS)
        bounds_1dp = tuple(round(value, 1) for value in warped.bounds)
        assert bounds_1dp == (-8789636.7, 2700460.0, -8524406.4, 2943560.2)

        label = warped.name
        assert label.startswith("WarpedVRT(")
        assert label.endswith("tests/data/RGB.byte.tif)")

        assert warped.indexes == (1, 2, 3)
        assert warped.nodatavals == (0, 0, 0)
        assert warped.dtypes == ("uint8", "uint8", "uint8")
        assert warped.read().shape == (3, 736, 803)
Ejemplo n.º 11
0
    def _read_window(self, vrt: WarpedVRT, dst_window: Window) -> MaskedArray:
        """Read the part of ``vrt`` that covers ``dst_window``.

        ``dst_window`` lives in the pixel grid of the default output format.
        It is turned into bounds and then into a window of ``vrt``, which is
        read as a masked array resampled to the window's rounded size, with
        one band per entry of ``self.layer.input_bands``.

        An "Access window out of range" error on a window that is one pixel
        tall or wide yields a fully masked zero array of that shape. Any other
        ``RasterioIOError`` is logged and re-raised.
        """
        dst_profile = self.dst[self.default_format]
        dst_bounds: Bounds = bounds(dst_window, dst_profile.transform)
        window = vrt.window(*dst_bounds)

        src_bounds = transform_bounds(dst_profile.crs, self.src.crs, *dst_bounds)
        LOGGER.debug(
            f"Read {dst_window} for Tile {self.tile_id} - this corresponds to bounds {src_bounds} in source"
        )

        rows = int(round(dst_window.height))
        cols = int(round(dst_window.width))
        shape = (len(self.layer.input_bands), rows, cols)

        try:
            return vrt.read(window=window, out_shape=shape, masked=True)
        except rasterio.RasterioIOError as err:
            is_sliver = rows == 1 or cols == 1
            if not ("Access window out of range" in str(err) and is_sliver):
                LOGGER.warning(
                    f"RasterioIO error while reading {dst_window} for Tile {self.tile_id}. "
                    "Will make attempt to retry."
                )
                raise

            LOGGER.warning(
                f"Access window out of range while reading {dst_window} for Tile {self.tile_id}. "
                "This is most likely due to subpixel misalignment. "
                "Returning empty array instead."
            )
            return np.ma.array(data=np.zeros(shape=shape), mask=np.ones(shape=shape))
Ejemplo n.º 12
0
    def _getitem(self, key):
        from rasterio.vrt import WarpedVRT

        band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

        if not band_key or any(start == stop for (start, stop) in window):
            # no need to do IO
            shape = (len(band_key),) + tuple(stop - start for (start, stop) in window)
            out = np.zeros(shape, dtype=self.dtype)
        else:
            with self.lock:
                riods = self.manager.acquire(needs_lock=False)
                if self.vrt_params is not None:
                    riods = WarpedVRT(riods, **self.vrt_params)
                out = riods.read(band_key, window=window)

        if squeeze_axis:
            out = np.squeeze(out, axis=squeeze_axis)
        return out[np_inds]
Ejemplo n.º 13
0
    def _getitem(self, key):
        from rasterio.vrt import WarpedVRT
        band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

        if not band_key or any(start == stop for (start, stop) in window):
            # no need to do IO
            shape = (len(band_key),) + tuple(
                stop - start for (start, stop) in window)
            out = np.zeros(shape, dtype=self.dtype)
        else:
            with self.lock:
                riods = self.manager.acquire(needs_lock=False)
                if self.vrt_params is not None:
                    riods = WarpedVRT(riods, **self.vrt_params)
                out = riods.read(band_key, window=window)

        if squeeze_axis:
            out = np.squeeze(out, axis=squeeze_axis)
        return out[np_inds]
Ejemplo n.º 14
0
def _warper(img, transform, s_crs, t_crs, resampling):
    """
    Reproject ``img`` from ``s_crs`` to ``t_crs`` via an in-memory GeoTIFF.

    ``img`` is unpacked as (bands, rows, cols) and georeferenced by
    ``transform``. Returns the warped array together with the WarpedVRT
    that produced it.
    """
    band_count, rows, cols = img.shape
    profile = dict(
        driver="GTiff",
        height=rows,
        width=cols,
        count=band_count,
        dtype=str(img.dtype.name),
        crs=s_crs,
        transform=transform,
    )
    with MemoryFile() as memfile, memfile.open(**profile) as mraster:
        # rasterio band indexes are 1-based.
        for band_index, band_data in enumerate(img, start=1):
            mraster.write(band_data, band_index)
        # --- Virtual Warp
        vrt = WarpedVRT(mraster, crs=t_crs, resampling=resampling)
        warped = vrt.read()
    return warped, vrt
Ejemplo n.º 15
0
    return x


# go through all tiles of input image
# run convolutional model on tile
# write labels to output label raster
# NOTE(review): this snippet is truncated. image_pad_y is computed below, but
# the rest of the loop body (padding the image tile, running the model,
# writing labels) is not shown. num_tiles_y/num_tiles_x, tile_size,
# padded_tile_size, patch_radius, height_threshold, chm_vrt and image are
# all defined elsewhere.
with tqdm.tqdm(total=num_tiles_y * num_tiles_x) as pbar:
    # Tile origins step by tile_size and keep patch_radius pixels of margin
    # from the image edges.
    for y in range(patch_radius, image_height - patch_radius, tile_size):
        for x in range(patch_radius, image_width - patch_radius, tile_size):
            pbar.update(1)

            # rasterio Window(col_off, row_off, width, height): a square of
            # padded_tile_size pixels starting patch_radius up/left of (x, y).
            window = Window(x - patch_radius, y - patch_radius,
                            padded_tile_size, padded_tile_size)

            # get tile from chm
            chm_tile = chm_vrt.read(1, window=window)
            # Presumably reads at the right/bottom edge come back smaller than
            # requested; zero-pad them on the bottom/right to the full size.
            if chm_tile.shape[0] != padded_tile_size or chm_tile.shape[
                    1] != padded_tile_size:
                pad = ((0, padded_tile_size - chm_tile.shape[0]),
                       (0, padded_tile_size - chm_tile.shape[1]))
                chm_tile = np.pad(chm_tile,
                                  pad,
                                  mode='constant',
                                  constant_values=0)

            # Add a leading axis so the CHM tile is (1, rows, cols), then flag
            # pixels whose canopy height is at or below height_threshold.
            chm_tile = np.expand_dims(chm_tile, axis=0)
            chm_bad = chm_tile <= height_threshold

            # get tile from image
            image_tile = image.read(window=window)
            # Number of rows missing from the image tile (same edge check as above).
            image_pad_y = padded_tile_size - image_tile.shape[1]
Ejemplo n.º 16
0
def main():
    """Compute per-chip F1 and IoU scores for every experiment in EXPERIMENTS.

    For each experiment, every prediction chip under ``<s3_dir>/predict`` is
    compared with its ground-truth chip. Scores are computed over all pixels,
    over NLCD classes 21-24 ("urban") and over the remaining pixels ("not
    urban"). One CSV per experiment is written to ``--output-dir`` and
    uploaded next to the predictions as ``stats-iou-f1.csv``. Experiments
    whose CSV already exists are skipped unless ``--overwrite`` is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-o",
                        "--output-dir",
                        type=str,
                        default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--overwrite",
                        action="store_true",
                        help="Overwrite existing CSVs in output dir")
    args = parser.parse_args()

    # Opening the CSVs below fails if the output directory does not exist.
    os.makedirs(args.output_dir, exist_ok=True)

    client = boto3.client("s3")

    for experiment in EXPERIMENTS:
        logger.info("Processing Experiment: {}".format(experiment.id))
        experiment_url = urlparse(experiment.s3_dir)
        # urlparse keeps the leading "/" on the path. Strip it once so the
        # listing prefix and the upload key use the same key form.
        experiment_prefix = experiment_url.path.lstrip("/")

        output_csv = os.path.join(args.output_dir,
                                  "{}-iou-f1.csv".format(experiment.id))
        if not args.overwrite and os.path.isfile(output_csv):
            logger.info("SKIPPING Experiment {}. Already exists".format(
                experiment.id))
            continue

        list_result = client.list_objects_v2(
            Bucket=experiment_url.hostname,
            Prefix=os.path.join(experiment_prefix, "predict"),
        )
        if list_result["IsTruncated"]:
            raise ValueError(
                "Didn't get all results, implement ContinuationToken")

        try:
            s3_keys = list_result["Contents"]
        except KeyError:
            logger.warning("No predictions for {}. Continuing...".format(
                experiment))
            continue

        # newline="" as required by the csv module for files it writes.
        with open(output_csv, "w", newline="") as fp_csv:
            csv_fieldnames = [
                "chip_id",
                "f1_all",
                "f1_urban",
                "f1_not_urban",
                "iou_all",
                "iou_urban",
                "iou_not_urban",
            ]
            writer = csv.DictWriter(fp_csv, csv_fieldnames)
            writer.writeheader()

            for obj in s3_keys:
                chip_id = os.path.basename(obj["Key"])

                prediction_tif_uri = "s3://{}/{}".format(
                    experiment_url.hostname, obj["Key"])
                logger.info("\tprediction: {}".format(prediction_tif_uri))

                # TODO: Encode this in Experiment somehow
                # Replace chip S1 qualifier with QC to match filenames in ground truth dir
                if chip_id.startswith("SEN1FLOODS11"):
                    truth_tif_uri = os.path.join(experiment.ground_truth_dir,
                                                 chip_id)
                else:
                    chip_name = chip_id.replace("_S1.tif", "_QC.tif")
                    truth_tif_uri = os.path.join(experiment.ground_truth_dir,
                                                 chip_name)
                logger.info("\ttruth: {}".format(truth_tif_uri))

                with rasterio.open(prediction_tif_uri) as ds_p, rasterio.open(
                        truth_tif_uri) as ds_t, rasterio.open(
                            NLCD_TIF_URI) as ds_nlcd:
                    p_band = ds_p.read(1).flatten()
                    t_band = ds_t.read(1).flatten()
                    # NLCD resampled onto the prediction grid so pixels align.
                    with WarpedVRT(
                            ds_nlcd,
                            crs=ds_p.crs,
                            height=ds_p.height,
                            width=ds_p.width,
                            resampling=Resampling.nearest,
                            transform=ds_p.transform,
                            nodata=NODATA,
                    ) as nlcd_vrt:
                        nlcd_band = nlcd_vrt.read(1).flatten()

                assert len(p_band) == len(t_band)
                assert len(p_band) == len(nlcd_band)

                # TODO: Encode this in Experiment somehow...
                # For two class USFIMR experiments, collapse ground truth three class
                # labels for water to two class: 2 (flood) + 1 (perm) -> 1 (water)
                if experiment.id.startswith(
                        "USFIMR"
                ) and 'flood' not in experiment.id and 'permanent' not in experiment.id:
                    t_band = np.where(t_band == 2, 1, t_band)

                # Urban pixels are NLCD classes 21-24; everything else is
                # replaced with NODATA in the "urban" bands and vice versa.
                nlcd_urban_mask = np.ma.masked_outside(nlcd_band, 21, 24).mask
                p_band_urban = np.ma.array(p_band,
                                           mask=nlcd_urban_mask).filled(NODATA)
                t_band_urban = np.ma.array(t_band,
                                           mask=nlcd_urban_mask).filled(NODATA)

                nlcd_not_urban_mask = np.ma.masked_inside(nlcd_band, 21,
                                                          24).mask
                p_band_not_urban = np.ma.array(
                    p_band, mask=nlcd_not_urban_mask).filled(NODATA)
                t_band_not_urban = np.ma.array(
                    t_band, mask=nlcd_not_urban_mask).filled(NODATA)

                labels = experiment.labels
                scores = {
                    "chip_id":
                    chip_id,
                    "f1_all":
                    f1_score(t_band, p_band, labels=labels, average=None),
                    "f1_urban":
                    f1_score(t_band_urban,
                             p_band_urban,
                             labels=labels,
                             average=None),
                    "f1_not_urban":
                    f1_score(t_band_not_urban,
                             p_band_not_urban,
                             labels=labels,
                             average=None),
                    "iou_all":
                    jaccard_score(t_band, p_band, labels=labels, average=None),
                    "iou_urban":
                    jaccard_score(t_band_urban,
                                  p_band_urban,
                                  labels=labels,
                                  average=None),
                    "iou_not_urban":
                    jaccard_score(t_band_not_urban,
                                  p_band_not_urban,
                                  labels=labels,
                                  average=None),
                }
                logger.debug("\t\t{}".format(scores))
                writer.writerow(scores)

        client.upload_file(
            output_csv,
            experiment_url.hostname,
            os.path.join(experiment_prefix, "stats-iou-f1.csv"),
        )
Ejemplo n.º 17
0
mask, transform, window = get_input_area_mask("flm")

print("Reading and warping Florida Marine Blueprint...")
nodata = 255
# Both the source and its warped view are closed once band 1 has been read.
with rasterio.open(src_dir / "FLBlueprintVer1.tif") as src, WarpedVRT(
    src,
    width=window.width,
    height=window.height,
    nodata=nodata,
    crs=DATA_CRS,
    transform=transform,
    resampling=Resampling.nearest,
) as vrt:
    data = vrt.read(1).astype("uint8")

# remap data: values 1-3 map to themselves.
# NOTE(review): presumably remap() sends values missing from the table to
# nodata — confirm against its definition.
remap_table = np.array([[1, 1], [2, 2], [3, 3]], dtype="uint8")
data = remap(data, remap_table, nodata=nodata)

# apply input area mask
data = np.where(mask == 1, data, nodata).astype("uint8")

write_raster(outfilename,
             data,
             transform=transform,
             crs=DATA_CRS,
             nodata=nodata)

add_overviews(outfilename)
Ejemplo n.º 18
0
        def make_tile(level, tile):
            """Write one output GeoTIFF tile for pyramid ``level``.

            :param level: pyramid level. For ``level > 1`` the tile is built
                from a VRT over the previous level's child tiles, read at half
                the resolution; otherwise it is read from
                ``self.get_dataset(level)`` through ``tile[1]``.
            :param tile: indexable with ``tile[0]`` = (x, y) tile indexes,
                ``tile[1]`` = the rasterio Window of the tile, and ``tile[2]`` =
                the output affine transform.
            :return: None. The tile is written under ``out_folder/<level>/``.
            """

            # x,y tile indexes
            x = tile[0][0]
            y = tile[0][1]

            def div_by_16(value):
                # Largest multiple of 16 that is <= value (used for GeoTIFF
                # block sizes). Same result as the former recursive countdown.
                return value - value % 16

            # put tile in its respective dir
            out_dir = out_folder.joinpath(str(level))
            if not out_dir.exists():
                out_dir.mkdir(exist_ok=True)

            size_x = tile[1].width if tile[1].width > 0 else 1
            size_y = tile[1].height if tile[1].height > 0 else 1

            # Out file constructor
            # how many chars to use for representing the tiles.
            name_length = max(len(str(self.tileinfos[level].countTilesX)),
                              len(str(self.tileinfos[level].countTilesY))) + 1
            filename = name_template.format(basename=self.name,
                                            x=str(x).zfill(name_length),
                                            y=str(y).zfill(name_length))
            out_filepath = out_dir.joinpath(filename)

            # Work on a copy: default_gtiff_profile is shared, and updating it
            # in place would carry this tile's settings into every later use.
            profile = dict(default_gtiff_profile)
            profile.update(
                crs='epsg:4326',
                driver='GTiff',
                transform=tile[2],
                compress='lzw',
                count=1,
                width=size_x,
                height=size_y,
                blockysize=div_by_16(min(self.blockSize, tile[1].height)),
                blockxsize=div_by_16(min(self.blockSize, tile[1].width)),
            )

            if level > 1:
                # NOTE: resampling very large levels can overflow GDAL (seen:
                # "ERROR 1: Integer overflow : nSrcXSize=425985,
                # nSrcYSize=163840"); that case is not handled here.
                try:
                    prev_info = self.tileinfos[level - 1]
                except KeyError:
                    # Previous level's tile grid not cached yet: build it.
                    _meta = self.get_metadata(level - 1)
                    prev_info = TileInfo(_meta['width'], _meta['height'],
                                         self.TileWidth, self.TileHeight)
                    self.tileinfos[level - 1] = prev_info

                # ndigits for the child tile names, computed like the zfill
                # width above but from the previous level's tile counts.
                name_length = max(len(str(prev_info.countTilesX)),
                                  len(str(prev_info.countTilesY))) + 1

                prev_lvl_tiles = tile_children(zoom=level,
                                               src=out_filepath,
                                               ndigits=name_length)
                vrt_handler = buildvrt(prev_lvl_tiles)
                with rio.open(vrt_handler) as src:
                    profile.update(nodata=src.nodata, dtype=src.meta['dtype'])
                    # Halve the resolution: pixels twice as large, half as many
                    # (integer pixel counts, as the VRT size must be whole).
                    resolution_factor = 2
                    lvlx_transform = Affine(src.transform.a * resolution_factor,
                                            src.transform.b, src.transform.c,
                                            src.transform.d,
                                            src.transform.e * resolution_factor,
                                            src.transform.f)
                    with WarpedVRT(src,
                                   transform=lvlx_transform,
                                   width=src.width // 2,
                                   height=src.height // 2) as vrt:
                        data = vrt.read(1)
            else:
                with self.get_dataset(level) as src:
                    profile.update(nodata=src.nodata, dtype=src.meta['dtype'])
                    data = src.read(1, window=tile[1])

            try:
                with rio.open(out_filepath, 'w', **profile) as dst:
                    window_out = Window(0, 0, size_x, size_y)
                    dst.write(data, window=window_out, indexes=1)
            except Exception:
                # Show the profile that failed, then propagate the real error
                # instead of replacing it with a bare Exception.
                print(profile)
                raise
Ejemplo n.º 19
0
def main(args):
    """Cut ``args.raster`` into slippy-map tiles at zoom ``args.zoom``.

    Each tile is warped to EPSG:3857 and read at ``args.size`` x ``args.size``
    pixels. When ``args.no_data`` is set, a tile is skipped if any of its four
    border rows/columns is entirely ``args.no_data`` on all bands.

    Label rasters (``args.type == "label"``) must have one band and are saved
    as paletted PNGs. Image rasters (1 or 3 bands) are scaled to 8 bits when
    stored as uint16/uint32, then saved as PNG (1 band) or WebP (3 bands).
    Tiles go to ``args.out/<zoom>/<x>/<y>``. With ``args.web_ui``, a web
    viewer is generated for the tiles that were kept.
    """

    if args.type == "label":
        try:
            config = load_config(args.config)
        except Exception:
            sys.exit("Error: Unable to load DataSet config file")

        classes = config["classes"]["title"]
        colors = config["classes"]["colors"]
        assert len(classes) == len(colors), "classes and colors coincide"
        assert len(colors) == 2, "only binary models supported right now"

    try:
        raster = rasterio_open(args.raster)
        w, s, e, n = bounds = transform_bounds(raster.crs, "EPSG:4326",
                                               *raster.bounds)
        transform, _, _ = calculate_default_transform(raster.crs, "EPSG:3857",
                                                      raster.width,
                                                      raster.height, *bounds)
    except Exception:
        sys.exit("Error: Unable to load raster or deal with it's projection")

    tiles = [
        mercantile.Tile(x=x, y=y, z=z)
        for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)
    ]
    tiles_nodata = []

    # Close the source raster once all tiles have been cut.
    with raster:
        for tile in tqdm(tiles, desc="Tiling", unit="tile", ascii=True):

            w, s, e, n = tile_bounds = mercantile.xy_bounds(tile)

            # Inspired by Rio-Tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
            with WarpedVRT(
                raster,
                crs="EPSG:3857",
                resampling=Resampling.bilinear,
                add_alpha=False,
                transform=from_bounds(*tile_bounds, args.size, args.size),
                width=math.ceil((e - w) / transform.a),
                height=math.ceil((s - n) / transform.e),
            ) as warp_vrt:
                data = warp_vrt.read(out_shape=(len(raster.indexes), args.size,
                                                args.size),
                                     window=warp_vrt.window(w, s, e, n))

            # If no_data is set, remove all tiles with at least one whole border filled only with no_data (on all bands)
            # (The former ``type(args.no_data) is not None`` test was always True.)
            if args.no_data is not None and (
                    np.all(data[:, 0, :] == args.no_data)
                    or np.all(data[:, -1, :] == args.no_data)
                    or np.all(data[:, :, 0] == args.no_data)
                    or np.all(data[:, :, -1] == args.no_data)):
                tiles_nodata.append(tile)
                continue

            # data is laid out as (bands, rows, cols).
            C = data.shape[0]

            os.makedirs(os.path.join(args.out, str(args.zoom), str(tile.x)),
                        exist_ok=True)
            path = os.path.join(args.out, str(args.zoom), str(tile.x), str(tile.y))

            if args.type == "label":
                assert C == 1, "Error: Label raster input should be 1 band"

                ext = "png"
                img = Image.fromarray(np.squeeze(data, axis=0), mode="P")
                img.putpalette(make_palette(colors[0], colors[1]))
                img.save("{}.{}".format(path, ext), optimize=True)

            elif args.type == "image":
                assert C == 1 or C == 3, "Error: Image raster input should be either 1 or 3 bands"

                # GeoTiff could be 16 or 32bits
                if data.dtype == "uint16":
                    data = np.uint8(data / 256)
                elif data.dtype == "uint32":
                    data = np.uint8(data / (256 * 256))

                if C == 1:
                    ext = "png"
                    Image.fromarray(np.squeeze(data, axis=0),
                                    mode="L").save("{}.{}".format(path, ext),
                                                   optimize=True)
                elif C == 3:
                    ext = "webp"
                    Image.fromarray(np.moveaxis(data, 0, 2),
                                    mode="RGB").save("{}.{}".format(path, ext),
                                                     optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile in tiles if tile not in tiles_nodata]
        web_ui(args.out, args.web_ui, tiles, tiles, ext, template)
    inputs_df.loc[inputs_df.value.isin(values)].geometry.values.data)

### Warp TNC resilient and connected landscapes to match Blueprint input area
print("Reading and warping TNC resilient and connected landscapes...")
with rasterio.open(src_dir / "Resilient_and_Connected20180308.tif") as rc:
    # Capture nodata while the dataset is open; it is needed after closing.
    rc_nodata = int(rc.nodata)
    with WarpedVRT(
        rc,
        width=window.width,
        height=window.height,
        nodata=rc_nodata,
        transform=transform,
        crs=DATA_CRS,
        resampling=Resampling.nearest,
    ) as vrt:
        data = vrt.read(1)

# convert to uint8: source nodata becomes 255
data = np.where(data == rc_nodata, 255, data)

# clip to mask
data = np.where(mask == 1, data, 255).astype("uint8")

tnc_data = data.copy()

# Reclassify to incremental values based on lookup table
print("Reclassifying TNC data...")
table = read_dataframe(
    src_dir / "Resilient_and_Connected20180308.tif.vat.dbf",
    read_geometry=False,
    columns=["Value"],
)

    # NOTE(review): fragment. The enclosing statement that opens `inland` and
    # `marine` (and defines inland_hubs_data, marine_hubs_data, out_dir) is not
    # visible here.
    inland_data = inland.read(1)

    # Marine data must be resampled to 30m with matching offset to inland
    # No crs is passed, so the VRT keeps marine's CRS; assumes marine and
    # inland share a CRS — TODO confirm.
    # NOTE(review): this VRT is never closed.
    vrt = WarpedVRT(
        marine,
        width=inland.width,
        height=inland.height,
        nodata=marine.nodata,
        transform=inland.transform,
        resampling=Resampling.nearest,
    )
    print("Reading and warping marine corridors...")

    # Band 1 of the warped marine raster, on inland's grid.
    marine_data = vrt.read()[0]

    # consolidate all values into a single raster, writing hubs over corridors
    # Output codes: 255 nodata, 1 inland corridor, 3 marine corridor,
    # 0 inland hub, 2 marine hub; later assignments win where they overlap.
    data = np.ones(shape=inland_data.shape, dtype="uint8") * 255
    data[inland_data == 1] = 1
    data[marine_data == 1] = 3
    data[inland_hubs_data == 1] = 0
    data[marine_hubs_data == 1] = 2

    # Reuse inland's profile, switching to uint8 with 255 as nodata.
    meta = inland.profile.copy()
    meta["dtype"] = "uint8"
    meta["nodata"] = 255

    with rasterio.open(out_dir / "corridors.tif", "w", **meta) as out:
        out.write(data.astype("uint8"), 1)
Ejemplo n.º 22
0
def extract_patches(image_uri, patch_radius, chm_uri, height_threshold,
                    labels_uri):
    """Extract patches from an image

    At each labeled pixel, we extract a square patch around it. A location is
    skipped when:

    - its canopy height (CHM warped onto the image grid with bilinear
      resampling) equals the CHM nodata value or is at or below
      ``height_threshold``;
    - the patch would cross the image edge;
    - the patch contains any negative pixel value (treated as "no data").

    Arguments:
      image_uri: URI for image
      patch_radius: radius of patch (e.g. radius of 7 = 15x15 patch)
      chm_uri: URI for canopy height model
      height_threshold: height at or below which pixels will be discarded
      labels_uri: URI for the labels raster (255 marks unlabeled pixels)
    Returns:
      image patches as an array of shape (index, height, width, channels),
      and the matching patch labels. Both have length 0 when no patch passes
      the filters.
    """

    # "no data value" for labels
    label_ndv = 255
    patch_size = 2 * patch_radius + 1

    with rasterio.open(labels_uri, 'r') as f:
        labels_raster = f.read(1)

    # get all labeled locations in the labels raster
    rows, cols = np.where(labels_raster != label_ndv)

    # create lists for the patches and labels
    image_patches = []
    patch_labels = []

    # Context managers close the image, the CHM and its warped view even if
    # an error is raised part-way through.
    with rasterio.open(image_uri) as image, rasterio.open(chm_uri) as chm:
        image_width = image.meta['width']
        image_height = image.meta['height']
        band_count = image.count
        image_dtype = image.dtypes[0]
        chm_ndv = chm.nodata

        # CHM resampled onto the image grid so (row, col) refer to the same pixel.
        with WarpedVRT(chm,
                       crs=image.meta['crs'],
                       transform=image.meta['transform'],
                       width=image_width,
                       height=image_height,
                       resampling=Resampling.bilinear) as chm_vrt:

            # extract the patch for each location
            with tqdm.tqdm(total=len(rows)) as pbar:

                for row, col in zip(rows, cols):

                    # increment the progress bar
                    pbar.update(1)

                    # check height in canopy height model
                    chm_val = chm_vrt.read(1, window=((row, row + 1), (col, col + 1)))
                    if chm_val == chm_ndv or chm_val <= height_threshold:
                        continue

                    # check patch bounds against image bounds
                    if row - patch_radius < 0 or col - patch_radius < 0:
                        continue
                    if row + patch_radius >= image_height or col + patch_radius >= image_width:
                        continue

                    # get patch from image
                    image_patch = image.read(window=((row - patch_radius,
                                                      row + patch_radius + 1),
                                                     (col - patch_radius,
                                                      col + patch_radius + 1)))

                    # check for nodata in patch
                    if np.any(image_patch < 0):
                        continue

                    # append the patch and label to the lists
                    image_patches.append(image_patch)
                    patch_labels.append(labels_raster[row, col])

    # np.stack rejects an empty list; return correctly shaped empty arrays.
    if not image_patches:
        return (np.empty((0, patch_size, patch_size, band_count), dtype=image_dtype),
                np.empty((0,), dtype=labels_raster.dtype))

    # stack the patches into a numpy array
    image_patches = np.stack(image_patches, axis=0)

    # re-order the dimensions so that we have (index,height,width,channels)
    image_patches = np.transpose(image_patches, axes=[0, 2, 3, 1])

    # stack the labels into a numpy array
    patch_labels = np.stack(patch_labels, axis=0)

    return image_patches, patch_labels