示例#1
0
    def test_read_tiles(self):
        """The tiles fixture holds three tiles; the second one, in sorted order, is 18/69623/104945."""
        fixture = "tests/fixtures/tiles.csv"
        ordered = sorted(tiles_from_csv(fixture))

        self.assertEqual(len(ordered), 3)
        self.assertEqual(ordered[1], mercantile.Tile(69623, 104945, 18))
示例#2
0
def main(args):
    """Subset the tiles listed in a cover file, from ``args.dir`` into ``args.out``.

    Every cover entry is resolved to a source file. XYZ tiles
    (``mercantile.Tile``) are looked up through ``tile_from_xyz`` under
    ``args.dir``; any other entry is used directly as a file path. The source
    is then deleted (``args.delete``), copied (``args.copy``) or, by default,
    symlinked into ``args.out`` with a relative link. Unless disabled, or in
    delete mode, a Web UI is generated over the subset.

    Args:
        args: parsed CLI namespace; reads ``out``, ``delete``, ``copy``,
            ``dir``, ``cover``, ``no_web_ui``, ``web_ui_template`` and
            ``web_ui_base_url``.

    Raises:
        AssertionError: when neither ``out`` nor ``delete`` is provided, when a
            resolved source file does not exist, or when the processed tiles use
            mixed extensions while a Web UI is requested.
    """
    assert args.out or args.delete, "out parameter is required"
    # ``out`` is optional in delete mode, and os.path.expanduser(None) raises.
    if args.out:
        args.out = os.path.expanduser(args.out)

    print("RoboSat.pink - subset {} with cover {}, on CPU".format(
        args.dir, args.cover),
          file=sys.stderr,
          flush=True)

    ext = set()
    tiles = set(tiles_from_csv(os.path.expanduser(args.cover)))

    for tile in tqdm(tiles, ascii=True, unit="tiles"):

        if isinstance(tile, mercantile.Tile):
            src_tile = tile_from_xyz(args.dir, tile.x, tile.y, tile.z)
            if not src_tile:
                print("WARNING: skipping tile {}".format(tile),
                      file=sys.stderr,
                      flush=True)
                continue
            _, src = src_tile
            rel_dir = os.path.join(str(tile.z), str(tile.x))
        else:
            # Non-XYZ cover entries are file paths, mirrored as-is under out.
            src = tile
            rel_dir = os.path.dirname(tile)

        assert os.path.isfile(src)
        ext.add(os.path.splitext(src)[1][1:])

        if args.delete:
            # Deletion only touches the source: no destination tree is created.
            os.remove(src)
            assert not os.path.lexists(src)
            continue

        dst_dir = os.path.join(args.out, rel_dir)
        dst = os.path.join(dst_dir, os.path.basename(src))
        os.makedirs(dst_dir, exist_ok=True)

        if args.copy:
            shutil.copyfile(src, dst)
            assert os.path.exists(dst)
        else:
            # Replace a previous link, then point at the source relative to dst.
            if os.path.islink(dst):
                os.remove(dst)
            os.symlink(os.path.relpath(src, dst_dir), dst)
            assert os.path.islink(dst)

    # ``ext`` is empty when every tile was skipped: there is nothing to publish.
    if ext and not args.no_web_ui and not args.delete:
        assert len(ext) == 1, "ERROR: Mixed extensions, can't generate Web UI"
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiles, tiles, list(ext)[0], template)
示例#3
0
def main(args):
    """Subset the XYZ tiles listed in a cover file, from ``args.dir`` into ``args.out``.

    For every cover tile, the single file matching ``dir/z/x/y.*`` is moved
    (``args.move``), deleted (``args.delete``) or, by default, copied to
    ``out/z/x/y.<ext>``. Tiles with zero or several matching files are skipped
    with a warning. Unless disabled, or in delete mode, a Web UI is generated.

    Args:
        args: parsed CLI namespace; reads ``out``, ``delete``, ``move``,
            ``dir``, ``cover``, ``no_web_ui``, ``web_ui_template`` and
            ``web_ui_base_url``.

    Raises:
        SystemExit: when ``out`` is missing outside delete mode, or when a tile
            cannot be processed.
    """
    if not args.out and not args.delete:
        sys.exit("ERROR: out parameter is required")

    # ``out`` is optional in delete mode, and os.path.expanduser(None) raises.
    if args.out:
        args.out = os.path.expanduser(args.out)
    src_root = os.path.expanduser(args.dir)  # loop invariant
    extension = ""

    print("RoboSat.pink - subset {} with cover {}".format(
        args.dir, args.cover))

    tiles = set(tiles_from_csv(os.path.expanduser(args.cover)))
    for tile in tqdm(tiles, desc="Subset", unit="tiles", ascii=True):

        paths = glob(
            os.path.join(src_root, str(tile.z), str(tile.x),
                         "{}.*".format(tile.y)))
        if len(paths) != 1:
            print("Warning: {} skipped.".format(tile))
            continue
        src = paths[0]
        extension = os.path.splitext(src)[1][1:]

        try:
            # ``move`` takes precedence over ``delete`` when both are set.
            if args.delete and not args.move:
                # Deletion only touches the source: no destination tree is created.
                assert os.path.isfile(src)
                os.remove(src)
                continue

            dst_dir = os.path.join(args.out, str(tile.z), str(tile.x))
            os.makedirs(dst_dir, exist_ok=True)
            dst = os.path.join(dst_dir, "{}.{}".format(tile.y, extension))

            if args.move:
                assert os.path.isfile(src)
                shutil.move(src, dst)
            else:
                shutil.copyfile(src, dst)

        # Filesystem failures and the isfile checks only; a bare except would
        # also swallow KeyboardInterrupt and genuine programming errors.
        except (OSError, AssertionError):
            sys.exit("Error: Unable to process tile: {}".format(str(tile)))

    if not args.no_web_ui and not args.delete:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, tiles, tiles, extension, template)
示例#4
0
def main(args):
    """Copy, move or delete the XYZ tiles listed in a cover file.

    For every cover tile, the single file matching ``dir/z/x/y.*`` is handled
    according to ``args.mode``: ``copy`` or ``move`` to ``out/z/x/y.<ext>``,
    or ``delete`` in place. Tiles with zero or several matching files are
    skipped with a warning. A Web UI is generated when ``args.web_ui`` is set
    and an output directory is available.

    Args:
        args: parsed CLI namespace; reads ``out``, ``mode``, ``dir``,
            ``cover``, ``web_ui``, ``web_ui_template`` and ``web_ui_base_url``.

    Raises:
        SystemExit: when ``out`` is missing in copy/move mode, or when a tile
            cannot be processed.
    """
    if not args.out and args.mode in ["copy", "move"]:
        sys.exit("out parameter is required")

    tiles = set(tiles_from_csv(args.cover))
    extension = ""

    for tile in tqdm(tiles, desc="Subset", unit="tiles", ascii=True):

        paths = glob(
            os.path.join(args.dir, str(tile.z), str(tile.x),
                         "{}.*".format(tile.y)))
        if len(paths) != 1:
            print("Warning: {} skipped.".format(tile))
            continue
        src = paths[0]
        extension = os.path.splitext(src)[1][1:]

        try:
            if args.mode == "delete":
                # Deletion only touches the source: ``out`` may be absent here.
                assert os.path.isfile(src)
                os.remove(src)

            elif args.mode in ("copy", "move"):
                dst_dir = os.path.join(args.out, str(tile.z), str(tile.x))
                os.makedirs(dst_dir, exist_ok=True)
                dst = os.path.join(dst_dir, "{}.{}".format(tile.y, extension))

                if args.mode == "move":
                    assert os.path.isfile(src)
                    shutil.move(src, dst)
                else:
                    shutil.copyfile(src, dst)

        # Filesystem failures and the isfile checks only; a bare except would
        # also swallow KeyboardInterrupt and genuine programming errors.
        except (OSError, AssertionError):
            sys.exit("Error: Unable to process {}".format(tile))

    # Without an output directory there is nowhere to write the Web UI.
    if args.web_ui and args.out:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, tiles, tiles, extension, template)
示例#5
0
def main(args):
    """Subset the tiles listed in a cover file using a pool of worker threads.

    Every cover entry is resolved to a source file. XYZ tiles
    (``mercantile.Tile``) are looked up through ``tile_from_xyz`` under
    ``args.dir``; any other entry is used directly as a file path. The source
    is then moved (``args.move``), deleted (``args.delete``) or, by default,
    copied into ``args.out``. Unless disabled, or in delete mode, a Web UI is
    generated over the subset.

    Args:
        args: parsed CLI namespace; reads ``out``, ``delete``, ``move``,
            ``workers``, ``dir``, ``cover``, ``no_web_ui``,
            ``web_ui_template`` and ``web_ui_base_url``.

    Raises:
        AssertionError: when neither ``out`` nor ``delete`` is provided, or
            when a resolved source file does not exist.
    """
    assert args.out or args.delete, "out parameter is required"
    # ``out`` is optional in delete mode, and os.path.expanduser(None) raises.
    if args.out:
        args.out = os.path.expanduser(args.out)
    if not args.workers:
        # os.cpu_count() returns None when the count cannot be determined.
        args.workers = min(4, os.cpu_count() or 1)

    print("RoboSat.pink - subset {} with cover {}, on CPU, with {} workers".format(args.dir, args.cover, args.workers))

    tiles = set(tiles_from_csv(os.path.expanduser(args.cover)))

    def worker(tile):
        """Process one cover entry; return its source extension, or None if skipped."""
        if isinstance(tile, mercantile.Tile):
            src_tile = tile_from_xyz(args.dir, tile.x, tile.y, tile.z)
            if not src_tile:
                print("WARNING: skipping tile {}".format(tile),
                      file=sys.stderr,
                      flush=True)
                return None
            _, src = src_tile
            rel_dir = os.path.join(str(tile.z), str(tile.x))
        else:
            # Non-XYZ cover entries are file paths, mirrored as-is under out.
            src = tile
            rel_dir = os.path.dirname(tile)

        assert os.path.isfile(src)
        extension = os.path.splitext(src)[1][1:]

        # ``move`` takes precedence over ``delete`` when both are set.
        if args.delete and not args.move:
            # Deletion only touches the source: no destination tree is created.
            os.remove(src)
            return extension

        dst_dir = os.path.join(args.out, rel_dir)
        os.makedirs(dst_dir, exist_ok=True)
        dst = os.path.join(dst_dir, os.path.basename(src))

        if args.move:
            shutil.move(src, dst)
        else:
            shutil.copyfile(src, dst)

        return extension

    # Extension of the last processed tile, handed to the Web UI; stays "" for
    # an empty (or fully skipped) cover instead of being left unbound.
    extension = ""
    with tqdm(total=len(tiles), ascii=True, unit="tiles") as progress, \
            futures.ThreadPoolExecutor(args.workers) as executor:
        for tile_extension in executor.map(worker, tiles):
            if tile_extension is not None:
                extension = tile_extension
            progress.update()

    if not args.no_web_ui and not args.delete:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, tiles, tiles, extension, template)
