Beispiel #1
0
def gpu_worker(rank, world_size, lock_file, args, config, dataset, palette, transparency):
    """Predict this process' share of `dataset` on GPU `rank` and write label tiles.

    The worker joins an NCCL process group, rebuilds the network described by the
    checkpoint, runs it over the batches yielded by a distributed sampler and
    writes one label tile per input tile with `tile_label_to_file`.

    Args:
        rank: index of this process in the group; also used as CUDA device index.
        world_size: number of processes in the group.
        lock_file: path used for the ``file://`` rendezvous of the process group.
        args: parsed CLI arguments; `checkpoint`, `bs`, `workers`, `metatiles`
            and `out` are read here.
        config: not used by this function; kept for interface compatibility.
        dataset: dataset whose batches unpack as ``(images, tiles)``.
        palette: forwarded unchanged to `tile_label_to_file`.
        transparency: forwarded unchanged to `tile_label_to_file`.

    Raises:
        AssertionError: if the checkpoint model version does not match the
            instantiated module, or if the loader yields no batch.
    """

    dist.init_process_group(backend="nccl", init_method="file://" + lock_file, world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)

    # Read the checkpoint a single time, with `~` expanded, and reuse it both to
    # build the network and to restore its weights.
    chkpt = torch.load(os.path.expanduser(args.checkpoint), map_location="cuda:{}".format(rank))
    nn_module = load_module("neat_eo.nn.{}".format(chkpt["nn"].lower()))
    model = getattr(nn_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"], chkpt["encoder"].lower()).to(rank)
    model = DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True)

    assert model.module.version == chkpt["model_version"], "Model Version mismatch"
    model.load_state_dict(chkpt["state_dict"])

    sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=args.bs, shuffle=False, num_workers=args.workers, sampler=sampler)
    assert len(loader), "Empty predict dataset directory. Check your path."

    C, W, H = chkpt["shape_out"]

    def predict(batch):
        """Run the network on `batch` and return the result as a NumPy array on the CPU."""
        return model(batch.to(rank)).data.cpu().numpy()

    model.eval()
    with torch.no_grad():

        # Only the first rank displays a progress bar.
        dataloader = tqdm(loader, desc="Predict", unit="Batch/GPU", ascii=True) if rank == 0 else loader

        for images, tiles in dataloader:

            if args.metatiles:
                N = images.shape[0]
                ts = int(W)
                hs = ts // 2
                qs = ts // 4

                # Four overlapping windows are predicted and only the central
                # part (a `qs` margin dropped on every side) of each prediction
                # is kept to fill one quadrant of the output.
                # NOTE(review): presumably metatile inputs are 1.5 x W wide - confirm.
                # fmt:off
                probs = np.zeros((N, C, W, H), dtype=np.float64)  # np.float alias no longer exists in NumPy
                probs[:, :, 0:hs, 0:hs] = predict(images[:, :, 0:ts, 0:ts])[:, :, qs:-qs, qs:-qs]
                probs[:, :, 0:hs,  hs:] = predict(images[:, :, 0:ts,  hs:])[:, :, qs:-qs, qs:-qs]
                probs[:, :, hs:,  0:hs] = predict(images[:, :, hs:,  0:ts])[:, :, qs:-qs, qs:-qs]
                probs[:, :, hs:,   hs:] = predict(images[:, :, hs:,   hs:])[:, :, qs:-qs, qs:-qs]
                # fmt:on
            else:
                probs = predict(images)

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))
                mask = np.zeros((W, H), dtype=np.uint8)

                # Each channel whose rounded value is 1 adds its index to the mask.
                for c in range(C):
                    mask += np.around(prob[c, :, :]).astype(np.uint8) * c

                tile_label_to_file(args.out, mercantile.Tile(x, y, z), palette, transparency, mask)
Beispiel #2
0
        def worker(tile_key):
            """Merge the partial renderings of one tile into a single tile file.

            `tile_key` is an ``(x, y, z)`` triple convertible to ints. Tiles that
            have a single entry in `tiles_map` are ignored (returns None). For the
            others, every partial tile found under ``splits_path/<index>`` is read
            and copied into the pixels of the merged image that are still zero.

            Returns the written `mercantile.Tile`, or None when nothing is written
            (single-source tile, or image flagged by `is_nodata`).
            """

            parts = tiles_map[tile_key]
            if len(parts) == 1:
                return

            merged = np.zeros((width, height, len(args.bands)), np.uint8)
            x, y, z = map(int, tile_key)

            # Pick the reader matching the kind of tile being merged.
            read_tile = tile_label_from_file if args.label else tile_image_from_file

            for index, _ in enumerate(parts):
                _, path = tile_from_xyz(os.path.join(splits_path, str(index)), x, y, z)
                part = read_tile(path)

                if len(part.shape) == 2:
                    part = part.reshape((width, height, 1))  # H,W -> H,W,C

                assert merged.shape == part.shape, "{}, {}".format(merged.shape, part.shape)

                # Only pixels still at zero receive values from this part.
                empty = np.where(merged == 0)
                merged[empty] += part[empty]

            if not args.label and is_nodata(merged, args.nodata, args.nodata_threshold, args.keep_borders):
                progress.update()
                return

            tile = mercantile.Tile(x=x, y=y, z=z)

            if args.label:
                tile_label_to_file(args.out, tile, palette, merged)
            else:
                tile_image_to_file(args.out, tile, merged)

            progress.update()
            return tile
