def main(args): config = load_config(args.config) check_classes(config) index = [i for i in (list(range(len(config["classes"])))) if config["classes"][i]["title"] == args.type] assert index, "Requested type {} not found among classes title in the config file.".format(args.type) print("RoboSat.pink - vectorize {} from {}".format(args.type, args.masks), file=sys.stderr, flush=True) out = open(args.out, "w", encoding="utf-8") assert out, "Unable to write in output file" out.write('{"type":"FeatureCollection","features":[') first = True for tile, path in tqdm(list(tiles_from_dir(args.masks, xyz_path=True)), ascii=True, unit="mask"): mask = (np.array(Image.open(path).convert("P"), dtype=np.uint8) == index).astype(np.uint8) try: C, W, H = mask.shape except: W, H = mask.shape transform = rasterio.transform.from_bounds((*mercantile.bounds(tile.x, tile.y, tile.z)), W, H) for shape, value in rasterio.features.shapes(mask, transform=transform, mask=mask): geom = '"geometry":{{"type": "Polygon", "coordinates":{}}}'.format(json.dumps(shape["coordinates"])) out.write('{}{{"type":"Feature",{}}}'.format("" if first else ",", geom)) first = False out.write("]}")
def main(args): config = load_config(args.config) check_classes(config) index = [ i for i in (list(range(len(config["classes"])))) if config["classes"][i]["title"] == args.type ] assert index, "Requested type {} not found among classes title in the config file.".format( args.type) print("RoboSat.pink - vectorize {} from {}".format(args.type, args.masks)) with open(args.out, "w", encoding="utf-8") as out: first = True out.write('{"type":"FeatureCollection","features":[') for tile, path in tqdm(list(tiles_from_slippy_map(args.masks)), ascii=True, unit="mask"): features = (np.array(Image.open(path).convert("P"), dtype=np.uint8) == index).astype(np.uint8) try: C, W, H = features.shape except: W, H = features.shape transform = rasterio.transform.from_bounds( (*mercantile.bounds(tile.x, tile.y, tile.z)), W, H) for shape, value in rasterio.features.shapes(features, transform=transform): prop = '"properties":{{"x":{},"y":{},"z":{}}}'.format( int(tile.x), int(tile.y), int(tile.z)) geom = '"geometry":{{"type": "Polygon", "coordinates":{}}}'.format( json.dumps(shape["coordinates"])) out.write('{}{{"type":"Feature",{},{}}}'.format( "," if not first else "", geom, prop)) first = False out.write("]}")
def main(args):
    """Tile rasters into XYZ tiles (images or labels), with multi-raster split merging.

    Two threaded passes: first each raster is warped and cut into tiles (tiles
    covered by several rasters go to a temporary .splits directory), then the
    splits are summed per tile into the final output. Optionally emits a
    Leaflet web UI over the produced tiles.
    """
    if not args.workers:
        args.workers = min(os.cpu_count(), len(args.rasters))

    if args.label:
        config = load_config(args.config)
        check_classes(config)
        colors = [classe["color"] for classe in config["classes"]]
        palette = make_palette(*colors)

    cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None
    splits_path = os.path.join(os.path.expanduser(args.out), ".splits")
    tiles_map = {}

    print("RoboSat.pink - tile on CPU, with {} workers".format(args.workers))

    # First sweep: check band consistency and map each tile to the raster(s)
    # that cover it, so overlapping rasters can be merged afterwards.
    bands = -1
    for path in args.rasters:
        raster = rasterio_open(path)
        w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)

        if bands != -1:
            assert bands == len(raster.indexes), "Coverage must be bands consistent"
        bands = len(raster.indexes)

        tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
        tiles = list(set(tiles) & set(cover)) if cover else tiles

        for tile in tiles:
            tile_key = (str(tile.x), str(tile.y), str(tile.z))
            if tile_key not in tiles_map.keys():
                tiles_map[tile_key] = []
            tiles_map[tile_key].append(path)

    # Output extension depends on the band count (labels are always 1-band png).
    if args.label:
        ext = "png"
        bands = 1
    if not args.label:
        if bands == 1:
            ext = "png"
        if bands == 3:
            ext = "webp"
        if bands > 3:
            ext = "tiff"

    tiles = []
    progress = tqdm(total=len(tiles_map), ascii=True, unit="tile")

    # Begin to tile plain tiles
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(path):
            raster = rasterio_open(path)
            w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
            transform, _, _ = calculate_default_transform(raster.crs, "EPSG:3857", raster.width, raster.height, w, s, e, n)
            tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
            tiled = []

            for tile in tiles:
                if cover and tile not in cover:
                    continue

                w, s, e, n = mercantile.xy_bounds(tile)

                # inspired by rio-tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
                warp_vrt = WarpedVRT(
                    raster,
                    crs="epsg:3857",
                    resampling=Resampling.bilinear,
                    add_alpha=False,
                    transform=from_bounds(w, s, e, n, args.ts, args.ts),
                    width=math.ceil((e - w) / transform.a),
                    height=math.ceil((s - n) / transform.e),
                )
                data = warp_vrt.read(out_shape=(len(raster.indexes), args.ts, args.ts), window=warp_vrt.window(w, s, e, n))
                image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C

                tile_key = (str(tile.x), str(tile.y), str(tile.z))

                # Skip no-data tiles only when a single raster covers them:
                # split tiles might become valid once merged.
                if not args.label and len(tiles_map[tile_key]) == 1 and is_nodata(image, threshold=args.nodata_threshold):
                    progress.update()
                    continue

                # Tiles covered by several rasters go to a per-raster split dir.
                if len(tiles_map[tile_key]) > 1:
                    out = os.path.join(splits_path, str(tiles_map[tile_key].index(path)))
                else:
                    out = args.out

                x, y, z = map(int, tile)

                if not args.label:
                    ret = tile_image_to_file(out, mercantile.Tile(x=x, y=y, z=z), image)
                if args.label:
                    ret = tile_label_to_file(out, mercantile.Tile(x=x, y=y, z=z), palette, image)
                assert ret, "Unable to write tile {} from raster {}.".format(str(tile), raster)

                # Only directly-written tiles are final at this stage.
                if len(tiles_map[tile_key]) == 1:
                    progress.update()
                    tiled.append(mercantile.Tile(x=x, y=y, z=z))

            return tiled

        for tiled in executor.map(worker, args.rasters):
            if tiled is not None:
                tiles.extend(tiled)

    # Aggregate remaining tiles splits
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile_key):
            if len(tiles_map[tile_key]) == 1:
                return

            image = np.zeros((args.ts, args.ts, bands), np.uint8)
            x, y, z = map(int, tile_key)

            # Sum every split written for this tile into one image.
            for i in range(len(tiles_map[tile_key])):
                root = os.path.join(splits_path, str(i))
                _, path = tile_from_xyz(root, x, y, z)

                if not args.label:
                    split = tile_image_from_file(path)
                if args.label:
                    split = tile_label_from_file(path)
                    split = split.reshape((args.ts, args.ts, 1))  # H,W -> H,W,C

                assert image.shape == split.shape
                image[:, :, :] += split[:, :, :]

            if not args.label and is_nodata(image, threshold=args.nodata_threshold):
                progress.update()
                return

            tile = mercantile.Tile(x=x, y=y, z=z)

            if not args.label:
                ret = tile_image_to_file(args.out, tile, image)
            if args.label:
                ret = tile_label_to_file(args.out, tile, palette, image)
            # FIX: message had two {} placeholders but a single format argument,
            # which raised IndexError instead of the intended assertion message.
            assert ret, "Unable to write tile {}.".format(str(tile_key))

            progress.update()
            return tile

        for tiled in executor.map(worker, tiles_map.keys()):
            if tiled is not None:
                tiles.append(tiled)

        if splits_path and os.path.isdir(splits_path):
            shutil.rmtree(splits_path)  # Delete suffixes dir if any

    if tiles and not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiles, tiles, ext, template)