示例#6
0
def main(args):
    """Download the tiles listed in ``args.cover`` from a tile server into ``args.out``.

    Tiles are fetched concurrently by ``args.workers`` threads sharing one HTTP
    session. Each worker stretches every request to at least
    ``args.workers / args.rate`` seconds, so all workers together stay around
    ``args.rate`` requests per second. Tiles already present as
    ``out/z/x/y.<format>`` are skipped.

    ``args.type`` selects how ``args.url`` is templated:

    - ``XYZ`` and ``TMS`` fill ``{x}``, ``{y}`` and ``{z}``; TMS uses a
      flipped y index.
    - ``WMS`` fills ``{xmin}``, ``{ymin}``, ``{xmax}`` and ``{ymax}`` from
      ``xy_bounds(tile)``.

    A Web UI is generated unless disabled.

    Args:
        args: parsed CLI namespace; reads ``cover``, ``out``, ``workers``,
            ``rate``, ``url``, ``type``, ``format``, ``timeout``,
            ``no_web_ui``, ``web_ui_template`` and ``web_ui_base_url``.

    Raises:
        SystemExit: when ``args.type`` is not one of XYZ, TMS or WMS.
    """
    # Reject an unknown type up front: it would otherwise surface as an
    # UnboundLocalError on ``url`` inside a worker thread.
    if args.type not in ("XYZ", "TMS", "WMS"):
        sys.exit("ERROR: unsupported tile server type {}, expected XYZ, TMS or WMS".format(args.type))

    # Expand once, so the log, the tiles and the Web UI all share one directory.
    args.out = os.path.expanduser(args.out)
    tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
    os.makedirs(args.out, exist_ok=True)

    if not args.workers:
        # os.cpu_count() returns None when the count cannot be determined.
        args.workers = max(1, math.floor((os.cpu_count() or 1) * 0.5))

    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)
    log.log(
        "RoboSat.pink - download with {} workers, at max {} req/s, from: {}".
        format(args.workers, args.rate, args.url))

    already_dl = 0
    dl = 0

    with requests.Session() as session:

        progress = tqdm(total=len(tiles), ascii=True, unit="image")
        with futures.ThreadPoolExecutor(args.workers) as executor:

            def worker(tile):
                """Fetch one tile and return ``(tile, url, ok)``.

                ``url`` is None when no request was made: either the tile is
                already on disk (``ok`` True) or its directory could not be
                created (``ok`` False).
                """
                tick = time.monotonic()
                progress.update()

                x, y, z = map(str, [tile.x, tile.y, tile.z])
                try:
                    os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
                except OSError:
                    return tile, None, False

                path = os.path.join(args.out, z, x,
                                    "{}.{}".format(y, args.format))
                if os.path.isfile(path):  # already downloaded
                    return tile, None, True

                if args.type == "XYZ":
                    url = args.url.format(x=tile.x, y=tile.y, z=tile.z)
                elif args.type == "TMS":
                    tms_y = (2**tile.z) - tile.y - 1  # flipped y index
                    url = args.url.format(x=tile.x, y=tms_y, z=tile.z)
                else:  # WMS; the type was validated above
                    xmin, ymin, xmax, ymax = xy_bounds(tile)
                    url = args.url.format(xmin=xmin,
                                          ymin=ymin,
                                          xmax=xmax,
                                          ymax=ymax)

                res = tile_image_from_url(session, url, args.timeout)
                if res is None:  # let's retry once
                    res = tile_image_from_url(session, url, args.timeout)
                    if res is None:
                        return tile, url, False

                try:
                    tile_image_to_file(args.out, tile, res)
                except OSError:
                    return tile, url, False

                # Rate limiting: each worker spends at least workers / rate
                # seconds per request.
                time_for_req = time.monotonic() - tick
                time_per_worker = args.workers / args.rate
                if time_for_req < time_per_worker:
                    time.sleep(time_per_worker - time_for_req)

                return tile, url, True

            for tile, url, ok in executor.map(worker, tiles):
                if url and ok:
                    dl += 1
                elif not url and ok:
                    already_dl += 1
                else:
                    log.log("Warning:\n {} failed, skipping.\n {}\n".format(
                        tile, url))

        progress.close()

    if already_dl:
        log.log(
            "Notice: {} tiles were already downloaded previously, and so skipped now."
            .format(already_dl))
    if already_dl + dl == len(tiles):
        log.log("Notice: Coverage is fully downloaded.")

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, tiles, tiles, args.format, template)
示例#7
0
def main(args):
    """Rasterize the features of one class into label tiles over a cover.

    Features come from one of two sources:

    - GeoJSON files (``args.geojson``), indexed in memory by tile at the
      cover zoom.
    - A PostgreSQL query (``args.pg`` + ``args.sql``), whose ``TILE_GEOM``
      placeholder is replaced by each tile envelope.

    Each cover tile is burnt with the class burn value into an ``args.ts``
    sized label written by ``tile_label_to_file``. The per-tile feature count
    is recorded in ``out/instances_<type>.cover``. A Web UI is generated
    unless disabled.

    Args:
        args: parsed CLI namespace; reads ``config``, ``type``, ``cover``,
            ``out``, ``ts``, ``geojson``, ``pg``, ``sql``, ``append``,
            ``no_web_ui``, ``web_ui_template`` and ``web_ui_base_url``.

    Raises:
        AssertionError: on inconsistent options, an unknown class type, an
            empty or mixed-zoom cover (GeoJSON input), or an SQL query without
            a usable ``TILE_GEOM`` filter.
    """
    assert not (args.sql and
                args.geojson), "You can only use at once --pg OR --geojson."
    assert not (args.pg and not args.sql
                ), "With PostgreSQL --pg, --sql must also be provided"
    # Without any input, ``features`` below would be unbound.
    assert args.geojson or args.pg, "Either --geojson or --pg (with --sql) must be provided."
    assert len(args.ts.split(
        ",")) == 2, "--ts expect width,height value (e.g 512,512)"
    tile_size = list(map(int, args.ts.split(",")))  # parsed once, not per tile

    config = load_config(args.config)
    check_classes(config)

    palette = make_palette([classe["color"] for classe in config["classes"]],
                           complementary=True)
    index = [
        config["classes"].index(classe) for classe in config["classes"]
        if classe["title"] == args.type
    ]
    assert index, "Requested type is not contains in your config file classes."
    # NOTE(review): the first class (index 0) yields int(2 ** -1) == 0,
    # presumably meaning background; confirm.
    burn_value = int(math.pow(2, index[0] - 1))  # 8bits One Hot Encoding
    assert 0 <= burn_value <= 128

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    # Read the cover once: it drives the spatial index, the burn loop and the Web UI.
    cover_tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))

    if args.geojson:

        assert cover_tiles, "Empty cover"

        zoom = cover_tiles[0].z
        assert all(tile.z == zoom for tile in cover_tiles
                   ), "Unsupported zoom mixed cover. Use PostGIS instead"

        # One index shared by every input file, so features of all files are rasterized.
        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)
            srid = geojson_srid(feature_collection)

            for feature in tqdm(feature_collection["features"],
                                ascii=True,
                                unit="feature"):
                feature_map = geojson_parse_feature(
                    zoom, srid, feature_map, feature)

        features = args.geojson

    if args.pg:

        conn = psycopg2.connect(args.pg)
        db = conn.cursor()

        assert "limit" not in args.sql.lower(), "LIMIT is not supported"
        assert "TILE_GEOM" in args.sql, "TILE_GEOM filter not found in your SQL"
        # Drop the tile filter so the SRID probe below runs on unfiltered rows.
        # ``flags`` must be a keyword: re.sub's 4th positional argument is ``count``.
        sql = re.sub(r"ST_Intersects( )*\((.*)?TILE_GEOM(.*)?\)", "1=1",
                     args.sql, flags=re.I)
        assert sql and sql != args.sql

        db.execute(
            """SELECT ST_Srid("1") AS srid FROM ({} LIMIT 1) AS t("1")""".
            format(sql))
        srid = db.fetchone()[0]
        assert srid and int(srid) > 0, "Unable to retrieve geometry SRID."

        features = args.sql

    log.log(
        "RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(
            args.type, features, args.cover))
    with open(os.path.join(args.out,
                           "instances_" + args.type.lower() + ".cover"),
              mode="w") as cover:

        for tile in tqdm(cover_tiles, ascii=True, unit="tile"):

            geojson = None

            if args.pg:

                w, s, e, n = tile_bbox(tile)
                tile_geom = "ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {})".format(
                    w, s, e, n, srid)

                query = """
                WITH
                  sql  AS ({}),
                  geom AS (SELECT "1" AS geom FROM sql AS t("1")),
                  json AS (SELECT '{{"type": "Feature", "geometry": '
                         || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                         || '}}' AS features
                        FROM geom)
                SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
                FROM json
                """.format(args.sql.replace("TILE_GEOM", tile_geom))

                try:
                    db.execute(query)
                    row = db.fetchone()
                    geojson = json.loads(
                        row[0])["features"] if row and row[0] else None
                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(
                        tile))
                    # A failed statement leaves the PostgreSQL transaction
                    # aborted: drop this connection and start a fresh one.
                    conn.close()
                    conn = psycopg2.connect(args.pg)
                    db = conn.cursor()

            if args.geojson:
                geojson = feature_map.get(tile)

            num = 0
            out = None
            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, tile_size,
                                        burn_value)

            if out is None:
                # No feature, or geojson_tile_burn returned None: write an all-zero label.
                num = 0
                out = np.zeros(shape=tile_size, dtype=np.uint8)

            tile_label_to_file(args.out,
                               tile,
                               palette,
                               out,
                               append=args.append)
            cover.write("{},{},{}  {}{}".format(tile.x, tile.y, tile.z, num,
                                                os.linesep))

    if args.pg:
        conn.close()

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, cover_tiles, cover_tiles, "png", template)
示例#8
0
def main(args):
    """Rasterize the features of one class into label tiles over a cover.

    Features come from one of two sources:

    - GeoJSON files (``args.geojson``), indexed in memory by the tiles they
      cover at the cover zoom.
    - A PostGIS query (``args.postgis``, run on ``args.pg_dsn``), clipped to
      each tile.

    Each tile is burnt with the class index into an ``args.ts`` x ``args.ts``
    label written by ``tile_label_to_file``. The per-tile feature count goes
    to ``out/instances.cover``. A Web UI is generated unless disabled.

    Args:
        args: parsed CLI namespace; reads ``geojson``, ``postgis``,
            ``pg_dsn``, ``config``, ``type``, ``cover``, ``out``, ``ts``,
            ``no_web_ui``, ``web_ui_template`` and ``web_ui_base_url``.

    Raises:
        SystemExit: on inconsistent input options, an unknown class type, an
            unreadable or mixed-zoom cover (GeoJSON input), or when the PostGIS
            geometry SRID cannot be retrieved.
    """
    if (args.geojson and args.postgis) or (not args.geojson
                                           and not args.postgis):
        sys.exit(
            "ERROR: Input features to rasterize must be either GeoJSON or PostGIS"
        )

    if args.postgis and not args.pg_dsn:
        sys.exit(
            "ERROR: With PostGIS input features, --pg_dsn must be provided")

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette(*[classe["color"] for classe in config["classes"]],
                           complementary=True)
    # next() needs a default: without one it raises StopIteration on a miss and
    # the error message below could never be reached.
    burn_value = next((config["classes"].index(classe)
                       for classe in config["classes"]
                       if classe["title"] == args.type), None)
    if burn_value is None:
        sys.exit(
            "ERROR: asked type to rasterize is not contains in your config file classes."
        )

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, srid, feature_map, polygon, i):
        """Index one Polygon geometry into ``feature_map`` under every tile it covers at ``zoom``.

        ``i`` is the source feature index, used only in the warning log.
        """
        try:
            if srid != 4326:
                polygon = list(
                    geojson_reproject({
                        "type": "feature",
                        "geometry": polygon
                    }, srid, 4326))[0]

            # Keep only x, y: GeoJSON coordinates could be N dimensionals.
            # ``ring_index`` must not reuse ``i``, which names the feature in logs.
            for ring_index, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][ring_index] = [[point[0], point[1]]
                                                      for point in ring]

            if polygon["coordinates"]:
                for tile in burntiles.burn([{
                        "type": "feature",
                        "geometry": polygon
                }],
                                           zoom=zoom):
                    feature_map[mercantile.Tile(*tile)].append({
                        "type": "feature",
                        "geometry": polygon
                    })

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, srid, feature_map, geometry, i):
        """Index a Polygon or each part of a MultiPolygon; log and skip other types."""
        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, srid, feature_map,
                                                geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(zoom, srid, feature_map, {
                    "type": "Polygon",
                    "coordinates": polygon
                }, i)
        else:
            log.log(
                "Notice: {} is a non surfacic geometry type, skipping feature {}"
                .format(geometry["type"], i))

        return feature_map

    if args.geojson:

        # The spatial index is built at a single zoom: the cover must be non
        # empty and use one zoom level only.
        try:
            tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
            zoom = tiles[0].z
            consistent = all(tile.z == zoom for tile in tiles)
        except Exception:
            consistent = False
        if not consistent:
            sys.exit("ERROR: Inconsistent cover {}".format(args.cover))

        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)

            # SRID from the optional "crs" member, defaulting to 4326.
            crs_mapping = {"CRS84": "4326", "900913": "3857"}
            try:
                srid = feature_collection["crs"]["properties"][
                    "name"].split(":")[-1]
                srid = int(crs_mapping.get(srid, srid))
            except (KeyError, TypeError, ValueError, AttributeError):
                srid = 4326

            for i, feature in enumerate(
                    tqdm(feature_collection["features"],
                         ascii=True,
                         unit="feature")):

                if feature["geometry"]["type"] == "GeometryCollection":
                    for geometry in feature["geometry"]["geometries"]:
                        feature_map = geojson_parse_geometry(
                            zoom, srid, feature_map, geometry, i)
                else:
                    feature_map = geojson_parse_geometry(
                        zoom, srid, feature_map, feature["geometry"], i)
        features = args.geojson

    if args.postgis:

        pg_conn = psycopg2.connect(args.pg_dsn)
        pg = pg_conn.cursor()

        pg.execute(
            "SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(
                args.postgis))
        try:
            srid = pg.fetchone()[0]
        except Exception:
            sys.exit("Unable to retrieve geometry SRID.")

        features = args.postgis

    log.log(
        "RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(
            args.type, features, args.cover))
    with open(os.path.join(args.out, "instances.cover"),
              mode="w") as cover:

        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))),
                         ascii=True,
                         unit="tile"):

            # Reset per tile, so a failed query never reuses the previous tile's features.
            geojson = None

            if args.postgis:

                # mercantile.bounds returns (west, south, east, north).
                w, s, e, n = mercantile.bounds(tile)

                query = """
                WITH
                  a AS ({}),
                  b AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}) AS geom)
                SELECT '{{
  "type": "FeatureCollection", "features": [{{"type": "Feature", "geometry": '
  || ST_AsGeoJSON(ST_Transform(ST_Intersection(a.geom, b.geom), 4326), 6)
  || '}}]}}'
                FROM a, b
                WHERE ST_Intersects(a.geom, b.geom)
                """.format(args.postgis, w, s, e, n, srid)

                try:
                    pg.execute(query)
                    # The query yields one row per intersecting geometry:
                    # gather all of them, not just the first one.
                    geojson = [
                        feature for row in pg.fetchall() if row[0]
                        for feature in json.loads(row[0])["features"]
                    ] or None

                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(
                        tile))
                    # A failed statement leaves the PostgreSQL transaction
                    # aborted: drop this connection and start a fresh one.
                    pg_conn.close()
                    pg_conn = psycopg2.connect(args.pg_dsn)
                    pg = pg_conn.cursor()

            if args.geojson:
                geojson = feature_map.get(tile)

            num = 0
            out = None
            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, args.ts,
                                        burn_value)

            if out is None:
                # No feature, or geojson_tile_burn returned None: write an all-zero label.
                num = 0
                out = np.zeros(shape=(args.ts, args.ts), dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, out)
            cover.write("{},{},{}  {}{}".format(tile.x, tile.y, tile.z, num,
                                                os.linesep))

    if args.postgis:
        pg_conn.close()

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
        web_ui(args.out, base_url, tiles, tiles, "png", template)
