def _getitem(self, key):
    """Read the data selected by ``key`` and return it as a numpy array.

    ``self._get_indexer(key)`` supplies the band selection, the per-axis
    ``(start, stop)`` window, the axes to squeeze and the final numpy index.
    When no band is selected or any window axis is empty, a zero-filled array
    of the right shape is returned without touching the dataset.

    Otherwise the dataset is acquired from ``self.manager`` under
    ``self.lock`` (wrapped in a ``WarpedVRT`` when ``self.vrt_params`` is
    set), read, optionally mask-filled with NaN (``self.masked``) and
    optionally scaled with the dataset's per-band scales and offsets
    (``self.mask_and_scale``).
    """
    band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

    if not band_key or any(start == stop for (start, stop) in window):
        # no need to do IO
        shape = (len(band_key),) + tuple(stop - start for (start, stop) in window)
        out = np.zeros(shape, dtype=self.dtype)
    else:
        with self.lock:
            riods = self.manager.acquire(needs_lock=False)
            if self.vrt_params is not None:
                # Imported lazily: only needed when warping is requested.
                from rasterio.vrt import WarpedVRT

                riods = WarpedVRT(riods, **self.vrt_params)
            out = riods.read(band_key, window=window, masked=self.masked)
            if self.masked:
                out = np.ma.filled(out.astype(self.dtype), np.nan)
            if self.mask_and_scale:
                scales = riods.scales
                offsets = riods.offsets
                if np.ndim(band_key) == 0:
                    # A scalar band yields a 2D array: scale all of it.
                    src_band = band_key - 1
                    out[...] = out * scales[src_band] + offsets[src_band]
                else:
                    # ``out`` holds only the requested bands, so it must be
                    # indexed by position; scales/offsets are indexed by the
                    # zero-based dataset band number.
                    for position, band in enumerate(band_key):
                        src_band = band - 1
                        out[position] = (
                            out[position] * scales[src_band] + offsets[src_band]
                        )

    if squeeze_axis:
        out = np.squeeze(out, axis=squeeze_axis)
    return out[np_inds]
def coregister_raster(a_uri, b_uri, dest_file, resampling=Resampling.bilinear):
    """
    Coregister raster a to the extent, resolution and projection of raster b.

    Band 1 of raster a is warped onto the grid of raster b (crs, transform,
    width and height) using ``resampling`` and written to ``dest_file`` as a
    tiled, LZW-compressed single-band GeoTIFF carrying raster a's nodata value.

    a_uri (read), b_uri (read), and dest_file (write) are passed to
    rasterio.open and thus are bound by its semantics.
    """
    with rasterio.open(a_uri) as ds_a, rasterio.open(b_uri) as ds_b:
        # The VRT is a dataset handle of its own; close it once the data is read.
        with WarpedVRT(
            ds_a,
            crs=ds_b.crs,
            height=ds_b.height,
            width=ds_b.width,
            resampling=resampling,
            transform=ds_b.transform,
        ) as vrt:
            data = vrt.read(1)

        with rasterio.open(
            dest_file,
            "w",
            compress="lzw",
            count=1,
            crs=ds_b.crs,
            driver="GTiff",
            dtype=data.dtype,
            height=ds_b.height,
            width=ds_b.width,
            nodata=ds_a.nodata,
            tiled=True,
            transform=ds_b.transform,
        ) as dst:
            dst.write(data, indexes=1)
def extract_window(src, window, transform, nodata):
    """Extract raster data from src within window, and warp to DATA_CRS

    Only the first band is read. The warp uses nearest-neighbor resampling.

    Parameters
    ----------
    src : open rasterio Dataset
    window : rasterio Window object
        only its ``width`` and ``height`` are used, as the output size
    transform : rasterio Transform object
        transform of the output grid
    nodata : int or float

    Returns
    -------
    2d array
    """
    # Context-manage the VRT so its handle is released, and read just band 1
    # rather than reading every band and discarding all but the first.
    with WarpedVRT(
        src,
        width=window.width,
        height=window.height,
        nodata=nodata,
        transform=transform,
        crs=DATA_CRS,
        resampling=Resampling.nearest,
    ) as vrt:
        return vrt.read(1)
def test_vrt_src_kept_alive(path_rgb_byte_tif):
    """VRT source dataset is kept alive, preventing crashes"""
    # NOTE(review): the read is placed after the ``with`` block, as the
    # docstring implies -- the VRT must still be readable once the context
    # that opened its source has exited. Confirm against the original layout.
    with rasterio.open(path_rgb_byte_tif) as dst:
        vrt = WarpedVRT(dst, crs="EPSG:3857")

    # At least one warped pixel must be non-zero.
    assert (vrt.read() != 0).any()
    vrt.close()