def main(args):
    """Train a model: merge CLI overrides into the config, build loaders, nn and
    optimizer, then run the train/eval epoch loop, checkpointing periodically."""
    config = load_config(args.config)
    args.out = os.path.expanduser(args.out)
    # CLI arguments, when provided, take precedence over config file values.
    config["model"]["loader"] = args.loader if args.loader else config["model"]["loader"]
    config["model"]["bs"] = args.bs if args.bs else config["model"]["bs"]
    config["model"]["lr"] = args.lr if args.lr else config["model"]["lr"]
    # Tile size is given as "width,height" on the command line.
    config["model"]["ts"] = tuple(map(int, args.ts.split(","))) if args.ts else config["model"]["ts"]
    config["model"]["nn"] = args.nn if args.nn else config["model"]["nn"]
    config["model"]["encoder"] = args.encoder if args.encoder else config["model"]["encoder"]
    config["model"]["loss"] = args.loss if args.loss else config["model"]["loss"]
    config["model"]["da"] = args.da if args.da else config["model"]["da"]
    config["model"]["dap"] = args.dap if args.dap else config["model"]["dap"]
    # Default worker count mirrors the batch size.
    args.workers = config["model"]["bs"] if not args.workers else args.workers
    check_classes(config)
    check_channels(config)
    check_model(config)
    assert os.path.isdir(os.path.expanduser(args.dataset)), "Dataset is not a directory"
    # Nothing to do when both training and validation are disabled.
    if args.no_training and args.no_validation:
        sys.exit()
    log = Logs(os.path.join(args.out, "log"))
    # NOTE(review): these CSV handles are never explicitly closed; rows could be
    # lost on abnormal termination — confirm process() flushes them.
    csv_train = None if args.no_training else open(os.path.join(args.out, "training.csv"), mode="a")
    csv_val = None if args.no_validation else open(os.path.join(args.out, "validation.csv"), mode="a")
    if torch.cuda.is_available():
        log.log("RoboSat.pink - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(torch.__version__, torch.version.cuda, torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - training on CPU, with {} workers - (Torch:{})".format(args.workers, torch.__version__))
        log.log("")
        log.log("==========================================================")
        log.log("WARNING: Are you -really- sure about not training on GPU ?")
        log.log("==========================================================")
        log.log("")
        device = torch.device("cpu")
    log.log("--- Input tensor from Dataset: {} ---".format(args.dataset))
    num_channel = 1  # 1-based numerotation
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(num_channel, channel["name"], band))
            num_channel += 1
    log.log("--- Output Classes ---")
    for c, classe in enumerate(config["classes"]):
        log.log("Class {}:\t\t {}".format(c, classe["title"]))
    log.log("--- Hyper Parameters ---")
    for hp in config["model"]:
        log.log("{}{}".format(hp.ljust(25, " "), config["model"][hp]))
    # Dataset loaders are resolved dynamically from the configured loader name.
    loader = load_module("robosat_pink.loaders.{}".format(config["model"]["loader"].lower()))
    loader_train = getattr(loader, config["model"]["loader"])(config, config["model"]["ts"], os.path.join(args.dataset, "training"), None, "train")
    loader_val = getattr(loader, config["model"]["loader"])(config, config["model"]["ts"], os.path.join(args.dataset, "validation"), None, "train")
    encoder = config["model"]["encoder"].lower()
    # Model class is also resolved dynamically, then wrapped for multi-GPU.
    nn_module = load_module("robosat_pink.nn.{}".format(config["model"]["nn"].lower()))
    nn = getattr(nn_module, config["model"]["nn"])(loader_train.shape_in, loader_train.shape_out, encoder, config).to(device)
    nn = torch.nn.DataParallel(nn)
    optimizer = Adam(nn.parameters(), lr=config["model"]["lr"])
    resume = 0
    if args.checkpoint:
        chkpt = torch.load(os.path.expanduser(args.checkpoint), map_location=device)
        nn.load_state_dict(chkpt["state_dict"])
        log.log("--- Using Checkpoint ---")
        log.log("Path:\t\t {}".format(args.checkpoint))
        log.log("UUID:\t\t {}".format(chkpt["uuid"]))
        # --resume restores the optimizer state and restarts after the saved epoch.
        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]
            assert resume < args.epochs, "Epoch asked, already reached by the given checkpoint"
    loss_module = load_module("robosat_pink.losses.{}".format(config["model"]["loss"].lower()))
    criterion = getattr(loss_module, config["model"]["loss"])().to(device)
    bs = config["model"]["bs"]
    train_loader = DataLoader(loader_train, batch_size=bs, shuffle=True, drop_last=True, num_workers=args.workers)
    val_loader = DataLoader(loader_val, batch_size=bs, shuffle=False, drop_last=True, num_workers=args.workers)
    # Evaluation-only mode: single eval pass, then exit.
    if args.no_training:
        epoch = 0
        process(val_loader, config, log, csv_val, epoch, device, nn, criterion, "eval")
        sys.exit()
    for epoch in range(resume + 1, args.epochs + 1):  # 1-N based
        UUID = uuid.uuid1()
        log.log("---{}Epoch: {}/{} -- UUID: {}".format(os.linesep, epoch, args.epochs, UUID))
        process(train_loader, config, log, csv_train, epoch, device, nn, criterion, "train", optimizer)
        try:  # https://github.com/pytorch/pytorch/issues/9176
            nn_doc = nn.module.doc
            nn_version = nn.module.version
        except AttributeError:
            nn_doc = nn.doc
            nn_version = nn.version
        # Checkpoint payload: model weights plus enough metadata to reload it.
        states = {
            "uuid": UUID,
            "model_version": nn_version,
            "producer_name": "RoboSat.pink",
            "producer_version": rsp.__version__,
            "model_licence": "MIT",
            "domain": "pink.RoboSat",  # reverse-DNS
            "doc_string": nn_doc,
            "shape_in": loader_train.shape_in,
            "shape_out": loader_train.shape_out,
            "state_dict": nn.state_dict(),
            "epoch": epoch,
            "nn": config["model"]["nn"],
            "encoder": config["model"]["encoder"],
            "optimizer": optimizer.state_dict(),
            "loader": config["model"]["loader"],
        }
        checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}.pth".format(epoch))
        # Save on the last epoch and every args.saving epochs.
        if epoch == args.epochs or not (epoch % args.saving):
            log.log("[Saving checkpoint]")
            torch.save(states, checkpoint_path)
        if not args.no_validation:
            process(val_loader, config, log, csv_val, epoch, device, nn, criterion, "eval")
