def gpu_worker(rank, world_size, lock_file, args, config, dataset, palette, transparency):
    """Run prediction on one GPU as a member of a distributed process group.

    The worker joins an NCCL process group, rebuilds the network described by
    the checkpoint, predicts its shard of ``dataset`` and writes one label
    tile per sample through ``tile_label_to_file``.

    Args:
        rank: index of this process in the group; also used as CUDA device id.
        world_size: total number of processes in the group.
        lock_file: path used for the ``file://`` rendezvous of the group.
        args: parsed CLI arguments; reads ``checkpoint``, ``bs``, ``workers``,
            ``metatiles`` and ``out``.
        config: not used by this function (kept for interface compatibility).
        dataset: dataset whose batches unpack as ``(images, tiles)``.
        palette: forwarded to ``tile_label_to_file``.
        transparency: forwarded to ``tile_label_to_file``.

    Raises:
        AssertionError: if the checkpoint model version does not match the
            instantiated module, or if the loader is empty.
    """
    dist.init_process_group(backend="nccl", init_method="file://" + lock_file, world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)

    # Load the checkpoint once, with "~" expanded, directly onto this worker's GPU.
    chkpt = torch.load(os.path.expanduser(args.checkpoint), map_location="cuda:{}".format(rank))

    nn_module = load_module("neat_eo.nn.{}".format(chkpt["nn"].lower()))
    nn = getattr(nn_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"], chkpt["encoder"].lower()).to(rank)
    nn = DistributedDataParallel(nn, device_ids=[rank], find_unused_parameters=True)

    assert nn.module.version == chkpt["model_version"], "Model Version mismatch"
    nn.load_state_dict(chkpt["state_dict"])

    sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=args.bs, shuffle=False, num_workers=args.workers, sampler=sampler)
    assert len(loader), "Empty predict dataset directory. Check your path."

    C, W, H = chkpt["shape_out"]

    # Metatile geometry is loop invariant: quarter, half and full output size.
    qs = int(W / 4)
    hs = int(W / 2)
    ts = int(W)

    def predict(batch):
        """Forward a batch on this worker's GPU and return a NumPy array."""
        return nn(batch.to(rank)).detach().cpu().numpy()

    nn.eval()
    with torch.no_grad():

        # Only the first rank draws a progress bar, to keep the output readable.
        dataloader = tqdm(loader, desc="Predict", unit="Batch/GPU", ascii=True) if rank == 0 else loader

        for images, tiles in dataloader:

            if args.metatiles:
                N = images.shape[0]

                # Four overlapping windows are predicted and only the inner part
                # (qs:-qs) of each prediction is kept for its output quadrant.
                # NOTE(review): presumably the input images are 1.5 x the output
                # size so that every window is ts x ts - confirm with the dataset.
                # fmt:off
                probs = np.zeros((N, C, W, H), dtype=np.float64)  # np.float was removed in NumPy 1.24
                probs[:, :, 0:hs, 0:hs] = predict(images[:, :, 0:ts, 0:ts])[:, :, qs:-qs, qs:-qs]
                probs[:, :, 0:hs, hs:]  = predict(images[:, :, 0:ts, hs:])[:, :, qs:-qs, qs:-qs]
                probs[:, :, hs:, 0:hs]  = predict(images[:, :, hs:, 0:ts])[:, :, qs:-qs, qs:-qs]
                probs[:, :, hs:, hs:]   = predict(images[:, :, hs:, hs:])[:, :, qs:-qs, qs:-qs]
                # fmt:on
            else:
                probs = predict(images)

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))

                # Each channel contributes its own index wherever its rounded
                # probability is 1; contributions are summed into one mask.
                mask = np.zeros((W, H), dtype=np.uint8)
                for c in range(C):
                    mask += np.around(prob[c, :, :]).astype(np.uint8) * c

                tile_label_to_file(args.out, mercantile.Tile(x, y, z), palette, transparency, mask)
