コード例 #1
0
ファイル: vectorize.py プロジェクト: vizzbee/robosat.pink
def main(args):
    """Vectorize one class of mask tiles into a GeoJSON FeatureCollection file.

    Loads the config named by ``args.config``, finds the position(s) of the
    class whose title equals ``args.type``, then polygonizes the matching
    pixels of every mask tile found under ``args.masks`` and streams the
    resulting polygons to ``args.out``.

    Args:
        args: parsed CLI namespace; reads ``config``, ``type``, ``masks`` and ``out``.

    Raises:
        AssertionError: if no class in the config carries the requested title.
    """
    config = load_config(args.config)
    check_classes(config)
    index = [i for i, classe in enumerate(config["classes"]) if classe["title"] == args.type]
    assert index, "Requested type {} not found among classes title in the config file.".format(args.type)
    print("RoboSat.pink - vectorize {} from {}".format(args.type, args.masks), file=sys.stderr, flush=True)

    # The context manager guarantees the closing "]}" is flushed and the handle
    # released, even if a tile fails half-way through.
    with open(args.out, "w", encoding="utf-8") as out:
        out.write('{"type":"FeatureCollection","features":[')

        first = True
        for tile, path in tqdm(list(tiles_from_dir(args.masks, xyz_path=True)), ascii=True, unit="mask"):
            # `index` is a list, so the comparison broadcasts it against the paletted image.
            mask = (np.array(Image.open(path).convert("P"), dtype=np.uint8) == index).astype(np.uint8)
            try:
                _, W, H = mask.shape
            except ValueError:  # two-dimensional mask: nothing to drop
                W, H = mask.shape
            # Plain star-unpacking: the former `(*bounds)` spelling is rejected by recent parsers.
            transform = rasterio.transform.from_bounds(*mercantile.bounds(tile.x, tile.y, tile.z), W, H)

            for shape, value in rasterio.features.shapes(mask, transform=transform, mask=mask):
                geom = '"geometry":{{"type": "Polygon", "coordinates":{}}}'.format(json.dumps(shape["coordinates"]))
                # Features are comma separated; only the very first one has no leading comma.
                out.write('{}{{"type":"Feature",{}}}'.format("" if first else ",", geom))
                first = False

        out.write("]}")
コード例 #2
0
def main(args):
    """Feed every mask tile of one class to its feature handler and save the result.

    The class list is read from the ``common.classes`` entry of the dataset
    config. The handler registered in ``handlers`` for ``args.type`` receives,
    for each tile under ``args.masks``, a binary mask of the pixels whose
    palette index equals the position of ``args.type`` in the class list.
    """
    dataset = load_config(args.dataset)
    class_names = dataset["common"]["classes"]

    # Every registered handler must correspond to a declared class label.
    assert set(handlers.keys()) <= set(class_names), "handlers have a class label"

    class_index = class_names.index(args.type)
    feature_handler = handlers[args.type]()

    mask_tiles = list(tiles_from_slippy_map(args.masks))
    for tile, path in tqdm(mask_tiles, ascii=True, unit="mask"):
        pixels = np.array(Image.open(path).convert("P"), dtype=np.uint8)
        feature_handler.apply(tile, (pixels == class_index).astype(np.uint8))

    feature_handler.save(args.out)
コード例 #3
0
def main(args):
    """Vectorize one class of mask tiles into a GeoJSON FeatureCollection file.

    Loads the config named by ``args.config``, finds the position(s) of the
    class whose title equals ``args.type``, then polygonizes every mask tile
    found under ``args.masks``. Each polygon is written to ``args.out`` as a
    Feature carrying the x/y/z of the tile it came from as properties.

    Args:
        args: parsed CLI namespace; reads ``config``, ``type``, ``masks`` and ``out``.

    Raises:
        AssertionError: if no class in the config carries the requested title.
    """
    config = load_config(args.config)
    check_classes(config)
    index = [i for i, classe in enumerate(config["classes"]) if classe["title"] == args.type]
    assert index, "Requested type {} not found among classes title in the config file.".format(args.type)

    print("RoboSat.pink - vectorize {} from {}".format(args.type, args.masks))

    with open(args.out, "w", encoding="utf-8") as out:
        first = True
        out.write('{"type":"FeatureCollection","features":[')

        for tile, path in tqdm(list(tiles_from_slippy_map(args.masks)), ascii=True, unit="mask"):
            # `index` is a list, so the comparison broadcasts it against the paletted image.
            features = (np.array(Image.open(path).convert("P"), dtype=np.uint8) == index).astype(np.uint8)
            try:
                _, W, H = features.shape
            except ValueError:  # two-dimensional mask: nothing to drop
                W, H = features.shape
            # Plain star-unpacking: the former `(*bounds)` spelling is rejected by recent parsers.
            transform = rasterio.transform.from_bounds(*mercantile.bounds(tile.x, tile.y, tile.z), W, H)

            for shape, value in rasterio.features.shapes(features, transform=transform):
                prop = '"properties":{{"x":{},"y":{},"z":{}}}'.format(int(tile.x), int(tile.y), int(tile.z))
                geom = '"geometry":{{"type": "Polygon", "coordinates":{}}}'.format(json.dumps(shape["coordinates"]))
                # Features are comma separated; only the very first one has no leading comma.
                out.write('{}{{"type":"Feature",{},{}}}'.format("," if not first else "", geom, prop))
                first = False

        out.write("]}")