def main(args):
    """Predict masks for a tiled dataset with a trained checkpoint.

    Loads the checkpointed model, runs batched inference over the dataset
    loader, writes one label (or translated) tile per input tile, and
    optionally emits a Leaflet web UI over the predictions.
    """
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    palette = make_palette([classe["color"] for classe in config["classes"]])

    # FIX: the condition used torch.device("cuda"), which is ALWAYS truthy
    # (it just builds a device object) — on CPU-only hosts that made
    # args.workers = device_count() * 2 == 0. Test actual CUDA availability.
    args.workers = torch.cuda.device_count() * 2 if torch.cuda.is_available() and not args.workers else args.workers

    cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(torch.__version__, torch.version.cuda, torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(args.workers))
        log.log("")
        log.log("============================================================")
        log.log("WARNING: Are you -really- sure about not predicting on GPU ?")
        log.log("============================================================")
        log.log("")
        device = torch.device("cpu")

    # Model class and loader are resolved dynamically from checkpoint metadata.
    chkpt = torch.load(args.checkpoint, map_location=device)
    model_module = load_module("robosat_pink.models.{}".format(chkpt["nn"].lower()))
    nn = getattr(model_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"]).to(device)
    nn = torch.nn.DataParallel(nn)
    nn.load_state_dict(chkpt["state_dict"])
    nn.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    mode = "predict" if not args.translate else "predict_translate"
    loader_module = load_module("robosat_pink.loaders.{}".format(chkpt["loader"].lower()))
    loader_predict = getattr(loader_module, chkpt["loader"])(config, chkpt["shape_in"][1:3], args.dataset, cover, mode=mode)

    loader = DataLoader(loader_predict, batch_size=args.bs, num_workers=args.workers)
    assert len(loader), "Empty predict dataset directory. Check your path."

    tiled = []
    with torch.no_grad():  # don't track tensors with autograd during prediction
        for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True):
            images = images.to(device)
            outputs = nn(images)
            probs = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))
                # Drop the background channel, round to a binary uint8 mask.
                mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()
                if args.translate:
                    tile_translate_to_file(args.out, mercantile.Tile(x, y, z), palette, mask)
                else:
                    tile_label_to_file(args.out, mercantile.Tile(x, y, z), palette, mask)
                tiled.append(mercantile.Tile(x, y, z))

    if not args.no_web_ui and not args.translate:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiled, tiled, "png", template)