示例#9
0
def main(args):
    """Compute a tile cover, then write it as CSV, or write its extent.

    Exactly one input is used:

    - a raster (``args.raster``);
    - a GeoJSON file (``args.geojson``);
    - a PostgreSQL query (``args.sql`` on ``args.pg``);
    - a bbox (``args.bbox``, ``w,s,e,n`` with an optional trailing CRS);
    - an existing cover CSV (``args.cover``);
    - a tiles directory (``args.dir``).

    When ``args.zoom`` is given, tiles at another zoom are replaced by the
    de-duplicated tiles at ``args.zoom`` that cover their bounds.

    The result is written in one of two ways:

    - ``args.extent``: the cover is reduced to its ``w,s,e,n`` extent in
      EPSG:4326, written to ``args.out[0]`` or printed.
    - Otherwise: the cover is optionally shuffled and split by percentages
      (``args.splits``, e.g. ``70/30``), then written as one CSV per
      ``args.out`` path.

    Raises:
        AssertionError: on inconsistent options, a zoom missing for inputs that
            need one, an unusable projection, or an empty SQL result.
    """
    assert not (args.extent and args.splits
                ), "--splits and --extent are mutually exclusive options."
    assert not (args.extent and
                len(args.out) > 1), "--extent option imply a single output."
    assert not (args.sql and not args.pg), "--sql option imply --pg"
    inputs = (args.bbox, args.geojson, args.sql, args.dir, args.raster,
              args.cover)
    assert sum(
        value is not None for value in inputs
    ) == 1, "One, and only one, input type must be provided, among: --dir, --bbox, --cover, --raster, --geojson or --sql"

    if args.bbox:
        try:
            w, s, e, n, crs = args.bbox.split(",")
            w, s, e, n = map(float, (w, s, e, n))
        except ValueError:  # no trailing CRS
            crs = None
            w, s, e, n = map(float, args.bbox.split(","))
        assert isinstance(w, float) and isinstance(
            s, float), "Invalid bbox parameter."

    if args.splits:
        splits = [int(split) for split in args.splits.split("/")]
        assert len(splits) == len(args.out) and 0 < sum(
            splits) <= 100, "Invalid split value or incoherent with out paths."

    assert not (not args.zoom and
                (args.geojson or args.bbox
                 or args.raster)), "Zoom parameter is required."

    args.out = [os.path.expanduser(out) for out in args.out]

    cover = []

    if args.raster:
        print("RoboSat.pink - cover from {} at zoom {}".format(
            args.raster, args.zoom),
              file=sys.stderr,
              flush=True)
        with rasterio_open(os.path.expanduser(args.raster)) as r:
            w, s, e, n = transform_bounds(r.crs, "EPSG:4326", *r.bounds)
            assert isinstance(w, float) and isinstance(
                s, float), "Unable to deal with raster projection"

            cover = list(tiles(w, s, e, n, args.zoom))

    if args.geojson:
        print("RoboSat.pink - cover from {} at zoom {}".format(
            args.geojson, args.zoom),
              file=sys.stderr,
              flush=True)
        with open(os.path.expanduser(args.geojson)) as f:
            feature_collection = json.load(f)
            srid = geojson_srid(feature_collection)
            feature_map = collections.defaultdict(list)

            for feature in tqdm(feature_collection["features"],
                                ascii=True,
                                unit="feature"):
                feature_map = geojson_parse_feature(args.zoom, srid,
                                                    feature_map, feature)

        cover = feature_map.keys()

    if args.sql:
        print("RoboSat.pink - cover from {} {} at zoom {}".format(
            args.sql, args.pg, args.zoom),
              file=sys.stderr,
              flush=True)
        conn = psycopg2.connect(args.pg)
        assert conn, "Unable to connect to PostgreSQL database."
        db = conn.cursor()

        query = """
            WITH
              sql  AS ({}),
              geom AS (SELECT "1" AS geom FROM sql AS t("1"))
              SELECT '{{"type": "Feature", "geometry": '
                     || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                     || '}}' AS features
              FROM geom
            """.format(args.sql)

        db.execute(query)
        assert db.rowcount is not None and db.rowcount != -1, "SQL Query return no result."

        feature_map = collections.defaultdict(list)

        for feature in tqdm(
                db.fetchall(), ascii=True, unit="feature"
        ):  # FIXME: fetchall will not always fit in memory...
            feature_map = geojson_parse_feature(args.zoom, 4326, feature_map,
                                                json.loads(feature[0]))

        cover = feature_map.keys()

    if args.bbox:
        print("RoboSat.pink - cover from {} at zoom {}".format(
            args.bbox, args.zoom),
              file=sys.stderr,
              flush=True)
        if crs:
            w, s, e, n = transform_bounds(crs, "EPSG:4326", w, s, e, n)
            assert isinstance(w, float) and isinstance(
                s, float), "Unable to deal with raster projection"

        cover = list(tiles(w, s, e, n, args.zoom))

    if args.cover:
        print("RoboSat.pink - cover from {}".format(args.cover),
              file=sys.stderr,
              flush=True)
        cover = list(tiles_from_csv(args.cover))

    if args.dir:
        print("RoboSat.pink - cover from {}".format(args.dir),
              file=sys.stderr,
              flush=True)
        cover = list(tiles_from_dir(args.dir, xyz=not (args.no_xyz)))

    # Normalize to args.zoom, and track the EPSG:4326 extent when requested.
    # ``seen`` mirrors ``_cover`` for O(1) membership tests, replacing a
    # quadratic scan of the whole list for every candidate tile.
    _cover = []
    seen = set()
    extent_w, extent_s, extent_e, extent_n = (180.0, 90.0, -180.0, -90.0)
    for tile in tqdm(cover, ascii=True, unit="tile"):
        if args.zoom and tile.z != args.zoom:
            w, s, e, n = transform_bounds("EPSG:3857", "EPSG:4326",
                                          *xy_bounds(tile))
            for t in tiles(w, s, e, n, args.zoom):
                if t not in seen:
                    seen.add(t)
                    _cover.append(t)
        else:
            if args.extent:
                w, s, e, n = transform_bounds("EPSG:3857", "EPSG:4326",
                                              *xy_bounds(tile))
            seen.add(tile)
            _cover.append(tile)

        if args.extent:
            extent_w = min(extent_w, w)
            extent_s = min(extent_s, s)
            extent_e = max(extent_e, e)
            extent_n = max(extent_n, n)

    cover = _cover

    if args.splits:
        shuffle(cover)  # in-place
        cover_splits = [
            math.floor(len(cover) * split / 100) for split in splits
        ]
        # NOTE(review): this compares the tile count with the percentage sum
        # (i.e. 100), so covers of 100 tiles or fewer may lose rounded-down
        # tiles; confirm the intended condition.
        if len(splits) > 1 and sum(splits) == 100 and len(cover) > sum(
                splits):
            cover_splits[0] = len(cover) - sum(cover_splits[1:])  # no tile waste
        start = 0
        covers = []
        for size in cover_splits:
            covers.append(cover[start:start + size])
            start += size
    else:
        covers = [cover]

    if args.extent:
        out_dir = os.path.dirname(args.out[0]) if args.out else ""
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        extent = "{:.8f},{:.8f},{:.8f},{:.8f}".format(extent_w, extent_s,
                                                      extent_e, extent_n)

        if args.out:
            with open(args.out[0], "w") as fp:
                fp.write(extent)
        else:
            print(extent)
    else:
        for i, split_cover in enumerate(covers):

            out_dir = os.path.dirname(args.out[i])
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

            # newline="" lets the csv module control line endings (no blank
            # lines on Windows).
            with open(args.out[i], "w", newline="") as fp:
                csv.writer(fp).writerows(split_cover)