def worker(path):
    """Tile one raster into Web Mercator tiles at ``args.zoom``.

    For every tile covering the raster, the raster is warped to EPSG:3857 with
    bilinear resampling and read at ``args.ts`` x ``args.ts`` pixels, then
    written as a label tile (``args.label``) or an image tile. Tiles listed for
    more than one raster in ``tiles_map`` are written to a per-raster directory
    under ``splits_path``; all others go to ``args.out``.

    Exits the process if a tile cannot be produced or written.

    Returns the list of tiles written for tiles that only this raster covers.
    """
    raster = rasterio_open(path)
    w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
    transform, _, _ = calculate_default_transform(
        raster.crs, "EPSG:3857", raster.width, raster.height, w, s, e, n
    )
    tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
    tiled = []

    for tile in tiles:
        try:
            w, s, e, n = mercantile.xy_bounds(tile)

            # inspired by rio-tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
            # The VRT is context-managed so one handle is not leaked per tile.
            with WarpedVRT(
                raster,
                crs="epsg:3857",
                resampling=Resampling.bilinear,
                add_alpha=False,
                transform=from_bounds(w, s, e, n, args.ts, args.ts),
                width=math.ceil((e - w) / transform.a),
                height=math.ceil((s - n) / transform.e),
            ) as warp_vrt:
                data = warp_vrt.read(
                    out_shape=(len(raster.indexes), args.ts, args.ts),
                    window=warp_vrt.window(w, s, e, n),
                )
            image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that Ctrl-C and
            # interpreter exit are not reported as tiling failures.
            sys.exit("Error: Unable to tile {} from raster {}.".format(str(tile), raster))

        tile_key = (str(tile.x), str(tile.y), str(tile.z))
        if not args.label and len(tiles_map[tile_key]) == 1 and is_border(image):
            progress.update()
            continue

        if len(tiles_map[tile_key]) > 1:
            out = os.path.join(splits_path, str(tiles_map[tile_key].index(path)))
        else:
            out = args.out

        x, y, z = map(int, tile)

        if args.label:
            ret = tile_label_to_file(out, mercantile.Tile(x=x, y=y, z=z), palette, image)
        else:
            ret = tile_image_to_file(out, mercantile.Tile(x=x, y=y, z=z), image)

        if not ret:
            sys.exit("Error: Unable to write tile {} from raster {}.".format(str(tile), raster))

        # NOTE(review): only tiles covered by a single raster are counted and
        # returned here; confirm that multi-raster tiles are reported elsewhere.
        if len(tiles_map[tile_key]) == 1:
            progress.update()
            tiled.append(mercantile.Tile(x=x, y=y, z=z))

    return tiled
def test_vrt_mem_src_kept_alive(path_rgb_byte_tif):
    """VRT in-memory source dataset is kept alive, preventing crashes"""
    # Load the whole file into memory so the source is a MemoryFile dataset.
    with open(path_rgb_byte_tif, "rb") as fp:
        bands = fp.read()

    # NOTE(review): the read is placed after the ``with`` block, as the
    # docstring implies -- the VRT must still be readable once the MemoryFile
    # and its dataset contexts have exited. Confirm against the original layout.
    with MemoryFile(bands) as memfile, memfile.open() as dst:
        vrt = WarpedVRT(dst, crs="EPSG:3857")

    # At least one warped pixel must be non-zero.
    assert (vrt.read() != 0).any()
    vrt.close()
def worker(path):
    """Tile one raster into ``width`` x ``height`` Web Mercator tiles.

    For every tile at ``args.zoom`` covering the raster (restricted to
    ``cover`` when one is given), the raster is warped to EPSG:3857 with
    bilinear resampling and written as a label tile (``args.label``) or an
    image tile. Image tiles that ``is_nodata`` rejects are skipped when this
    raster is the only one covering the tile. Tiles listed for more than one
    raster in ``tiles_map`` are written to a per-raster directory under
    ``splits_path``; all others go to ``args.out``.

    Returns the list of tiles written for tiles that only this raster covers.
    """
    raster = rasterio_open(path)
    w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
    tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
    tiled = []

    for tile in tiles:
        if cover and tile not in cover:
            continue

        w, s, e, n = mercantile.xy_bounds(tile)

        # The VRT is context-managed so one handle is not leaked per tile.
        with WarpedVRT(
            raster,
            crs="epsg:3857",
            resampling=Resampling.bilinear,
            add_alpha=False,
            transform=from_bounds(w, s, e, n, width, height),
            width=width,
            height=height,
        ) as warp_vrt:
            # rasterio arrays are (bands, rows, cols), i.e. height before width.
            data = warp_vrt.read(
                out_shape=(len(raster.indexes), height, width),
                window=warp_vrt.window(w, s, e, n),
            )
        image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C

        tile_key = (str(tile.x), str(tile.y), str(tile.z))
        if (
            not args.label
            and len(tiles_map[tile_key]) == 1
            and is_nodata(image, args.nodata, args.nodata_threshold, args.keep_borders)
        ):
            progress.update()
            continue

        if len(tiles_map[tile_key]) > 1:
            out = os.path.join(splits_path, str(tiles_map[tile_key].index(path)))
        else:
            out = args.out

        x, y, z = map(int, tile)

        if args.label:
            tile_label_to_file(out, mercantile.Tile(x=x, y=y, z=z), palette, image)
        else:
            tile_image_to_file(out, mercantile.Tile(x=x, y=y, z=z), image)

        # NOTE(review): only tiles covered by a single raster are counted and
        # returned here; confirm that multi-raster tiles are reported elsewhere.
        if len(tiles_map[tile_key]) == 1:
            progress.update()
            tiled.append(mercantile.Tile(x=x, y=y, z=z))

    return tiled