コード例 #4
0
ファイル: tile.py プロジェクト: liushuaicare/robosat.pink
def main(args):
    """Cut the rasters listed in ``args.rasters`` into tiles written under ``args.out``.

    Works in two passes, each on a thread pool of ``args.workers`` threads:

    1. every raster is warped tile by tile; a tile covered by a single raster
       is written straight to ``args.out``, a tile covered by several rasters
       is written to a per-raster directory under ``<out>/.splits``;
    2. the splits of each shared tile are summed into one image, written to
       ``args.out``, and the ``.splits`` directory is removed.

    When ``args.label`` is set, tiles are written as paletted labels using the
    class colors from the config; otherwise they are written as images and
    tiles flagged by ``is_nodata`` are skipped.

    Args:
        args: parsed CLI namespace; reads ``rasters``, ``out``, ``zoom``, ``ts``,
            ``workers``, ``label``, ``config``, ``cover``, ``nodata_threshold``
            and the ``web_ui*`` options.

    Raises:
        AssertionError: if rasters disagree on band count, if a tile cannot be
            written, or if a split does not have the expected shape.
    """

    if not args.workers:
        args.workers = min(os.cpu_count(), len(args.rasters))

    if args.label:
        config = load_config(args.config)
        check_classes(config)
        colors = [classe["color"] for classe in config["classes"]]
        palette = make_palette(*colors)

    cover = list(tiles_from_csv(os.path.expanduser(args.cover))) if args.cover else None
    # Built once: membership is tested for every tile of every raster.
    cover_set = set(cover) if cover else None

    splits_path = os.path.join(os.path.expanduser(args.out), ".splits")
    tiles_map = {}  # (x, y, z) as strings -> list of raster paths covering that tile

    print("RoboSat.pink - tile on CPU, with {} workers".format(args.workers))

    bands = -1
    for path in args.rasters:
        raster = rasterio_open(path)
        w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)

        if bands != -1:
            assert bands == len(raster.indexes), "Coverage must be bands consistent"
        bands = len(raster.indexes)

        tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
        tiles = list(set(tiles) & cover_set) if cover else tiles

        for tile in tiles:
            tile_key = (str(tile.x), str(tile.y), str(tile.z))
            tiles_map.setdefault(tile_key, []).append(path)

    # NOTE(review): `ext` stays unset for a 2-band (or empty) image coverage; it is
    # only read by the final web_ui call.
    if args.label:
        ext = "png"
        bands = 1
    else:
        if bands == 1:
            ext = "png"
        if bands == 3:
            ext = "webp"
        if bands > 3:
            ext = "tiff"

    tiles = []
    progress = tqdm(total=len(tiles_map), ascii=True, unit="tile")
    # Begin to tile plain tiles
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(path):
            """Warp one raster into its tiles; return the tiles written directly to args.out."""

            raster = rasterio_open(path)
            w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
            transform, _, _ = calculate_default_transform(
                raster.crs, "EPSG:3857", raster.width, raster.height, w, s, e, n
            )
            tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
            tiled = []

            for tile in tiles:

                if cover and tile not in cover_set:
                    continue

                w, s, e, n = mercantile.xy_bounds(tile)

                # inspired by rio-tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
                warp_vrt = WarpedVRT(
                    raster,
                    crs="epsg:3857",
                    resampling=Resampling.bilinear,
                    add_alpha=False,
                    transform=from_bounds(w, s, e, n, args.ts, args.ts),
                    width=math.ceil((e - w) / transform.a),
                    height=math.ceil((s - n) / transform.e),
                )
                data = warp_vrt.read(
                    out_shape=(len(raster.indexes), args.ts, args.ts), window=warp_vrt.window(w, s, e, n)
                )
                image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C

                tile_key = (str(tile.x), str(tile.y), str(tile.z))
                sources = tiles_map[tile_key]

                # A nodata image tile owned by a single raster is dropped right away;
                # shared tiles are judged after aggregation.
                if not args.label and len(sources) == 1 and is_nodata(image, threshold=args.nodata_threshold):
                    progress.update()
                    continue

                if len(sources) > 1:
                    # One sub-directory per contributing raster, named by its rank in the list.
                    out = os.path.join(splits_path, str(sources.index(path)))
                else:
                    out = args.out

                x, y, z = map(int, tile)
                tile_xyz = mercantile.Tile(x=x, y=y, z=z)

                if not args.label:
                    ret = tile_image_to_file(out, tile_xyz, image)
                if args.label:
                    ret = tile_label_to_file(out, tile_xyz, palette, image)

                assert ret, "Unable to write tile {} from raster {}.".format(str(tile), raster)

                if len(sources) == 1:
                    progress.update()
                    tiled.append(tile_xyz)

            return tiled

        for tiled in executor.map(worker, args.rasters):
            if tiled is not None:
                tiles.extend(tiled)

    # Aggregate remaining tiles splits
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile_key):
            """Sum the splits of one shared tile and write it; return the tile, or None if skipped."""

            if len(tiles_map[tile_key]) == 1:
                return

            image = np.zeros((args.ts, args.ts, bands), np.uint8)

            x, y, z = map(int, tile_key)
            for i in range(len(tiles_map[tile_key])):
                root = os.path.join(splits_path, str(i))
                _, path = tile_from_xyz(root, x, y, z)

                if not args.label:
                    split = tile_image_from_file(path)
                if args.label:
                    split = tile_label_from_file(path)
                    split = split.reshape((args.ts, args.ts, 1))  # H,W -> H,W,C

                assert image.shape == split.shape
                image[:, :, :] += split[:, :, :]

            if not args.label and is_nodata(image, threshold=args.nodata_threshold):
                progress.update()
                return

            tile = mercantile.Tile(x=x, y=y, z=z)

            if not args.label:
                ret = tile_image_to_file(args.out, tile, image)

            if args.label:
                ret = tile_label_to_file(args.out, tile, palette, image)

            # Both placeholders are filled so a failed write reports an
            # AssertionError rather than an IndexError from str.format.
            assert ret, "Unable to write tile {} from raster {}.".format(str(tile), tiles_map[tile_key])

            progress.update()
            return tile

        for tiled in executor.map(worker, tiles_map.keys()):
            if tiled is not None:
                tiles.append(tiled)

        if splits_path and os.path.isdir(splits_path):
            shutil.rmtree(splits_path)  # Delete suffixes dir if any

    if tiles and not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiles, tiles, ext, template)
