def main(args):
    assert os.path.isdir(os.path.expanduser(args.dataset)), "--dataset path is not a directory"
    args.cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None
    config = load_config(args.config)

    if not args.workers:
        args.workers = os.cpu_count()

    print("neo dataset {} on CPU, with {} workers".format(args.mode, args.workers), file=sys.stderr, flush=True)

    if args.mode == "check":
        check_classes(config)
        check_channels(config)
        # TODO check dataset

    if args.mode == "weights":
        check_classes(config)
        weights = compute_classes_weights(args.dataset, config["classes"], args.cover, args.workers)
        print(",".join(map(str, weights)))
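# A cover CSV, as consumed by tiles_from_csv above, is assumed to hold one XYZ
# tile per line (x,y,z) — a minimal sketch; extra trailing columns are allowed
# where extra_columns=True is passed elsewhere in this file, e.g.:
#
#   8528,5973,14
#   8529,5973,14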
def main(args):
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)

    assert torch.cuda.is_available(), "No GPU support found. Check CUDA and NVidia Driver install."
    assert torch.distributed.is_nccl_available(), "No NCCL support found. Check your PyTorch install."

    world_size = torch.cuda.device_count()
    args.bs = args.bs if args.bs is not None else math.floor(os.cpu_count() / world_size)
    args.workers = args.workers if args.workers is not None else args.bs

    palette, transparency = make_palette([classe["color"] for classe in config["classes"]])
    args.cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None
    args.out = os.path.expanduser(args.out)

    log = Logs(os.path.join(args.out, "log"))

    chkpt = torch.load(args.checkpoint, map_location=torch.device("cpu"))
    chkpt["loader"] = "SemSeg"
    log.log("neo predict on {} GPUs, with {} workers/GPU and {} tiles/batch".format(world_size, args.workers, args.bs))
    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))
    log.log("---")

    loader = load_module("neat_eo.loaders.{}".format(chkpt["loader"].lower()))
    lock_file = os.path.abspath(os.path.join(args.out, str(uuid.uuid1())))

    dataset = getattr(loader, chkpt["loader"])(
        config,
        chkpt["shape_in"][1:3],
        args.dataset,
        args.cover,
        mode="predict",
        metatiles=args.metatiles,
        keep_borders=args.keep_borders,
    )

    mp.spawn(gpu_worker, nprocs=world_size, args=(world_size, lock_file, args, config, dataset, palette, transparency))

    if os.path.exists(lock_file):
        os.remove(lock_file)

    if not args.no_web_ui and dataset.cover:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, dataset.cover, dataset.cover, "png", template)
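# Note: torch.multiprocessing.spawn prepends the process rank to the args
# tuple, so the gpu_worker used above is expected to take the rank first —
# a sketch only, the actual worker is defined elsewhere in this repo:
#
#   def gpu_worker(rank, world_size, lock_file, args, config, dataset, palette, transparency):
#       ...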
def main(args):
    config = load_config(args.config)
    check_classes(config)
    index = [i for i in range(len(config["classes"])) if config["classes"][i]["title"] == args.type]
    assert index, "Requested type {} not found among classes titles in the config file.".format(args.type)

    masks = list(tiles_from_dir(args.masks, xyz_path=True))
    assert len(masks), "empty masks directory: {}".format(args.masks)

    print("neo vectorize {} from {}".format(args.type, args.masks), file=sys.stderr, flush=True)

    if os.path.dirname(os.path.expanduser(args.out)):
        os.makedirs(os.path.dirname(os.path.expanduser(args.out)), exist_ok=True)
    out = open(args.out, "w", encoding="utf-8")
    assert out, "Unable to write in output file"

    out.write('{"type":"FeatureCollection","features":[')

    first = True
    for tile, path in tqdm(masks, ascii=True, unit="mask"):
        mask = (np.array(Image.open(path).convert("P"), dtype=np.uint8) == index).astype(np.uint8)
        try:
            C, W, H = mask.shape
        except ValueError:
            W, H = mask.shape
        transform = rasterio.transform.from_bounds((*mercantile.bounds(tile.x, tile.y, tile.z)), W, H)

        for shape, value in rasterio.features.shapes(mask, transform=transform, mask=mask):
            geom = '"geometry":{{"type": "Polygon", "coordinates":{}}}'.format(json.dumps(shape["coordinates"]))
            out.write('{}{{"type":"Feature",{}}}'.format("" if first else ",", geom))
            first = False

    out.write("]}")
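# The streamed output above is a plain GeoJSON FeatureCollection, one Feature
# per extracted mask polygon — a sketch with made-up coordinates:
#
#   {"type":"FeatureCollection","features":[
#     {"type":"Feature","geometry":{"type": "Polygon", "coordinates":[[[2.35,48.85], ...]]}}
#   ]}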
def main(args):
    config = load_config(args.config)
    args.out = os.path.expanduser(args.out)
    args.cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None

    if args.classes_weights:
        try:
            args.classes_weights = list(map(float, args.classes_weights.split(",")))
        except ValueError:
            assert args.classes_weights == "auto", "invalid --classes_weights value"
    else:
        args.classes_weights = [classe["weight"] for classe in config["classes"]]

    args.tiles_weights = (
        [(tile, weight) for tile, weight in tiles_from_csv(os.path.expanduser(args.tiles_weights), extra_columns=True)]
        if args.tiles_weights
        else None
    )

    config["model"]["loader"] = args.loader if args.loader else config["model"]["loader"]
    config["model"]["ts"] = tuple(map(int, args.ts.split(","))) if args.ts else config["model"]["ts"]
    config["model"]["nn"] = args.nn if args.nn else config["model"]["nn"]
    config["model"]["encoder"] = args.encoder if args.encoder else config["model"]["encoder"]
    config["train"]["bs"] = args.bs if args.bs else config["train"]["bs"]
    config["train"]["loss"] = args.loss if args.loss else config["train"]["loss"]
    config["train"]["optimizer"]["name"] = args.optimizer if args.optimizer else config["train"]["optimizer"]["name"]
    config["train"]["optimizer"]["lr"] = args.lr if args.lr else config["train"]["optimizer"]["lr"]
    check_classes(config)
    check_channels(config)
    check_model(config)

    log = Logs(os.path.join(args.out, "log"))

    assert torch.cuda.is_available(), "No GPU support found. Check CUDA and NVidia Driver install."
    assert torch.distributed.is_nccl_available(), "No NCCL support found. Check your PyTorch install."
    world_size = torch.cuda.device_count() if args.train_dataset else 1
    args.workers = min(config["train"]["bs"] if not args.workers else args.workers, math.floor(os.cpu_count() / world_size))

    assert args.eval_dataset or args.train_dataset, "Provide at least one dataset"
    if args.eval_dataset and not args.train_dataset and not args.checkpoint:
        log.log("\n\nNOTICE: No Checkpoint provided for eval only. Seems peculiar.\n\n")

    log.log("neo train/eval on {} GPUs, with {} workers/GPU".format(world_size, args.workers))
    log.log("---")

    loader = load_module("neat_eo.loaders.{}".format(config["model"]["loader"].lower()))

    train_dataset = None
    if args.train_dataset:
        assert os.path.isdir(os.path.expanduser(args.train_dataset)), "--train_dataset path is not a directory"
        train_dataset = getattr(loader, config["model"]["loader"])(
            config, config["model"]["ts"], args.train_dataset, args.cover, args.tiles_weights, "train"
        )
        assert len(train_dataset), "Empty or Invalid --train_dataset content"
        shape_in = train_dataset.shape_in
        shape_out = train_dataset.shape_out
        log.log("\nDataSet Training: {}".format(args.train_dataset))

        if args.classes_weights == "auto":
            args.classes_weights = compute_classes_weights(args.train_dataset, config["classes"], args.cover, os.cpu_count())

    eval_dataset = None
    if args.eval_dataset:
        assert os.path.isdir(os.path.expanduser(args.eval_dataset)), "--eval_dataset path is not a directory"
        eval_dataset = getattr(loader, config["model"]["loader"])(
            config, config["model"]["ts"], args.eval_dataset, args.cover, args.tiles_weights, "eval"
        )
        assert len(eval_dataset), "Empty or Invalid --eval_dataset content"
        shape_in = eval_dataset.shape_in
        shape_out = eval_dataset.shape_out
        log.log("DataSet Eval: {}".format(args.eval_dataset))

        if not args.train_dataset and args.classes_weights == "auto":
            args.classes_weights = compute_classes_weights(args.eval_dataset, config["classes"], args.cover, os.cpu_count())

    log.log("\n--- Input tensor")
    num_channel = 1  # 1-based numbering
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {} - (band:{})".format(num_channel, channel["name"], band))
            num_channel += 1

    log.log("\n--- Output Classes ---")
    for c, classe in enumerate(config["classes"]):
        log.log("Class {}:\t\t {} ({:.2f})".format(c, classe["title"], args.classes_weights[c]))

    log.log("\n--- Model ---")
    for hp in config["model"]:
        log.log("{}{}".format(hp.ljust(25, " "), config["model"][hp]))

    lock_file = os.path.abspath(os.path.join(args.out, str(uuid.uuid1())))
    mp.spawn(
        gpu_worker,
        nprocs=world_size,
        args=(world_size, lock_file, train_dataset, eval_dataset, shape_in, shape_out, args, config),
    )

    if os.path.exists(lock_file):
        os.remove(lock_file)
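# A minimal sketch of the config shape these tools read, inferred from the
# key accesses above — key names are the ones actually used, values are only
# illustrative examples:
#
#   config = {
#       "classes": [{"title": "Building", "color": "deeppink", "weight": 1.0}],
#       "channels": [{"name": "rgb", "bands": [1, 2, 3]}],
#       "model": {"loader": "SemSeg", "nn": "Albunet", "encoder": "resnet50", "ts": (512, 512)},
#       "train": {"bs": 4, "loss": "Lovasz", "optimizer": {"name": "Adam", "lr": 0.0001}},
#   }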
def main(args):
    assert not (args.label and args.format), "Format option not supported for label, output must be kept as png"

    try:
        args.bands = list(map(int, args.bands.split(","))) if args.bands else None
    except ValueError:
        raise ValueError("invalid --bands value")

    if not args.workers:
        args.workers = min(os.cpu_count(), len(args.rasters))

    if args.label:
        config = load_config(args.config)
        check_classes(config)
        colors = [classe["color"] for classe in config["classes"]]
        palette = make_palette(colors)

    assert len(args.ts.split(",")) == 2, "--ts expects width,height values (e.g. 512,512)"
    width, height = list(map(int, args.ts.split(",")))

    cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None

    splits_path = os.path.join(os.path.expanduser(args.out), ".splits")
    args.out = os.path.expanduser(args.out)
    if os.path.dirname(os.path.expanduser(args.out)):
        os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    raster = rasterio_open(os.path.expanduser(args.rasters[0]))
    args.bands = args.bands if args.bands else raster.indexes
    raster.close()

    print(
        "neo tile {} rasters on bands {}, on CPU with {} workers".format(len(args.rasters), args.bands, args.workers),
        file=sys.stderr,
        flush=True,
    )

    skip = []
    tiles_map = {}
    total = 0

    for path in args.rasters:
        raster = rasterio_open(os.path.expanduser(path))
        assert set(args.bands).issubset(set(raster.indexes)), "Missing bands in raster {}".format(path)

        try:
            w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
        except Exception:
            log.log("WARNING: missing or invalid raster projection, SKIPPING: {}".format(path))
            skip.append(path)
            continue

        tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
        tiles = list(set(tiles) & set(cover)) if cover else tiles
        total += len(tiles)

        for tile in tiles:
            tile_key = (str(tile.x), str(tile.y), str(tile.z))
            if tile_key not in tiles_map.keys():
                tiles_map[tile_key] = []
            tiles_map[tile_key].append(path)

        raster.close()

    assert total, "Nothing left to tile"

    if len(args.bands) == 1 or args.label:
        ext = "png" if args.format is None else args.format
    if len(args.bands) == 3:
        ext = "webp" if args.format is None else args.format
    if len(args.bands) > 3:
        ext = "tiff" if args.format is None else args.format

    tiles = []
    progress = tqdm(desc="Coverage tiling", total=total, ascii=True, unit="tile")
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(path):
            if path in skip:
                return None

            raster = rasterio_open(path)
            w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
            tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
            tiled = []

            for tile in tiles:
                if cover and tile not in cover:
                    continue

                w, s, e, n = mercantile.xy_bounds(tile)

                warp_vrt = WarpedVRT(
                    raster,
                    crs="epsg:3857",
                    resampling=Resampling.bilinear,
                    add_alpha=False,
                    transform=from_bounds(w, s, e, n, width, height),
                    width=width,
                    height=height,
                )
                data = warp_vrt.read(
                    out_shape=(len(args.bands), width, height), indexes=args.bands, window=warp_vrt.window(w, s, e, n)
                )

                if data.dtype == "uint16":  # GeoTiff could be 16 bits
                    data = np.uint8(data / 256)
                elif data.dtype == "uint32":  # or 32 bits
                    data = np.uint8(data / (256 * 256))

                image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C

                tile_key = (str(tile.x), str(tile.y), str(tile.z))
                if (
                    not args.label
                    and len(tiles_map[tile_key]) == 1
                    and is_nodata(image, args.nodata, args.nodata_threshold, args.keep_borders)
                ):
                    progress.update()
                    continue

                if len(tiles_map[tile_key]) > 1:
                    out = os.path.join(splits_path, str(tiles_map[tile_key].index(path)))
                else:
                    out = args.out

                x, y, z = map(int, tile)

                if not args.label:
                    tile_image_to_file(out, mercantile.Tile(x=x, y=y, z=z), image, ext=ext)
                if args.label:
                    tile_label_to_file(out, mercantile.Tile(x=x, y=y, z=z), palette, args.nodata, image)

                if len(tiles_map[tile_key]) == 1:
                    tiled.append(mercantile.Tile(x=x, y=y, z=z))

                progress.update()

            raster.close()
            return tiled

        for tiled in executor.map(worker, args.rasters):
            if tiled is not None:
                tiles.extend(tiled)

    total = sum([1 for tile_key in tiles_map.keys() if len(tiles_map[tile_key]) > 1])
    progress = tqdm(desc="Aggregate splits", total=total, ascii=True, unit="tile")
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile_key):
            if len(tiles_map[tile_key]) == 1:
                return

            image = np.zeros((width, height, len(args.bands)), np.uint8)
            x, y, z = map(int, tile_key)

            for i in range(len(tiles_map[tile_key])):
                root = os.path.join(splits_path, str(i))
                _, path = tile_from_xyz(root, x, y, z)

                if not args.label:
                    split = tile_image_from_file(path)
                if args.label:
                    split = tile_label_from_file(path)
                    if len(split.shape) == 2:
                        split = split.reshape((width, height, 1))  # H,W -> H,W,C

                assert image.shape == split.shape, "{}, {}".format(image.shape, split.shape)
                image[np.where(image == 0)] += split[np.where(image == 0)]

            if not args.label and is_nodata(image, args.nodata, args.nodata_threshold, args.keep_borders):
                progress.update()
                return

            tile = mercantile.Tile(x=x, y=y, z=z)

            if not args.label:
                tile_image_to_file(args.out, tile, image)
            if args.label:
                tile_label_to_file(args.out, tile, palette, image)

            progress.update()
            return tile

        for tiled in executor.map(worker, tiles_map.keys()):
            if tiled is not None:
                tiles.append(tiled)

    if splits_path and os.path.isdir(splits_path):
        shutil.rmtree(splits_path)  # Delete suffixes dir if any

    if tiles and not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiles, tiles, ext, template)
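# The "Aggregate splits" pass above merges overlapping raster splits by only
# filling pixels that are still zero, so the first split to cover a pixel
# wins. A minimal self-contained sketch of that rule (hypothetical helper,
# not part of this repo):

def _merge_splits_example():
    import numpy as np

    image = np.zeros((2, 2, 1), np.uint8)
    first = np.array([[[9], [0]], [[0], [0]]], np.uint8)
    second = np.array([[[5], [5]], [[5], [5]]], np.uint8)

    for split in (first, second):
        # Same rule as the aggregate worker: only zero pixels get filled
        image[np.where(image == 0)] += split[np.where(image == 0)]

    return image  # pixels: [[9, 5], [5, 5]] — the 9 from the first split is kept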
def main(args):
    assert not (args.geojson is not None and args.pg is not None), "You have to choose between --pg or --geojson"
    assert len(args.ts.split(",")) == 2, "--ts expects width,height values (e.g. 512,512)"

    config = load_config(args.config)
    check_classes(config)
    args.pg = config["auth"]["pg"] if not args.pg and "pg" in config["auth"].keys() else args.pg
    assert not (args.sql and not args.pg), "With --sql option, --pg dsn setting must also be provided"

    palette, transparency = make_palette([classe["color"] for classe in config["classes"]], complementary=True)
    index = [config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type]
    assert index, "Requested type is not contained in your config file classes."
    burn_value = index[0]
    assert 0 < burn_value <= 255

    if args.sql:
        assert "limit" not in args.sql.lower(), "LIMIT is not supported"
        assert "TILE_GEOM" in args.sql, "TILE_GEOM filter not found in your SQL"
        sql = re.sub(r"ST_Intersects( )*\((.*)?TILE_GEOM(.*)?\)", "1=1", args.sql, flags=re.I)  # NB: flags keyword; a positional 4th arg would be count
        assert sql and sql != args.sql, "Incorrect TILE_GEOM filter in your SQL"

    if os.path.dirname(os.path.expanduser(args.out)):
        os.makedirs(os.path.expanduser(args.out), exist_ok=True)
    args.out = os.path.expanduser(args.out)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    tiles = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))]
    assert len(tiles), "Empty Cover: {}".format(args.cover)

    if args.geojson:
        zoom = tiles[0].z
        assert not [tile for tile in tiles if tile.z != zoom], "Unsupported zoom mixed cover. Use PostGIS instead"

        feature_map = collections.defaultdict(list)

        log.log("neo rasterize - Compute spatial index")
        for geojson_file in args.geojson:
            with open(os.path.expanduser(geojson_file)) as geojson:
                feature_collection = json.load(geojson)
                srid = geojson_srid(feature_collection)

                for i, feature in enumerate(tqdm(feature_collection["features"], ascii=True, unit="feature")):
                    feature_map = geojson_parse_feature(zoom, srid, feature_map, feature, args.buffer)

        features = args.geojson

    if args.sql:
        conn = psycopg2.connect(args.pg)
        db = conn.cursor()

        db.execute("""SELECT ST_Srid("1") AS srid FROM ({} LIMIT 1) AS t("1")""".format(sql))
        srid = db.fetchone()[0]
        assert srid and int(srid) > 0, "Unable to retrieve geometry SRID."

        features = args.sql

    if args.geojson and not len(feature_map):
        log.log("-----------------------------------------------")
        log.log("NOTICE: no feature to rasterize, seems peculiar")
        log.log("-----------------------------------------------")

    log.log("neo rasterize - rasterizing {} from {} on cover {}".format(args.type, features, args.cover))
    with open(os.path.join(os.path.expanduser(args.out), args.type.lower() + "_cover.csv"), mode="w") as cover:
        for tile in tqdm(tiles, ascii=True, unit="tile"):
            geojson = None

            if args.sql:
                w, s, e, n = tile_bbox(tile)
                tile_geom = "ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {})".format(w, s, e, n, srid)

                query = """
                WITH
                  sql  AS ({}),
                  geom AS (SELECT "1" AS geom FROM sql AS t("1")),
                  json AS (SELECT '{{"type": "Feature", "geometry": '
                         || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                         || '}}' AS features
                        FROM geom)
                SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}'
                FROM json
                """.format(
                    args.sql.replace("TILE_GEOM", tile_geom)
                )

                db.execute(query)
                row = db.fetchone()
                try:
                    geojson = json.loads(row[0])["features"] if row and row[0] else None
                except Exception:
                    log.log("Warning: Invalid geometries, skipping {}".format(tile))
                    conn = psycopg2.connect(args.pg)
                    db = conn.cursor()

            if args.geojson:
                geojson = feature_map[tile] if tile in feature_map else None

            if geojson:
                num = len(geojson)
                out = geojson_tile_burn(tile, geojson, 4326, list(map(int, args.ts.split(","))), burn_value)

            if not geojson or out is None:
                num = 0
                out = np.zeros(shape=list(map(int, args.ts.split(","))), dtype=np.uint8)

            tile_label_to_file(args.out, tile, palette, transparency, out, append=args.append)
            cover.write("{},{},{} {}{}".format(tile.x, tile.y, tile.z, num, os.linesep))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
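# The --sql query is expected to select a single geometry column and to embed
# the literal TILE_GEOM placeholder inside an ST_Intersects() filter, which
# the loop above substitutes tile by tile with a reprojected envelope.
# A minimal sketch (table and column names are hypothetical):
#
#   SELECT geom FROM buildings WHERE ST_Intersects(TILE_GEOM, geom)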
def main(args):
    if not args.masks or not args.labels:
        assert args.mode != "list", "Parameters masks and labels are mandatory in list mode."
        assert not (args.min or args.max), "Both --masks and --labels are mandatory, for metric filtering."

    if args.min or args.max:
        config = load_config(args.config)

    args.out = os.path.expanduser(args.out)
    cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None

    args_minmax = set()
    args.min = {(m[0], m[1]): m[2] for m in args.min} if args.min else dict()
    args.max = {(m[0], m[1]): m[2] for m in args.max} if args.max else dict()
    args_minmax.update(args.min.keys())
    args_minmax.update(args.max.keys())
    minmax = dict()
    for mm in args_minmax:
        mm_min = float(args.min[mm]) if mm in args.min else 0.0
        mm_max = float(args.max[mm]) if mm in args.max else 1.0
        assert mm_min < mm_max, "--min must be lower than --max, on {}".format(mm)
        minmax[mm] = {
            "min": mm_min,
            "max": mm_max,
            "class_id": [c for c, classe in enumerate(config["classes"]) if classe["title"] == mm[0]][0],
            "module": load_module("neat_eo.metrics." + mm[1]),
        }

    if not args.workers:
        args.workers = os.cpu_count()

    print("neo compare {} on CPU, with {} workers".format(args.mode, args.workers), file=sys.stderr, flush=True)

    if args.images:
        tiles = [tile for tile in tiles_from_dir(args.images[0], cover=cover)]
        assert len(tiles), "Empty images dir: {}".format(args.images[0])

        for image in args.images[1:]:
            assert sorted(tiles) == sorted([tile for tile in tiles_from_dir(image, cover=cover)]), "Inconsistent images dirs"

    if args.labels and args.masks:
        tiles_masks = [tile for tile in tiles_from_dir(args.masks, cover=cover)]
        tiles_labels = [tile for tile in tiles_from_dir(args.labels, cover=cover)]
        if args.images:
            assert sorted(tiles) == sorted(tiles_masks) == sorted(tiles_labels), "Inconsistent images/label/mask directories"
        else:
            assert len(tiles_masks), "Empty masks dir: {}".format(args.masks)
            assert len(tiles_labels), "Empty labels dir: {}".format(args.labels)
            assert sorted(tiles_masks) == sorted(tiles_labels), "Label and Mask directories are not consistent"
            tiles = tiles_masks

    tiles_list = []
    tiles_compare = []
    progress = tqdm(total=len(tiles), ascii=True, unit="tile")
    log = False if args.mode == "list" else Logs(os.path.join(args.out, "log"))

    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile):
            x, y, z = list(map(str, tile))

            if args.masks and args.labels:
                label = np.array(Image.open(os.path.join(args.labels, z, x, "{}.png".format(y))))
                mask = np.array(Image.open(os.path.join(args.masks, z, x, "{}.png".format(y))))
                assert label.shape == mask.shape, "Inconsistent tiles (size or dimensions)"

                metrics = dict()
                for mm in minmax:
                    try:
                        metrics[mm] = getattr(minmax[mm]["module"], "get")(
                            torch.as_tensor(label, device="cpu"),
                            torch.as_tensor(mask, device="cpu"),
                            minmax[mm]["class_id"],
                        )
                    except Exception:
                        progress.update()
                        return False, tile

                    if not (minmax[mm]["min"] <= metrics[mm] <= minmax[mm]["max"]):
                        progress.update()
                        return True, tile

                tiles_compare.append(tile)

            if args.mode == "side":
                for i, root in enumerate(args.images):
                    img = tile_image_from_file(tile_from_xyz(root, x, y, z)[1], force_rgb=True)

                    if i == 0:
                        side = np.zeros((img.shape[0], img.shape[1] * len(args.images), 3))
                        side = np.swapaxes(side, 0, 1) if args.vertical else side
                        image_shape = img.shape
                    else:
                        assert image_shape[0:2] == img.shape[0:2], "Inconsistent image size to compare"

                    if args.vertical:
                        side[i * image_shape[0] : (i + 1) * image_shape[0], :, :] = img
                    else:
                        side[:, i * image_shape[0] : (i + 1) * image_shape[0], :] = img  # assumes square tiles

                tile_image_to_file(args.out, tile, np.uint8(side))

            elif args.mode == "stack":
                for i, root in enumerate(args.images):
                    tile_image = tile_image_from_file(tile_from_xyz(root, x, y, z)[1], force_rgb=True)

                    if i == 0:
                        image_shape = tile_image.shape[0:2]
                        stack = tile_image / len(args.images)
                    else:
                        assert image_shape == tile_image.shape[0:2], "Inconsistent image size to compare"
                        stack = stack + (tile_image / len(args.images))

                tile_image_to_file(args.out, tile, np.uint8(stack))

            elif args.mode == "list":
                tiles_list.append([tile, metrics])

            progress.update()
            return True, tile

        for ok, tile in executor.map(worker, tiles):
            if not ok and log:
                log.log("Warning: skipping. {}".format(str(tile)))

    if args.mode == "list":
        with open(args.out, mode="w") as out:
            if args.geojson:
                out.write('{"type":"FeatureCollection","features":[')

            first = True
            for tile_list in tiles_list:
                tile, metrics = tile_list
                x, y, z = list(map(str, tile))

                if args.geojson:
                    prop = '"properties":{{"x":{},"y":{},"z":{}'.format(x, y, z)
                    for metric in metrics:
                        prop += ',"{}":{:.3f}'.format(metric, metrics[metric])
                    geom = '"geometry":{}'.format(json.dumps(feature(tile, precision=6)["geometry"]))
                    out.write('{}{{"type":"Feature",{},{}}}}}'.format("," if not first else "", geom, prop))
                    first = False

                if not args.geojson:
                    out.write("{},{},{}".format(x, y, z))
                    for metric in metrics:
                        out.write("\t{:.3f}".format(metrics[metric]))
                    out.write(os.linesep)

            if args.geojson:
                out.write("]}")
            out.close()

    base_url = args.web_ui_base_url if args.web_ui_base_url else "."

    if args.mode == "side" and not args.no_web_ui:
        template = "compare.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, base_url, tiles, tiles_compare, args.format, template, union_tiles=False)

    if args.mode == "stack" and not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile in tiles_from_dir(args.images[0])]
        web_ui(args.out, base_url, tiles, tiles_compare, args.format, template)
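# The --min/--max filters above are consumed as (class title, metric module,
# threshold) triples, each metric being resolved from neat_eo.metrics.<name>.
# A plausible invocation sketch (class and metric names are examples only):
#
#   neo compare --mode list --masks masks/ --labels labels/ \
#       --min Building iou 0.5 --geojson --out tiles.json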
def main(args):
    config = load_config(args.config)
    args.cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] if args.cover else None

    if args.classes_weights:
        try:
            args.classes_weights = list(map(float, args.classes_weights.split(",")))
        except ValueError:
            assert args.classes_weights == "auto", "invalid --classes_weights value"
            args.classes_weights = compute_classes_weights(args.dataset, config["classes"], args.cover, os.cpu_count())
    else:
        args.classes_weights = [classe["weight"] for classe in config["classes"]]

    args.tiles_weights = (
        [(tile, weight) for tile, weight in tiles_from_csv(os.path.expanduser(args.tiles_weights), extra_columns=True)]
        if args.tiles_weights
        else None
    )

    args.bs = args.bs if args.bs else config["train"]["bs"]
    check_classes(config)
    check_channels(config)
    check_model(config)

    assert torch.cuda.is_available(), "No GPU support found. Check CUDA and NVidia Driver install."
    assert torch.distributed.is_nccl_available(), "No NCCL support found. Check your PyTorch install."
    world_size = 1  # Hard coded, since eval on multi GPUs is not yet implemented
    args.workers = min(args.bs if not args.workers else args.workers, math.floor(os.cpu_count() / world_size))

    print("neo eval on 1 GPU, with {} workers, and {} tiles/batch".format(args.workers, args.bs))

    loader = load_module("neat_eo.loaders.{}".format(config["model"]["loader"].lower()))

    assert os.path.isdir(os.path.expanduser(args.dataset)), "--dataset path is not a directory"
    dataset = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], args.dataset, args.cover, args.tiles_weights, "eval"
    )
    assert len(dataset), "Empty or Invalid --dataset content"
    shape_in = dataset.shape_in
    shape_out = dataset.shape_out

    print("DataSet Eval: {}".format(args.dataset))
    print("\n--- Input tensor")
    num_channel = 1  # 1-based numbering
    for channel in config["channels"]:
        for band in channel["bands"]:
            print("Channel {}:\t\t {} - (band:{})".format(num_channel, channel["name"], band))
            num_channel += 1

    print("\n--- Output Classes ---")
    for c, classe in enumerate(config["classes"]):
        print("Class {}:\t\t {} ({:.2f})".format(c, classe["title"], args.classes_weights[c]))

    print("\n--- Model ---")
    for hp in config["model"]:
        print("{}{}".format(hp.ljust(25, " "), config["model"][hp]))

    lock_file = os.path.abspath(os.path.join("/tmp", str(uuid.uuid1())))
    mp.spawn(gpu_worker, nprocs=world_size, args=(world_size, lock_file, dataset, shape_in, shape_out, args, config))

    if os.path.exists(lock_file):
        os.remove(lock_file)
def main(args):
    assert args.cover or args.granules or args.scenes, "Either --cover OR --granules OR --scenes is mandatory"
    assert not (args.download and not args.out), "--download implies out parameter"
    assert args.limit, "What about increasing --limit value ?"

    config = load_config(args.config)

    if args.cover:
        args.pg = args.pg if args.pg else config["auth"]["pg"]
        assert args.pg, "PostgreSQL connection setting is mandatory with --cover"
        args.granules = tiles_to_granules(tiles_from_csv(os.path.expanduser(args.cover)), args.pg)

    if args.out:
        args.out = os.path.expanduser(args.out)
        os.makedirs(args.out, exist_ok=True)
        log = Logs(os.path.join(args.out, "log"), out=sys.stderr)
    else:
        log = Logs(None, out=sys.stderr)

    log.log("neo sat on granules: {}".format(" ".join(args.granules)))

    scenes = search_scenes(args, log)

    if args.download:
        log.log("")
        log.log("=============================================================================")
        log.log("Downloading selected scenes")
        log.log("=============================================================================")

        report = []
        login, password = dict([auth.split("=") for auth in config["auth"]["theia"].split(" ")]).values()

        with futures.ThreadPoolExecutor(args.workers) as executor:

            def worker(scene):
                scene_dir = os.path.join(args.out, scene["dir"][:42])  # 42 related to Theia MD issue, dirty workaround
                if not os.path.isabs(scene_dir):
                    scene_dir = "./" + scene_dir
                if glob.glob(scene_dir + "*"):
                    scene["dir"] = glob.glob(scene_dir + "*")[0]
                    return scene, None, True  # Already Downloaded

                token = get_token(login, password)
                url = THEIA_URL + "/resto2/collections/SENTINEL2/{}/download/?issuerId=theia".format(scene["uuid"])
                resp = requests.get(url, headers={"Authorization": "Bearer {}".format(token)}, stream=True)
                if resp is None:
                    return scene, None, False  # Auth issue

                zip_path = os.path.join(args.out, scene["uuid"] + ".zip")
                try:
                    with open(zip_path, "wb") as fp:
                        progress = tqdm(unit="B", desc=scene["uuid"], total=int(resp.headers["Content-Length"]))
                        for chunk in resp.iter_content(chunk_size=16384):
                            progress.update(16384)
                            fp.write(chunk)
                        return scene, zip_path, True
                except OSError:
                    return scene, None, False  # Write issue

            for scene, zip_path, ok in executor.map(worker, scenes):
                if zip_path and md5(zip_path) == scene["checksum"]:
                    scene["dir"] = os.path.dirname(ZipFile(zip_path).namelist()[0])
                    ZipFile(zip_path).extractall(args.out)
                    os.remove(zip_path)
                    report.append("Scene {} available in {}".format(scene["uuid"], scene["dir"]))
                elif ok:
                    report.append("SKIPPING downloading {}, as already in {}".format(scene["uuid"], scene["dir"]))
                else:
                    report.append("ERROR: Unable to retrieve Scene {}".format(scene["uuid"]))

        log.log("")
        log.log("=============================================================================")
        for line in report:
            log.log(line)
        log.log("=============================================================================")