def test_wrap_file(path_rgb_byte_tif):
    """A WarpedVRT wrapping a file exposes the expected dataset properties."""
    expected_bounds = (-8789636.7, 2700460.0, -8524406.4, 2943560.2)

    with rasterio.open(path_rgb_byte_tif) as src:
        vrt = WarpedVRT(src, crs=DST_CRS)

        # Georeferencing: CRS and bounds (rounded to one decimal).
        assert vrt.crs == CRS.from_string(DST_CRS)
        rounded_bounds = tuple(round(coord, 1) for coord in vrt.bounds)
        assert rounded_bounds == expected_bounds

        # The VRT name wraps the source path.
        vrt_name = vrt.name
        assert vrt_name.startswith('WarpedVRT(')
        assert vrt_name.endswith('tests/data/RGB.byte.tif)')

        # Band layout matches the three-band byte source.
        assert vrt.indexes == (1, 2, 3)
        assert vrt.nodatavals == (0, 0, 0)
        assert vrt.dtypes == ('uint8', 'uint8', 'uint8')

        pixels = vrt.read()
        assert pixels.shape == (3, 736, 803)
def test_wrap_file(path_rgb_byte_tif):
    """A WarpedVRT built with ``dst_crs`` has the expected dataset properties."""
    with rasterio.open(path_rgb_byte_tif) as src:
        vrt = WarpedVRT(src, dst_crs='EPSG:3857')
        assert vrt.crs == 'EPSG:3857'

        # Compare bounds at one-decimal precision.
        left, bottom, right, top = (round(x, 1) for x in vrt.bounds)
        assert (left, bottom, right, top) == (
            -8789636.7, 2700460.0, -8524406.4, 2943560.2)

        label = vrt.name
        assert label.startswith('WarpedVRT(')
        assert label.endswith('tests/data/RGB.byte.tif)')

        assert vrt.indexes == (1, 2, 3)
        assert vrt.nodatavals == (0, 0, 0)
        assert vrt.dtypes == ('uint8', 'uint8', 'uint8')

        band_count, n_rows, n_cols = vrt.read().shape
        assert (band_count, n_rows, n_cols) == (3, 736, 803)
def test_wrap_file(path_rgb_byte_tif):
    """A WarpedVRT around a file reports the expected dataset properties."""
    with rasterio.open(path_rgb_byte_tif) as src:
        vrt = WarpedVRT(src, crs=DST_CRS)

        target_crs = CRS.from_string(DST_CRS)
        assert vrt.crs == target_crs

        # Bounds are compared after rounding to one decimal place.
        bounds_1dp = tuple(round(value, 1) for value in vrt.bounds)
        assert bounds_1dp == (
            -8789636.7,
            2700460.0,
            -8524406.4,
            2943560.2,
        )

        dataset_name = vrt.name
        assert dataset_name.startswith("WarpedVRT(")
        assert dataset_name.endswith("tests/data/RGB.byte.tif)")

        assert vrt.indexes == (1, 2, 3)
        assert vrt.nodatavals == (0, 0, 0)
        assert vrt.dtypes == ("uint8", "uint8", "uint8")

        warped = vrt.read()
        assert warped.shape == (3, 736, 803)
def _read_window(self, vrt: WarpedVRT, dst_window: Window) -> MaskedArray:
    """Read the part of the input raster that corresponds to ``dst_window``.

    The destination window is converted to bounds using the transform of the
    default-format destination, and those bounds are converted to a window on
    ``vrt``. The data is read as a masked array with one layer per input band
    and the (rounded) height and width of ``dst_window``.

    If rasterio reports "Access window out of range" for a request that is a
    single pixel high or wide, a fully masked zero array of the requested
    shape is returned instead. Any other ``RasterioIOError`` is logged and
    re-raised.
    """
    target = self.dst[self.default_format]

    dst_bounds: Bounds = bounds(dst_window, target.transform)
    window = vrt.window(*dst_bounds)

    # Only used for the debug message below.
    src_bounds = transform_bounds(target.crs, self.src.crs, *dst_bounds)
    LOGGER.debug(
        f"Read {dst_window} for Tile {self.tile_id} - this corresponds to bounds {src_bounds} in source"
    )

    n_rows = int(round(dst_window.height))
    n_cols = int(round(dst_window.width))
    shape = (len(self.layer.input_bands), n_rows, n_cols)

    try:
        return vrt.read(window=window, out_shape=shape, masked=True)
    except rasterio.RasterioIOError as err:
        out_of_range = "Access window out of range" in str(err)
        one_pixel_edge = n_rows == 1 or n_cols == 1

        if not (out_of_range and one_pixel_edge):
            LOGGER.warning(
                f"RasterioIO error while reading {dst_window} for Tile {self.tile_id}. "
                "Will make attempt to retry."
            )
            raise

        LOGGER.warning(
            f"Access window out of range while reading {dst_window} for Tile {self.tile_id}. "
            "This is most likely due to subpixel misalignment. "
            "Returning empty array instead."
        )
        return np.ma.array(data=np.zeros(shape=shape), mask=np.ones(shape=shape))