コード例 #5
0
def main(args):
    """Train (and/or validate) a model on the dataset directory ``args.dataset``.

    Command-line values, when given, override the matching ``config["model"]``
    entries. The loader, network and loss classes are resolved by name from the
    ``robosat_pink.loaders``, ``robosat_pink.nn`` and ``robosat_pink.losses``
    packages. Checkpoints are written to ``args.out`` on the last epoch and on
    every epoch that is a multiple of ``args.saving``.

    Args:
        args: parsed CLI namespace; reads ``config``, ``dataset``, ``out``,
            ``workers``, ``epochs``, ``saving``, ``checkpoint``, ``resume``,
            ``no_training``, ``no_validation`` and the model overrides
            (``loader``, ``bs``, ``lr``, ``ts``, ``nn``, ``encoder``, ``loss``,
            ``da``, ``dap``).

    Raises:
        AssertionError: if the dataset is not a directory, or if the resumed
            checkpoint already reached ``args.epochs``.
        SystemExit: when both training and validation are disabled, and after
            a validation-only run.
    """
    config = load_config(args.config)
    args.out = os.path.expanduser(args.out)
    config["model"]["loader"] = args.loader if args.loader else config["model"]["loader"]
    config["model"]["bs"] = args.bs if args.bs else config["model"]["bs"]
    config["model"]["lr"] = args.lr if args.lr else config["model"]["lr"]
    config["model"]["ts"] = tuple(map(int, args.ts.split(","))) if args.ts else config["model"]["ts"]
    config["model"]["nn"] = args.nn if args.nn else config["model"]["nn"]
    config["model"]["encoder"] = args.encoder if args.encoder else config["model"]["encoder"]
    config["model"]["loss"] = args.loss if args.loss else config["model"]["loss"]
    config["model"]["da"] = args.da if args.da else config["model"]["da"]
    config["model"]["dap"] = args.dap if args.dap else config["model"]["dap"]
    args.workers = config["model"]["bs"] if not args.workers else args.workers
    check_classes(config)
    check_channels(config)
    check_model(config)

    # Expanded once so the directory check and the loaders look at the same path.
    dataset = os.path.expanduser(args.dataset)
    assert os.path.isdir(dataset), "Dataset is not a directory"
    if args.no_training and args.no_validation:
        sys.exit()

    log = Logs(os.path.join(args.out, "log"))
    csv_train = None if args.no_training else open(os.path.join(args.out, "training.csv"), mode="a")
    csv_val = None if args.no_validation else open(os.path.join(args.out, "validation.csv"), mode="a")

    if torch.cuda.is_available():
        log.log(
            "RoboSat.pink - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers)
        )
        log.log(
            "(Torch:{} Cuda:{} CudNN:{})".format(
                torch.__version__, torch.version.cuda, torch.backends.cudnn.version()
            )
        )
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - training on CPU, with {} workers - (Torch:{})".format(args.workers, torch.__version__))
        log.log("")
        log.log("==========================================================")
        log.log("WARNING: Are you -really- sure about not training on GPU ?")
        log.log("==========================================================")
        log.log("")
        device = torch.device("cpu")

    log.log("--- Input tensor from Dataset: {} ---".format(args.dataset))
    num_channel = 1  # 1-based numerotation
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(num_channel, channel["name"], band))
            num_channel += 1

    log.log("--- Output Classes ---")
    for c, classe in enumerate(config["classes"]):
        log.log("Class {}:\t\t {}".format(c, classe["title"]))

    log.log("--- Hyper Parameters ---")
    for hp in config["model"]:
        log.log("{}{}".format(hp.ljust(25, " "), config["model"][hp]))

    # The loader class carries the configured name; its module is the lower-cased name.
    loader = load_module("robosat_pink.loaders.{}".format(config["model"]["loader"].lower()))
    loader_train = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], os.path.join(dataset, "training"), None, "train"
    )
    loader_val = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], os.path.join(dataset, "validation"), None, "train"
    )

    encoder = config["model"]["encoder"].lower()
    nn_module = load_module("robosat_pink.nn.{}".format(config["model"]["nn"].lower()))
    nn = getattr(nn_module, config["model"]["nn"])(
        loader_train.shape_in, loader_train.shape_out, encoder, config
    ).to(device)
    nn = torch.nn.DataParallel(nn)
    optimizer = Adam(nn.parameters(), lr=config["model"]["lr"])

    resume = 0
    if args.checkpoint:
        chkpt = torch.load(os.path.expanduser(args.checkpoint), map_location=device)
        nn.load_state_dict(chkpt["state_dict"])
        log.log("--- Using Checkpoint ---")
        log.log("Path:\t\t {}".format(args.checkpoint))
        log.log("UUID:\t\t {}".format(chkpt["uuid"]))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]
            assert resume < args.epochs, "Epoch asked, already reached by the given checkpoint"

    loss_module = load_module("robosat_pink.losses.{}".format(config["model"]["loss"].lower()))
    criterion = getattr(loss_module, config["model"]["loss"])().to(device)

    bs = config["model"]["bs"]
    train_loader = DataLoader(loader_train, batch_size=bs, shuffle=True, drop_last=True, num_workers=args.workers)
    val_loader = DataLoader(loader_val, batch_size=bs, shuffle=False, drop_last=True, num_workers=args.workers)

    if args.no_training:
        # Validation-only run: a single evaluation pass, reported as epoch 0.
        epoch = 0
        process(val_loader, config, log, csv_val, epoch, device, nn, criterion, "eval")
        sys.exit()

    for epoch in range(resume + 1, args.epochs + 1):  # 1-N based
        UUID = uuid.uuid1()
        log.log("---{}Epoch: {}/{} -- UUID: {}".format(os.linesep, epoch, args.epochs, UUID))

        process(train_loader, config, log, csv_train, epoch, device, nn, criterion, "train", optimizer)

        try:  # https://github.com/pytorch/pytorch/issues/9176
            nn_doc = nn.module.doc
            nn_version = nn.module.version
        except AttributeError:
            nn_doc = nn.doc
            nn_version = nn.version

        states = {
            "uuid": UUID,
            "model_version": nn_version,
            "producer_name": "RoboSat.pink",
            "producer_version": rsp.__version__,
            "model_licence": "MIT",
            "domain": "pink.RoboSat",  # reverse-DNS
            "doc_string": nn_doc,
            "shape_in": loader_train.shape_in,
            "shape_out": loader_train.shape_out,
            "state_dict": nn.state_dict(),
            "epoch": epoch,
            "nn": config["model"]["nn"],
            "encoder": config["model"]["encoder"],
            "optimizer": optimizer.state_dict(),
            "loader": config["model"]["loader"],
        }
        checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}.pth".format(epoch))
        # Save on the final epoch and on every multiple of args.saving.
        if epoch == args.epochs or not (epoch % args.saving):
            log.log("[Saving checkpoint]")
            torch.save(states, checkpoint_path)

        if not args.no_validation:
            process(val_loader, config, log, csv_val, epoch, device, nn, criterion, "eval")
コード例 #6
0
ファイル: predict.py プロジェクト: dselivanov/robosat.pink
def main(args):
    """Predict label tiles for the dataset ``args.dataset`` with a saved checkpoint.

    The network and loader classes are looked up by the names stored in the
    checkpoint (``nn`` and ``loader`` entries). For every predicted tile the
    rounded probabilities of the channels after the first one are written to
    ``args.out`` as a label tile, or through ``tile_translate_to_file`` when
    ``args.translate`` is set. A web UI is generated afterwards unless it is
    disabled or translate mode is on.
    """
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)

    colors = [classe["color"] for classe in config["classes"]]
    palette = make_palette(colors)

    if torch.device("cuda") and not args.workers:
        args.workers = torch.cuda.device_count() * 2

    cover = None
    if args.cover:
        cover = list(tiles_from_csv(os.path.expanduser(args.cover)))

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(gpu_count, args.workers))
        versions = (torch.__version__, torch.version.cuda, torch.backends.cudnn.version())
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(*versions))
        device = torch.device("cuda")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(args.workers))
        for line in (
            "",
            "============================================================",
            "WARNING: Are you -really- sure about not predicting on GPU ?",
            "============================================================",
            "",
        ):
            log.log(line)
        device = torch.device("cpu")

    chkpt = torch.load(args.checkpoint, map_location=device)

    # The model class carries the stored name; its module is the lower-cased name.
    model_module = load_module("robosat_pink.models.{}".format(chkpt["nn"].lower()))
    model_class = getattr(model_module, chkpt["nn"])
    nn = model_class(chkpt["shape_in"], chkpt["shape_out"]).to(device)
    nn = torch.nn.DataParallel(nn)
    nn.load_state_dict(chkpt["state_dict"])
    nn.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    mode = "predict_translate" if args.translate else "predict"
    loader_module = load_module("robosat_pink.loaders.{}".format(chkpt["loader"].lower()))
    loader_class = getattr(loader_module, chkpt["loader"])
    loader_predict = loader_class(config, chkpt["shape_in"][1:3], args.dataset, cover, mode=mode)

    loader = DataLoader(loader_predict, batch_size=args.bs, num_workers=args.workers)
    assert len(loader), "Empty predict dataset directory. Check your path."

    tiled = []
    # don't track tensors with autograd during prediction
    with torch.no_grad():
        for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True):
            outputs = nn(images.to(device))
            probs = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))
                mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()

                write_tile = tile_translate_to_file if args.translate else tile_label_to_file
                write_tile(args.out, mercantile.Tile(x, y, z), palette, mask)
                tiled.append(mercantile.Tile(x, y, z))

    if not (args.no_web_ui or args.translate):
        template = args.web_ui_template or "leaflet.html"
        base_url = args.web_ui_base_url or "."
        web_ui(args.out, base_url, tiled, tiled, "png", template)