def worker(tile_key):
    """Merge the per-raster splits of one tile into a single output tile.

    Relies on names from the enclosing scope: ``tiles_map``, ``splits_path``,
    ``width``, ``height``, ``args``, ``palette`` and ``progress``.

    Args:
        tile_key: ``(x, y, z)`` key of ``tiles_map``; each item is converted
            with ``int``.

    Returns:
        The written ``mercantile.Tile``, or ``None`` when the tile has a
        single source (nothing to merge) or the merged image is rejected by
        ``is_nodata``.

    Raises:
        AssertionError: if a split does not have the expected shape.
    """
    sources = tiles_map[tile_key]
    if len(sources) == 1:
        return

    image = np.zeros((width, height, len(args.bands)), np.uint8)
    x, y, z = map(int, tile_key)

    # Splits are stored in one directory per source index.
    for i, _ in enumerate(sources):
        root = os.path.join(splits_path, str(i))
        _, path = tile_from_xyz(root, x, y, z)

        if args.label:
            split = tile_label_from_file(path)
        else:
            split = tile_image_from_file(path)

        if len(split.shape) == 2:
            split = split.reshape((width, height, 1))  # H,W -> H,W,C

        assert image.shape == split.shape, "{}, {}".format(image.shape, split.shape)

        # Fill only the values that are still empty (0), so earlier splits
        # take precedence over later ones.
        empty = np.where(image == 0)
        image[empty] += split[empty]

    if not args.label and is_nodata(image, args.nodata, args.nodata_threshold, args.keep_borders):
        progress.update()
        return

    tile = mercantile.Tile(x=x, y=y, z=z)

    if args.label:
        # The other tile_label_to_file call sites in this file pass five
        # positional arguments (root, tile, palette, transparency, label);
        # the label tiling path uses args.nodata for the fourth one.
        tile_label_to_file(args.out, tile, palette, args.nodata, image)
    else:
        tile_image_to_file(args.out, tile, image)

    progress.update()
    return tile
def worker(path):
    """Cut one raster into web-mercator tiles at ``args.zoom``.

    Relies on names from the enclosing scope: ``skip``, ``cover``,
    ``tiles_map``, ``splits_path``, ``width``, ``height``, ``ext``, ``args``,
    ``palette`` and ``progress``.

    Tiles covered by a single raster are written straight to ``args.out``;
    tiles shared by several rasters are written to a per-raster sub-directory
    of ``splits_path`` so they can be merged afterwards.

    Args:
        path: raster path, as used in ``skip`` and in ``tiles_map`` values.

    Returns:
        ``None`` if ``path`` is in ``skip``; otherwise the list of
        ``mercantile.Tile`` written directly to ``args.out``.
    """
    if path in skip:
        return None

    tiled = []
    raster = rasterio_open(path)
    try:
        w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)

        for tile in mercantile.tiles(w, s, e, n, args.zoom):

            if cover and tile not in cover:
                continue

            left, bottom, right, top = mercantile.xy_bounds(tile)

            # The VRT is closed as soon as the tile has been read.
            with WarpedVRT(
                raster,
                crs="epsg:3857",
                resampling=Resampling.bilinear,
                add_alpha=False,
                transform=from_bounds(left, bottom, right, top, width, height),
                width=width,
                height=height,
            ) as warp_vrt:
                data = warp_vrt.read(
                    out_shape=(len(args.bands), width, height),
                    indexes=args.bands,
                    window=warp_vrt.window(left, bottom, right, top),
                )

            if data.dtype == "uint16":  # GeoTiff could be 16 bits
                data = np.uint8(data / 256)
            elif data.dtype == "uint32":  # or 32 bits
                data = np.uint8(data / (256 * 256))

            image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C

            tile_key = (str(tile.x), str(tile.y), str(tile.z))
            sources = tiles_map[tile_key]
            single_source = len(sources) == 1

            if (
                not args.label
                and single_source
                and is_nodata(image, args.nodata, args.nodata_threshold, args.keep_borders)
            ):
                progress.update()
                continue

            # Tiles shared by several rasters go to a per-raster split directory.
            if len(sources) > 1:
                out = os.path.join(splits_path, str(sources.index(path)))
            else:
                out = args.out

            x, y, z = map(int, tile)
            out_tile = mercantile.Tile(x=x, y=y, z=z)

            if args.label:
                tile_label_to_file(out, out_tile, palette, args.nodata, image)
            else:
                tile_image_to_file(out, out_tile, image, ext=ext)

            # Shared tiles are reported (and counted) by the merge step instead.
            if single_source:
                tiled.append(out_tile)
                progress.update()
    finally:
        # Release the dataset even if reading or writing a tile failed.
        raster.close()

    return tiled