示例#10
0
def main(args):
    """Compute a tile cover from a single input source and write it as CSV.

    Exactly one input must be provided among --bbox, --geojson, --dir, --xyz,
    --raster and --cover. --zoom is mandatory for --bbox, --geojson and
    --raster; when --zoom is given, every input tile whose zoom differs is
    remapped to the tiles covering the same area at --zoom.

    With --splits (e.g. "70/20/10": one integer per --out path, summing to
    100) the cover is shuffled and dispatched across the --out files;
    otherwise the whole cover is written to the first --out path.

    Exits the process (sys.exit) with an error message on invalid input.
    """
    import mercantile  # burntiles.burn yields bare (x, y, z) rows; wrap them as Tiles

    inputs = (args.bbox, args.geojson, args.dir, args.xyz, args.raster, args.cover)
    if sum(arg is not None for arg in inputs) != 1:
        sys.exit(
            "ERROR: One, and only one, input type must be provided, among: "
            "--dir, --xyz, --bbox, --raster, --cover or --geojson."
        )

    if args.bbox:
        # Accepts either "w,s,e,n" (EPSG:4326) or "w,s,e,n,crs".
        try:
            w, s, e, n, crs = args.bbox.split(",")
            w, s, e, n = map(float, (w, s, e, n))
        except ValueError:
            try:
                crs = None
                w, s, e, n = map(float, args.bbox.split(","))
            except ValueError:
                sys.exit("ERROR: invalid bbox parameter.")

    if args.splits:
        # Explicit checks rather than asserts, so validation survives `python -O`.
        try:
            splits = [int(split) for split in args.splits.split("/")]
        except ValueError:
            splits = None
        if not splits or len(splits) != len(args.out) or sum(splits) != 100:
            sys.exit("ERROR: Invalid split value or incoherent with provided out paths.")

    if not args.zoom and (args.geojson or args.bbox or args.raster):
        sys.exit("ERROR: Zoom parameter is required.")

    args.out = [os.path.expanduser(out) for out in args.out]

    cover = []

    if args.raster:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.raster, args.zoom))
        with rasterio_open(os.path.expanduser(args.raster)) as r:
            try:
                w, s, e, n = transform_bounds(r.crs, "EPSG:4326", *r.bounds)
            except Exception:
                sys.exit("ERROR: unable to deal with raster projection")

            cover = list(tiles(w, s, e, n, args.zoom))

    if args.geojson:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.geojson, args.zoom))
        with open(os.path.expanduser(args.geojson)) as f:
            features = json.load(f)

        try:
            for feature in tqdm(features["features"], ascii=True, unit="feature"):
                burned = burntiles.burn([feature], args.zoom).tolist()
                # Wrap as Tile: the zoom remapping below reads `tile.z`.
                cover.extend(mercantile.Tile(*xyz) for xyz in burned)
        except Exception:
            sys.exit("ERROR: invalid or unsupported GeoJSON.")

        cover = list(set(cover))  # tiles can overlap for multiple features; unique tile ids

    if args.bbox:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.bbox, args.zoom))
        if crs:
            try:
                w, s, e, n = transform_bounds(crs, "EPSG:4326", w, s, e, n)
            except Exception:
                sys.exit("ERROR: unable to deal with bbox projection")

        cover = list(tiles(w, s, e, n, args.zoom))

    if args.cover:
        print("RoboSat.pink - cover from {}".format(args.cover))
        cover = list(tiles_from_csv(args.cover))

    if args.dir:
        print("RoboSat.pink - cover from {}".format(args.dir))
        cover = list(tiles_from_dir(args.dir, xyz=False))

    if args.xyz:
        print("RoboSat.pink - cover from {}".format(args.xyz))
        cover = list(tiles_from_dir(args.xyz, xyz=True))

    # Remap tiles to the requested zoom. `seen` mirrors `remapped` so the
    # duplicate check on remapped tiles is O(1) instead of a scan of the list.
    remapped = []
    seen = set()
    for tile in tqdm(cover, ascii=True, unit="tile"):
        if args.zoom and tile.z != args.zoom:
            # transform_bounds returns (west, south, east, north)
            west, south, east, north = transform_bounds("EPSG:3857", "EPSG:4326", *xy_bounds(tile))
            for t in tiles(west, south, east, north, args.zoom):
                if t not in seen:
                    seen.add(t)
                    remapped.append(t)
        else:
            seen.add(tile)
            remapped.append(tile)
    cover = remapped

    if args.splits:
        shuffle(cover)  # in-place
        sizes = [math.floor(len(cover) * split / 100) for split in splits]
        # Splits sum to 100: hand the rounding leftovers to the first split so no tile is dropped.
        sizes[0] = len(cover) - sum(sizes[1:])
        covers = []
        start = 0
        for size in sizes:
            covers.append(cover[start : start + size])
            start += size
    else:
        covers = [cover]

    for i, cover in enumerate(covers):

        out_dir = os.path.dirname(args.out[i])
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir, exist_ok=True)

        with open(args.out[i], "w") as fp:
            csv.writer(fp).writerows(cover)
示例#11
0
def main(args):
    """Rasterize the features of one config class onto every tile of a cover.

    Features are read either from GeoJSON files (--geojson) or from a
    PostgreSQL database (--pg connection string plus an --sql query). For
    each tile of --cover a label tile is written under --out, and one line
    "x,y,z  count" is written to ``<out>/instances.cover`` with the number
    of features burned into that tile.

    The burn value is the index, in the config classes, of the first class
    whose title equals --type. Exits the process (sys.exit) on incoherent
    options or an unknown --type.
    """
    if args.pg:
        if not args.sql:
            sys.exit("ERROR: With PostgreSQL db, --sql must be provided")

    if (args.sql and args.geojson) or (args.sql and not args.pg):
        sys.exit(
            "ERROR: You can use either --pg or --geojson inputs, but only one at once."
        )

    if not args.pg and not args.geojson:
        # Without any input, `features` would be unbound when logging below.
        sys.exit("ERROR: Either --pg (with --sql) or --geojson input must be provided.")

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette(*[classe["color"] for classe in config["classes"]],
                           complementary=True)
    # Default of None: a bare next() would raise StopIteration on an unknown
    # --type instead of reaching the explicit error message below.
    burn_value = next((index for index, classe in enumerate(config["classes"])
                       if classe["title"] == args.type), None)
    if burn_value is None:
        sys.exit(
            "ERROR: asked type to rasterize is not contains in your config file classes."
        )

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, srid, feature_map, polygon, i):
        """Index a Polygon geometry into feature_map, keyed by the tiles it covers.

        Args:
            zoom: zoom level at which tiles are burned.
            srid: SRID of the polygon coordinates; reprojected to 4326 if different.
            feature_map: mapping Tile -> list of features, appended to in place.
            polygon: GeoJSON Polygon geometry (its coordinates are mutated).
            i: index of the source feature, only used in log messages.

        Returns:
            feature_map. A ValueError is logged and the polygon skipped.
        """
        try:
            if srid != 4326:
                polygon = [
                    xy for xy in geojson_reproject(
                        {
                            "type": "feature",
                            "geometry": polygon
                        }, srid, 4326)
                ][0]

            # GeoJSON coordinates could be N dimensionals: keep only x, y.
            # Distinct loop variable so `i` still names the feature in the log below.
            for ring_index, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][ring_index] = [[point[0], point[1]] for point in ring]

            if polygon["coordinates"]:
                for tile in burntiles.burn([{
                        "type": "feature",
                        "geometry": polygon
                }], zoom=zoom):
                    feature_map[mercantile.Tile(*tile)].append({
                        "type": "feature",
                        "geometry": polygon
                    })

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, srid, feature_map, geometry, i):
        """Dispatch a geometry to geojson_parse_polygon; non-surfacic types are logged and skipped."""
        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, srid, feature_map,
                                                geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(zoom, srid, feature_map, {
                    "type": "Polygon",
                    "coordinates": polygon
                }, i)
        else:
            log.log(
                "Notice: {} is a non surfacic geometry type, skipping feature {}"
                .format(geometry["type"], i))

        return feature_map

    if args.geojson:

        tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
        assert tiles, "Empty cover"

        zoom = tiles[0].z
        assert not [tile for tile in tiles if tile.z != zoom
                    ], "Unsupported zoom mixed cover. Use PostGIS instead"

        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)

                # SRID from the optional (legacy) "crs" member; WGS84 otherwise.
                try:
                    crs_mapping = {"CRS84": "4326", "900913": "3857"}
                    srid = feature_collection["crs"]["properties"][
                        "name"].split(":")[-1]
                    srid = int(srid) if srid not in crs_mapping else int(
                        crs_mapping[srid])
                except (KeyError, TypeError, AttributeError, ValueError):
                    srid = 4326

                for i, feature in enumerate(
                        tqdm(feature_collection["features"],
                             ascii=True,
                             unit="feature")):

                    if feature["geometry"]["type"] == "GeometryCollection":
                        for geometry in feature["geometry"]["geometries"]:
                            feature_map = geojson_parse_geometry(
                                zoom, srid, feature_map, geometry, i)
                    else:
                        feature_map = geojson_parse_geometry(
                            zoom, srid, feature_map, feature["geometry"], i)
        features = args.geojson

    if args.pg:

        conn = psycopg2.connect(args.pg)
        db = conn.cursor()

        assert "limit" not in args.sql.lower(), "LIMIT is not supported"
        db.execute(
            "SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(
                args.sql))
        srid = db.fetchone()[0]
        assert srid, "Unable to retrieve geometry SRID."

        # NOTE: --sql is user-provided SQL by design and is interpolated as-is.
        if "where" not in args.sql.lower(
        ):  # TODO: Find a more reliable way to handle feature filtering
            args.sql += " WHERE ST_Intersects(tile.geom, geom)"
        else:
            args.sql += " AND ST_Intersects(tile.geom, geom)"
        features = args.sql

    log.log(
        "RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(
            args.type, features, args.cover))
    with open(os.path.join(os.path.expanduser(args.out), "instances.cover"),
              mode="w") as cover:

        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))),
                         ascii=True,
                         unit="tile"):

            geojson = None

            if args.pg:

                w, s, e, n = tile_bbox(tile)

                query = """
                WITH
                  tile AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}) AS geom),
                  geom AS (SELECT ST_Intersection(tile.geom, sql.geom) AS geom FROM tile CROSS JOIN LATERAL ({}) sql),
                  json AS (SELECT '{{"type": "Feature", "geometry": '
                         || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                         || '}}' AS features
                        FROM geom)
                SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
                FROM json
                """.format(w, s, e, n, srid, args.sql)

                # The query itself runs inside the try: a failing statement is
                # what the reconnection below recovers from (presumably an
                # aborted transaction on invalid geometries).
                try:
                    db.execute(query)
                    row = db.fetchone()
                    geojson = json.loads(
                        row[0])["features"] if row and row[0] else None
                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(
                        tile))
                    conn.close()  # don't leak the broken connection
                    conn = psycopg2.connect(args.pg)
                    db = conn.cursor()

            if args.geojson:
                geojson = feature_map.get(tile)

            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, args.ts,
                                        burn_value)

            if not geojson or out is None:
                num = 0
                out = np.zeros(shape=(args.ts, args.ts), dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, out)
            cover.write("{},{},{}  {}{}".format(tile.x, tile.y, tile.z, num,
                                                os.linesep))

    if args.pg:
        conn.close()

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
        web_ui(args.out, base_url, tiles, tiles, "png", template)