コード例 #7
0
def main(args):
    """Rasterize vector features of one class into label tiles over a tile cover.

    Features come either from GeoJSON files (``args.geojson``) or from a
    PostgreSQL query (``args.pg`` + ``args.sql``, where the ``TILE_GEOM`` token
    is replaced by each tile's envelope). For every tile listed in
    ``args.cover`` a label tile is written to ``args.out``; tiles without any
    feature get an all-zero label. A ``instances_<type>.cover`` file records
    the number of features burned per tile.

    Args:
        args: parsed CLI namespace; reads ``config``, ``type``, ``cover``,
            ``out``, ``ts``, ``geojson``, ``pg``, ``sql``, ``append`` and the
            ``web_ui*`` options.

    Raises:
        AssertionError: on inconsistent options, unknown class title, empty or
            mixed-zoom cover, unsupported SQL, or an unresolvable SRID.
    """

    assert not (args.sql and args.geojson), "You can only use at once --pg OR --geojson."
    assert not (args.pg and not args.sql), "With PostgreSQL --pg, --sql must also be provided"
    assert len(args.ts.split(",")) == 2, "--ts expect width,height value (e.g 512,512)"

    config = load_config(args.config)
    check_classes(config)

    palette = make_palette([classe["color"] for classe in config["classes"]], complementary=True)
    index = [config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type]
    assert index, "Requested type is not contains in your config file classes."
    burn_value = int(math.pow(2, index[0] - 1))  # 8bits One Hot Encoding
    assert 0 <= burn_value <= 128

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    if args.geojson:

        tiles = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))]
        assert tiles, "Empty cover"

        zoom = tiles[0].z
        assert not [tile for tile in tiles if tile.z != zoom], "Unsupported zoom mixed cover. Use PostGIS instead"

        # Created once, before the loop, so features accumulate across all input files.
        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)
                srid = geojson_srid(feature_collection)

                for feature in tqdm(feature_collection["features"], ascii=True, unit="feature"):
                    feature_map = geojson_parse_feature(zoom, srid, feature_map, feature)

        features = args.geojson

    if args.pg:

        conn = psycopg2.connect(args.pg)
        db = conn.cursor()

        assert "limit" not in args.sql.lower(), "LIMIT is not supported"
        assert "TILE_GEOM" in args.sql, "TILE_GEOM filter not found in your SQL"
        # Neutralize the tile filter to probe the geometry SRID on the whole table.
        # `flags=` must be a keyword: the fourth positional argument of re.sub is `count`.
        sql = re.sub(r"ST_Intersects( )*\((.*)?TILE_GEOM(.*)?\)", "1=1", args.sql, flags=re.I)
        assert sql and sql != args.sql

        db.execute("""SELECT ST_Srid("1") AS srid FROM ({} LIMIT 1) AS t("1")""".format(sql))
        srid = db.fetchone()[0]
        assert srid and int(srid) > 0, "Unable to retrieve geometry SRID."

        features = args.sql

    log.log("RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(args.type, features, args.cover))
    with open(os.path.join(os.path.expanduser(args.out), "instances_" + args.type.lower() + ".cover"), mode="w") as cover:

        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))), ascii=True, unit="tile"):

            geojson = None

            if args.pg:

                w, s, e, n = tile_bbox(tile)
                tile_geom = "ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {})".format(w, s, e, n, srid)

                query = """
                WITH
                  sql  AS ({}),
                  geom AS (SELECT "1" AS geom FROM sql AS t("1")),
                  json AS (SELECT '{{"type": "Feature", "geometry": '
                         || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                         || '}}' AS features
                        FROM geom)
                SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
                FROM json
                """.format(args.sql.replace("TILE_GEOM", tile_geom))

                db.execute(query)
                row = db.fetchone()
                try:
                    geojson = json.loads(row[0])["features"] if row and row[0] else None
                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(tile))
                    # Open a fresh connection and cursor before moving on to the next tile.
                    conn = psycopg2.connect(args.pg)
                    db = conn.cursor()

            if args.geojson:
                geojson = feature_map[tile] if tile in feature_map else None

            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, list(map(int, args.ts.split(","))), burn_value)

            # No feature, or nothing burned: emit an empty label so every cover tile exists.
            if not geojson or out is None:
                num = 0
                out = np.zeros(shape=list(map(int, args.ts.split(","))), dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, out, append=args.append)
            cover.write("{},{},{}  {}{}".format(tile.x, tile.y, tile.z, num, os.linesep))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        # Same expanded cover path as every other read of args.cover in this function.
        tiles = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
