Example #1
def main(args):
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    palette = make_palette([classe["color"] for classe in config["classes"]])
    args.workers = torch.cuda.device_count() * 2 if torch.cuda.is_available() and not args.workers else args.workers
    cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(
            torch.__version__, torch.version.cuda,
            torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(
            args.workers))
        log.log("")
        log.log("============================================================")
        log.log("WARNING: Are you -really- sure about not predicting on GPU ?")
        log.log("============================================================")
        log.log("")
        device = torch.device("cpu")

    chkpt = torch.load(args.checkpoint, map_location=device)
    model_module = load_module("robosat_pink.models.{}".format(
        chkpt["nn"].lower()))
    nn = getattr(model_module, chkpt["nn"])(chkpt["shape_in"],
                                            chkpt["shape_out"]).to(device)
    nn = torch.nn.DataParallel(nn)
    nn.load_state_dict(chkpt["state_dict"])
    nn.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    mode = "predict" if not args.translate else "predict_translate"
    loader_module = load_module("robosat_pink.loaders.{}".format(
        chkpt["loader"].lower()))
    loader_predict = getattr(loader_module,
                             chkpt["loader"])(config,
                                              chkpt["shape_in"][1:3],
                                              args.dataset,
                                              cover,
                                              mode=mode)

    loader = DataLoader(loader_predict,
                        batch_size=args.bs,
                        num_workers=args.workers)
    assert len(loader), "Empty predict dataset directory. Check your path."

    tiled = []
    with torch.no_grad():  # don't track tensors with autograd during prediction

        for images, tiles in tqdm(loader,
                                  desc="Eval",
                                  unit="batch",
                                  ascii=True):

            images = images.to(device)

            outputs = nn(images)
            probs = torch.nn.functional.softmax(outputs,
                                                dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))
                mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()
                if args.translate:
                    tile_translate_to_file(args.out, mercantile.Tile(x, y, z),
                                           palette, mask)
                else:
                    tile_label_to_file(args.out, mercantile.Tile(x, y, z),
                                       palette, mask)
                tiled.append(mercantile.Tile(x, y, z))

    if not args.no_web_ui and not args.translate:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiled, tiled, "png", template)
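
A minimal sketch of the mask derivation above, runnable on a dummy batch (shapes are illustrative): softmax turns logits into per-class probabilities, prob[1:] drops the background channel, and rounding yields a binary uint8 mask.

import numpy as np
import torch

outputs = torch.randn(1, 2, 256, 256)  # (batch, classes, H, W), as nn(images) would return
probs = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()
for prob in probs:
    mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()
    assert mask.shape == (256, 256) and set(np.unique(mask)) <= {0, 1}
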
Example #2
def main(args):
    config = load_config(args.config)
    args.out = os.path.expanduser(args.out)
    config["model"][
        "loader"] = args.loader if args.loader else config["model"]["loader"]
    config["model"]["bs"] = args.bs if args.bs else config["model"]["bs"]
    config["model"]["lr"] = args.lr if args.lr else config["model"]["lr"]
    config["model"]["ts"] = tuple(map(
        int, args.ts.split(","))) if args.ts else config["model"]["ts"]
    config["model"]["nn"] = args.nn if args.nn else config["model"]["nn"]
    config["model"]["encoder"] = args.encoder if args.encoder else config[
        "model"]["encoder"]
    config["model"][
        "loss"] = args.loss if args.loss else config["model"]["loss"]
    config["model"]["da"] = args.da if args.da else config["model"]["da"]
    config["model"]["dap"] = args.dap if args.dap else config["model"]["dap"]
    args.workers = config["model"]["bs"] if not args.workers else args.workers
    check_classes(config)
    check_channels(config)
    check_model(config)

    assert os.path.isdir(os.path.expanduser(args.dataset)), "Dataset is not a directory"
    if args.no_training and args.no_validation:
        sys.exit()

    log = Logs(os.path.join(args.out, "log"))
    csv_train = None if args.no_training else open(os.path.join(args.out, "training.csv"), mode="a")
    csv_val = None if args.no_validation else open(os.path.join(args.out, "validation.csv"), mode="a")

    if torch.cuda.is_available():
        log.log("RoboSat.pink - training on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(
            torch.__version__, torch.version.cuda,
            torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - training on CPU, with {} workers - (Torch:{})".
                format(args.workers, torch.__version__))
        log.log("")
        log.log("==========================================================")
        log.log("WARNING: Are you -really- sure about not training on GPU ?")
        log.log("==========================================================")
        log.log("")
        device = torch.device("cpu")

    log.log("--- Input tensor from Dataset: {} ---".format(args.dataset))
    num_channel = 1  # 1-based numbering
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(
                num_channel, channel["name"], band))
            num_channel += 1

    log.log("--- Output Classes ---")
    for c, classe in enumerate(config["classes"]):
        log.log("Class {}:\t\t {}".format(c, classe["title"]))

    log.log("--- Hyper Parameters ---")
    for hp in config["model"]:
        log.log("{}{}".format(hp.ljust(25, " "), config["model"][hp]))

    loader = load_module("robosat_pink.loaders.{}".format(
        config["model"]["loader"].lower()))
    loader_train = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], os.path.join(args.dataset,
                                                    "training"), None, "train")
    loader_val = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"],
        os.path.join(args.dataset, "validation"), None, "train")

    encoder = config["model"]["encoder"].lower()
    nn_module = load_module("robosat_pink.nn.{}".format(
        config["model"]["nn"].lower()))
    nn = getattr(nn_module, config["model"]["nn"])(loader_train.shape_in,
                                                   loader_train.shape_out,
                                                   encoder, config).to(device)
    nn = torch.nn.DataParallel(nn)
    optimizer = Adam(nn.parameters(), lr=config["model"]["lr"])

    resume = 0
    if args.checkpoint:
        chkpt = torch.load(os.path.expanduser(args.checkpoint),
                           map_location=device)
        nn.load_state_dict(chkpt["state_dict"])
        log.log("--- Using Checkpoint ---")
        log.log("Path:\t\t {}".format(args.checkpoint))
        log.log("UUID:\t\t {}".format(chkpt["uuid"]))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]
            assert resume < args.epochs, "Requested epoch already reached by the given checkpoint"

    loss_module = load_module("robosat_pink.losses.{}".format(
        config["model"]["loss"].lower()))
    criterion = getattr(loss_module, config["model"]["loss"])().to(device)

    bs = config["model"]["bs"]
    train_loader = DataLoader(loader_train,
                              batch_size=bs,
                              shuffle=True,
                              drop_last=True,
                              num_workers=args.workers)
    val_loader = DataLoader(loader_val,
                            batch_size=bs,
                            shuffle=False,
                            drop_last=True,
                            num_workers=args.workers)

    if args.no_training:
        epoch = 0
        process(val_loader, config, log, csv_val, epoch, device, nn, criterion,
                "eval")
        sys.exit()

    for epoch in range(resume + 1, args.epochs + 1):  # 1-N based
        UUID = uuid.uuid1()
        log.log("---{}Epoch: {}/{} -- UUID: {}".format(os.linesep, epoch,
                                                       args.epochs, UUID))

        process(train_loader, config, log, csv_train, epoch, device, nn,
                criterion, "train", optimizer)

        try:  # https://github.com/pytorch/pytorch/issues/9176
            nn_doc = nn.module.doc
            nn_version = nn.module.version
        except AttributeError:
            nn_doc = nn.doc
            nn_version = nn.version

        states = {
            "uuid": UUID,
            "model_version": nn_version,
            "producer_name": "RoboSat.pink",
            "producer_version": rsp.__version__,
            "model_licence": "MIT",
            "domain": "pink.RoboSat",  # reverse-DNS
            "doc_string": nn_doc,
            "shape_in": loader_train.shape_in,
            "shape_out": loader_train.shape_out,
            "state_dict": nn.state_dict(),
            "epoch": epoch,
            "nn": config["model"]["nn"],
            "encoder": config["model"]["encoder"],
            "optimizer": optimizer.state_dict(),
            "loader": config["model"]["loader"],
        }
        checkpoint_path = os.path.join(args.out,
                                       "checkpoint-{:05d}.pth".format(epoch))
        if epoch == args.epochs or not (epoch % args.saving):
            log.log("[Saving checkpoint]")
            torch.save(states, checkpoint_path)

        if not args.no_validation:
            process(val_loader, config, log, csv_val, epoch, device, nn,
                    criterion, "eval")
Example #3
def main(args):

    tiles = list(tiles_from_csv(args.cover))
    os.makedirs(os.path.expanduser(args.out), exist_ok=True)

    if not args.workers:
        args.workers = max(1, math.floor(os.cpu_count() * 0.5))

    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)
    log.log(
        "RoboSat.pink - download with {} workers, at max {} req/s, from: {}".
        format(args.workers, args.rate, args.url))

    already_dl = 0
    dl = 0

    with requests.Session() as session:

        progress = tqdm(total=len(tiles), ascii=True, unit="image")
        with futures.ThreadPoolExecutor(args.workers) as executor:

            def worker(tile):
                tick = time.monotonic()
                progress.update()

                try:
                    x, y, z = map(str, [tile.x, tile.y, tile.z])
                    os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
                except OSError:
                    return tile, None, False

                path = os.path.join(args.out, z, x,
                                    "{}.{}".format(y, args.format))
                if os.path.isfile(path):  # already downloaded
                    return tile, None, True

                if args.type == "XYZ":
                    url = args.url.format(x=tile.x, y=tile.y, z=tile.z)
                elif args.type == "TMS":
                    y = (2**tile.z) - tile.y - 1
                    url = args.url.format(x=tile.x, y=y, z=tile.z)
                elif args.type == "WMS":
                    xmin, ymin, xmax, ymax = xy_bounds(tile)
                    url = args.url.format(xmin=xmin,
                                          ymin=ymin,
                                          xmax=xmax,
                                          ymax=ymax)

                res = tile_image_from_url(session, url, args.timeout)
                if res is None:  # let's retry once
                    res = tile_image_from_url(session, url, args.timeout)
                    if res is None:
                        return tile, url, False

                try:
                    tile_image_to_file(args.out, tile, res)
                except OSError:
                    return tile, url, False

                tock = time.monotonic()

                time_for_req = tock - tick
                time_per_worker = args.workers / args.rate

                if time_for_req < time_per_worker:
                    time.sleep(time_per_worker - time_for_req)

                return tile, url, True

            for tile, url, ok in executor.map(worker, tiles):
                if url and ok:
                    dl += 1
                elif not url and ok:
                    already_dl += 1
                else:
                    log.log("Warning:\n {} failed, skipping.\n {}\n".format(
                        tile, url))

    if already_dl:
        log.log("Notice: {} tiles were already downloaded and were skipped.".format(already_dl))
    if already_dl + dl == len(tiles):
        log.log("Notice: Coverage is fully downloaded.")

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, tiles, tiles, args.format, template)
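
The sleep in worker throttles the pool: each worker spaces its requests at least workers / rate seconds apart, so aggregate throughput stays under args.rate requests per second. A worked example with illustrative numbers:

workers, rate = 8, 4              # illustrative: 8 workers, 4 req/s budget
time_per_worker = workers / rate  # 2.0 s minimum spacing per worker
# 8 workers * (1 request / 2.0 s) = 4 requests per second overall
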
Example #4
def main(args):

    args.out = os.path.expanduser(args.out)

    if not args.workers:
        args.workers = os.cpu_count()

    print("RoboSat.pink - compare {} on CPU, with {} workers".format(args.mode, args.workers), file=sys.stderr, flush=True)

    if not args.masks or not args.labels:
        assert args.mode != "list", "Parameters masks and labels are mandatory in list mode."
        assert args.minimum_fg == 0.0 and args.maximum_fg == 100.0, "Both masks and labels mandatory in QoD filtering."
        assert args.minimum_qod == 0.0 and args.maximum_qod == 100.0, "Both masks and labels mandatory in QoD filtering."

    if args.images:
        tiles = [tile for tile in tiles_from_dir(args.images[0])]
        for image in args.images[1:]:
            assert sorted(tiles) == sorted([tile for tile in tiles_from_dir(image)]), "Inconsistent images directories"

    if args.labels and args.masks:
        tiles_masks = [tile for tile in tiles_from_dir(args.masks)]
        tiles_labels = [tile for tile in tiles_from_dir(args.labels)]
        if args.images:
            assert sorted(tiles) == sorted(tiles_masks) == sorted(tiles_labels), "Inconsistent images/label/mask directories"
        else:
            assert sorted(tiles_masks) == sorted(tiles_labels), "Label and Mask directories are not consistent"
            tiles = tiles_masks

    tiles_list = []
    tiles_compare = []
    progress = tqdm(total=len(tiles), ascii=True, unit="tile")
    log = False if args.mode == "list" else Logs(os.path.join(args.out, "log"))

    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile):
            x, y, z = list(map(str, tile))

            if args.masks and args.labels:

                label = np.array(Image.open(os.path.join(args.labels, z, x, "{}.png".format(y))))
                mask = np.array(Image.open(os.path.join(args.masks, z, x, "{}.png".format(y))))

                assert label.shape == mask.shape, "Inconsistent tiles (size or dimensions)"

                try:
                    dist, fg_ratio, qod = compare(torch.as_tensor(label, device="cpu"), torch.as_tensor(mask, device="cpu"))
                except Exception:
                    progress.update()
                    return False, tile

                if not args.minimum_fg <= fg_ratio <= args.maximum_fg or not args.minimum_qod <= qod <= args.maximum_qod:
                    progress.update()
                    return True, tile

            tiles_compare.append(tile)

            if args.mode == "side":
                for i, root in enumerate(args.images):
                    img = tile_image_from_file(tile_from_xyz(root, x, y, z)[1])

                    if i == 0:
                        side = np.zeros((img.shape[0], img.shape[1] * len(args.images), 3))
                        side = np.swapaxes(side, 0, 1) if args.vertical else side
                        image_shape = img.shape
                    else:
                        assert image_shape[0:2] == img.shape[0:2], "Inconsistent image size to compare"

                    if args.vertical:
                        side[i * image_shape[0] : (i + 1) * image_shape[0], :, :] = img
                    else:  # horizontal: slice columns by the image width
                        side[:, i * image_shape[1] : (i + 1) * image_shape[1], :] = img

                tile_image_to_file(args.out, tile, np.uint8(side))

            elif args.mode == "stack":
                for i, root in enumerate(args.images):
                    tile_image = tile_image_from_file(tile_from_xyz(root, x, y, z)[1])

                    if i == 0:
                        image_shape = tile_image.shape[0:2]
                        stack = tile_image / len(args.images)
                    else:
                        assert image_shape == tile_image.shape[0:2], "Inconsistent image size to compare"
                        stack = stack + (tile_image / len(args.images))

                tile_image_to_file(args.out, tile, np.uint8(stack))

            elif args.mode == "list":
                tiles_list.append([tile, fg_ratio, qod])

            progress.update()
            return True, tile

        for ok, tile in executor.map(worker, tiles):  # worker returns (ok, tile)
            if not ok and log:
                log.log("Warning: skipping. {}".format(str(tile)))

    if args.mode == "list":
        with open(args.out, mode="w") as out:

            if args.geojson:
                out.write('{"type":"FeatureCollection","features":[')

            first = True
            for tile_list in tiles_list:
                tile, fg_ratio, qod = tile_list
                x, y, z = list(map(str, tile))
                if args.geojson:
                    prop = '"properties":{{"x":{},"y":{},"z":{},"fg":{:.1f},"qod":{:.1f}}}'.format(x, y, z, fg_ratio, qod)
                    geom = '"geometry":{}'.format(json.dumps(feature(tile, precision=6)["geometry"]))
                    out.write('{}{{"type":"Feature",{},{}}}'.format("," if not first else "", geom, prop))
                    first = False
                else:
                    out.write("{},{},{}\t{:.1f}\t{:.1f}{}".format(x, y, z, fg_ratio, qod, os.linesep))

            if args.geojson:
                out.write("]}")

    base_url = args.web_ui_base_url if args.web_ui_base_url else "."

    if args.mode == "side" and not args.no_web_ui:
        template = "compare.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, base_url, tiles, tiles_compare, args.format, template, union_tiles=False)

    if args.mode == "stack" and not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile in tiles_from_dir(args.images[0])]
        web_ui(args.out, base_url, tiles, tiles_compare, args.format, template)
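
In "side" mode the preallocated buffer lays K tiles of shape (H, W, 3) side by side, giving (H, K*W, 3), or (K*H, W, 3) with --vertical. A sketch of the same layout with np.concatenate on dummy images:

import numpy as np

imgs = [np.full((256, 256, 3), v, dtype=np.uint8) for v in (0, 127, 255)]
side_h = np.concatenate(imgs, axis=1)  # horizontal: (256, 768, 3)
side_v = np.concatenate(imgs, axis=0)  # --vertical: (768, 256, 3)
assert side_h.shape == (256, 768, 3) and side_v.shape == (768, 256, 3)
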
Example #5
def main(args):

    assert not (args.sql and args.geojson), "You can use either --pg OR --geojson, but not both."
    assert not (args.pg and not args.sql), "With PostgreSQL --pg, --sql must also be provided"
    assert len(args.ts.split(",")) == 2, "--ts expects width,height values (e.g. 512,512)"

    config = load_config(args.config)
    check_classes(config)

    palette = make_palette([classe["color"] for classe in config["classes"]],
                           complementary=True)
    index = [config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type]
    assert index, "Requested type is not contained in your config file classes."
    burn_value = int(math.pow(2, index[0] - 1))  # 8-bit one-hot encoding
    assert 0 <= burn_value <= 128

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    if args.geojson:

        tiles = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))]
        assert tiles, "Empty cover"

        zoom = tiles[0].z
        assert not [tile for tile in tiles if tile.z != zoom], "Unsupported mixed-zoom cover. Use PostGIS instead"

        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)
                srid = geojson_srid(feature_collection)

                for i, feature in enumerate(
                        tqdm(feature_collection["features"],
                             ascii=True,
                             unit="feature")):
                    feature_map = geojson_parse_feature(
                        zoom, srid, feature_map, feature)

        features = args.geojson

    if args.pg:

        conn = psycopg2.connect(args.pg)
        db = conn.cursor()

        assert "limit" not in args.sql.lower(), "LIMIT is not supported"
        assert "TILE_GEOM" in args.sql, "TILE_GEOM filter not found in your SQL"
        sql = re.sub(r"ST_Intersects( )*\((.*)?TILE_GEOM(.*)?\)", "1=1",
                     args.sql, re.I)
        assert sql and sql != args.sql

        db.execute(
            """SELECT ST_Srid("1") AS srid FROM ({} LIMIT 1) AS t("1")""".
            format(sql))
        srid = db.fetchone()[0]
        assert srid and int(srid) > 0, "Unable to retrieve geometry SRID."

        features = args.sql

    log.log(
        "RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(
            args.type, features, args.cover))
    with open(os.path.join(os.path.expanduser(args.out),
                           "instances_" + args.type.lower() + ".cover"),
              mode="w") as cover:

        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))),
                         ascii=True,
                         unit="tile"):

            geojson = None

            if args.pg:

                w, s, e, n = tile_bbox(tile)
                tile_geom = "ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {})".format(
                    w, s, e, n, srid)

                query = """
                WITH
                  sql  AS ({}),
                  geom AS (SELECT "1" AS geom FROM sql AS t("1")),
                  json AS (SELECT '{{"type": "Feature", "geometry": '
                         || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                         || '}}' AS features
                        FROM geom)
                SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
                FROM json
                """.format(args.sql.replace("TILE_GEOM", tile_geom))

                db.execute(query)
                row = db.fetchone()
                try:
                    geojson = json.loads(
                        row[0])["features"] if row and row[0] else None
                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(
                        tile))
                    conn = psycopg2.connect(args.pg)
                    db = conn.cursor()

            if args.geojson:
                geojson = feature_map[tile] if tile in feature_map else None

            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326,
                                        list(map(int, args.ts.split(","))),
                                        burn_value)

            if not geojson or out is None:
                num = 0
                out = np.zeros(shape=list(map(int, args.ts.split(","))),
                               dtype=np.uint8)

            tile_label_to_file(args.out,
                               tile,
                               palette,
                               out,
                               append=args.append)
            cover.write("{},{},{}  {}{}".format(tile.x, tile.y, tile.z, num,
                                                os.linesep))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
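
The burn value above encodes class i (1-based; class 0 is background) as bit i-1 of an 8-bit mask, which is why it is asserted to stay within [0, 128]. A quick check:

import math

for index in (1, 2, 3, 8):
    print(index, int(math.pow(2, index - 1)))  # -> 1, 2, 4, 128
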
Example #6
def main(args):
    config = load_config(args.config)
    args.out = os.path.expanduser(args.out)
    args.workers = torch.cuda.device_count() * 2 if torch.cuda.is_available() and not args.workers else args.workers
    config["model"]["loader"] = args.loader if args.loader else config["model"]["loader"]
    config["model"]["bs"] = args.bs if args.bs else config["model"]["bs"]
    config["model"]["lr"] = args.lr if args.lr else config["model"]["lr"]
    config["model"]["ts"] = args.ts if args.ts else config["model"]["ts"]
    config["model"]["nn"] = args.nn if args.nn else config["model"]["nn"]
    config["model"]["loss"] = args.loss if args.loss else config["model"]["loss"]
    config["model"]["da"] = args.da if args.da else config["model"]["da"]
    config["model"]["dap"] = args.dap if args.dap else config["model"]["dap"]
    check_classes(config)
    check_channels(config)
    check_model(config)

    if not os.path.isdir(os.path.expanduser(args.dataset)):
        sys.exit("ERROR: dataset {} is not a directory".format(args.dataset))

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(torch.__version__, torch.version.cuda, torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - training on CPU, with {} workers - (Torch:{})".format(args.workers, torch.__version__))
        log.log("WARNING: Are you really sure sure about not training on GPU ?")
        device = torch.device("cpu")

    loader = load_module("robosat_pink.loaders.{}".format(config["model"]["loader"].lower()))
    loader_train = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], os.path.join(args.dataset, "training"), "train"
    )
    loader_val = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], os.path.join(args.dataset, "validation"), "train"
    )

    model_module = load_module("robosat_pink.models.{}".format(config["model"]["nn"].lower()))

    nn = getattr(model_module, config["model"]["nn"])(loader_train.shape_in, loader_train.shape_out, config).to(device)
    nn = torch.nn.DataParallel(nn)
    optimizer = Adam(nn.parameters(), lr=config["model"]["lr"])

    resume = 0
    if args.checkpoint:
        chkpt = torch.load(os.path.expanduser(args.checkpoint), map_location=device)
        nn.load_state_dict(chkpt["state_dict"])
        log.log("Using checkpoint: {}".format(args.checkpoint))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]
            if resume >= args.epochs:
                sys.exit("ERROR: Epoch {} already reached by the given checkpoint".format(args.epochs))

    loss_module = load_module("robosat_pink.losses.{}".format(config["model"]["loss"].lower()))
    criterion = getattr(loss_module, config["model"]["loss"])().to(device)

    bs = config["model"]["bs"]
    train_loader = DataLoader(loader_train, batch_size=bs, shuffle=True, drop_last=True, num_workers=args.workers)
    val_loader = DataLoader(loader_val, batch_size=bs, shuffle=False, drop_last=True, num_workers=args.workers)

    log.log("--- Input tensor from Dataset: {} ---".format(args.dataset))
    num_channel = 1  # 1-based numbering
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(num_channel, channel["name"], band))
            num_channel += 1

    log.log("--- Hyper Parameters ---")
    for hp in config["model"]:
        log.log("{}{}".format(hp.ljust(25, " "), config["model"][hp]))

    for epoch in range(resume, args.epochs):
        UUID = uuid.uuid1()
        log.log("---{}Epoch: {}/{} -- UUID: {}".format(os.linesep, epoch + 1, args.epochs, UUID))

        process(train_loader, config, log, device, nn, criterion, "train", optimizer)
        if not args.no_validation:
            process(val_loader, config, log, device, nn, criterion, "eval")

        try:  # https://github.com/pytorch/pytorch/issues/9176
            nn_doc = nn.module.doc
            nn_version = nn.module.version
        except AttributeError:
            nn_version = nn.version
            nn_doc = nn.doc

        states = {
            "uuid": UUID,
            "model_version": nn_version,
            "producer_name": "RoboSat.pink",
            "producer_version": "0.4.0",
            "model_licence": "MIT",
            "domain": "pink.RoboSat",  # reverse-DNS
            "doc_string": nn_doc,
            "shape_in": loader_train.shape_in,
            "shape_out": loader_train.shape_out,
            "state_dict": nn.state_dict(),
            "epoch": epoch + 1,
            "nn": config["model"]["nn"],
            "optimizer": optimizer.state_dict(),
            "loader": config["model"]["loader"],
        }
        checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}.pth".format(epoch + 1))
        torch.save(states, checkpoint_path)
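
The try/except around nn.module (see the linked PyTorch issue) exists because DataParallel wraps the model. A more compact equivalent, as a sketch, falls back to the bare model when no .module attribute is present:

model = getattr(nn, "module", nn)  # unwrap DataParallel when present
nn_doc, nn_version = model.doc, model.version
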
Example #7
def main(args):

    if args.pg and not args.sql:
        sys.exit("ERROR: With a PostgreSQL db, --sql must be provided")

    if (args.sql and args.geojson) or (args.sql and not args.pg):
        sys.exit("ERROR: --sql requires --pg; use either --pg or --geojson inputs, but only one at once.")

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette(*[classe["color"] for classe in config["classes"]],
                           complementary=True)
    burn_value = next((config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type), None)
    if burn_value is None:  # next() with a default, so a missing type is reported instead of raising StopIteration
        sys.exit("ERROR: The type to rasterize is not contained in your config file classes.")

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, srid, feature_map, polygon, i):

        try:
            if srid != 4326:
                polygon = [xy for xy in geojson_reproject({"type": "feature", "geometry": polygon}, srid, 4326)][0]

            # GeoJSON coordinates could be N-dimensional: keep only x and y
            for i, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][i] = [[point[0], point[1]] for point in ring]

            if polygon["coordinates"]:
                for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=zoom):
                    feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, srid, feature_map, geometry, i):

        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, srid, feature_map,
                                                geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(zoom, srid, feature_map, {"type": "Polygon", "coordinates": polygon}, i)
        else:
            log.log("Notice: {} is not a polygonal geometry type, skipping feature {}".format(geometry["type"], i))

        return feature_map

    if args.geojson:

        tiles = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))]
        assert tiles, "Empty cover"

        zoom = tiles[0].z
        assert not [tile for tile in tiles if tile.z != zoom], "Unsupported mixed-zoom cover. Use PostGIS instead"

        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)

                try:
                    crs_mapping = {"CRS84": "4326", "900913": "3857"}
                    srid = feature_collection["crs"]["properties"]["name"].split(":")[-1]
                    srid = int(srid) if srid not in crs_mapping else int(crs_mapping[srid])
                except (KeyError, ValueError):
                    srid = 4326  # GeoJSON default CRS

                for i, feature in enumerate(
                        tqdm(feature_collection["features"],
                             ascii=True,
                             unit="feature")):

                    if feature["geometry"]["type"] == "GeometryCollection":
                        for geometry in feature["geometry"]["geometries"]:
                            feature_map = geojson_parse_geometry(
                                zoom, srid, feature_map, geometry, i)
                    else:
                        feature_map = geojson_parse_geometry(
                            zoom, srid, feature_map, feature["geometry"], i)
        features = args.geojson

    if args.pg:

        conn = psycopg2.connect(args.pg)
        db = conn.cursor()

        assert "limit" not in args.sql.lower(), "LIMIT is not supported"
        db.execute(
            "SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(
                args.sql))
        srid = db.fetchone()[0]
        assert srid, "Unable to retrieve geometry SRID."

        if "where" not in args.sql.lower(
        ):  # TODO: Find a more reliable way to handle feature filtering
            args.sql += " WHERE ST_Intersects(tile.geom, geom)"
        else:
            args.sql += " AND ST_Intersects(tile.geom, geom)"
        features = args.sql

    log.log(
        "RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(
            args.type, features, args.cover))
    with open(os.path.join(os.path.expanduser(args.out), "instances.cover"),
              mode="w") as cover:

        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))),
                         ascii=True,
                         unit="tile"):

            geojson = None

            if args.pg:

                w, s, e, n = tile_bbox(tile)

                query = """
                WITH
                  tile AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}) AS geom),
                  geom AS (SELECT ST_Intersection(tile.geom, sql.geom) AS geom FROM tile CROSS JOIN LATERAL ({}) sql),
                  json AS (SELECT '{{"type": "Feature", "geometry": '
                         || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                         || '}}' AS features
                        FROM geom)
                SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
                FROM json
                """.format(w, s, e, n, srid, args.sql)

                db.execute(query)
                row = db.fetchone()
                try:
                    geojson = json.loads(
                        row[0])["features"] if row and row[0] else None
                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(
                        tile))
                    conn = psycopg2.connect(args.pg)
                    db = conn.cursor()

            if args.geojson:
                geojson = feature_map[tile] if tile in feature_map else None

            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, args.ts,
                                        burn_value)

            if not geojson or out is None:
                num = 0
                out = np.zeros(shape=(args.ts, args.ts), dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, out)
            cover.write("{},{},{}  {}{}".format(tile.x, tile.y, tile.z, num,
                                                os.linesep))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
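
geojson_parse_polygon keeps only x and y from each ring, since GeoJSON positions may carry extra dimensions such as elevation. A sketch with an illustrative 3D ring:

ring = [[2.35, 48.85, 35.0], [2.36, 48.86, 36.0], [2.35, 48.85, 35.0]]
print([[point[0], point[1]] for point in ring])
# -> [[2.35, 48.85], [2.36, 48.86], [2.35, 48.85]]
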
Example #8
def main(args):
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    palette = make_palette([classe["color"] for classe in config["classes"]])
    if not args.bs:
        try:
            args.bs = config["model"]["bs"]
        except KeyError:
            pass

    assert args.bs, "For rsp predict, model/bs must be set either in the config file or passed through the --bs parameter"
    args.workers = args.bs if not args.workers else args.workers
    cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(
            torch.__version__, torch.version.cuda,
            torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(
            args.workers))
        log.log("")
        log.log("============================================================")
        log.log("WARNING: Are you -really- sure about not predicting on GPU ?")
        log.log("============================================================")
        log.log("")
        device = torch.device("cpu")

    chkpt = torch.load(args.checkpoint, map_location=device)
    nn_module = load_module("robosat_pink.nn.{}".format(chkpt["nn"].lower()))
    nn = getattr(nn_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"],
                                         chkpt["encoder"].lower()).to(device)
    nn = torch.nn.DataParallel(nn)
    nn.load_state_dict(chkpt["state_dict"])
    nn.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    with torch.no_grad():  # don't track tensors with autograd during prediction

        tiled = []
        if args.passes in ["first", "both"]:
            log.log("== Predict First Pass ==")
            tiled = predict(config, cover, args, palette, chkpt, nn, device,
                            "predict")

        if args.passes in ["second", "both"]:
            log.log("== Predict Second Pass ==")
            predict(config, cover, args, palette, chkpt, nn, device,
                    "predict_translate")

    if not args.no_web_ui and tiled:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiled, tiled, "png", template)
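
An invocation sketch: the attribute names below are the ones main() reads above, the paths are illustrative, robosat_pink must be installed, and predict() may read further attributes (e.g. a dataset path) not shown here.

from types import SimpleNamespace

main(SimpleNamespace(
    config="config.toml", checkpoint="out/checkpoint-00010.pth", out="out/masks",
    cover=None, bs=4, workers=None, passes="both",
    no_web_ui=False, web_ui_template=None, web_ui_base_url=None,
))
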
Example #9
def main(args):

    if (args.geojson and args.postgis) or (not args.geojson and not args.postgis):
        sys.exit("ERROR: Input features to rasterize must be either GeoJSON or PostGIS")

    if args.postgis and not args.pg_dsn:
        sys.exit(
            "ERROR: With PostGIS input features, --pg_dsn must be provided")

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette(*[classe["color"] for classe in config["classes"]],
                           complementary=True)
    burn_value = next((config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type), None)
    if burn_value is None:  # next() with a default, so a missing type is reported instead of raising StopIteration
        sys.exit("ERROR: The type to rasterize is not contained in your config file classes.")

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, srid, feature_map, polygon, i):

        try:
            if srid != 4326:
                polygon = [xy for xy in geojson_reproject({"type": "feature", "geometry": polygon}, srid, 4326)][0]

            # GeoJSON coordinates could be N-dimensional: keep only x and y
            for i, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][i] = [[point[0], point[1]] for point in ring]

            if polygon["coordinates"]:
                for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=zoom):
                    feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, srid, feature_map, geometry, i):

        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, srid, feature_map,
                                                geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(zoom, srid, feature_map, {"type": "Polygon", "coordinates": polygon}, i)
        else:
            log.log("Notice: {} is not a polygonal geometry type, skipping feature {}".format(geometry["type"], i))

        return feature_map

    if args.geojson:

        try:
            tiles = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))]
            zoom = tiles[0].z
            assert not [tile for tile in tiles if tile.z != zoom]
        except Exception:
            sys.exit("ERROR: Inconsistent cover {}".format(args.cover))

        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)

                try:
                    crs_mapping = {"CRS84": "4326", "900913": "3857"}
                    srid = feature_collection["crs"]["properties"]["name"].split(":")[-1]
                    srid = int(srid) if srid not in crs_mapping else int(crs_mapping[srid])
                except (KeyError, ValueError):
                    srid = 4326  # GeoJSON default CRS

                for i, feature in enumerate(
                        tqdm(feature_collection["features"],
                             ascii=True,
                             unit="feature")):

                    if feature["geometry"]["type"] == "GeometryCollection":
                        for geometry in feature["geometry"]["geometries"]:
                            feature_map = geojson_parse_geometry(
                                zoom, srid, feature_map, geometry, i)
                    else:
                        feature_map = geojson_parse_geometry(
                            zoom, srid, feature_map, feature["geometry"], i)
        features = args.geojson

    if args.postgis:

        pg_conn = psycopg2.connect(args.pg_dsn)
        pg = pg_conn.cursor()

        pg.execute(
            "SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(
                args.postgis))
        try:
            srid = pg.fetchone()[0]
        except Exception:
            sys.exit("Unable to retrieve geometry SRID.")

        features = args.postgis

    log.log(
        "RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(
            args.type, features, args.cover))
    with open(os.path.join(os.path.expanduser(args.out), "instances.cover"),
              mode="w") as cover:

        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))),
                         ascii=True,
                         unit="tile"):

            if args.postgis:

                w, s, e, n = mercantile.bounds(tile)  # mercantile.bounds returns (west, south, east, north)

                query = """
                WITH
                  a AS ({}),
                  b AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}) AS geom)
                SELECT '{{
  "type": "FeatureCollection", "features": [{{"type": "Feature", "geometry": '
  || ST_AsGeoJSON(ST_Transform(ST_Intersection(a.geom, b.geom), 4326), 6)
  || '}}]}}'
                FROM a, b
                WHERE ST_Intersects(a.geom, b.geom)
                """.format(args.postgis, s, w, e, n, srid)

                try:
                    pg.execute(query)
                    row = pg.fetchone()
                    geojson = json.loads(row[0])["features"] if row and row[0] else None

                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(
                        tile))
                    pg_conn = psycopg2.connect(args.pg_dsn)
                    pg = pg_conn.cursor()

            if args.geojson:
                geojson = feature_map[tile] if tile in feature_map else None

            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, args.ts,
                                        burn_value)

            if not geojson or out is None:
                num = 0
                out = np.zeros(shape=(args.ts, args.ts), dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, out)
            cover.write("{},{},{}  {}{}".format(tile.x, tile.y, tile.z, num,
                                                os.linesep))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
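
mercantile.bounds returns (west, south, east, north) in lon/lat, which lines up with ST_MakeEnvelope(xmin, ymin, xmax, ymax, 4326) in the query above. A sketch:

import mercantile

w, s, e, n = mercantile.bounds(mercantile.Tile(x=532, y=361, z=10))
print(w, s, e, n)  # west < east, south < north, in EPSG:4326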