示例#12
0
def main(args):
    """Run model prediction on tiles, on GPU when available, else on CPU.

    The batch size comes from --bs, falling back on config["model"]["bs"];
    --workers defaults to the batch size. The checkpoint's network is rebuilt
    from its "nn", "shape_in", "shape_out" and "encoder" entries, then one or
    both prediction passes run according to --passes. A web UI is generated
    from the first-pass tiles unless --no_web_ui is set.
    """
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    palette = make_palette([classe["color"] for classe in config["classes"]])
    if not args.bs:
        # Fall back on the config batch size; a missing entry is not an error here.
        try:
            args.bs = config["model"]["bs"]
        except (KeyError, TypeError):
            pass

    assert args.bs, "For rsp predict, model/bs must be set either in config file, or pass trought parameter --bs"
    args.workers = args.bs if not args.workers else args.workers
    cover = list(tiles_from_csv(os.path.expanduser(args.cover))) if args.cover else None

    # Expand and create the output dir before logging into it.
    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(
            torch.__version__, torch.version.cuda,
            torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(
            args.workers))
        log.log("")
        log.log("============================================================")
        log.log("WARNING: Are you -really- sure about not predicting on GPU ?")
        log.log("============================================================")
        log.log("")
        device = torch.device("cpu")

    chkpt = torch.load(args.checkpoint, map_location=device)
    nn_module = load_module("robosat_pink.nn.{}".format(chkpt["nn"].lower()))
    nn = getattr(nn_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"],
                                         chkpt["encoder"].lower()).to(device)
    nn = torch.nn.DataParallel(nn)
    nn.load_state_dict(chkpt["state_dict"])
    nn.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    with torch.no_grad():  # don't track tensors with autograd during prediction

        tiled = []
        if args.passes in ["first", "both"]:
            log.log("== Predict First Pass ==")
            tiled = predict(config, cover, args, palette, chkpt, nn, device,
                            "predict")

        if args.passes in ["second", "both"]:
            log.log("== Predict Second Pass ==")
            # NOTE(review): the second-pass result is not kept, so
            # --passes second alone never produces a web UI; confirm intent.
            predict(config, cover, args, palette, chkpt, nn, device,
                    "predict_translate")

    if not args.no_web_ui and tiled:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiled, tiled, "png", template)
示例#13
0
def main(args):
    """Cut one or more rasters into XYZ tiles at --zoom under --out.

    Tiles covered by a single raster are written directly to --out. Tiles
    covered by several rasters are first written per raster under
    ``<out>/.splits/<n>``, then merged: each pixel still 0 in the merged
    image takes the value of the next split, in --rasters order. Image tiles
    judged nodata by ``is_nodata`` are skipped (label tiles never are).
    Tiling can be restricted to the tiles listed in --cover, and is followed
    by a web UI unless --no_web_ui is set.
    """
    if not args.workers:
        # os.cpu_count() may return None when the count cannot be determined.
        args.workers = min(os.cpu_count() or 1, len(args.rasters))

    if args.label:
        config = load_config(args.config)
        check_classes(config)
        colors = [classe["color"] for classe in config["classes"]]
        palette = make_palette(colors)

    assert len(args.ts.split(
        ",")) == 2, "--ts expect width,height value (e.g 512,512)"
    width, height = list(map(int, args.ts.split(",")))

    # Built once as a set: membership is tested per raster and per tile below.
    # An empty or absent cover means "no restriction".
    cover_set = set(tiles_from_csv(os.path.expanduser(args.cover))) if args.cover else None

    splits_path = os.path.join(os.path.expanduser(args.out), ".splits")
    tiles_map = {}  # (x, y, z) as strings -> rasters covering that tile, in --rasters order

    print("RoboSat.pink - tile on CPU, with {} workers".format(args.workers),
          file=sys.stderr,
          flush=True)

    bands = -1
    for path in args.rasters:
        with rasterio_open(path) as raster:
            w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
            raster_bands = len(raster.indexes)

        if bands != -1:
            assert bands == raster_bands, "Coverage must be bands consistent"
        bands = raster_bands

        raster_tiles = [
            mercantile.Tile(x=x, y=y, z=z)
            for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)
        ]
        if cover_set:
            raster_tiles = list(set(raster_tiles) & cover_set)

        for tile in raster_tiles:
            tile_key = (str(tile.x), str(tile.y), str(tile.z))
            tiles_map.setdefault(tile_key, []).append(path)

    if args.label:
        ext = "png"
        bands = 1
    elif bands == 1:
        ext = "png"
    elif bands == 3:
        ext = "webp"
    elif bands > 3:
        ext = "tiff"
    # NOTE(review): a 2-band coverage leaves `ext` unset, which fails at the web UI step.

    tiles = []
    progress = tqdm(total=len(tiles_map), ascii=True, unit="tile")
    # Begin to tile plain tiles
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(path):
            """Tile one raster; return the tiles written straight to --out."""
            tiled = []

            with rasterio_open(path) as raster:
                w, s, e, n = transform_bounds(raster.crs, "EPSG:4326",
                                              *raster.bounds)

                for x, y, z in mercantile.tiles(w, s, e, n, args.zoom):
                    tile = mercantile.Tile(x=x, y=y, z=z)

                    if cover_set and tile not in cover_set:
                        continue

                    left, bottom, right, top = mercantile.xy_bounds(tile)

                    with WarpedVRT(
                        raster,
                        crs="epsg:3857",
                        resampling=Resampling.bilinear,
                        add_alpha=False,
                        transform=from_bounds(left, bottom, right, top, width, height),
                        width=width,
                        height=height,
                    ) as warp_vrt:
                        # out_shape is (bands, rows, cols): rows are the height.
                        data = warp_vrt.read(
                            out_shape=(len(raster.indexes), height, width),
                            window=warp_vrt.window(left, bottom, right, top),
                        )

                    if data.dtype == "uint16":  # GeoTiff could be 16 bits
                        data = np.uint8(data / 256)
                    elif data.dtype == "uint32":  # or 32 bits
                        # NOTE(review): values above 2**24 overflow uint8 here; confirm expected range.
                        data = np.uint8(data / (256 * 256))

                    image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C

                    tile_key = (str(tile.x), str(tile.y), str(tile.z))
                    sources = tiles_map[tile_key]

                    if (not args.label and len(sources) == 1
                            and is_nodata(image, args.nodata,
                                          args.nodata_threshold,
                                          args.keep_borders)):
                        progress.update()
                        continue

                    # Tiles shared by several rasters are staged per raster, then merged below.
                    if len(sources) > 1:
                        out = os.path.join(splits_path, str(sources.index(path)))
                    else:
                        out = args.out

                    if args.label:
                        tile_label_to_file(out, tile, palette, image)
                    else:
                        tile_image_to_file(out, tile, image)

                    if len(sources) == 1:
                        progress.update()
                        tiled.append(tile)

            return tiled

        for tiled in executor.map(worker, args.rasters):
            if tiled is not None:
                tiles.extend(tiled)

    # Aggregate remaining tiles splits
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile_key):
            """Merge the staged splits of a multi-raster tile; return the Tile written, if any."""
            sources = tiles_map[tile_key]
            if len(sources) == 1:
                return

            image = np.zeros((height, width, bands), np.uint8)

            x, y, z = map(int, tile_key)
            for i in range(len(sources)):
                root = os.path.join(splits_path, str(i))
                _, path = tile_from_xyz(root, x, y, z)

                if args.label:
                    split = tile_label_from_file(path)
                    split = split.reshape((height, width, 1))  # H,W -> H,W,C
                else:
                    split = tile_image_from_file(path)

                assert image.shape == split.shape
                empty = image == 0  # earlier splits win; only fill still-empty pixels
                image[empty] = split[empty]

            if not args.label and is_nodata(image, args.nodata,
                                            args.nodata_threshold,
                                            args.keep_borders):
                progress.update()
                return

            tile = mercantile.Tile(x=x, y=y, z=z)

            if args.label:
                tile_label_to_file(args.out, tile, palette, image)
            else:
                tile_image_to_file(args.out, tile, image)

            progress.update()
            return tile

        for tiled in executor.map(worker, tiles_map.keys()):
            if tiled is not None:
                tiles.append(tiled)

        if splits_path and os.path.isdir(splits_path):
            shutil.rmtree(splits_path)  # Delete suffixes dir if any

    progress.close()

    if tiles and not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiles, tiles, ext, template)