def main(args):
    """Rasterize vector features (GeoJSON files or a PostGIS query) to label tiles.

    For each tile of the cover, burns the requested class into a label raster
    and records the instance count per tile in an instances_<type>.cover file.
    Optionally emits a Leaflet web UI.
    """
    assert not (args.sql and args.geojson), "You can only use at once --pg OR --geojson."
    assert not (args.pg and not args.sql), "With PostgreSQL --pg, --sql must also be provided"
    assert len(args.ts.split(",")) == 2, "--ts expect width,height value (e.g 512,512)"

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette([classe["color"] for classe in config["classes"]], complementary=True)
    index = [config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type]
    assert index, "Requested type is not contains in your config file classes."
    burn_value = int(math.pow(2, index[0] - 1))  # 8bits One Hot Encoding
    # NOTE(review): for class index 0 this yields int(2**-1) == 0, i.e. nothing
    # is burned — presumably index 0 is background; confirm.
    assert 0 <= burn_value <= 128

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    if args.geojson:
        tiles = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))]
        assert tiles, "Empty cover"

        zoom = tiles[0].z
        assert not [tile for tile in tiles if tile.z != zoom], "Unsupported zoom mixed cover. Use PostGIS instead"

        # Spatial index: tile -> features intersecting it, accumulated over
        # every input GeoJSON file.
        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:
            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)
                srid = geojson_srid(feature_collection)

                # FIX: feature_map was re-initialized here for every file,
                # silently discarding the features of all but the last
                # --geojson input. Accumulate across files instead.
                for i, feature in enumerate(tqdm(feature_collection["features"], ascii=True, unit="feature")):
                    feature_map = geojson_parse_feature(zoom, srid, feature_map, feature)

        features = args.geojson

    if args.pg:
        conn = psycopg2.connect(args.pg)
        db = conn.cursor()

        assert "limit" not in args.sql.lower(), "LIMIT is not supported"
        assert "TILE_GEOM" in args.sql, "TILE_GEOM filter not found in your SQL"

        # Neutralize the per-tile spatial filter to probe the geometry SRID.
        # FIX: re.I was passed as re.sub's positional `count` argument (not a
        # flag), so the pattern was case-sensitive; pass it via flags=.
        sql = re.sub(r"ST_Intersects( )*\((.*)?TILE_GEOM(.*)?\)", "1=1", args.sql, flags=re.I)
        assert sql and sql != args.sql

        db.execute("""SELECT ST_Srid("1") AS srid FROM ({} LIMIT 1) AS t("1")""".format(sql))
        srid = db.fetchone()[0]
        assert srid and int(srid) > 0, "Unable to retrieve geometry SRID."

        features = args.sql

    log.log("RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(args.type, features, args.cover))
    with open(os.path.join(os.path.expanduser(args.out), "instances_" + args.type.lower() + ".cover"), mode="w") as cover:

        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))), ascii=True, unit="tile"):

            geojson = None

            if args.pg:
                w, s, e, n = tile_bbox(tile)
                tile_geom = "ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {})".format(w, s, e, n, srid)

                query = """
                WITH
                  sql  AS ({}),
                  geom AS (SELECT "1" AS geom FROM sql AS t("1")),
                  json AS (SELECT '{{"type": "Feature", "geometry": '
                         || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                         || '}}' AS features
                         FROM geom)
                SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
                FROM json
                """.format(args.sql.replace("TILE_GEOM", tile_geom))

                db.execute(query)
                row = db.fetchone()
                try:
                    geojson = json.loads(row[0])["features"] if row and row[0] else None
                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(tile))
                    # Reconnect: the failed statement may have aborted the session.
                    conn = psycopg2.connect(args.pg)
                    db = conn.cursor()

            if args.geojson:
                geojson = feature_map[tile] if tile in feature_map else None

            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, list(map(int, args.ts.split(","))), burn_value)

            if not geojson or out is None:
                num = 0
                out = np.zeros(shape=list(map(int, args.ts.split(","))), dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, out, append=args.append)
            cover.write("{},{},{} {}{}".format(tile.x, tile.y, tile.z, num, os.linesep))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