コード例 #8
0
def main(args):
    """Train a model on the dataset directory ``args.dataset``, saving a checkpoint per epoch.

    Command-line values, when given, override the matching ``config["model"]``
    entries. The loader, network and loss classes are resolved by name from the
    ``robosat_pink.loaders``, ``robosat_pink.models`` and
    ``robosat_pink.losses`` packages. After each epoch (and its optional
    validation pass) a ``checkpoint-<epoch>.pth`` file is written to ``args.out``.

    Args:
        args: parsed CLI namespace; reads ``config``, ``dataset``, ``out``,
            ``workers``, ``epochs``, ``checkpoint``, ``resume``,
            ``no_validation`` and the model overrides (``loader``, ``bs``,
            ``lr``, ``ts``, ``nn``, ``loss``, ``da``, ``dap``).

    Raises:
        SystemExit: if the dataset is not a directory, or if the resumed
            checkpoint already reached ``args.epochs``.
    """
    config = load_config(args.config)
    args.out = os.path.expanduser(args.out)
    # NOTE(review): torch.device("cuda") only builds a device handle; it does not
    # look like a GPU availability test - confirm the intent before changing it.
    args.workers = torch.cuda.device_count() * 2 if torch.device("cuda") and not args.workers else args.workers
    config["model"]["loader"] = args.loader if args.loader else config["model"]["loader"]
    config["model"]["bs"] = args.bs if args.bs else config["model"]["bs"]
    config["model"]["lr"] = args.lr if args.lr else config["model"]["lr"]
    config["model"]["ts"] = args.ts if args.ts else config["model"]["ts"]
    config["model"]["nn"] = args.nn if args.nn else config["model"]["nn"]
    config["model"]["loss"] = args.loss if args.loss else config["model"]["loss"]
    config["model"]["da"] = args.da if args.da else config["model"]["da"]
    config["model"]["dap"] = args.dap if args.dap else config["model"]["dap"]
    check_classes(config)
    check_channels(config)
    check_model(config)

    # Expanded once so the directory check and the loaders look at the same path.
    dataset = os.path.expanduser(args.dataset)
    if not os.path.isdir(dataset):
        sys.exit("ERROR: dataset {} is not a directory".format(args.dataset))

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(torch.__version__, torch.version.cuda, torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - training on CPU, with {} workers - (Torch:{})".format(args.workers, torch.__version__))
        log.log("WARNING: Are you really sure sure about not training on GPU ?")
        device = torch.device("cpu")

    # The loader class carries the configured name; its module is the lower-cased name.
    loader = load_module("robosat_pink.loaders.{}".format(config["model"]["loader"].lower()))
    loader_train = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], os.path.join(dataset, "training"), "train"
    )
    loader_val = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], os.path.join(dataset, "validation"), "train"
    )

    model_module = load_module("robosat_pink.models.{}".format(config["model"]["nn"].lower()))

    nn = getattr(model_module, config["model"]["nn"])(loader_train.shape_in, loader_train.shape_out, config).to(device)
    nn = torch.nn.DataParallel(nn)
    optimizer = Adam(nn.parameters(), lr=config["model"]["lr"])

    resume = 0
    if args.checkpoint:
        chkpt = torch.load(os.path.expanduser(args.checkpoint), map_location=device)
        nn.load_state_dict(chkpt["state_dict"])
        log.log("Using checkpoint: {}".format(args.checkpoint))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]
            if resume >= args.epochs:
                # Report the same value the check compares against.
                sys.exit("ERROR: Epoch {} already reached by the given checkpoint".format(args.epochs))

    loss_module = load_module("robosat_pink.losses.{}".format(config["model"]["loss"].lower()))
    criterion = getattr(loss_module, config["model"]["loss"])().to(device)

    bs = config["model"]["bs"]
    train_loader = DataLoader(loader_train, batch_size=bs, shuffle=True, drop_last=True, num_workers=args.workers)
    val_loader = DataLoader(loader_val, batch_size=bs, shuffle=False, drop_last=True, num_workers=args.workers)

    log.log("--- Input tensor from Dataset: {} ---".format(args.dataset))
    num_channel = 1  # 1-based numerotation
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(num_channel, channel["name"], band))
            num_channel += 1

    log.log("--- Hyper Parameters ---")
    for hp in config["model"]:
        log.log("{}{}".format(hp.ljust(25, " "), config["model"][hp]))

    for epoch in range(resume, args.epochs):  # 0-based here, reported and saved as epoch + 1
        UUID = uuid.uuid1()
        log.log("---{}Epoch: {}/{} -- UUID: {}".format(os.linesep, epoch + 1, args.epochs, UUID))

        process(train_loader, config, log, device, nn, criterion, "train", optimizer)
        if not args.no_validation:
            process(val_loader, config, log, device, nn, criterion, "eval")

        try:  # https://github.com/pytorch/pytorch/issues/9176
            nn_doc = nn.module.doc
            nn_version = nn.module.version
        except AttributeError:
            nn_version = nn.version
            nn_doc = nn.doc  # assignment: the former `==` left nn_doc unbound

        states = {
            "uuid": UUID,
            "model_version": nn_version,
            "producer_name": "RoboSat.pink",
            "producer_version": "0.4.0",
            "model_licence": "MIT",
            "domain": "pink.RoboSat",  # reverse-DNS
            "doc_string": nn_doc,
            "shape_in": loader_train.shape_in,
            "shape_out": loader_train.shape_out,
            "state_dict": nn.state_dict(),
            "epoch": epoch + 1,
            "nn": config["model"]["nn"],
            "optimizer": optimizer.state_dict(),
            "loader": config["model"]["loader"],
        }
        checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}.pth".format(epoch + 1))
        torch.save(states, checkpoint_path)