示例#14
0
def main(args):
    """Rasterize GeoJSON or PostGIS features onto every tile of a cover.

    Exactly one input must be given: --geojson files, or a --postgis SQL
    query together with --pg_dsn. Features are burned with value 1 into a
    label tile written under --out for every tile of --cover. In GeoJSON
    mode, one line "x,y,z  count" per tile is also written to
    ``<out>/instances.cover``.

    Exits the process (sys.exit) on invalid options or unreadable inputs;
    per-tile failures are logged and skipped.
    """
    if (args.geojson and args.postgis) or (not args.geojson
                                           and not args.postgis):
        sys.exit(
            "ERROR: Input features to rasterize must be either GeoJSON or PostGIS"
        )

    if args.postgis and not args.pg_dsn:
        sys.exit(
            "ERROR: With PostGIS input features, --pg_dsn must be provided")

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette(*[classe["color"] for classe in config["classes"]],
                           complementary=True)
    burn_value = 1

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)
    cover_path = os.path.expanduser(args.cover)

    def geojson_parse_polygon(zoom, srid, feature_map, polygon, i):
        """Index a Polygon geometry into feature_map, keyed by the tiles it covers.

        Args:
            zoom: zoom level at which tiles are burned.
            srid: SRID of the polygon coordinates; reprojected to 4326 if different.
            feature_map: mapping Tile -> list of features, appended to in place.
            polygon: GeoJSON Polygon geometry (its coordinates are mutated).
            i: index of the source feature, only used in log messages.

        Returns:
            feature_map. A ValueError is logged and the polygon skipped.
        """
        try:
            if srid != 4326:
                polygon = [
                    xy for xy in geojson_reproject(
                        {
                            "type": "feature",
                            "geometry": polygon
                        }, srid, 4326)
                ][0]

            # GeoJSON coordinates could be N dimensionals: keep only x, y.
            # Distinct loop variable so `i` still names the feature in the log below.
            for ring_index, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][ring_index] = [[point[0], point[1]] for point in ring]

            if polygon["coordinates"]:
                for tile in burntiles.burn([{
                        "type": "feature",
                        "geometry": polygon
                }], zoom=zoom):
                    feature_map[mercantile.Tile(*tile)].append({
                        "type": "feature",
                        "geometry": polygon
                    })

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, srid, feature_map, geometry, i):
        """Dispatch a geometry to geojson_parse_polygon; non-surfacic types are logged and skipped."""
        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, srid, feature_map,
                                                geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(zoom, srid, feature_map, {
                    "type": "Polygon",
                    "coordinates": polygon
                }, i)
        else:
            log.log(
                "Notice: {} is a non surfacic geometry type, skipping feature {}"
                .format(geometry["type"], i))

        return feature_map

    if args.geojson:

        # The cover must be readable, non empty and single-zoom. Explicit check
        # rather than assert, so validation survives `python -O`.
        try:
            tiles = list(tiles_from_csv(cover_path))
            zoom = tiles[0].z
            consistent = all(tile.z == zoom for tile in tiles)
        except Exception:
            consistent = False
        if not consistent:
            sys.exit("ERROR: Inconsistent cover {}".format(args.cover))

        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                try:
                    feature_collection = json.load(geojson)
                except ValueError:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
                    sys.exit("ERROR: {} is not a valid JSON file.".format(
                        geojson_file))

                # SRID from the optional (legacy) "crs" member; WGS84 otherwise.
                try:
                    crs_mapping = {"CRS84": "4326", "900913": "3857"}
                    srid = feature_collection["crs"]["properties"][
                        "name"].split(":")[-1]
                    srid = int(srid) if srid not in crs_mapping else int(
                        crs_mapping[srid])
                except (KeyError, TypeError, AttributeError, ValueError):
                    srid = 4326

                for i, feature in enumerate(
                        tqdm(feature_collection["features"],
                             ascii=True,
                             unit="feature")):

                    try:
                        if feature["geometry"]["type"] == "GeometryCollection":
                            for geometry in feature["geometry"]["geometries"]:
                                feature_map = geojson_parse_geometry(
                                    zoom, srid, feature_map, geometry, i)
                        else:
                            feature_map = geojson_parse_geometry(
                                zoom, srid, feature_map, feature["geometry"],
                                i)
                    except Exception:
                        sys.exit(
                            "ERROR: Unable to parse {} file. Seems not a valid GEOJSON file."
                            .format(geojson_file))

        log.log(
            "RoboSat.pink - rasterize - rasterizing tiles from {} on cover {}".
            format(args.geojson, args.cover))
        with open(os.path.join(os.path.expanduser(args.out),
                               "instances.cover"),
                  mode="w") as cover:
            for tile in tqdm(tiles, ascii=True, unit="tile"):

                try:
                    if tile in feature_map:
                        cover.write("{},{},{}  {}{}".format(
                            tile.x, tile.y, tile.z, len(feature_map[tile]),
                            os.linesep))
                        out = geojson_tile_burn(tile, feature_map[tile], 4326,
                                                args.ts, burn_value)
                    else:
                        cover.write("{},{},{}  {}{}".format(
                            tile.x, tile.y, tile.z, 0, os.linesep))
                        out = np.zeros(shape=(args.ts, args.ts),
                                       dtype=np.uint8)

                    tile_label_to_file(args.out, tile, palette, out)
                except Exception:
                    log.log("Warning: Unable to rasterize tile. Skipping {}".
                            format(str(tile)))

    if args.postgis:

        try:
            pg_conn = psycopg2.connect(args.pg_dsn)
            pg = pg_conn.cursor()
        except Exception:
            sys.exit("Unable to connect PostgreSQL: {}".format(args.pg_dsn))

        log.log(
            "RoboSat.pink - rasterize - rasterizing tiles from PostGIS on cover {}"
            .format(args.cover))
        log.log(" SQL {}".format(args.postgis))
        # NOTE: --postgis is user-provided SQL by design and is interpolated as-is.
        try:
            pg.execute(
                "SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(
                    args.postgis))
            srid = pg.fetchone()[0]
        except Exception:
            sys.exit("Unable to retrieve geometry SRID.")

        for tile in tqdm(list(tiles_from_csv(cover_path)),
                         ascii=True,
                         unit="tile"):

            # mercantile.bounds returns (west, south, east, north)
            w, s, e, n = mercantile.bounds(tile)
            raster = np.zeros((args.ts, args.ts))

            query = """
WITH
     bbox      AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}  ) AS bbox),
     bbox_merc AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), 3857) AS bbox),

     rast_a    AS (SELECT ST_AddBand(
                           ST_SetSRID(
                             ST_MakeEmptyRaster({}, {}, ST_Xmin(bbox), ST_Ymax(bbox), (ST_YMax(bbox) - ST_YMin(bbox)) / {}),
                           3857),
                          '8BUI'::text, 0) AS rast
                   FROM bbox_merc),

     features  AS (SELECT ST_Union(ST_Transform(ST_Force2D(geom), 3857)) AS geom
                   FROM ({}) AS sub, bbox
                   WHERE ST_Intersects(geom, bbox)),

     rast_b    AS (SELECT ST_AsRaster(geom, rast, '8BUI', {}) AS rast
                   FROM features, rast_a
                   WHERE NOT ST_IsEmpty(geom))

SELECT ST_AsBinary(ST_MapAlgebra(rast_a.rast, rast_b.rast, '{}', NULL, 'FIRST')) AS wkb FROM rast_a, rast_b

""".format(w, s, e, n, srid, w, s, e, n, args.ts, args.ts, args.ts,
            args.postgis, burn_value, burn_value)

            try:
                pg.execute(query)
                row = pg.fetchone()
                if row:
                    raster = np.squeeze(wkb_to_numpy(io.BytesIO(row[0])),
                                        axis=2)

            except Exception:
                log.log(
                    "Warning: Invalid geometries, skipping {}".format(tile))
                pg_conn.close()  # don't leak the broken connection
                pg_conn = psycopg2.connect(args.pg_dsn)
                pg = pg_conn.cursor()

            try:
                tile_label_to_file(args.out, tile, palette, raster)
            except Exception:
                log.log(
                    "Warning: Unable to rasterize tile. Skipping {}".format(
                        str(tile)))

        pg_conn.close()

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = list(tiles_from_csv(cover_path))
        web_ui(args.out, base_url, tiles, tiles, "png", template)
示例#15
0
def main(args):
    """Compute a tile cover, and write it to CSV file(s) or reduce it to an extent.

    Exactly one input must be provided among --bbox, --geojson, --dir, --raster
    and --cover. Tiles whose zoom differs from --zoom (when given) are replaced
    by the --zoom tiles covering their bounds. The cover is then either:
      - shuffled and split across several --out CSV files (--splits), or
      - written to a single --out CSV file, or
      - reduced to a "w,s,e,n" extent (--extent), written to --out or printed on stdout.
    """

    # --out may be omitted with --extent: the extent is then printed on stdout.
    outs = args.out or []

    assert not (args.extent and args.splits), "--splits and --extent are mutually exclusive options."
    assert not (args.extent and len(outs) > 1), "--extent option imply a single output."
    assert (
        int(args.bbox is not None)
        + int(args.geojson is not None)
        + int(args.dir is not None)
        + int(args.raster is not None)
        + int(args.cover is not None)
        == 1
    ), "One, and only one, input type must be provided, among: --dir, --bbox, --cover, --raster or --geojson."

    if args.bbox:
        # Accept either "w,s,e,n" (EPSG:4326) or "w,s,e,n,crs".
        try:
            w, s, e, n, crs = args.bbox.split(",")
            w, s, e, n = map(float, (w, s, e, n))
        except ValueError:
            crs = None
            w, s, e, n = map(float, args.bbox.split(","))
        assert isinstance(w, float) and isinstance(s, float), "Invalid bbox parameter."

    if args.splits:
        splits = [int(split) for split in args.splits.split("/")]
        assert len(splits) == len(outs) and 0 < sum(splits) <= 100, "Invalid split value or incoherent with out paths."

    assert not (not args.zoom and (args.geojson or args.bbox or args.raster)), "Zoom parameter is required."

    args.out = [os.path.expanduser(out) for out in outs]

    cover = []

    if args.raster:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.raster, args.zoom), file=sys.stderr, flush=True)
        with rasterio_open(os.path.expanduser(args.raster)) as r:
            w, s, e, n = transform_bounds(r.crs, "EPSG:4326", *r.bounds)
            assert isinstance(w, float) and isinstance(s, float), "Unable to deal with raster projection"

            cover = [tile for tile in tiles(w, s, e, n, args.zoom)]

    if args.geojson:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.geojson, args.zoom), file=sys.stderr, flush=True)
        with open(os.path.expanduser(args.geojson)) as f:
            feature_collection = json.load(f)
            srid = geojson_srid(feature_collection)
            feature_map = collections.defaultdict(list)

            for feature in tqdm(feature_collection["features"], ascii=True, unit="feature"):
                feature_map = geojson_parse_feature(args.zoom, srid, feature_map, feature)

        cover = feature_map.keys()

    if args.bbox:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.bbox, args.zoom), file=sys.stderr, flush=True)
        if crs:
            w, s, e, n = transform_bounds(crs, "EPSG:4326", w, s, e, n)
            assert isinstance(w, float) and isinstance(s, float), "Unable to deal with raster projection"

        cover = [tile for tile in tiles(w, s, e, n, args.zoom)]

    if args.cover:
        print("RoboSat.pink - cover from {}".format(args.cover), file=sys.stderr, flush=True)
        cover = [tile for tile in tiles_from_csv(args.cover)]

    if args.dir:
        print("RoboSat.pink - cover from {}".format(args.dir), file=sys.stderr, flush=True)
        cover = [tile for tile in tiles_from_dir(args.dir, xyz=not (args.no_xyz))]

    _cover = []
    seen = set()  # mirrors _cover's content, for O(1) duplicate lookups
    extent_w, extent_s, extent_e, extent_n = (180.0, 90.0, -180.0, -90.0)
    for tile in tqdm(cover, ascii=True, unit="tile"):
        if args.zoom and tile.z != args.zoom:
            # transform_bounds returns (west, south, east, north)
            w, s, e, n = transform_bounds("EPSG:3857", "EPSG:4326", *xy_bounds(tile))
            for t in tiles(w, s, e, n, args.zoom):
                if t not in seen:
                    seen.add(t)
                    _cover.append(t)
        else:
            if args.extent:
                w, s, e, n = transform_bounds("EPSG:3857", "EPSG:4326", *xy_bounds(tile))
            seen.add(tile)
            _cover.append(tile)

        if args.extent:
            extent_w, extent_s = min(extent_w, w), min(extent_s, s)
            extent_e, extent_n = max(extent_e, e), max(extent_n, n)

    cover = _cover

    if args.splits:
        shuffle(cover)  # in-place
        cover_splits = [math.floor(len(cover) * split / 100) for split in splits]
        if len(splits) > 1 and sum(splits) == 100 and len(cover) > sum(splits):
            cover_splits[0] = len(cover) - sum(cover_splits[1:])  # no tile waste
        start = 0
        covers = []
        for size in cover_splits:
            covers.append(cover[start : start + size])
            start += size
    else:
        covers = [cover]

    if args.extent:
        if args.out and os.path.dirname(args.out[0]) and not os.path.isdir(os.path.dirname(args.out[0])):
            os.makedirs(os.path.dirname(args.out[0]), exist_ok=True)

        extent = "{:.8f},{:.8f},{:.8f},{:.8f}".format(extent_w, extent_s, extent_e, extent_n)

        if args.out:
            with open(args.out[0], "w") as fp:
                fp.write(extent)
        else:
            print(extent)
    else:
        assert len(args.out) >= len(covers), "--out parameter is required."
        for i, cover in enumerate(covers):

            if os.path.dirname(args.out[i]) and not os.path.isdir(os.path.dirname(args.out[i])):
                os.makedirs(os.path.dirname(args.out[i]), exist_ok=True)

            # newline="" as recommended by the csv module documentation.
            with open(args.out[i], "w", newline="") as fp:
                csv.writer(fp).writerows(cover)