def main(args):
    """Train a model (legacy variant): apply CLI overrides, build loaders and nn,
    run the epoch loop and save a checkpoint after every epoch."""
    config = load_config(args.config)
    args.out = os.path.expanduser(args.out)
    # FIX: the condition used torch.device("cuda"), which is ALWAYS truthy —
    # on CPU-only hosts that made args.workers = device_count() * 2 == 0.
    args.workers = torch.cuda.device_count() * 2 if torch.cuda.is_available() and not args.workers else args.workers

    # CLI arguments, when provided, take precedence over config file values.
    config["model"]["loader"] = args.loader if args.loader else config["model"]["loader"]
    config["model"]["bs"] = args.bs if args.bs else config["model"]["bs"]
    config["model"]["lr"] = args.lr if args.lr else config["model"]["lr"]
    config["model"]["ts"] = args.ts if args.ts else config["model"]["ts"]
    config["model"]["nn"] = args.nn if args.nn else config["model"]["nn"]
    config["model"]["loss"] = args.loss if args.loss else config["model"]["loss"]
    config["model"]["da"] = args.da if args.da else config["model"]["da"]
    config["model"]["dap"] = args.dap if args.dap else config["model"]["dap"]
    check_classes(config)
    check_channels(config)
    check_model(config)

    if not os.path.isdir(os.path.expanduser(args.dataset)):
        sys.exit("ERROR: dataset {} is not a directory".format(args.dataset))

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(torch.__version__, torch.version.cuda, torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - training on CPU, with {} workers - (Torch:{})".format(args.workers, torch.__version__))
        log.log("WARNING: Are you really sure sure about not training on GPU ?")
        device = torch.device("cpu")

    # Loader and model classes are resolved dynamically from the config.
    loader = load_module("robosat_pink.loaders.{}".format(config["model"]["loader"].lower()))
    loader_train = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], os.path.join(args.dataset, "training"), "train"
    )
    loader_val = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], os.path.join(args.dataset, "validation"), "train"
    )

    model_module = load_module("robosat_pink.models.{}".format(config["model"]["nn"].lower()))
    nn = getattr(model_module, config["model"]["nn"])(loader_train.shape_in, loader_train.shape_out, config).to(device)
    nn = torch.nn.DataParallel(nn)
    optimizer = Adam(nn.parameters(), lr=config["model"]["lr"])

    resume = 0
    if args.checkpoint:
        chkpt = torch.load(os.path.expanduser(args.checkpoint), map_location=device)
        nn.load_state_dict(chkpt["state_dict"])
        log.log("Using checkpoint: {}".format(args.checkpoint))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]
            if resume >= args.epochs:
                # FIX: the message formatted config["model"]["epochs"], a key
                # this code path never sets (KeyError while reporting the
                # error); report the checkpoint's epoch instead.
                sys.exit("ERROR: Epoch {} already reached by the given checkpoint".format(resume))

    loss_module = load_module("robosat_pink.losses.{}".format(config["model"]["loss"].lower()))
    criterion = getattr(loss_module, config["model"]["loss"])().to(device)

    bs = config["model"]["bs"]
    train_loader = DataLoader(loader_train, batch_size=bs, shuffle=True, drop_last=True, num_workers=args.workers)
    val_loader = DataLoader(loader_val, batch_size=bs, shuffle=False, drop_last=True, num_workers=args.workers)

    log.log("--- Input tensor from Dataset: {} ---".format(args.dataset))
    num_channel = 1  # 1-based numerotation
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(num_channel, channel["name"], band))
            num_channel += 1

    log.log("--- Hyper Parameters ---")
    for hp in config["model"]:
        log.log("{}{}".format(hp.ljust(25, " "), config["model"][hp]))

    for epoch in range(resume, args.epochs):
        UUID = uuid.uuid1()
        log.log("---{}Epoch: {}/{} -- UUID: {}".format(os.linesep, epoch + 1, args.epochs, UUID))

        process(train_loader, config, log, device, nn, criterion, "train", optimizer)

        if not args.no_validation:
            process(val_loader, config, log, device, nn, criterion, "eval")

        try:  # https://github.com/pytorch/pytorch/issues/9176
            nn_doc = nn.module.doc
            nn_version = nn.module.version
        except AttributeError:
            nn_version = nn.version
            # FIX: this line used `==` (a no-op comparison) instead of `=`,
            # leaving nn_doc unbound and raising NameError at checkpoint save.
            nn_doc = nn.doc

        # Checkpoint payload: model weights plus enough metadata to reload it.
        states = {
            "uuid": UUID,
            "model_version": nn_version,
            "producer_name": "RoboSat.pink",
            "producer_version": "0.4.0",
            "model_licence": "MIT",
            "domain": "pink.RoboSat",  # reverse-DNS
            "doc_string": nn_doc,
            "shape_in": loader_train.shape_in,
            "shape_out": loader_train.shape_out,
            "state_dict": nn.state_dict(),
            "epoch": epoch + 1,
            "nn": config["model"]["nn"],
            "optimizer": optimizer.state_dict(),
            "loader": config["model"]["loader"],
        }
        checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}.pth".format(epoch + 1))
        torch.save(states, checkpoint_path)