Beispiel #3
0
        def worker(path):
            """Cut the raster at `path` into tiles at zoom `args.zoom` and write them.

            Every tile intersecting the raster bounds (and present in `cover` when
            a cover is given) is warped to EPSG:3857 and written with
            `tile_image_to_file` or `tile_label_to_file`. Tiles fed by several
            rasters (more than one entry in `tiles_map`) are written under
            ``splits_path/<index of path>`` instead of `args.out`.

            Returns:
                None when `path` is listed in `skip`; otherwise the list of tiles
                written that have this raster as their only source.
            """

            if path in skip:
                return None

            tiled = []
            raster = rasterio_open(path)

            # try/finally so the dataset is released even when a tile fails.
            try:
                w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
                tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]

                for tile in tiles:

                    if cover and tile not in cover:
                        continue

                    w, s, e, n = mercantile.xy_bounds(tile)

                    # The VRT is only needed for this read; close it right after
                    # instead of leaving one open handle per tile.
                    with WarpedVRT(
                        raster,
                        crs="epsg:3857",
                        resampling=Resampling.bilinear,
                        add_alpha=False,
                        transform=from_bounds(w, s, e, n, width, height),
                        width=width,
                        height=height,
                    ) as warp_vrt:
                        data = warp_vrt.read(
                            out_shape=(len(args.bands), width, height),
                            indexes=args.bands,
                            window=warp_vrt.window(w, s, e, n),
                        )

                    if data.dtype == "uint16":  # GeoTiff could be 16 bits
                        data = np.uint8(data / 256)
                    elif data.dtype == "uint32":  # or 32 bits
                        # 2**32 / 2**8 == 256**3: dividing by less leaves values
                        # above 255 that overflow when cast to uint8.
                        data = np.uint8(data / (256 * 256 * 256))

                    image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C

                    tile_key = (str(tile.x), str(tile.y), str(tile.z))
                    single_source = len(tiles_map[tile_key]) == 1

                    if (
                        not args.label
                        and single_source
                        and is_nodata(image, args.nodata, args.nodata_threshold, args.keep_borders)
                    ):
                        progress.update()
                        continue

                    if len(tiles_map[tile_key]) > 1:
                        # Partial tile: stored apart, in a directory named after
                        # the position of this raster among the tile sources.
                        out = os.path.join(splits_path, str(tiles_map[tile_key].index(path)))
                    else:
                        out = args.out

                    x, y, z = map(int, tile)
                    out_tile = mercantile.Tile(x=x, y=y, z=z)

                    if args.label:
                        tile_label_to_file(out, out_tile, palette, args.nodata, image)
                    else:
                        tile_image_to_file(out, out_tile, image, ext=ext)

                    if single_source:
                        tiled.append(out_tile)

                    progress.update()
            finally:
                raster.close()

            return tiled