def _getitem(self, key):
    """Return the array selected by ``key``.

    The indexer from ``self._get_indexer`` gives the bands, the per-axis
    ``(start, stop)`` window, the axes to squeeze and the final numpy index.
    Data is read from the managed dataset under ``self.lock`` (through a
    ``WarpedVRT`` when ``self.vrt_params`` is set); an empty selection yields
    zeros without any read.
    """
    from rasterio.vrt import WarpedVRT

    bands, extents, squeeze_axis, np_inds = self._get_indexer(key)

    if bands and all(lo != hi for (lo, hi) in extents):
        with self.lock:
            dataset = self.manager.acquire(needs_lock=False)
            if self.vrt_params is not None:
                dataset = WarpedVRT(dataset, **self.vrt_params)
            out = dataset.read(bands, window=extents)
    else:
        # Empty selection: build the result directly, no IO required.
        spatial_shape = tuple(hi - lo for (lo, hi) in extents)
        out = np.zeros((len(bands),) + spatial_shape, dtype=self.dtype)

    if squeeze_axis:
        out = np.squeeze(out, axis=squeeze_axis)
    return out[np_inds]
def _getitem(self, key):
    """Fetch the data addressed by ``key`` as a numpy array.

    ``self._get_indexer(key)`` is unpacked into the band selection, the
    ``(start, stop)`` window per axis, the axes to squeeze and the trailing
    numpy index. When nothing is selected, zeros of the matching shape are
    returned; otherwise the managed dataset is read while holding
    ``self.lock``, via a ``WarpedVRT`` if ``self.vrt_params`` is set.
    """
    from rasterio.vrt import WarpedVRT

    band_key, window, squeeze_axis, np_inds = self._get_indexer(key)

    # Decide whether any IO is needed at all.
    skip_io = not band_key
    if not skip_io:
        skip_io = any(stop == start for (start, stop) in window)

    if skip_io:
        dims = [len(band_key)]
        dims.extend(stop - start for (start, stop) in window)
        out = np.zeros(tuple(dims), dtype=self.dtype)
    else:
        with self.lock:
            source = self.manager.acquire(needs_lock=False)
            reader = (
                source
                if self.vrt_params is None
                else WarpedVRT(source, **self.vrt_params)
            )
            out = reader.read(band_key, window=window)

    if squeeze_axis:
        out = np.squeeze(out, axis=squeeze_axis)
    return out[np_inds]
def _warper(img, transform, s_crs, t_crs, resampling):
    """
    Warp an image returning it as a virtual file

    ``img`` is unpacked as (bands, height, width). It is written band by band
    to an in-memory GeoTIFF georeferenced with ``transform`` and ``s_crs``,
    then read back through a ``WarpedVRT`` targeting ``t_crs`` with the given
    ``resampling``.

    Returns a tuple ``(warped_array, vrt)``.
    """
    b, h, w = img.shape
    with MemoryFile() as memfile:
        with memfile.open(
            driver="GTiff",
            height=h,
            width=w,
            count=b,
            dtype=str(img.dtype.name),
            crs=s_crs,
            transform=transform,
        ) as mraster:
            # Raster band indexes are 1-based.
            for band in range(b):
                mraster.write(img[band, :, :], band + 1)
            # --- Virtual Warp
            # NOTE(review): the VRT is assumed to be created while ``mraster``
            # is still open -- confirm against the original indentation.
            vrt = WarpedVRT(mraster, crs=t_crs, resampling=resampling)
            img = vrt.read()
    # NOTE(review): ``vrt`` is returned unclosed after both ``with`` blocks
    # have exited; presumably callers only read metadata from it -- confirm.
    return img, vrt