def main(args):
    """Rasterize vector features (legacy variant) from GeoJSON files or PostGIS.

    Builds a tile -> features spatial index (GeoJSON path) or queries PostGIS
    per tile, burns the requested class into a label raster for every cover
    tile, and records instance counts in an instances.cover file.
    """
    if args.pg:
        if not args.sql:
            sys.exit("ERROR: With PostgreSQL db, --sql must be provided")

    if (args.sql and args.geojson) or (args.sql and not args.pg):
        sys.exit("ERROR: You can use either --pg or --geojson inputs, but only one at once.")

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette(*[classe["color"] for classe in config["classes"]], complementary=True)

    # FIX: next() without a default raises a bare StopIteration when no class
    # title matches, so the "not in locals()" guard below was dead code and the
    # user saw a traceback instead of the intended message. Use a sentinel.
    burn_value = next(
        (config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type), None
    )
    if burn_value is None:
        sys.exit("ERROR: asked type to rasterize is not contains in your config file classes.")

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, srid, feature_map, polygon, i):
        """Reproject one polygon to EPSG:4326 and index it under each tile it covers."""
        try:
            if srid != 4326:
                polygon = [xy for xy in geojson_reproject({"type": "feature", "geometry": polygon}, srid, 4326)][0]

            for i, ring in enumerate(polygon["coordinates"]):  # GeoJSON coordinates could be N dimensionals
                polygon["coordinates"][i] = [[x, y] for point in ring for x, y in zip([point[0]], [point[1]])]

            if polygon["coordinates"]:
                for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=zoom):
                    feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, srid, feature_map, geometry, i):
        """Dispatch a geometry to the polygon indexer; skip non-surfacic types."""
        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, srid, feature_map, geometry, i)
        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(zoom, srid, feature_map, {"type": "Polygon", "coordinates": polygon}, i)
        else:
            log.log("Notice: {} is a non surfacic geometry type, skipping feature {}".format(geometry["type"], i))

        return feature_map

    if args.geojson:
        tiles = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))]
        assert tiles, "Empty cover"

        zoom = tiles[0].z
        assert not [tile for tile in tiles if tile.z != zoom], "Unsupported zoom mixed cover. Use PostGIS instead"

        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:
            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)

                # Best-effort SRID detection; default to EPSG:4326 on any
                # malformed or absent crs member.
                try:
                    crs_mapping = {"CRS84": "4326", "900913": "3857"}
                    srid = feature_collection["crs"]["properties"]["name"].split(":")[-1]
                    srid = int(srid) if srid not in crs_mapping else int(crs_mapping[srid])
                except Exception:  # FIX: was a bare except (swallowed KeyboardInterrupt too)
                    srid = int(4326)

                for i, feature in enumerate(tqdm(feature_collection["features"], ascii=True, unit="feature")):
                    if feature["geometry"]["type"] == "GeometryCollection":
                        for geometry in feature["geometry"]["geometries"]:
                            feature_map = geojson_parse_geometry(zoom, srid, feature_map, geometry, i)
                    else:
                        feature_map = geojson_parse_geometry(zoom, srid, feature_map, feature["geometry"], i)
        features = args.geojson

    if args.pg:
        conn = psycopg2.connect(args.pg)
        db = conn.cursor()

        assert "limit" not in args.sql.lower(), "LIMIT is not supported"
        db.execute("SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(args.sql))
        srid = db.fetchone()[0]
        assert srid, "Unable to retrieve geometry SRID."

        if "where" not in args.sql.lower():  # TODO: Find a more reliable way to handle feature filtering
            args.sql += " WHERE ST_Intersects(tile.geom, geom)"
        else:
            args.sql += " AND ST_Intersects(tile.geom, geom)"
        features = args.sql

    log.log("RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(args.type, features, args.cover))
    with open(os.path.join(os.path.expanduser(args.out), "instances.cover"), mode="w") as cover:
        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))), ascii=True, unit="tile"):

            geojson = None

            if args.pg:
                w, s, e, n = tile_bbox(tile)

                query = """
                WITH
                  tile AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}) AS geom),
                  geom AS (SELECT ST_Intersection(tile.geom, sql.geom) AS geom FROM tile CROSS JOIN LATERAL ({}) sql),
                  json AS (SELECT '{{"type": "Feature", "geometry": '
                         || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                         || '}}' AS features
                         FROM geom)
                SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
                FROM json
                """.format(w, s, e, n, srid, args.sql)

                db.execute(query)
                row = db.fetchone()
                try:
                    geojson = json.loads(row[0])["features"] if row and row[0] else None
                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(tile))
                    # Reconnect: the failed statement may have aborted the session.
                    conn = psycopg2.connect(args.pg)
                    db = conn.cursor()

            if args.geojson:
                geojson = feature_map[tile] if tile in feature_map else None

            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, args.ts, burn_value)

            if not geojson or out is None:
                num = 0
                out = np.zeros(shape=(args.ts, args.ts), dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, out)
            cover.write("{},{},{} {}{}".format(tile.x, tile.y, tile.z, num, os.linesep))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
def main(args):
    """Predict masks with a trained checkpoint, in one or two passes.

    Resolves batch size from CLI or config, loads the checkpointed nn, runs
    the first and/or second ("translate") prediction pass, then optionally
    emits a Leaflet web UI over the first-pass tiles.
    """
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    palette = make_palette([classe["color"] for classe in config["classes"]])

    # Fall back to the config batch size; the config may legitimately omit it.
    if not args.bs:
        try:
            args.bs = config["model"]["bs"]
        except (KeyError, TypeError):  # FIX: was a bare except (caught KeyboardInterrupt too)
            pass
    # FIX: typo in the user-facing message ("pass trought parameter").
    assert args.bs, "For rsp predict, model/bs must be set either in config file, or passed through parameter --bs"
    args.workers = args.bs if not args.workers else args.workers

    cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(torch.__version__, torch.version.cuda, torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(args.workers))
        log.log("")
        log.log("============================================================")
        log.log("WARNING: Are you -really- sure about not predicting on GPU ?")
        log.log("============================================================")
        log.log("")
        device = torch.device("cpu")

    # Model class is resolved dynamically from checkpoint metadata.
    chkpt = torch.load(args.checkpoint, map_location=device)
    nn_module = load_module("robosat_pink.nn.{}".format(chkpt["nn"].lower()))
    nn = getattr(nn_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"], chkpt["encoder"].lower()).to(device)
    nn = torch.nn.DataParallel(nn)
    nn.load_state_dict(chkpt["state_dict"])
    nn.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    with torch.no_grad():  # don't track tensors with autograd during prediction

        tiled = []
        if args.passes in ["first", "both"]:
            log.log("== Predict First Pass ==")
            tiled = predict(config, cover, args, palette, chkpt, nn, device, "predict")

        if args.passes in ["second", "both"]:
            log.log("== Predict Second Pass ==")
            # Second-pass (translated) tiles are not listed in the web UI.
            predict(config, cover, args, palette, chkpt, nn, device, "predict_translate")

    if not args.no_web_ui and tiled:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiled, tiled, "png", template)