Beispiel #4
0
def main(args):
    """Rasterize vector features of one class into label tiles over a cover.

    Features come either from GeoJSON files (`args.geojson`) or from a PostGIS
    query (`args.sql`, which must contain a ``TILE_GEOM`` placeholder). For every
    tile of the cover a label tile is written with `tile_label_to_file`, and a
    ``<type>_cover.csv`` file listing each tile with its feature count is
    written in `args.out`. Unless `args.no_web_ui` is set, a web UI is generated.

    Args:
        args: parsed CLI arguments. Fields read: `geojson`, `pg`, `sql`, `ts`,
            `config`, `type`, `out`, `cover`, `buffer`, `append`, `no_web_ui`,
            `web_ui_template`, `web_ui_base_url`. `args.pg` and `args.out` are
            updated in place.

    Raises:
        AssertionError: on inconsistent options, unknown class type, empty or
            mixed-zoom cover, or an unusable SQL query.
    """

    assert not (args.geojson is not None and args.pg is not None), "You have to choose between --pg or --geojson"
    assert len(args.ts.split(",")) == 2, "--ts expect width,height value (e.g 512,512)"

    config = load_config(args.config)
    check_classes(config)

    args.pg = config["auth"]["pg"] if not args.pg and "pg" in config["auth"] else args.pg
    assert not (args.sql and not args.pg), "With --sql option, --pg dsn setting must also be provided"
    # Without any source the code below has nothing to rasterize (and used to
    # stop on an undefined variable); fail with an explicit message instead.
    assert args.geojson or args.sql, "You have to provide either --geojson or --sql"

    palette, transparency = make_palette([classe["color"] for classe in config["classes"]], complementary=True)
    index = [i for i, classe in enumerate(config["classes"]) if classe["title"] == args.type]
    assert index, "Requested type is not contains in your config file classes."
    burn_value = index[0]
    assert 0 < burn_value <= 255

    if args.sql:
        assert "limit" not in args.sql.lower(), "LIMIT is not supported"
        assert "TILE_GEOM" in args.sql, "TILE_GEOM filter not found in your SQL"
        # `flags` must be passed by keyword: the fourth positional parameter of
        # re.sub is `count`, not `flags`.
        sql = re.sub(r"ST_Intersects( )*\((.*)?TILE_GEOM(.*)?\)", "1=1", args.sql, flags=re.I)
        assert sql and sql != args.sql, "Incorrect TILE_GEOM filter in your SQL"

    args.out = os.path.expanduser(args.out)
    if os.path.dirname(args.out):
        os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
    assert len(tiles), "Empty Cover: {}".format(args.cover)

    # Parsed once; the same size is used for every tile.
    tile_size = tuple(map(int, args.ts.split(",")))

    # Defined up front so it exists whichever source is used.
    feature_map = collections.defaultdict(list)
    conn = None
    db = None

    if args.geojson:
        zoom = tiles[0].z
        assert all(tile.z == zoom for tile in tiles), "Unsupported zoom mixed cover. Use PostGIS instead"

        log.log("neo rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)
                srid = geojson_srid(feature_collection)

                for feature in tqdm(feature_collection["features"], ascii=True, unit="feature"):
                    feature_map = geojson_parse_feature(zoom, srid, feature_map, feature, args.buffer)

        features = args.geojson

    if args.sql:
        conn = psycopg2.connect(args.pg)
        db = conn.cursor()

        db.execute("""SELECT ST_Srid("1") AS srid FROM ({} LIMIT 1) AS t("1")""".format(sql))
        srid = db.fetchone()[0]
        assert srid and int(srid) > 0, "Unable to retrieve geometry SRID."

        features = args.sql

    # The spatial index is only filled from GeoJSON input; with --sql features
    # are fetched tile by tile, so an empty index means nothing there.
    if args.geojson and not len(feature_map):
        log.log("-----------------------------------------------")
        log.log("NOTICE: no feature to rasterize, seems peculiar")
        log.log("-----------------------------------------------")

    log.log("neo rasterize - rasterizing {} from {} on cover {}".format(args.type, features, args.cover))
    try:
        with open(os.path.join(args.out, args.type.lower() + "_cover.csv"), mode="w") as cover:

            for tile in tqdm(tiles, ascii=True, unit="tile"):

                geojson = None

                if args.sql:
                    w, s, e, n = tile_bbox(tile)
                    tile_geom = "ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {})".format(w, s, e, n, srid)

                    query = """
                WITH
                  sql  AS ({}),
                  geom AS (SELECT "1" AS geom FROM sql AS t("1")),
                  json AS (SELECT '{{"type": "Feature", "geometry": '
                         || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                         || '}}' AS features
                        FROM geom)
                SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
                FROM json
                """.format(
                        args.sql.replace("TILE_GEOM", tile_geom)
                    )

                    db.execute(query)
                    row = db.fetchone()
                    try:
                        geojson = json.loads(row[0])["features"] if row and row[0] else None
                    except Exception:
                        log.log("Warning: Invalid geometries, skipping {}".format(tile))
                        # Release the previous connection before opening a new one.
                        conn.close()
                        conn = psycopg2.connect(args.pg)
                        db = conn.cursor()

                if args.geojson:
                    geojson = feature_map[tile] if tile in feature_map else None

                if geojson:
                    num = len(geojson)
                    out = geojson_tile_burn(tile, geojson, 4326, list(tile_size), burn_value)

                # No feature on this tile, or nothing burnt: write an empty label.
                if not geojson or out is None:
                    num = 0
                    out = np.zeros(shape=list(tile_size), dtype=np.uint8)

                tile_label_to_file(args.out, tile, palette, transparency, out, append=args.append)
                # Text mode already translates "\n"; writing os.linesep would
                # produce "\r\r\n" on Windows.
                cover.write("{},{},{}  {}\n".format(tile.x, tile.y, tile.z, num))
    finally:
        if conn is not None:
            conn.close()

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        # Reuse the cover loaded above rather than re-reading the CSV from a
        # path that was not user-expanded.
        web_ui(args.out, base_url, tiles, tiles, "png", template)