def main(args):
    """Rasterize vector features (GeoJSON files or a PostGIS query) on a tile cover.

    For every tile of ``args.cover`` a label tile is written to ``args.out``
    and a line ``x,y,z  <feature count>`` is appended to
    ``<type>_cover.csv``. Unless ``args.no_web_ui`` is set, a web UI is
    generated for the output directory.

    Args:
        args: parsed CLI arguments; reads ``geojson``, ``pg``, ``sql``,
            ``ts``, ``config``, ``type``, ``out``, ``cover``, ``buffer``,
            ``append``, ``no_web_ui``, ``web_ui_template`` and
            ``web_ui_base_url``. ``args.pg`` and ``args.out`` are updated in
            place.

    Raises:
        AssertionError: on inconsistent options, unknown ``--type``, empty or
            mixed-zoom cover, unsupported SQL, or unknown geometry SRID.
    """
    assert not (args.geojson is not None and args.pg is not None), "You have to choose between --pg or --geojson"
    assert len(args.ts.split(",")) == 2, "--ts expect width,height value (e.g 512,512)"

    config = load_config(args.config)
    check_classes(config)

    # Fall back on the DSN stored in the config when none is given explicitly.
    if not args.pg and "pg" in config["auth"]:
        args.pg = config["auth"]["pg"]
    assert not (args.sql and not args.pg), "With --sql option, --pg dsn setting must also be provided"
    # Without any feature source the code below has nothing to rasterize.
    assert args.geojson or args.sql, "You have to provide either --geojson or --sql"

    palette, transparency = make_palette([classe["color"] for classe in config["classes"]], complementary=True)
    index = [config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type]
    assert index, "Requested type is not contains in your config file classes."
    burn_value = index[0]
    assert 0 < burn_value <= 255

    if args.sql:
        assert "limit" not in args.sql.lower(), "LIMIT is not supported"
        assert "TILE_GEOM" in args.sql, "TILE_GEOM filter not found in your SQL"
        # Neutralize the tile filter to probe the geometry SRID on the whole
        # query. re.I must be passed as `flags`: the 4th positional argument
        # of re.sub is `count`.
        sql = re.sub(r"ST_Intersects( )*\((.*)?TILE_GEOM(.*)?\)", "1=1", args.sql, flags=re.I)
        assert sql and sql != args.sql, "Incorrect TILE_GEOM filter in your SQL"

    if os.path.dirname(os.path.expanduser(args.out)):
        os.makedirs(os.path.expanduser(args.out), exist_ok=True)
    args.out = os.path.expanduser(args.out)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
    assert len(tiles), "Empty Cover: {}".format(args.cover)

    if args.geojson:
        zoom = tiles[0].z
        assert not [tile for tile in tiles if tile.z != zoom], "Unsupported zoom mixed cover. Use PostGIS instead"

        feature_map = collections.defaultdict(list)

        log.log("neo rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)
                srid = geojson_srid(feature_collection)

                for feature in tqdm(feature_collection["features"], ascii=True, unit="feature"):
                    feature_map = geojson_parse_feature(zoom, srid, feature_map, feature, args.buffer)

        features = args.geojson

    if args.sql:
        conn = psycopg2.connect(args.pg)
        db = conn.cursor()

        db.execute("""SELECT ST_Srid("1") AS srid FROM ({} LIMIT 1) AS t("1")""".format(sql))
        srid = db.fetchone()[0]
        assert srid and int(srid) > 0, "Unable to retrieve geometry SRID."

        features = args.sql

    # feature_map only exists in GeoJSON mode.
    if args.geojson and not len(feature_map):
        log.log("-----------------------------------------------")
        log.log("NOTICE: no feature to rasterize, seems peculiar")
        log.log("-----------------------------------------------")

    log.log("neo rasterize - rasterizing {} from {} on cover {}".format(args.type, features, args.cover))

    tile_size = list(map(int, args.ts.split(",")))  # loop invariant

    with open(os.path.join(os.path.expanduser(args.out), args.type.lower() + "_cover.csv"), mode="w") as cover:

        for tile in tqdm(tiles, ascii=True, unit="tile"):

            geojson = None

            if args.sql:
                w, s, e, n = tile_bbox(tile)
                tile_geom = "ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {})".format(w, s, e, n, srid)

                # NOTE(review): the user query is interpolated as-is; it comes
                # from the operator's own --sql option, not from untrusted input.
                query = """
                WITH
                  sql  AS ({}),
                  geom AS (SELECT "1" AS geom FROM sql AS t("1")),
                  json AS (SELECT '{{"type": "Feature", "geometry": '
                         || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                         || '}}' AS features
                        FROM geom)
                SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
                FROM json
                """.format(
                    args.sql.replace("TILE_GEOM", tile_geom)
                )

                db.execute(query)
                row = db.fetchone()
                try:
                    geojson = json.loads(row[0])["features"] if row and row[0] else None
                except Exception:
                    # Deliberate best effort: the tile is written empty and
                    # processing goes on with a fresh connection.
                    log.log("Warning: Invalid geometries, skipping {}".format(tile))
                    conn.close()  # do not leak the previous connection
                    conn = psycopg2.connect(args.pg)
                    db = conn.cursor()

            if args.geojson:
                geojson = feature_map[tile] if tile in feature_map else None

            num, out = 0, None
            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, tile_size, burn_value)

            # No feature, or nothing burnt: write an empty mask and report 0.
            if out is None:
                num = 0
                out = np.zeros(shape=tile_size, dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, transparency, out, append=args.append)
            # "\n" (not os.linesep): the file is opened in text mode, which
            # already translates line endings.
            cover.write("{},{},{}  {}\n".format(tile.x, tile.y, tile.z, num))

    if args.sql:
        conn.close()

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        # Reuse the cover loaded above (read with "~" expanded) instead of
        # parsing the CSV a second time from the raw path.
        web_ui(args.out, base_url, tiles, tiles, "png", template)