def main(args):
    """Cut input rasters into XYZ tiles (images or labels), then merge overlapping splits.

    Phase 1 tiles each raster in a thread pool; tiles covered by several rasters
    are written into per-raster ".splits" subdirs. Phase 2 merges those splits
    pixel-wise, then the splits dir is removed and an optional web UI is built.
    """
    if not args.workers:
        args.workers = min(os.cpu_count(), len(args.rasters))

    # Label mode needs the class palette from the config file.
    if args.label:
        config = load_config(args.config)
        check_classes(config)
        colors = [classe["color"] for classe in config["classes"]]
        palette = make_palette(colors)

    assert len(args.ts.split(",")) == 2, "--ts expect width,height value (e.g 512,512)"
    width, height = list(map(int, args.ts.split(",")))

    # Optional tile whitelist from a cover CSV; None means "no restriction".
    cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None

    # Scratch dir for tiles covered by more than one raster.
    splits_path = os.path.join(os.path.expanduser(args.out), ".splits")
    tiles_map = {}  # (x, y, z) as strings -> list of raster paths covering that tile

    print("RoboSat.pink - tile on CPU, with {} workers".format(args.workers), file=sys.stderr, flush=True)

    # First scan: check band consistency across rasters and build tiles_map.
    bands = -1
    for path in args.rasters:
        raster = rasterio_open(path)
        w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)

        if bands != -1:
            assert bands == len(raster.indexes), "Coverage must be bands consistent"
        bands = len(raster.indexes)

        tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
        tiles = list(set(tiles) & set(cover)) if cover else tiles

        for tile in tiles:
            tile_key = (str(tile.x), str(tile.y), str(tile.z))
            if tile_key not in tiles_map.keys():
                tiles_map[tile_key] = []
            tiles_map[tile_key].append(path)

    # Output extension: labels are single-band PNG; images depend on band count.
    if args.label:
        ext = "png"
        bands = 1
    if not args.label:
        if bands == 1:
            ext = "png"
        if bands == 3:
            ext = "webp"
        if bands > 3:
            ext = "tiff"

    tiles = []  # final list of written tiles, fed to the web UI
    progress = tqdm(total=len(tiles_map), ascii=True, unit="tile")

    # Begin to tile plain tiles
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(path):
            # Tile a single raster; returns the tiles fully handled here
            # (tiles shared with other rasters are deferred to aggregation).
            raster = rasterio_open(path)
            w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
            tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
            tiled = []

            for tile in tiles:
                if cover and tile not in cover:
                    continue

                w, s, e, n = mercantile.xy_bounds(tile)

                # Reproject the raster window to WebMercator at the tile size.
                warp_vrt = WarpedVRT(
                    raster,
                    crs="epsg:3857",
                    resampling=Resampling.bilinear,
                    add_alpha=False,
                    transform=from_bounds(w, s, e, n, width, height),
                    width=width,
                    height=height,
                )
                data = warp_vrt.read(out_shape=(len(raster.indexes), width, height), window=warp_vrt.window(w, s, e, n))

                if data.dtype == "uint16":  # GeoTiff could be 16 bits
                    data = np.uint8(data / 256)
                elif data.dtype == "uint32":  # or 32 bits
                    data = np.uint8(data / (256 * 256))

                image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C
                tile_key = (str(tile.x), str(tile.y), str(tile.z))

                # Skip nodata-only image tiles (unless they'll be merged later).
                if (not args.label and len(tiles_map[tile_key]) == 1
                        and is_nodata(image, args.nodata, args.nodata_threshold, args.keep_borders)):
                    progress.update()
                    continue

                # Shared tiles go to a per-raster split dir, keyed by raster index.
                if len(tiles_map[tile_key]) > 1:
                    out = os.path.join(splits_path, str(tiles_map[tile_key].index(path)))
                else:
                    out = args.out

                x, y, z = map(int, tile)

                if not args.label:
                    tile_image_to_file(out, mercantile.Tile(x=x, y=y, z=z), image)
                if args.label:
                    tile_label_to_file(out, mercantile.Tile(x=x, y=y, z=z), palette, image)

                # Only single-raster tiles are final at this stage.
                if len(tiles_map[tile_key]) == 1:
                    progress.update()
                    tiled.append(mercantile.Tile(x=x, y=y, z=z))

            return tiled

        for tiled in executor.map(worker, args.rasters):
            if tiled is not None:
                tiles.extend(tiled)

    # Aggregate remaining tiles splits
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile_key):
            # Merge the per-raster splits of one shared tile; zero pixels are
            # filled from later splits (first non-zero value wins per pixel).
            if len(tiles_map[tile_key]) == 1:
                return

            image = np.zeros((width, height, bands), np.uint8)
            x, y, z = map(int, tile_key)

            for i in range(len(tiles_map[tile_key])):
                root = os.path.join(splits_path, str(i))
                _, path = tile_from_xyz(root, x, y, z)

                if not args.label:
                    split = tile_image_from_file(path)
                if args.label:
                    split = tile_label_from_file(path)
                    split = split.reshape((width, height, 1))  # H,W -> H,W,C

                assert image.shape == split.shape
                image[np.where(image == 0)] += split[np.where(image == 0)]

            if not args.label and is_nodata(image, args.nodata, args.nodata_threshold, args.keep_borders):
                progress.update()
                return

            tile = mercantile.Tile(x=x, y=y, z=z)

            if not args.label:
                tile_image_to_file(args.out, tile, image)
            if args.label:
                tile_label_to_file(args.out, tile, palette, image)

            progress.update()
            return tile

        for tiled in executor.map(worker, tiles_map.keys()):
            if tiled is not None:
                tiles.append(tiled)

    if splits_path and os.path.isdir(splits_path):
        shutil.rmtree(splits_path)  # Delete suffixes dir if any

    if tiles and not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiles, tiles, ext, template)