コード例 #9
0
def main(args):
    """Burn the features of one class into a label tile for every tile of a cover.

    Features are read either from GeoJSON files (``args.geojson``) or from a
    PostgreSQL connection (``args.pg``) queried with ``args.sql``. For each tile
    listed in the ``args.cover`` CSV a label is written below ``args.out`` with
    ``tile_label_to_file`` and a ``x,y,z  count`` line is written to
    ``instances.cover`` in that same directory. Unless ``args.no_web_ui`` is
    set, ``web_ui`` is called once every tile has been handled.

    Args:
        args: parsed command line namespace. Attributes read here: ``pg``,
            ``sql``, ``geojson``, ``config``, ``type``, ``out``, ``cover``,
            ``ts``, ``no_web_ui``, ``web_ui_template`` and ``web_ui_base_url``.
            ``args.out`` is replaced by its user-expanded form and, with a
            PostgreSQL input, a spatial filter is appended to ``args.sql``.

    Raises:
        SystemExit: on missing or conflicting input options, or when
            ``args.type`` is not the title of any class of the config file.
        AssertionError: on an empty or mixed-zoom cover (GeoJSON input), on a
            SQL query containing LIMIT, or when no SRID can be retrieved.
    """
    if args.pg and not args.sql:
        sys.exit("ERROR: With PostgreSQL db, --sql must be provided")

    if (args.sql and args.geojson) or (args.sql and not args.pg):
        sys.exit("ERROR: You can use either --pg or --geojson inputs, but only one at once.")

    if not args.geojson and not args.pg:
        # Without any input, `features` was never bound and the run ended in a NameError.
        sys.exit("ERROR: Either --pg or --geojson input must be provided.")

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette(*[classe["color"] for classe in config["classes"]], complementary=True)

    # next() needs an explicit default: without it an unknown title raises
    # StopIteration and the error message below can never be reached.
    burn_value = next(
        (index for index, classe in enumerate(config["classes"]) if classe["title"] == args.type),
        None,
    )
    if burn_value is None:
        sys.exit("ERROR: asked type to rasterize is not contains in your config file classes.")

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, srid, feature_map, polygon, i):
        """Register one Polygon geometry under every tile of `zoom` returned by burntiles.

        `i` is the index of the owning feature; it is only used in the warning
        logged when a ValueError is raised while handling the polygon.
        """
        try:
            if srid != 4326:
                polygon = list(geojson_reproject({"type": "feature", "geometry": polygon}, srid, 4326))[0]

            # GeoJSON coordinates could be N dimensionals: keep the first two of each point.
            # The loop index must not be called `i`, or the warning below would
            # report a ring index instead of the feature index.
            for ring_index, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][ring_index] = [[point[0], point[1]] for point in ring]

            if polygon["coordinates"]:
                for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=zoom):
                    feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, srid, feature_map, geometry, i):
        """Dispatch a geometry to the polygon parser; non Polygon/MultiPolygon types are logged and skipped."""
        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, srid, feature_map, geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(
                    zoom, srid, feature_map, {"type": "Polygon", "coordinates": polygon}, i
                )
        else:
            log.log(
                "Notice: {} is a non surfacic geometry type, skipping feature {}".format(geometry["type"], i)
            )

        return feature_map

    if args.geojson:

        tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
        assert tiles, "Empty cover"

        # The spatial index is built for a single zoom level.
        zoom = tiles[0].z
        assert all(tile.z == zoom for tile in tiles), "Unsupported zoom mixed cover. Use PostGIS instead"

        feature_map = collections.defaultdict(list)
        crs_mapping = {"CRS84": "4326", "900913": "3857"}

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)

            # Fall back to EPSG:4326 when the collection has no usable "crs" member.
            try:
                srid = feature_collection["crs"]["properties"]["name"].split(":")[-1]
                srid = int(crs_mapping.get(srid, srid))
            except (KeyError, TypeError, AttributeError, ValueError):
                srid = 4326

            for i, feature in enumerate(tqdm(feature_collection["features"], ascii=True, unit="feature")):

                if feature["geometry"]["type"] == "GeometryCollection":
                    for geometry in feature["geometry"]["geometries"]:
                        feature_map = geojson_parse_geometry(zoom, srid, feature_map, geometry, i)
                else:
                    feature_map = geojson_parse_geometry(zoom, srid, feature_map, feature["geometry"], i)

        features = args.geojson

    if args.pg:

        conn = psycopg2.connect(args.pg)
        db = conn.cursor()

        # The user query is wrapped with its own LIMIT just below.
        assert "limit" not in args.sql.lower(), "LIMIT is not supported"
        db.execute("SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(args.sql))
        srid = db.fetchone()[0]
        assert srid, "Unable to retrieve geometry SRID."

        # TODO: Find a more reliable way to handle feature filtering
        if "where" not in args.sql.lower():
            args.sql += " WHERE ST_Intersects(tile.geom, geom)"
        else:
            args.sql += " AND ST_Intersects(tile.geom, geom)"
        features = args.sql

    log.log("RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(args.type, features, args.cover))
    with open(os.path.join(args.out, "instances.cover"), mode="w") as cover:

        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))), ascii=True, unit="tile"):

            geojson = None

            if args.pg:

                w, s, e, n = tile_bbox(tile)

                query = """
                WITH
                  tile AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}) AS geom),
                  geom AS (SELECT ST_Intersection(tile.geom, sql.geom) AS geom FROM tile CROSS JOIN LATERAL ({}) sql),
                  json AS (SELECT '{{"type": "Feature", "geometry": '
                         || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                         || '}}' AS features
                        FROM geom)
                SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
                FROM json
                """.format(w, s, e, n, srid, args.sql)

                db.execute(query)
                row = db.fetchone()
                try:
                    geojson = json.loads(row[0])["features"] if row and row[0] else None
                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(tile))
                    # NOTE(review): a fresh connection is opened here and the previous
                    # one is not closed explicitly; confirm this is intended.
                    conn = psycopg2.connect(args.pg)
                    db = conn.cursor()

            if args.geojson:
                geojson = feature_map[tile] if tile in feature_map else None

            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, args.ts, burn_value)

            # No feature on this tile, or nothing burnt: write an all-zero label.
            if not geojson or out is None:
                num = 0
                out = np.zeros(shape=(args.ts, args.ts), dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, out)
            cover.write("{},{},{}  {}{}".format(tile.x, tile.y, tile.z, num, os.linesep))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        # Expand the cover path here too, as for the two other reads above.
        tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
        web_ui(args.out, base_url, tiles, tiles, "png", template)
コード例 #10
0
def main(args):
    """Load a checkpoint and run the prediction passes selected by ``args.passes``.

    The batch size comes from ``args.bs`` or, when unset, from ``model/bs`` in
    the config file. The network class named in the checkpoint is loaded from
    ``robosat_pink.nn.<name>``, wrapped in ``torch.nn.DataParallel``, switched
    to evaluation mode and handed to ``predict`` once per requested pass.
    A web UI is generated from the tiles returned by the first pass, unless
    ``args.no_web_ui`` is set or that pass returned nothing.

    Args:
        args: parsed command line namespace. Attributes read here: ``config``,
            ``bs``, ``workers``, ``cover``, ``out``, ``checkpoint``, ``passes``,
            ``no_web_ui``, ``web_ui_template`` and ``web_ui_base_url``.
            ``args.bs`` and ``args.workers`` are filled in when unset.

    Raises:
        AssertionError: when no batch size is available from either source.
    """
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    palette = make_palette([classe["color"] for classe in config["classes"]])

    if not args.bs:
        try:
            args.bs = config["model"]["bs"]
        except (KeyError, TypeError):
            # Only a missing config entry is tolerated; the assertion below reports it.
            pass

    assert args.bs, "For rsp predict, model/bs must be set either in config file, or pass trought parameter --bs"
    if not args.workers:
        args.workers = args.bs
    cover = list(tiles_from_csv(os.path.expanduser(args.cover))) if args.cover else None

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log(
            "RoboSat.pink - predict on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers)
        )
        log.log(
            "(Torch:{} Cuda:{} CudNN:{})".format(
                torch.__version__, torch.version.cuda, torch.backends.cudnn.version()
            )
        )
        device = torch.device("cuda")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(args.workers))
        log.log("")
        log.log("============================================================")
        log.log("WARNING: Are you -really- sure about not predicting on GPU ?")
        log.log("============================================================")
        log.log("")
        device = torch.device("cpu")

    chkpt = torch.load(args.checkpoint, map_location=device)
    nn_module = load_module("robosat_pink.nn.{}".format(chkpt["nn"].lower()))
    nn_class = getattr(nn_module, chkpt["nn"])
    nn = nn_class(chkpt["shape_in"], chkpt["shape_out"], chkpt["encoder"].lower()).to(device)
    # The state dict is loaded after wrapping, so its keys must match the DataParallel wrapper.
    nn = torch.nn.DataParallel(nn)
    nn.load_state_dict(chkpt["state_dict"])
    nn.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    with torch.no_grad():  # don't track tensors with autograd during prediction

        tiled = []
        if args.passes in ["first", "both"]:
            log.log("== Predict First Pass ==")
            tiled = predict(config, cover, args, palette, chkpt, nn, device, "predict")

        if args.passes in ["second", "both"]:
            log.log("== Predict Second Pass ==")
            predict(config, cover, args, palette, chkpt, nn, device, "predict_translate")

    if not args.no_web_ui and tiled:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiled, tiled, "png", template)