return x # go through all tiles of input image # run convolutional model on tile # write labels to output label raster with tqdm.tqdm(total=num_tiles_y * num_tiles_x) as pbar: for y in range(patch_radius, image_height - patch_radius, tile_size): for x in range(patch_radius, image_width - patch_radius, tile_size): pbar.update(1) window = Window(x - patch_radius, y - patch_radius, padded_tile_size, padded_tile_size) # get tile from chm chm_tile = chm_vrt.read(1, window=window) if chm_tile.shape[0] != padded_tile_size or chm_tile.shape[ 1] != padded_tile_size: pad = ((0, padded_tile_size - chm_tile.shape[0]), (0, padded_tile_size - chm_tile.shape[1])) chm_tile = np.pad(chm_tile, pad, mode='constant', constant_values=0) chm_tile = np.expand_dims(chm_tile, axis=0) chm_bad = chm_tile <= height_threshold # get tile from image image_tile = image.read(window=window) image_pad_y = padded_tile_size - image_tile.shape[1]
def main():
    """Score every experiment's predictions against ground truth.

    For each experiment in ``EXPERIMENTS`` the prediction rasters under the
    ``predict`` prefix of the experiment's S3 directory are listed. Each
    prediction is compared with its ground-truth chip, and per-class F1 and
    Jaccard (IoU) scores are computed for all pixels, for pixels where the
    NLCD raster is in the range 21-24 ("urban") and for the remaining pixels
    ("not urban"). One CSV per experiment is written to the output directory
    and uploaded next to the experiment as ``stats-iou-f1.csv``.

    Command line options:
        -o / --output-dir: directory for the CSVs (default DEFAULT_OUTPUT_DIR)
        --overwrite: recompute experiments whose CSV already exists
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--output-dir", type=str, default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--overwrite",
                        action="store_true",
                        help="Overwrite existing CSVs in output dir")
    args = parser.parse_args()

    client = boto3.client("s3")
    # The paginator follows continuation tokens, so listings longer than one
    # page are handled instead of rejected.
    paginator = client.get_paginator("list_objects_v2")

    csv_fieldnames = [
        "chip_id",
        "f1_all",
        "f1_urban",
        "f1_not_urban",
        "iou_all",
        "iou_urban",
        "iou_not_urban",
    ]

    for experiment in EXPERIMENTS:
        logger.info("Processing Experiment: {}".format(experiment.id))
        experiment_url = urlparse(experiment.s3_dir)
        bucket = experiment_url.hostname
        # S3 keys do not start with "/"; use the same prefix for listing
        # and for the upload below.
        key_prefix = experiment_url.path.lstrip("/")

        output_csv = os.path.join(args.output_dir,
                                  "{}-iou-f1.csv".format(experiment.id))
        if not args.overwrite and os.path.isfile(output_csv):
            logger.info("SKIPPING Experiment {}. Already exists".format(
                experiment.id))
            continue

        s3_keys = []
        for page in paginator.paginate(
                Bucket=bucket,
                Prefix=os.path.join(key_prefix, "predict"),
        ):
            # Pages without matches carry no "Contents" entry.
            s3_keys.extend(page.get("Contents", []))

        if not s3_keys:
            print("WARNING: No predictions for {}. Continuing...".format(
                experiment))
            continue

        # newline="" lets the csv module control line endings itself.
        with open(output_csv, "w", newline="") as fp_csv:
            writer = csv.DictWriter(fp_csv, csv_fieldnames)
            writer.writeheader()

            for obj in s3_keys:
                chip_id = os.path.basename(obj["Key"])
                prediction_tif_uri = "s3://{}/{}".format(bucket, obj["Key"])
                logger.info("\tprediction: {}".format(prediction_tif_uri))

                # TODO: Encode this in Experiment somehow
                # Replace chip S1 qualifier with QC to match filenames in ground truth dir
                if chip_id.startswith("SEN1FLOODS11"):
                    chip_name = chip_id
                else:
                    chip_name = chip_id.replace("_S1.tif", "_QC.tif")
                truth_tif_uri = os.path.join(experiment.ground_truth_dir,
                                             chip_name)
                logger.info("\ttruth: {}".format(truth_tif_uri))

                with rasterio.open(prediction_tif_uri) as ds_p, rasterio.open(
                        truth_tif_uri) as ds_t, rasterio.open(
                            NLCD_TIF_URI) as ds_nlcd:
                    p_band = ds_p.read(1).flatten()
                    t_band = ds_t.read(1).flatten()
                    # Warp NLCD onto the prediction grid; the VRT handle is
                    # closed as soon as the band has been read.
                    with WarpedVRT(
                            ds_nlcd,
                            crs=ds_p.crs,
                            height=ds_p.height,
                            width=ds_p.width,
                            resampling=Resampling.nearest,
                            transform=ds_p.transform,
                            nodata=NODATA,
                    ) as nlcd_vrt:
                        nlcd_band = nlcd_vrt.read(1).flatten()

                assert len(p_band) == len(t_band)
                assert len(p_band) == len(nlcd_band)

                # TODO: Encode this in Experiment somehow...
                # For two class USFIMR experiments, collapse ground truth three class
                # labels for water to two class: 2 (flood) + 1 (perm) -> 1 (water)
                if experiment.id.startswith(
                        "USFIMR"
                ) and 'flood' not in experiment.id and 'permanent' not in experiment.id:
                    t_band = np.where(t_band == 2, 1, t_band)

                # Mask everything outside NLCD 21-24, keeping "urban" pixels.
                nlcd_urban_mask = np.ma.masked_outside(nlcd_band, 21, 24).mask
                p_band_urban = np.ma.array(
                    p_band, mask=nlcd_urban_mask).filled(NODATA)
                t_band_urban = np.ma.array(
                    t_band, mask=nlcd_urban_mask).filled(NODATA)

                # Mask everything inside NLCD 21-24, keeping the rest.
                nlcd_not_urban_mask = np.ma.masked_inside(nlcd_band, 21,
                                                          24).mask
                p_band_not_urban = np.ma.array(
                    p_band, mask=nlcd_not_urban_mask).filled(NODATA)
                t_band_not_urban = np.ma.array(
                    t_band, mask=nlcd_not_urban_mask).filled(NODATA)

                labels = experiment.labels
                scores = {
                    "chip_id":
                    chip_id,
                    "f1_all":
                    f1_score(t_band, p_band, labels=labels, average=None),
                    "f1_urban":
                    f1_score(t_band_urban,
                             p_band_urban,
                             labels=labels,
                             average=None),
                    "f1_not_urban":
                    f1_score(t_band_not_urban,
                             p_band_not_urban,
                             labels=labels,
                             average=None),
                    "iou_all":
                    jaccard_score(t_band,
                                  p_band,
                                  labels=labels,
                                  average=None),
                    "iou_urban":
                    jaccard_score(t_band_urban,
                                  p_band_urban,
                                  labels=labels,
                                  average=None),
                    "iou_not_urban":
                    jaccard_score(t_band_not_urban,
                                  p_band_not_urban,
                                  labels=labels,
                                  average=None),
                }
                logger.debug("\t\t{}".format(scores))
                writer.writerow(scores)

        client.upload_file(
            output_csv,
            bucket,
            os.path.join(key_prefix, "stats-iou-f1.csv"),
        )
# Warp the Florida Marine Blueprint onto the "flm" input-area grid, remap its
# values, clip it to the input-area mask and write it out with overviews.
mask, transform, window = get_input_area_mask("flm")

print("Reading and warping Florida Marine Blueprint...")
nodata = 255
with rasterio.open(src_dir / "FLBlueprintVer1.tif") as src:
    # Context-manage the VRT so its handle is released, and read just band 1
    # rather than reading every band and discarding all but the first.
    with WarpedVRT(
        src,
        width=window.width,
        height=window.height,
        nodata=nodata,
        crs=DATA_CRS,
        transform=transform,
        resampling=Resampling.nearest,
    ) as vrt:
        data = vrt.read(1).astype("uint8")

# remap data (each row is passed to remap() as a [from, to] style pair)
remap_table = np.array([[1, 1], [2, 2], [3, 3]], dtype="uint8")
data = remap(data, remap_table, nodata=nodata)

# apply input area mask
data = np.where(mask == 1, data, nodata).astype("uint8")

write_raster(outfilename, data, transform=transform, crs=DATA_CRS, nodata=nodata)
add_overviews(outfilename)
def make_tile(level, tile):
    """
    Write one tile of a pyramid level to its own LZW-compressed GeoTIFF.

    For ``level`` greater than 1 the tile is built from tiles of the previous
    level: they are collected with ``tile_children``, mosaicked with
    ``buildvrt`` and resampled to half the width and height. For other levels
    the tile window is read directly from ``self.get_dataset(level)``.

    :param level: pyramid level the tile belongs to
    :param tile: sequence of ``((x, y), window, transform)`` -- the tile
        indexes, a window providing ``width``/``height`` and the affine
        transform of the tile
    :return: None; the tile is written to ``out_folder/<level>/``
    """
    # x,y tile indexes
    x = tile[0][0]
    y = tile[0][1]

    def div_by_16(value):
        """Round ``value`` down to the nearest multiple of 16."""
        return value - value % 16

    # put tile in its respective dir
    out_dir = out_folder.joinpath(str(level))
    out_dir.mkdir(exist_ok=True)

    # a tile is always at least one pixel wide and high
    size_x = tile[1].width if tile[1].width > 0 else 1
    size_y = tile[1].height if tile[1].height > 0 else 1

    # Out file constructor
    # how many chars to use for representing the tiles.
    name_length = max(len(str(self.tileinfos[level].countTilesX)),
                      len(str(self.tileinfos[level].countTilesY))) + 1
    filename = name_template.format(basename=self.name,
                                    x=str(x).zfill(name_length),
                                    y=str(y).zfill(name_length))
    out_filepath = out_dir.joinpath(filename)
    ## End

    # Work on a copy: updating the shared default profile in place would leak
    # this tile's settings (nodata, dtype, size, ...) into every later use.
    profile = dict(default_gtiff_profile)
    profile.update(
        crs='epsg:4326',
        driver='GTiff',
        transform=tile[2],
        compress='lzw',
        count=1,
        width=size_x,
        height=size_y,
        blockysize=div_by_16(min(self.blockSize, tile[1].height)),
        blockxsize=div_by_16(min(self.blockSize, tile[1].width)),
    )

    if level > 1:
        # except OSError:
        #     # in this level, the amount of pixels that need to be resampled are too many.
        #     # I am choosing to use pixel at the central coordinate of the processing tile
        #     # Sample error:
        #     # ERROR 1: Integer overflow : nSrcXSize=425985, nSrcYSize=163840
        # TODO: don't be lazy, clean write
        prev_level = level - 1
        try:
            prev_info = self.tileinfos[prev_level]
        except KeyError:
            # Tile layout of the previous level not cached yet: build it.
            _meta = self.get_metadata(prev_level)
            prev_info = TileInfo(_meta['width'], _meta['height'],
                                 self.TileWidth, self.TileHeight)
            self.tileinfos[prev_level] = prev_info
        name_length = max(len(str(prev_info.countTilesX)),
                          len(str(prev_info.countTilesY))) + 1

        prev_lvl_tiles = tile_children(zoom=level, src=out_filepath,
                                       ndigits=name_length)
        vrt_handler = buildvrt(prev_lvl_tiles)
        with rio.open(vrt_handler) as src:
            profile.update(nodata=src.nodata, dtype=src.meta['dtype'])
            # Halve the resolution: pixels are twice as large in x and y.
            resolution_factor = 2
            lvlx_transform = Affine(src.transform.a * resolution_factor,
                                    src.transform.b,
                                    src.transform.c,
                                    src.transform.d,
                                    src.transform.e * resolution_factor,
                                    src.transform.f)
            # Raster dimensions must be whole pixels (and at least one).
            lvlx_width = max(1, src.width // 2)
            lvlx_height = max(1, src.height // 2)
            with WarpedVRT(src, transform=lvlx_transform,
                           width=lvlx_width, height=lvlx_height) as vrt:
                data = vrt.read(1)
    else:
        with self.get_dataset(level) as src:
            profile.update(nodata=src.nodata, dtype=src.meta['dtype'])
            data = src.read(1, window=tile[1])

    try:
        with rio.open(out_filepath, 'w', **profile) as dst:
            window_out = Window(0, 0, size_x, size_y)
            dst.write(data, window=window_out, indexes=1)
    except Exception:
        # Show the profile for debugging, then re-raise the real error
        # instead of replacing it with an anonymous Exception.
        print(profile)
        raise
def main(args):
    """Cut a raster into Web Mercator tiles of ``args.size`` pixels.

    The raster ``args.raster`` is warped to EPSG:3857 with bilinear resampling
    and one file per tile at ``args.zoom`` is written below
    ``args.out/<zoom>/<x>/<y>``:

    * ``args.type == "label"``: single-band paletted PNG, with the palette
      built from the two class colors of the config file ``args.config``.
    * ``args.type == "image"``: greyscale PNG for 1 band, WebP for 3 bands;
      16 and 32 bit data is scaled down to 8 bit.

    When ``args.no_data`` is set, tiles having at least one whole border
    filled with that value on all bands are skipped. When ``args.web_ui`` is
    set, a web UI is generated for the tiles that were kept.

    Exits the process if the config or the raster cannot be loaded.
    """
    if args.type == "label":
        try:
            config = load_config(args.config)
        except Exception:
            sys.exit("Error: Unable to load DataSet config file")

        classes = config["classes"]["title"]
        colors = config["classes"]["colors"]
        assert len(classes) == len(colors), "classes and colors coincide"
        assert len(colors) == 2, "only binary models supported right now"

    try:
        raster = rasterio_open(args.raster)
        w, s, e, n = bounds = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
        transform, _, _ = calculate_default_transform(
            raster.crs, "EPSG:3857", raster.width, raster.height, *bounds
        )
    except Exception:
        sys.exit("Error: Unable to load raster or deal with it's projection")

    tiles = [
        mercantile.Tile(x=x, y=y, z=z)
        for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)
    ]
    tiles_nodata = []

    for tile in tqdm(tiles, desc="Tiling", unit="tile", ascii=True):

        w, s, e, n = tile_bounds = mercantile.xy_bounds(tile)

        # Inspired by Rio-Tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
        # The VRT is context-managed so one handle is not leaked per tile.
        with WarpedVRT(
            raster,
            crs="EPSG:3857",
            resampling=Resampling.bilinear,
            add_alpha=False,
            transform=from_bounds(*tile_bounds, args.size, args.size),
            width=math.ceil((e - w) / transform.a),
            height=math.ceil((s - n) / transform.e),
        ) as warp_vrt:
            data = warp_vrt.read(
                out_shape=(len(raster.indexes), args.size, args.size),
                window=warp_vrt.window(w, s, e, n),
            )

        # If no_data is set, remove all tiles with at least one whole border filled only with no_data (on all bands)
        if args.no_data is not None and (
            np.all(data[:, 0, :] == args.no_data)
            or np.all(data[:, -1, :] == args.no_data)
            or np.all(data[:, :, 0] == args.no_data)
            or np.all(data[:, :, -1] == args.no_data)
        ):
            tiles_nodata.append(tile)
            continue

        # data is (bands, rows, cols); only the band count is needed below
        C = data.shape[0]

        os.makedirs(os.path.join(args.out, str(args.zoom), str(tile.x)), exist_ok=True)
        path = os.path.join(args.out, str(args.zoom), str(tile.x), str(tile.y))

        if args.type == "label":
            assert C == 1, "Error: Label raster input should be 1 band"

            ext = "png"
            img = Image.fromarray(np.squeeze(data, axis=0), mode="P")
            img.putpalette(make_palette(colors[0], colors[1]))
            img.save("{}.{}".format(path, ext), optimize=True)

        elif args.type == "image":
            assert C == 1 or C == 3, "Error: Image raster input should be either 1 or 3 bands"

            # GeoTiff could be 16 or 32bits
            if data.dtype == "uint16":
                data = np.uint8(data / 256)
            elif data.dtype == "uint32":
                data = np.uint8(data / (256 * 256))

            if C == 1:
                ext = "png"
                Image.fromarray(np.squeeze(data, axis=0), mode="L").save(
                    "{}.{}".format(path, ext), optimize=True
                )
            elif C == 3:
                ext = "webp"
                Image.fromarray(np.moveaxis(data, 0, 2), mode="RGB").save(
                    "{}.{}".format(path, ext), optimize=True
                )

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        # Set lookup keeps the filtering linear in the number of tiles.
        skipped = set(tiles_nodata)
        tiles = [tile for tile in tiles if tile not in skipped]
        web_ui(args.out, args.web_ui, tiles, tiles, ext, template)
inputs_df.loc[inputs_df.value.isin(values)].geometry.values.data) ### Warp TNC resilient and connected landscapes to match Blueprint input area print("Reading and warping TNC resilient and connected landscapes...") with rasterio.open(src_dir / "Resilient_and_Connected20180308.tif") as rc: vrt = WarpedVRT( rc, width=window.width, height=window.height, nodata=int(rc.nodata), transform=transform, crs=DATA_CRS, resampling=Resampling.nearest, ) data = vrt.read()[0] # convert to uint8 data = np.where(data == int(rc.nodata), 255, data) # clip to mask data = np.where(mask == 1, data, 255).astype("uint8") tnc_data = data.copy() # Reclassify to incremental values based on lookup table print("Reclassifying TNC data...") table = read_dataframe( src_dir / "Resilient_and_Connected20180308.tif.vat.dbf", read_geometry=False, columns=["Value"],
) inland_data = inland.read(1) # Marine data must be resampled to 30m with matching offset to inland vrt = WarpedVRT( marine, width=inland.width, height=inland.height, nodata=marine.nodata, transform=inland.transform, resampling=Resampling.nearest, ) print("Reading and warping marine corridors...") marine_data = vrt.read()[0] # consolidate all values into a single raster, writing hubs over corridors data = np.ones(shape=inland_data.shape, dtype="uint8") * 255 data[inland_data == 1] = 1 data[marine_data == 1] = 3 data[inland_hubs_data == 1] = 0 data[marine_hubs_data == 1] = 2 meta = inland.profile.copy() meta["dtype"] = "uint8" meta["nodata"] = 255 with rasterio.open(out_dir / "corridors.tif", "w", **meta) as out: out.write(data.astype("uint8"), 1)
def extract_patches(image_uri, patch_radius, chm_uri, height_threshold,
                    labels_uri):
    """Extract patches from an image

    At each labeled pixel, we extract a square patch around it. We discard the
    patch if the height according to the CHM is nodata or not above a
    threshold, if the patch would extend past the image bounds, or if the
    patch contains negative values (treated as "no data").

    The CHM is warped onto the image grid (bilinear resampling) so that it can
    be sampled with image row/column indexes.

    Arguments:
      image_uri: URI for image
      patch_radius: radius of patch (e.g. radius of 7 = 15x15 patch)
      chm_uri: URI for canopy height model
      height_threshold: threshold below which pixels will be discarded
      labels_uri: URI for the labels raster

    Returns:
      image patches as an array ordered (index, height, width, channels),
      patch labels

    Raises:
      ValueError: raised by ``np.stack`` when no patch passes the filters
    """
    # "no data value" for labels
    label_ndv = 255

    # create lists for the patches and labels
    image_patches = []
    patch_labels = []

    # Context managers close the image, the CHM and the warped CHM even when
    # a read fails part-way through.
    with rasterio.open(image_uri) as image, rasterio.open(chm_uri) as chm:
        image_width = image.meta['width']
        image_height = image.meta['height']

        with WarpedVRT(chm,
                       crs=image.meta['crs'],
                       transform=image.meta['transform'],
                       width=image_width,
                       height=image_height,
                       resampling=Resampling.bilinear) as chm_vrt:

            with rasterio.open(labels_uri, 'r') as f:
                labels_raster = f.read(1)

            # get all labeled locations in the labels raster
            rows, cols = np.where(labels_raster != label_ndv)

            # extract the patch for each location
            # tqdm makes the cool progress bar that you see
            with tqdm.tqdm(total=len(rows)) as pbar:
                for row, col in zip(rows, cols):
                    # increment the progress bar
                    pbar.update(1)

                    # check height in canopy height model (1x1 window)
                    chm_val = chm_vrt.read(1,
                                           window=((row, row + 1),
                                                   (col, col + 1)))
                    if chm_val == chm.nodata or chm_val <= height_threshold:
                        continue

                    # check patch bounds against image bounds
                    if row - patch_radius < 0 or col - patch_radius < 0:
                        continue
                    if (row + patch_radius >= image_height
                            or col + patch_radius >= image_width):
                        continue

                    # get patch from image
                    image_patch = image.read(
                        window=((row - patch_radius, row + patch_radius + 1),
                                (col - patch_radius, col + patch_radius + 1)))

                    # check for nodata in patch
                    if np.any(image_patch < 0):
                        continue

                    # append the patch and label to the lists
                    image_patches.append(image_patch)
                    patch_labels.append(labels_raster[row, col])

    # stack the patches into a numpy array
    image_patches = np.stack(image_patches, axis=0)

    # re-order the dimensions so that we have (index,height,width,channels)
    image_patches = np.transpose(image_patches, axes=[0, 2, 3, 1])

    # stack the labels into a numpy array
    patch_labels = np.stack(patch_labels, axis=0)

    return image_patches, patch_labels