def main(args):
    """Rasterize vector features (GeoJSON files or a PostGIS query) into label tiles.

    Builds a tile -> features spatial index (GeoJSON path) or queries PostGIS
    per tile, burns the requested class into each tile label PNG, and writes
    an instances.cover CSV with per-tile feature counts.
    """
    if (args.geojson and args.postgis) or (not args.geojson and not args.postgis):
        sys.exit("ERROR: Input features to rasterize must be either GeoJSON or PostGIS")
    if args.postgis and not args.pg_dsn:
        sys.exit("ERROR: With PostGIS input features, --pg_dsn must be provided")

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette(*[classe["color"] for classe in config["classes"]], complementary=True)

    # FIX: next() without a default raises StopIteration when the type is absent,
    # which made the original `if "burn_value" not in locals()` guard dead code.
    burn_value = next((config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type), None)
    if burn_value is None:
        sys.exit("ERROR: asked type to rasterize is not contains in your config file classes.")

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, srid, feature_map, polygon, i):
        # Index one polygon into feature_map under every tile it touches.
        try:
            if srid != 4326:
                polygon = [xy for xy in geojson_reproject({"type": "feature", "geometry": polygon}, srid, 4326)][0]

            for i, ring in enumerate(polygon["coordinates"]):  # GeoJSON coordinates could be N dimensionals
                polygon["coordinates"][i] = [[x, y] for point in ring for x, y in zip([point[0]], [point[1]])]

            if polygon["coordinates"]:
                for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=zoom):
                    feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, srid, feature_map, geometry, i):
        # Dispatch on geometry type; only surfacic geometries are rasterizable.
        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, srid, feature_map, geometry, i)
        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(zoom, srid, feature_map, {"type": "Polygon", "coordinates": polygon}, i)
        else:
            log.log("Notice: {} is a non surfacic geometry type, skipping feature {}".format(geometry["type"], i))

        return feature_map

    if args.geojson:
        # The cover must exist and be single-zoom for the spatial index.
        try:
            tiles = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))]
            zoom = tiles[0].z
            assert not [tile for tile in tiles if tile.z != zoom]
        except:
            sys.exit("ERROR: Inconsistent cover {}".format(args.cover))

        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:
            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)

                # SRID from the (legacy) crs member; default to WGS84.
                try:
                    crs_mapping = {"CRS84": "4326", "900913": "3857"}
                    srid = feature_collection["crs"]["properties"]["name"].split(":")[-1]
                    srid = int(srid) if srid not in crs_mapping else int(crs_mapping[srid])
                except:
                    srid = int(4326)

                for i, feature in enumerate(tqdm(feature_collection["features"], ascii=True, unit="feature")):

                    if feature["geometry"]["type"] == "GeometryCollection":
                        for geometry in feature["geometry"]["geometries"]:
                            feature_map = geojson_parse_geometry(zoom, srid, feature_map, geometry, i)
                    else:
                        feature_map = geojson_parse_geometry(zoom, srid, feature_map, feature["geometry"], i)
        features = args.geojson

    if args.postgis:
        pg_conn = psycopg2.connect(args.pg_dsn)
        pg = pg_conn.cursor()

        pg.execute("SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(args.postgis))
        try:
            srid = pg.fetchone()[0]
        except Exception:
            sys.exit("Unable to retrieve geometry SRID.")

        features = args.postgis

    log.log("RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(args.type, features, args.cover))
    with open(os.path.join(os.path.expanduser(args.out), "instances.cover"), mode="w") as cover:

        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))), ascii=True, unit="tile"):

            if args.postgis:
                # mercantile.bounds returns (west, south, east, north); the original
                # unpacked these as s, w, e, n — misleading names, same values/order.
                w, s, e, n = mercantile.bounds(tile)

                query = """
                WITH
                  a AS ({}),
                  b AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}) AS geom)
                SELECT '{{ "type": "FeatureCollection", "features": [{{"type": "Feature", "geometry": '
                       || ST_AsGeoJSON(ST_Transform(ST_Intersection(a.geom, b.geom), 4326), 6) || '}}]}}'
                FROM a, b
                WHERE ST_Intersects(a.geom, b.geom)
                """.format(args.postgis, w, s, e, n, srid)

                try:
                    pg.execute(query)
                    row = pg.fetchone()
                    geojson = json.loads(row[0])["features"] if row else None
                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(tile))
                    # FIX: without this, `geojson` could be unbound (NameError) when
                    # the very first tile's query fails.
                    geojson = None
                    pg_conn = psycopg2.connect(args.pg_dsn)  # reconnect: failed cursor is unusable
                    pg = pg_conn.cursor()

            if args.geojson:
                geojson = feature_map[tile] if tile in feature_map else None

            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, args.ts, burn_value)

            # Empty or failed burn -> blank label tile.
            if not geojson or out is None:
                num = 0
                out = np.zeros(shape=(args.ts, args.ts), dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, out)
            cover.write("{},{},{} {}{}".format(tile.x, tile.y, tile.z, num, os.linesep))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)