示例#16
0
def main(args):
    """Rasterize GeoJSON or PostGIS features into label tiles, one per cover tile.

    Exactly one of --geojson (a list of GeoJSON files) or --postgis (an SQL
    query exposing a `geom` column) must be given. Every tile listed in the
    --cover CSV is written to --out via write_tile; tiles without any feature
    are written as all-zero masks. An optional Web UI is generated at the end.
    """

    if bool(args.geojson) == bool(args.postgis):
        sys.exit("Input features to rasterize must be either GeoJSON or PostGIS")

    config = load_config(args.config)
    tile_size = args.tile_size if args.tile_size else config["model"]["tile_size"]
    colors = [classe["color"] for classe in config["classes"]]
    burn_value = 1

    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    # Read the cover once: both input modes and the Web UI use it.
    cover = list(tiles_from_csv(args.cover))

    def geojson_parse_polygon(zoom, feature_map, polygon, i):
        """Append `polygon` (from feature number `i`) to feature_map for every tile it burns at `zoom`."""

        try:
            # GeoJSON coordinates could be N dimensionals: keep only x and y.
            # The ring index is `j`, so that `i` still names the feature in the warning below.
            for j, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][j] = [[point[0], point[1]] for point in ring]

            for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=zoom):
                feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, feature_map, geometry, i):
        """Dispatch Polygon / MultiPolygon geometries to geojson_parse_polygon; log and skip other types."""

        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, feature_map, geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(zoom, feature_map, {"type": "Polygon", "coordinates": polygon}, i)
        else:
            log.log("Notice: {} is a non surfacic geometry type, skipping feature {}".format(geometry["type"], i))

        return feature_map

    if args.geojson:

        if not cover:
            sys.exit("Empty cover file: {}".format(args.cover))
        zoom = cover[0].z

        if any(tile.z != zoom for tile in cover):
            sys.exit("With GeoJson input, all tiles z values have to be the same, in the csv cover file.")

        feature_map = collections.defaultdict(list)

        # Compute a spatial index like: tile -> features burning it
        for geojson_file in args.geojson:
            with open(geojson_file) as geojson:
                feature_collection = json.load(geojson)
                for i, feature in enumerate(tqdm(feature_collection["features"], ascii=True, unit="feature")):

                    if feature["geometry"]["type"] == "GeometryCollection":
                        for geometry in feature["geometry"]["geometries"]:
                            feature_map = geojson_parse_geometry(zoom, feature_map, geometry, i)
                    else:
                        feature_map = geojson_parse_geometry(zoom, feature_map, feature["geometry"], i)

        # Rasterize tiles
        for tile in tqdm(cover, ascii=True, unit="tile"):
            if tile in feature_map:
                out = geojson_tile_burn(tile, feature_map[tile], tile_size, burn_value)
            else:
                out = np.zeros(shape=(tile_size, tile_size), dtype=np.uint8)

            write_tile(args.out, tile, colors, out)

    if args.postgis:

        pg_dsn = config["dataset"]["pg_dsn"]
        try:
            pg_conn = psycopg2.connect(pg_dsn)
            pg = pg_conn.cursor()
        except Exception:
            sys.exit("Unable to connect PostgreSQL: {}".format(pg_dsn))

        try:
            pg.execute("SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(args.postgis))
            srid = pg.fetchone()[0]
        except Exception:
            sys.exit("Unable to retrieve geometry SRID.")

        for tile in tqdm(cover, ascii=True, unit="tile"):

            # mercantile.bounds returns (west, south, east, north): ST_MakeEnvelope's xmin, ymin, xmax, ymax.
            w, s, e, n = mercantile.bounds(tile)
            # Same dtype as the GeoJSON branch's empty tiles.
            raster = np.zeros((tile_size, tile_size), dtype=np.uint8)

            query = """
WITH
     bbox      AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}  ) AS bbox),
     bbox_merc AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), 3857) AS bbox),

     rast_a    AS (SELECT ST_AddBand(
                           ST_SetSRID(
                             ST_MakeEmptyRaster({}, {}, ST_Xmin(bbox), ST_Ymax(bbox), (ST_YMax(bbox) - ST_YMin(bbox)) / {}),
                           3857),
                          '8BUI'::text, 0) AS rast
                   FROM bbox_merc),

     features  AS (SELECT ST_Union(ST_Transform(ST_Force2D(geom), 3857)) AS geom
                   FROM ({}) AS sub, bbox
                   WHERE ST_Intersects(geom, bbox)),

     rast_b    AS (SELECT ST_AsRaster(geom, rast, '8BUI', {}) AS rast
                   FROM features, rast_a
                   WHERE NOT ST_IsEmpty(geom))

SELECT ST_AsBinary(ST_MapAlgebra(rast_a.rast, rast_b.rast, '{}', NULL, 'FIRST')) AS wkb FROM rast_a, rast_b

""".format(
                w, s, e, n, srid, w, s, e, n, tile_size, tile_size, tile_size, args.postgis, burn_value, burn_value
            )

            try:
                pg.execute(query)
                row = pg.fetchone()
                if row:
                    raster = np.squeeze(wkb_to_numpy(io.BytesIO(row[0])), axis=2)

            except Exception:
                log.log("Warning: Invalid geometries, skipping {}".format(tile))
                # The aborted connection is unusable: release it before opening a fresh one.
                pg_conn.close()
                pg_conn = psycopg2.connect(pg_dsn)
                pg = pg_conn.cursor()

            write_tile(args.out, tile, colors, raster)

        pg_conn.close()

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, cover, cover, "png", template)
示例#17
0
def main(args):
    """Predict masks over a dataset with a trained checkpoint and write them as tiles.

    Loads the model and loader classes named in the checkpoint, runs inference
    without autograd, and writes one label (or translated) tile per input tile
    to args.out. A Web UI is generated unless disabled or in translate mode.
    """
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    palette = make_palette([classe["color"] for classe in config["classes"]])
    if not args.workers:
        # A torch.device object is always truthy: test actual CUDA availability instead.
        # Without CUDA, 0 workers means the DataLoader loads in the main process.
        args.workers = torch.cuda.device_count() * 2 if torch.cuda.is_available() else 0
    cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(torch.__version__, torch.version.cuda, torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(args.workers))
        log.log("")
        log.log("============================================================")
        log.log("WARNING: Are you -really- sure about not predicting on GPU ?")
        log.log("============================================================")
        log.log("")
        device = torch.device("cpu")

    # NOTE: torch.load unpickles the checkpoint; only load checkpoints from trusted sources.
    chkpt = torch.load(args.checkpoint, map_location=device)
    model_module = load_module("robosat_pink.models.{}".format(chkpt["nn"].lower()))
    model = getattr(model_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"]).to(device)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(chkpt["state_dict"])
    model.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    mode = "predict" if not args.translate else "predict_translate"
    loader_module = load_module("robosat_pink.loaders.{}".format(chkpt["loader"].lower()))
    loader_predict = getattr(loader_module, chkpt["loader"])(
        config, chkpt["shape_in"][1:3], args.dataset, cover, mode=mode
    )

    loader = DataLoader(loader_predict, batch_size=args.bs, num_workers=args.workers)
    assert len(loader), "Empty predict dataset directory. Check your path."

    tiled = []
    with torch.no_grad():  # don't track tensors with autograd during prediction

        for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True):

            images = images.to(device)

            outputs = model(images)
            probs = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()

            for xyz, prob in zip(tiles, probs):
                tile = mercantile.Tile(*map(int, xyz))
                # Drop the first class channel and round the remaining probabilities to 0/1.
                mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()
                if args.translate:
                    tile_translate_to_file(args.out, tile, palette, mask)
                else:
                    tile_label_to_file(args.out, tile, palette, mask)
                tiled.append(tile)

    if not args.no_web_ui and not args.translate:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiled, tiled, "png", template)