コード例 #11
0
ファイル: tile.py プロジェクト: ndavid/robosat.pink
def main(args):
    """Cut the rasters of ``args.rasters`` into web mercator tiles below ``args.out``.

    Work is done in two threaded stages. First every raster is warped tile by
    tile: a tile covered by a single raster is written straight to
    ``args.out``, a tile covered by several rasters is written to a per-raster
    directory below ``<out>/.splits``. Then the splits of each shared tile are
    merged (a later split only fills pixels that are still 0) and written to
    ``args.out``; the splits directory is removed afterwards. With
    ``args.label`` tiles are written as labels with the config palette,
    otherwise as images, and image tiles detected as nodata are skipped.

    Args:
        args: parsed command line namespace. Attributes read here: ``workers``,
            ``rasters``, ``label``, ``config``, ``ts``, ``cover``, ``out``,
            ``zoom``, ``nodata``, ``nodata_threshold``, ``keep_borders``,
            ``no_web_ui``, ``web_ui_template`` and ``web_ui_base_url``.
            ``args.workers`` is filled in when unset.

    Raises:
        AssertionError: when ``args.ts`` is not a ``width,height`` pair, when
            rasters do not all have the same number of bands, or when a split
            does not have the expected shape.
    """
    if not args.workers:
        # os.cpu_count() returns None when the count cannot be determined.
        args.workers = min(os.cpu_count() or 1, len(args.rasters))

    if args.label:
        config = load_config(args.config)
        check_classes(config)
        colors = [classe["color"] for classe in config["classes"]]
        palette = make_palette(colors)

    assert len(args.ts.split(",")) == 2, "--ts expect width,height value (e.g 512,512)"
    width, height = map(int, args.ts.split(","))

    # A set makes the per tile membership tests below O(1) instead of a list scan.
    cover = set(tiles_from_csv(os.path.expanduser(args.cover))) if args.cover else None

    splits_path = os.path.join(os.path.expanduser(args.out), ".splits")
    tiles_map = {}  # (x, y, z) as strings -> paths of the rasters covering that tile

    print("RoboSat.pink - tile on CPU, with {} workers".format(args.workers), file=sys.stderr, flush=True)

    bands = -1
    for path in args.rasters:
        # NOTE(review): the raster handle is never closed explicitly, here or in
        # the worker below; confirm whether rasterio_open returns a context manager.
        raster = rasterio_open(path)
        w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)

        if bands != -1:
            assert bands == len(raster.indexes), "Coverage must be bands consistent"
        bands = len(raster.indexes)

        tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
        tiles = set(tiles) & cover if cover else tiles

        for tile in tiles:
            tile_key = (str(tile.x), str(tile.y), str(tile.z))
            tiles_map.setdefault(tile_key, []).append(path)

    # NOTE(review): `ext` stays unbound for 2-band imagery (and when no raster was
    # given), which only surfaces in the final web_ui call; confirm the expected format.
    if args.label:
        ext = "png"
        bands = 1
    elif bands == 1:
        ext = "png"
    elif bands == 3:
        ext = "webp"
    elif bands > 3:
        ext = "tiff"

    tiles = []
    progress = tqdm(total=len(tiles_map), ascii=True, unit="tile")

    # Begin to tile plain tiles
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def tile_raster(path):
            """Warp one raster onto each tile it covers and return the tiles written to args.out."""
            raster = rasterio_open(path)
            w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
            raster_tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
            tiled = []

            for tile in raster_tiles:

                if cover and tile not in cover:
                    continue

                w, s, e, n = mercantile.xy_bounds(tile)

                warp_vrt = WarpedVRT(
                    raster,
                    crs="epsg:3857",
                    resampling=Resampling.bilinear,
                    add_alpha=False,
                    transform=from_bounds(w, s, e, n, width, height),
                    width=width,
                    height=height,
                )
                # out_shape is (bands, rows, columns), i.e. height before width;
                # the previous (width, height) order only worked for square tiles.
                data = warp_vrt.read(
                    out_shape=(len(raster.indexes), height, width), window=warp_vrt.window(w, s, e, n)
                )

                if data.dtype == "uint16":  # GeoTiff could be 16 bits
                    data = np.uint8(data / 256)
                elif data.dtype == "uint32":  # or 32 bits
                    data = np.uint8(data / (256 * 256))

                image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C

                tile_key = (str(tile.x), str(tile.y), str(tile.z))
                sources = tiles_map[tile_key]

                if (
                    not args.label
                    and len(sources) == 1
                    and is_nodata(image, args.nodata, args.nodata_threshold, args.keep_borders)
                ):
                    progress.update()
                    continue

                # Tiles shared by several rasters go to a per-raster split directory
                # and are merged in the second stage.
                if len(sources) > 1:
                    out = os.path.join(splits_path, str(sources.index(path)))
                else:
                    out = args.out

                x, y, z = map(int, tile)

                if args.label:
                    tile_label_to_file(out, mercantile.Tile(x=x, y=y, z=z), palette, image)
                else:
                    tile_image_to_file(out, mercantile.Tile(x=x, y=y, z=z), image)

                if len(sources) == 1:
                    progress.update()
                    tiled.append(mercantile.Tile(x=x, y=y, z=z))

            return tiled

        for tiled in executor.map(tile_raster, args.rasters):
            if tiled is not None:
                tiles.extend(tiled)

    # Aggregate remaining tiles splits
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def merge_splits(tile_key):
            """Merge the splits of a tile shared by several rasters; return the tile, or None if nothing was written."""
            if len(tiles_map[tile_key]) == 1:
                return

            image = np.zeros((height, width, bands), np.uint8)  # H,W,C like the splits

            x, y, z = map(int, tile_key)
            for i in range(len(tiles_map[tile_key])):
                root = os.path.join(splits_path, str(i))
                _, path = tile_from_xyz(root, x, y, z)

                if args.label:
                    split = tile_label_from_file(path)
                    split = split.reshape((height, width, 1))  # H,W -> H,W,C
                else:
                    split = tile_image_from_file(path)

                assert image.shape == split.shape
                # Only fill the pixels that are still empty: earlier splits win.
                empty = image == 0
                image[empty] += split[empty]

            if not args.label and is_nodata(image, args.nodata, args.nodata_threshold, args.keep_borders):
                progress.update()
                return

            tile = mercantile.Tile(x=x, y=y, z=z)

            if args.label:
                tile_label_to_file(args.out, tile, palette, image)
            else:
                tile_image_to_file(args.out, tile, image)

            progress.update()
            return tile

        for tiled in executor.map(merge_splits, tiles_map):
            if tiled is not None:
                tiles.append(tiled)

        if splits_path and os.path.isdir(splits_path):
            shutil.rmtree(splits_path)  # Delete suffixes dir if any

    if tiles and not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiles, tiles, ext, template)