示例#18
0
def main(args):
    """Download the tiles listed in the --tiles CSV from a XYZ, TMS or WMS server.

    Tiles already present on disk are skipped. `args.rate` is both the number
    of worker threads and the target request rate: each worker pads its
    requests to last at least `num_workers / args.rate` seconds.
    """
    if args.type not in ("XYZ", "TMS", "WMS"):
        sys.exit("Unsupported tile server type: {}".format(args.type))

    tiles = list(tiles_from_csv(args.tiles))
    already_dl = 0
    dl = 0

    with requests.Session() as session:
        num_workers = args.rate

        os.makedirs(args.out, exist_ok=True)
        log = Logs(os.path.join(args.out, "log"), out=sys.stderr)
        log.log("Begin download from {}".format(args.url))

        progress = tqdm(total=len(tiles), ascii=True, unit="image")

        with futures.ThreadPoolExecutor(num_workers) as executor:

            def worker(tile):
                """Download one tile; return (tile, url, ok), with url None when the tile was already on disk."""
                tick = time.monotonic()

                x, y, z = map(str, [tile.x, tile.y, tile.z])

                os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
                path = os.path.join(args.out, z, x, "{}.{}".format(y, args.ext))

                if os.path.isfile(path):
                    progress.update()
                    return tile, None, True

                if args.type == "XYZ":
                    url = args.url.format(x=tile.x, y=tile.y, z=tile.z)
                elif args.type == "TMS":
                    # TMS rows are counted from the south. mercantile.Tile is an immutable
                    # namedtuple, so compute the flipped row rather than assigning tile.y.
                    tms_y = (2 ** tile.z) - tile.y - 1
                    url = args.url.format(x=tile.x, y=tms_y, z=tile.z)
                else:  # WMS
                    xmin, ymin, xmax, ymax = xy_bounds(tile)
                    url = args.url.format(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)

                res = tile_image_from_url(session, url, args.timeout)
                if not res:
                    return tile, url, False

                try:
                    image = cv2.imdecode(np.frombuffer(res.read(), np.uint8), cv2.IMREAD_COLOR)
                    # imdecode returns None on undecodable data; imwrite returns False on failure.
                    if image is None or not cv2.imwrite(path, image):
                        return tile, url, False
                except (OSError, cv2.error):
                    return tile, url, False
                progress.update()

                tock = time.monotonic()

                time_for_req = tock - tick
                time_per_worker = num_workers / args.rate

                if time_for_req < time_per_worker:
                    time.sleep(time_per_worker - time_for_req)

                return tile, url, True

            for tile, url, ok in executor.map(worker, tiles):
                if url and ok:
                    dl += 1
                elif not url and ok:
                    already_dl += 1
                else:
                    log.log("Warning:\n {} failed, skipping.\n {}\n".format(tile, url))

    if already_dl:
        log.log("Notice:\n {} tiles were already downloaded previously, and so skipped now.".format(already_dl))
    if already_dl + dl == len(tiles):
        log.log(" Coverage is fully downloaded.")

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, tiles, tiles, args.ext, template)
示例#19
0
def main(args):
    """Cut one or more rasters into slippy map tiles at args.zoom.

    A tile covered by a single raster is written directly to args.out. A tile
    covered by several rasters is first written once per raster under a
    temporary `.splits` directory; those splits are then summed into one tile
    and the `.splits` directory is removed. With --label, tiles are written as
    palettized label tiles instead of images.
    """

    if not args.workers:
        args.workers = min(os.cpu_count(), len(args.rasters))

    if args.label:
        config = load_config(args.config)
        check_classes(config)
        colors = [classe["color"] for classe in config["classes"]]
        palette = make_palette(*colors)

    # A set, for O(1) membership tests. As before, an empty cover file means "no filtering".
    cover = set(tiles_from_csv(os.path.expanduser(args.cover))) if args.cover else None

    splits_path = os.path.join(os.path.expanduser(args.out), ".splits")
    tiles_map = {}  # (x, y, z) as strings -> raster paths covering that tile

    print("RoboSat.pink - tile on CPU, with {} workers".format(args.workers))

    bands = -1
    for path in args.rasters:
        with rasterio_open(path) as raster:
            w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)

            if bands != -1:
                assert bands == len(raster.indexes), "Coverage must be bands consistent"
            bands = len(raster.indexes)

        tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
        tiles = list(set(tiles) & cover) if cover else tiles

        for tile in tiles:
            tiles_map.setdefault((str(tile.x), str(tile.y), str(tile.z)), []).append(path)

    if args.label:
        ext = "png"
        bands = 1
    elif bands == 1:
        ext = "png"
    elif bands == 3:
        ext = "webp"
    else:
        # Previously only bands > 3 set ext, leaving it undefined (NameError) for 2 bands.
        # NOTE(review): assumes tile_image_to_file writes tiff for any non 1/3 band count — confirm.
        ext = "tiff"

    tiles = []
    progress = tqdm(total=len(tiles_map), ascii=True, unit="tile")
    # Begin to tile plain tiles
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(path):
            """Write every tile of raster `path`; return the tiles written straight to args.out."""

            tiled = []
            with rasterio_open(path) as raster:
                w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
                transform, _, _ = calculate_default_transform(
                    raster.crs, "EPSG:3857", raster.width, raster.height, w, s, e, n
                )

                for x, y, z in mercantile.tiles(w, s, e, n, args.zoom):
                    tile = mercantile.Tile(x=x, y=y, z=z)

                    if cover and tile not in cover:
                        continue

                    left, bottom, right, top = mercantile.xy_bounds(tile)

                    # inspired by rio-tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
                    with WarpedVRT(
                        raster,
                        crs="epsg:3857",
                        resampling=Resampling.bilinear,
                        add_alpha=False,
                        transform=from_bounds(left, bottom, right, top, args.ts, args.ts),
                        width=math.ceil((right - left) / transform.a),
                        height=math.ceil((bottom - top) / transform.e),
                    ) as warp_vrt:
                        data = warp_vrt.read(
                            out_shape=(len(raster.indexes), args.ts, args.ts),
                            window=warp_vrt.window(left, bottom, right, top),
                        )
                    image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C

                    sources = tiles_map[(str(tile.x), str(tile.y), str(tile.z))]
                    if not args.label and len(sources) == 1 and is_nodata(image, threshold=args.nodata_threshold):
                        progress.update()
                        continue

                    # Tiles shared by several rasters go to a per-raster split directory.
                    out = os.path.join(splits_path, str(sources.index(path))) if len(sources) > 1 else args.out

                    if args.label:
                        ret = tile_label_to_file(out, tile, palette, image)
                    else:
                        ret = tile_image_to_file(out, tile, image)

                    assert ret, "Unable to write tile {} from raster {}.".format(str(tile), raster)

                    if len(sources) == 1:
                        progress.update()
                        tiled.append(tile)

            return tiled

        for tiled in executor.map(worker, args.rasters):
            if tiled is not None:
                tiles.extend(tiled)

    # Aggregate remaining tiles splits
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile_key):
            """Sum the splits of a multi-raster tile into args.out; return the tile, or None if skipped."""

            if len(tiles_map[tile_key]) == 1:
                return

            image = np.zeros((args.ts, args.ts, bands), np.uint8)

            x, y, z = map(int, tile_key)
            for i in range(len(tiles_map[tile_key])):
                root = os.path.join(splits_path, str(i))
                _, path = tile_from_xyz(root, x, y, z)

                if args.label:
                    split = tile_label_from_file(path).reshape((args.ts, args.ts, 1))  # H,W -> H,W,C
                else:
                    split = tile_image_from_file(path)

                assert image.shape == split.shape
                # NOTE(review): uint8 addition wraps where splits overlap with non-zero values;
                # this assumes the splits are disjoint.
                image[:, :, :] += split[:, :, :]

            if not args.label and is_nodata(image, threshold=args.nodata_threshold):
                progress.update()
                return

            tile = mercantile.Tile(x=x, y=y, z=z)

            if args.label:
                ret = tile_label_to_file(args.out, tile, palette, image)
            else:
                ret = tile_image_to_file(args.out, tile, image)

            # The former message had two placeholders for one argument (IndexError on failure).
            assert ret, "Unable to write tile {} from raster splits.".format(str(tile_key))

            progress.update()
            return tile

        for tiled in executor.map(worker, tiles_map):
            if tiled is not None:
                tiles.append(tiled)

        if splits_path and os.path.isdir(splits_path):
            shutil.rmtree(splits_path)  # Delete suffixes dir if any

    if tiles and not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiles, tiles, ext, template)
示例#20
0
def main(args):
    """Burn GeoJSON polygon features into palettized label tiles for each cover tile.

    All cover tiles must be at args.zoom. When a tile PNG already exists in
    args.out, the newly burned labels are merged with it (pixel-wise maximum).
    """
    config = load_config(args.config)
    tile_size = args.tile_size if args.tile_size else config["model"]["tile_size"]
    colors = [classe["color"] for classe in config["classes"]]

    os.makedirs(args.out, exist_ok=True)

    # Read the cover once: the zoom check, the burning and the Web UI all use it.
    cover = list(tiles_from_csv(args.cover))

    # We can only rasterize all tiles at a single zoom.
    assert all(tile.z == args.zoom for tile in cover)

    # Find all tiles the features cover and make a map object for quick lookup.
    feature_map = collections.defaultdict(list)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def parse_polygon(feature_map, polygon, i):
        """Append `polygon` (from feature number `i`) to feature_map for every tile it burns at args.zoom."""

        try:
            # GeoJSON coordinates could be N dimensionals: keep only x and y.
            # The ring index is `j`, so that `i` still names the feature in the warning below.
            for j, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][j] = [[point[0], point[1]] for point in ring]

            for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=args.zoom):
                feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def parse_geometry(feature_map, geometry, i):
        """Dispatch Polygon / MultiPolygon geometries to parse_polygon; log and skip other types."""

        if geometry["type"] == "Polygon":
            feature_map = parse_polygon(feature_map, geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = parse_polygon(feature_map, {"type": "Polygon", "coordinates": polygon}, i)
        else:
            log.log("Notice: {} is a non surfacic geometry type, skipping feature {}".format(geometry["type"], i))

        return feature_map

    for features_path in args.features:
        with open(features_path) as f:
            fc = json.load(f)
            for i, feature in enumerate(tqdm(fc["features"], ascii=True, unit="feature")):

                if feature["geometry"]["type"] == "GeometryCollection":
                    for geometry in feature["geometry"]["geometries"]:
                        feature_map = parse_geometry(feature_map, geometry, i)
                else:
                    feature_map = parse_geometry(feature_map, feature["geometry"], i)

    # Loop invariant: the same palette is applied to every tile.
    palette = complementary_palette(make_palette(colors[0], colors[1]))

    # Burn features to tiles and write to a slippy map directory.
    for tile in tqdm(cover, ascii=True, unit="tile"):
        if tile in feature_map:
            out = burn(tile, feature_map[tile], tile_size)
        else:
            out = np.zeros(shape=(tile_size, tile_size), dtype=np.uint8)

        out_dir = os.path.join(args.out, str(tile.z), str(tile.x))
        os.makedirs(out_dir, exist_ok=True)

        out_path = os.path.join(out_dir, "{}.png".format(tile.y))

        if os.path.exists(out_path):
            # Merge with the labels already on disk; close the file once read.
            with Image.open(out_path) as prev:
                out = np.maximum(out, np.array(prev))

        out = Image.fromarray(out, mode="P")
        out.putpalette(palette)
        out.save(out_path, optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, cover, cover, "png", template)