コード例 #12
0
def main(args):
    """Burn the features of one class into a label tile for every tile of a cover.

    Features are read either from GeoJSON files (``args.geojson``) or through a
    PostGIS query (``args.postgis``) run on the ``args.pg_dsn`` connection. For
    each tile listed in the ``args.cover`` CSV a label is written below
    ``args.out`` with ``tile_label_to_file`` and a ``x,y,z  count`` line is
    written to ``instances.cover`` in that same directory. Unless
    ``args.no_web_ui`` is set, ``web_ui`` is called once every tile has been
    handled.

    Args:
        args: parsed command line namespace. Attributes read here: ``geojson``,
            ``postgis``, ``pg_dsn``, ``config``, ``type``, ``out``, ``cover``,
            ``ts``, ``no_web_ui``, ``web_ui_template`` and ``web_ui_base_url``.
            ``args.out`` is replaced by its user-expanded form.

    Raises:
        SystemExit: on missing or conflicting input options, on an unreadable,
            empty or mixed-zoom cover (GeoJSON input), when the SRID cannot be
            retrieved, or when ``args.type`` is not the title of any class of
            the config file.
    """
    if (args.geojson and args.postgis) or (not args.geojson and not args.postgis):
        sys.exit("ERROR: Input features to rasterize must be either GeoJSON or PostGIS")

    if args.postgis and not args.pg_dsn:
        sys.exit("ERROR: With PostGIS input features, --pg_dsn must be provided")

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette(*[classe["color"] for classe in config["classes"]], complementary=True)

    # next() needs an explicit default: without it an unknown title raises
    # StopIteration and the error message below can never be reached.
    burn_value = next(
        (index for index, classe in enumerate(config["classes"]) if classe["title"] == args.type),
        None,
    )
    if burn_value is None:
        sys.exit("ERROR: asked type to rasterize is not contains in your config file classes.")

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, srid, feature_map, polygon, i):
        """Register one Polygon geometry under every tile of `zoom` returned by burntiles.

        `i` is the index of the owning feature; it is only used in the warning
        logged when a ValueError is raised while handling the polygon.
        """
        try:
            if srid != 4326:
                polygon = list(geojson_reproject({"type": "feature", "geometry": polygon}, srid, 4326))[0]

            # GeoJSON coordinates could be N dimensionals: keep the first two of each point.
            # The loop index must not be called `i`, or the warning below would
            # report a ring index instead of the feature index.
            for ring_index, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][ring_index] = [[point[0], point[1]] for point in ring]

            if polygon["coordinates"]:
                for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=zoom):
                    feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, srid, feature_map, geometry, i):
        """Dispatch a geometry to the polygon parser; non Polygon/MultiPolygon types are logged and skipped."""
        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, srid, feature_map, geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(
                    zoom, srid, feature_map, {"type": "Polygon", "coordinates": polygon}, i
                )
        else:
            log.log(
                "Notice: {} is a non surfacic geometry type, skipping feature {}".format(geometry["type"], i)
            )

        return feature_map

    if args.geojson:

        # Any failure while reading the cover, an empty cover, or a cover mixing
        # zoom levels is reported the same way.
        try:
            tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
            zoom = tiles[0].z
            assert all(tile.z == zoom for tile in tiles)
        except Exception:
            sys.exit("ERROR: Inconsistent cover {}".format(args.cover))

        feature_map = collections.defaultdict(list)
        crs_mapping = {"CRS84": "4326", "900913": "3857"}

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)

            # Fall back to EPSG:4326 when the collection has no usable "crs" member.
            try:
                srid = feature_collection["crs"]["properties"]["name"].split(":")[-1]
                srid = int(crs_mapping.get(srid, srid))
            except (KeyError, TypeError, AttributeError, ValueError):
                srid = 4326

            for i, feature in enumerate(tqdm(feature_collection["features"], ascii=True, unit="feature")):

                if feature["geometry"]["type"] == "GeometryCollection":
                    for geometry in feature["geometry"]["geometries"]:
                        feature_map = geojson_parse_geometry(zoom, srid, feature_map, geometry, i)
                else:
                    feature_map = geojson_parse_geometry(zoom, srid, feature_map, feature["geometry"], i)

        features = args.geojson

    if args.postgis:

        pg_conn = psycopg2.connect(args.pg_dsn)
        pg = pg_conn.cursor()

        pg.execute("SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(args.postgis))
        try:
            srid = pg.fetchone()[0]
        except Exception:
            sys.exit("Unable to retrieve geometry SRID.")

        features = args.postgis

    log.log("RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(args.type, features, args.cover))
    with open(os.path.join(args.out, "instances.cover"), mode="w") as cover:

        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))), ascii=True, unit="tile"):

            # Reset for every tile: when the query below fails, the tile must be
            # written empty instead of reusing the features of the previous tile.
            geojson = None

            if args.postgis:

                # mercantile.bounds yields west, south, east, north, which is the
                # xmin, ymin, xmax, ymax order ST_MakeEnvelope takes.
                w, s, e, n = mercantile.bounds(tile)

                query = """
                WITH
                  a AS ({}),
                  b AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}) AS geom)
                SELECT '{{
  "type": "FeatureCollection", "features": [{{"type": "Feature", "geometry": '
  || ST_AsGeoJSON(ST_Transform(ST_Intersection(a.geom, b.geom), 4326), 6)
  || '}}]}}'
                FROM a, b
                WHERE ST_Intersects(a.geom, b.geom)
                """.format(args.postgis, w, s, e, n, srid)

                try:
                    pg.execute(query)
                    row = pg.fetchone()
                    geojson = json.loads(row[0])["features"] if row else None

                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(tile))
                    # NOTE(review): a fresh connection is opened here and the previous
                    # one is not closed explicitly; confirm this is intended.
                    pg_conn = psycopg2.connect(args.pg_dsn)
                    pg = pg_conn.cursor()

            if args.geojson:
                geojson = feature_map[tile] if tile in feature_map else None

            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, args.ts, burn_value)

            # No feature on this tile, or nothing burnt: write an all-zero label.
            if not geojson or out is None:
                num = 0
                out = np.zeros(shape=(args.ts, args.ts), dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, out)
            cover.write("{},{},{}  {}{}".format(tile.x, tile.y, tile.z, num, os.linesep))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        # Expand the cover path here too, as for the other reads above.
        tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
        web_ui(args.out, base_url, tiles, tiles, "png", template)