def main():
    """Run a trained model over an XYZ tileset and save the predicted tiles.

    Positional command-line arguments:
        model      path to the model checkpoint
        config     path to the model config
        tiles      path to the XYZ tile folder
        outputdir  destination for the predictions; a leading ``s3://`` is
                   stripped (s3fs expects bucket-relative paths)
    Options:
        --aws_profile  AWS profile used both to read the tiles and to write
                       the predictions (default: ``default``)
    """
    ap = argparse.ArgumentParser(description='run model on tileset')
    ap.add_argument('model', help='path to model checkpoint')
    ap.add_argument('config', help='path to model config')
    ap.add_argument('tiles', help='path to XYZ tile folder')
    ap.add_argument('outputdir', help='name for tile output')
    ap.add_argument('--aws_profile', help='AWS Profile Name', default='default')
    args = ap.parse_args()

    config = load_config(args.config)
    tiles = S3SlippyMapTiles(args.tiles, mode='multibands', aws_profile=args.aws_profile)
    net = model(config, args.model)
    # Each tile is predicted independently, so the input order does not matter.
    loader = DataLoader(tiles, batch_size=config['model']['batch_size'], shuffle=False, num_workers=1)
    palette = make_palette(config["classes"][0]["color"])
    # Write with the profile the user asked for (previously hard-coded to 'esip').
    fs = s3fs.S3FileSystem(session=boto3.Session(profile_name=args.aws_profile))

    # Only strip the scheme when it is actually present; blindly slicing off
    # five characters mangled plain bucket paths.
    s3_prefix = 's3://'
    if args.outputdir.startswith(s3_prefix):
        outputdir = args.outputdir[len(s3_prefix):]
    else:
        outputdir = args.outputdir
    # rstrip('/') so that "bucket/tiles/" still yields the basename "tiles".
    outputdir = outputdir + '/' + os.path.basename(args.tiles.rstrip('/'))
    print("Saving predictions to {}.".format(outputdir))
    predict(net, loader, outputdir, palette, fs)
def main(args):
    """Vectorize the mask pixels of one class into a GeoJSON FeatureCollection.

    The class is the first entry of ``config["classes"]`` whose ``title``
    equals ``args.type``; its position is the palette index looked up in every
    mask tile under ``args.masks``. Each polygon of that class is written to
    ``args.out`` as a Feature whose properties carry the tile x/y/z.

    Exits the process if ``args.type`` is not a configured class title or if a
    tile cannot be vectorized.
    """
    config = load_config(args.config)
    check_classes(config)

    matches = [i for i, classe in enumerate(config["classes"]) if classe["title"] == args.type]
    if not matches:
        sys.exit(
            "ERROR: Requested type {} not found among classes title in the config file.".format(args.type))
    # Compare against a scalar: a list with several matching titles would not broadcast.
    index = matches[0]

    print("RoboSat.pink - vectorize {} from {}".format(args.type, args.masks))

    with open(args.out, "w", encoding="utf-8") as out:
        first = True
        out.write('{"type":"FeatureCollection","features":[')

        for tile, path in tqdm(list(tiles_from_slippy_map(args.masks)), ascii=True, unit="mask"):
            try:
                features = (np.array(Image.open(path).convert("P"), dtype=np.uint8) == index).astype(np.uint8)
                # numpy image arrays are (rows, cols) == (height, width); from_bounds takes width first.
                height, width = features.shape[-2:]
                transform = rasterio.transform.from_bounds(
                    *mercantile.bounds(tile.x, tile.y, tile.z), width, height)

                for shape, value in rasterio.features.shapes(features, transform=transform):
                    # shapes() also yields the background (value 0) regions; keep the class only.
                    if value != 1:
                        continue
                    prop = '"properties":{{"x":{},"y":{},"z":{}}}'.format(
                        int(tile.x), int(tile.y), int(tile.z))
                    geom = '"geometry":{{"type": "Polygon", "coordinates":{}}}'.format(
                        json.dumps(shape["coordinates"]))
                    out.write('{}{{"type":"Feature",{},{}}}'.format(
                        "," if not first else "", geom, prop))
                    first = False
            except Exception:
                sys.exit("ERROR: Unable to vectorize tile {}.".format(str(tile)))

        out.write("]}")
def main(args):
    """Extract features of one class from mask tiles with a pluggable handler.

    ``args.type`` names a feature module, searched in ``args.path`` (when
    given) and in the package's ``features`` directory. The module must define
    a ``<Type>Handler`` class. Each mask tile under ``args.masks`` is reduced
    to a binary mask of that class and fed to ``handler.apply``. The handler
    then saves its result to ``args.out``.
    """
    module_search_path = [args.path] if args.path else []
    module_search_path.append(os.path.join(Path(__file__).parent.parent, "features"))
    available = [name for _, name, _ in pkgutil.iter_modules(module_search_path) if name != "core"]
    if args.type not in available:
        sys.exit("Unknown type, thoses available are {}".format(available))

    config = load_config(args.config)
    classes = config["classes"]
    # Accept both the legacy {"titles": [...]} layout and the list-of-dicts
    # layout ([{"title": ...}, ...]) that the other commands read.
    if isinstance(classes, dict):
        labels = classes["titles"]
    else:
        labels = [classe["title"] for classe in classes]
    if args.type not in labels:
        sys.exit(
            "The type you asked is not consistent with yours classes in the config file provided."
        )
    index = labels.index(args.type)

    if args.path:
        sys.path.append(args.path)
        module = import_module(args.type)
    else:
        module = import_module("robosat_pink.features.{}".format(args.type))

    handler = getattr(module, "{}Handler".format(args.type.title()))()

    for tile, path in tqdm(list(tiles_from_slippy_map(args.masks)), ascii=True, unit="mask"):
        image = np.array(Image.open(path).convert("P"), dtype=np.uint8)
        mask = (image == index).astype(np.uint8)
        handler.apply(tile, mask)

    handler.save(args.out)
def main(args):
    """Train a segmentation model, optionally starting from a saved checkpoint.

    The checkpoint comes from ``args.checkpoint`` or, when that is not set,
    from ``config["checkpoint"]["path"]``. An ``s3://`` checkpoint is read with
    the ``config["dataset"]["aws_profile"]`` AWS profile. A checkpoint file
    (weights + optimizer state) is written to ``args.out`` after every epoch.
    """
    config = load_config(args.config)
    print(config)
    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
        log.log("RoboSat - training on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
    else:
        device = torch.device("cpu")
        log.log("RoboSat - training on CPU, with {} workers".format(args.workers))

    num_classes = len(config["classes"])
    num_channels = sum(len(channel["bands"]) for channel in config["channels"])
    pretrained = config["model"]["pretrained"]
    encoder = config["model"]["encoder"]

    models = [name for _, name, _ in pkgutil.iter_modules([os.path.dirname(robosat_pink.models.__file__)])]
    if config["model"]["name"] not in models:
        sys.exit("Unknown model, thoses available are {}".format(models))

    model_module = import_module("robosat_pink.models.{}".format(config["model"]["name"]))
    net = getattr(model_module, "{}".format(config["model"]["name"].title()))(
        num_classes=num_classes, num_channels=num_channels, encoder=encoder, pretrained=pretrained).to(device)
    net = torch.nn.DataParallel(net)
    optimizer = Adam(net.parameters(), lr=config["model"]["lr"], weight_decay=config["model"]["decay"])

    # NOTE(review): resume is never taken from the checkpoint's "epoch", so
    # training restarts at epoch 0 even after loading a checkpoint — confirm intent.
    resume = 0

    # The command-line checkpoint takes precedence; fall back to the config file.
    checkpoint = args.checkpoint
    if not checkpoint:
        try:
            checkpoint = config["checkpoint"]["path"]
        except (KeyError, TypeError):
            checkpoint = None

    if checkpoint:
        def map_location(storage, _):
            return storage.cuda() if torch.cuda.is_available() else storage.cpu()

        s3_prefix = "s3://"
        try:
            if checkpoint.startswith(s3_prefix):
                sess = boto3.Session(profile_name=config["dataset"]["aws_profile"])
                fs = s3fs.S3FileSystem(session=sess)
                # Open the remote file once, inside a context manager, so it is closed.
                with s3fs.S3File(fs, checkpoint[len(s3_prefix):], "rb") as remote:
                    state = torch.load(io.BytesIO(remote.read()), map_location=map_location)
            else:
                # map_location lets GPU-saved checkpoints load on CPU-only hosts.
                state = torch.load(checkpoint, map_location=map_location)
            optimizer.load_state_dict(state["optimizer"])
            net.load_state_dict(state["state_dict"])
            net.to(device)
            log.log("Using checkpoint: {}".format(checkpoint))
        except FileNotFoundError:
            print("{} checkpoint not found.".format(checkpoint))

    losses = [name for _, name, _ in pkgutil.iter_modules([os.path.dirname(robosat_pink.losses.__file__)])]
    if config["model"]["loss"] not in losses:
        sys.exit("Unknown loss, thoses available are {}".format(losses))

    loss_module = import_module("robosat_pink.losses.{}".format(config["model"]["loss"]))
    criterion = getattr(loss_module, "{}".format(config["model"]["loss"].title()))().to(device)

    train_loader, val_loader = get_dataset_loaders(config, args.workers, idDir=args.out)

    if resume >= config["model"]["epochs"]:
        sys.exit("Error: Epoch {} set in {} already reached by the checkpoint provided".format(
            config["model"]["epochs"], args.config))

    log.log("")
    log.log("--- Input tensor from Dataset: {} ---".format(
        config["dataset"]["image_bucket"] + '/' + config['dataset']['imagery_directory_regex']))

    log.log("")
    log.log("--- Hyper Parameters ---")
    log.log("Model:\t\t\t {}".format(config["model"]["name"]))
    log.log("Encoder model:\t\t {}".format(config["model"]["encoder"]))
    log.log("Loss function:\t\t {}".format(config["model"]["loss"]))
    log.log("ResNet pre-trained:\t {}".format(config["model"]["pretrained"]))
    log.log("Batch Size:\t\t {}".format(config["model"]["batch_size"]))
    log.log("Tile Size:\t\t {}".format(config["model"]["tile_size"]))
    log.log("Data Augmentation:\t {}".format(config["model"]["data_augmentation"]))
    log.log("Learning Rate:\t\t {}".format(config["model"]["lr"]))
    log.log("Weight Decay:\t\t {}".format(config["model"]["decay"]))
    log.log("")

    for epoch in range(resume, config["model"]["epochs"]):
        log.log("---")
        log.log("Epoch: {}/{}".format(epoch + 1, config["model"]["epochs"]))

        train_hist = train(train_loader, num_classes, device, net, optimizer, criterion)
        log.log(
            "Train loss: {:.4f}, mIoU: {:.3f}, IoU: {:.3f}, precision: {:.3f}, recall: {:.3f}".format(
                train_hist["loss"], train_hist["miou"], train_hist["fg_iou"],
                train_hist["precision"], train_hist["recall"]))

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        # Report the validation metrics (the training ones were logged here before).
        log.log(
            "Validate loss: {:.4f}, mIoU: {:.3f}, IoU: {:.3f}, precision: {:.3f}, recall: {:.3f}".format(
                val_hist["loss"], val_hist["miou"], val_hist["fg_iou"],
                val_hist["precision"], val_hist["recall"]))

        states = {"epoch": epoch + 1, "state_dict": net.state_dict(), "optimizer": optimizer.state_dict()}
        checkpoint_path = os.path.join(
            args.out, "checkpoint-{:05d}-of-{:05d}.pth".format(epoch + 1, config["model"]["epochs"]))
        torch.save(states, checkpoint_path)
def main(args):
    """Compare tile coverages side by side, stacked, or as a QoD-filtered list.

    ``args.mode``:
      * ``side``: each tile becomes the ``args.images`` versions laid out next
        to each other, or one above the other with ``args.vertical``;
      * ``stack``: each tile becomes the mean of the ``args.images`` versions;
      * ``list``: x,y,z with foreground ratio and QoD per tile, written as text
        or, with ``args.geojson``, as a GeoJSON FeatureCollection.

    When masks, labels and config are all given, tiles outside the
    ``minimum_fg``/``maximum_fg`` and ``minimum_qod``/``maximum_qod`` bounds are
    skipped. A web UI can optionally be generated over the written tiles.
    """
    if not args.masks or not args.labels or not args.config:
        if args.mode == "list":
            sys.exit("Parameters masks, labels and config, are all mandatories in list mode.")
        if args.minimum_fg > 0 or args.maximum_fg < 100 or args.minimum_qod > 0 or args.maximum_qod < 100:
            sys.exit("Parameters masks, labels and config, are all mandatories in QoD filtering.")

    if args.images:
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])]
        for image in args.images[1:]:
            assert sorted(tiles) == sorted([tile for tile, _ in tiles_from_slippy_map(image)]), "inconsistent coverages"

    if args.labels and args.masks:
        tiles_masks = [tile for tile, _ in tiles_from_slippy_map(args.masks)]
        tiles_labels = [tile for tile, _ in tiles_from_slippy_map(args.labels)]
        if args.images:
            assert sorted(tiles) == sorted(tiles_masks) == sorted(tiles_labels), "inconsistent coverages"
        else:
            assert sorted(tiles_masks) == sorted(tiles_labels), "inconsistent coverages"
            tiles = tiles_masks

    if args.mode == "list":
        out = open(args.out, mode="w")
        if args.geojson:
            out.write('{"type":"FeatureCollection","features":[')
            first = True

    # Class titles are only needed for QoD filtering; load the config once, not per tile.
    titles = None
    if args.masks and args.labels and args.config:
        titles = [classe["title"] for classe in load_config(args.config)["classes"]]

    tiles_compare = []
    for tile in tqdm(list(tiles), desc="Compare", unit="tile", ascii=True):
        x, y, z = list(map(str, tile))

        if titles is not None:
            _, fg_ratio, qod = compare(args.masks, args.labels, tile, titles)
            if not args.minimum_fg <= fg_ratio <= args.maximum_fg or not args.minimum_qod <= qod <= args.maximum_qod:
                continue

        tiles_compare.append(tile)

        if args.mode == "side":
            count = len(args.images)
            for i, root in enumerate(args.images):
                img = tile_image(tile_from_slippy_map(root, x, y, z)[1])
                height, width = img.shape[0], img.shape[1]
                if i == 0:
                    # Size the mosaic from the real height/width so non-square tiles fit
                    # (the old code sliced both layouts with the height only).
                    mosaic_shape = (height * count, width, 3) if args.vertical else (height, width * count, 3)
                    side = np.zeros(mosaic_shape)
                    image_shape = img.shape
                else:
                    assert image_shape == img.shape, "Unconsistent image size to compare"

                if args.vertical:
                    side[i * height:(i + 1) * height, :, :] = img
                else:
                    side[:, i * width:(i + 1) * width, :] = img

            os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
            side = Image.fromarray(np.uint8(side))
            side.save(os.path.join(args.out, z, x, "{}.{}".format(y, args.ext)), optimize=True)

        elif args.mode == "stack":
            for i, root in enumerate(args.images):
                img = tile_image(tile_from_slippy_map(root, x, y, z)[1])
                if i == 0:
                    image_shape = img.shape[0:2]
                    stack = img / len(args.images)
                else:
                    assert image_shape == img.shape[0:2], "Unconsistent image size to compare"
                    stack = stack + (img / len(args.images))

            os.makedirs(os.path.join(args.out, str(z), str(x)), exist_ok=True)
            stack = Image.fromarray(np.uint8(stack))
            stack.save(os.path.join(args.out, str(z), str(x), "{}.{}".format(y, args.ext)), optimize=True)

        elif args.mode == "list":
            if args.geojson:
                prop = '"properties":{{"x":{},"y":{},"z":{},"fg":{:.1f},"qod":{:.1f}}}'.format(
                    x, y, z, fg_ratio, qod)
                geom = '"geometry":{}'.format(json.dumps(feature(tile, precision=6)["geometry"]))
                out.write('{}{{"type":"Feature",{},{}}}'.format("," if not first else "", geom, prop))
                first = False
            else:
                out.write("{},{},{}\t\t{:.1f}\t\t{:.1f}{}".format(x, y, z, fg_ratio, qod, os.linesep))

    if args.mode == "list":
        if args.geojson:
            out.write("]}")
        out.close()

    base_url = args.web_ui_base_url if args.web_ui_base_url else "./"

    if args.mode == "side" and args.web_ui:
        template = "compare.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, base_url, None, tiles_compare, args.ext, template)

    if args.mode == "stack" and args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])]
        web_ui(args.out, base_url, tiles, tiles_compare, args.ext, template)
def main(args):
    """Cut one or more rasters into ``args.ts`` x ``args.ts`` XYZ tiles at ``args.zoom``.

    Tiles covered by a single raster are written straight to ``args.out``.
    Tiles covered by several rasters are first written per raster under a
    temporary ``<out>/.splits/<raster index>`` tree. They are then summed into
    one tile in ``args.out`` and the split tree is removed. With ``args.label``
    the rasters are written as label tiles using the config palette. Otherwise
    they are written as imagery, and tiles for which ``is_border`` is true are
    dropped. Optionally generates a web UI. Exits on any read/write failure.
    """
    if not args.workers:
        args.workers = max(1, math.floor(os.cpu_count() * 0.5))

    if args.label:
        config = load_config(args.config)
        check_classes(config)
        colors = [classe["color"] for classe in config["classes"]]
        palette = make_palette(*colors)

    splits_path = os.path.join(os.path.expanduser(args.out), ".splits")
    tiles_map = {}  # (x, y, z) as strings -> raster paths covering that tile

    print("RoboSat.pink - tile on CPU, with {} workers".format(args.workers))

    bands = -1
    for path in args.rasters:
        try:
            # Only bounds and band count are needed here; close the dataset right away.
            with rasterio_open(path) as raster:
                w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
                raster_bands = len(raster.indexes)
        except Exception:
            # Report the raster that failed (args.raster does not exist).
            sys.exit("Error: Unable to load raster {} or deal with it's projection".format(path))

        if bands != -1:
            assert bands == raster_bands, "Coverage must be bands consistent"
        bands = raster_bands

        for x, y, z in mercantile.tiles(w, s, e, n, args.zoom):
            tiles_map.setdefault((str(x), str(y), str(z)), []).append(path)

    if args.label:
        ext = "png"
        bands = 1
    else:
        # NOTE(review): 2-band imagery leaves ext unset (only used by the web UI).
        if bands == 1:
            ext = "png"
        if bands == 3:
            ext = "webp"
        if bands > 3:
            ext = "tiff"

    tiles = []
    progress = tqdm(total=len(tiles_map), ascii=True, unit="tile")

    # Begin to tile plain tiles
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(path):
            tiled = []
            with rasterio_open(path) as raster:
                w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
                # calculate_default_transform expects bounds in the *source* CRS.
                transform, _, _ = calculate_default_transform(
                    raster.crs, "EPSG:3857", raster.width, raster.height, *raster.bounds)

                for tile in list(mercantile.tiles(w, s, e, n, args.zoom)):
                    try:
                        left, bottom, right, top = mercantile.xy_bounds(tile)

                        # inspired by rio-tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
                        with WarpedVRT(
                            raster,
                            crs="epsg:3857",
                            resampling=Resampling.bilinear,
                            add_alpha=False,
                            transform=from_bounds(left, bottom, right, top, args.ts, args.ts),
                            width=math.ceil((right - left) / transform.a),
                            height=math.ceil((bottom - top) / transform.e),
                        ) as warp_vrt:
                            data = warp_vrt.read(
                                out_shape=(len(raster.indexes), args.ts, args.ts),
                                window=warp_vrt.window(left, bottom, right, top),
                            )
                        image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C
                    except Exception:
                        sys.exit("Error: Unable to tile {} from raster {}.".format(str(tile), raster))

                    covering = tiles_map[(str(tile.x), str(tile.y), str(tile.z))]
                    if not args.label and len(covering) == 1 and is_border(image):
                        progress.update()
                        continue

                    # Shared tiles go to a per-raster split dir and are merged below.
                    if len(covering) > 1:
                        out = os.path.join(splits_path, str(covering.index(path)))
                    else:
                        out = args.out

                    x, y, z = map(int, tile)
                    if args.label:
                        ret = tile_label_to_file(out, mercantile.Tile(x=x, y=y, z=z), palette, image)
                    else:
                        ret = tile_image_to_file(out, mercantile.Tile(x=x, y=y, z=z), image)

                    if not ret:
                        sys.exit("Error: Unable to write tile {} from raster {}.".format(str(tile), raster))

                    if len(covering) == 1:
                        progress.update()
                        tiled.append(mercantile.Tile(x=x, y=y, z=z))

            return tiled

        for tiled in executor.map(worker, args.rasters):
            if tiled is not None:
                tiles.extend(tiled)

    # Aggregate remaining tiles splits
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile_key):
            if len(tiles_map[tile_key]) == 1:
                return None

            image = np.zeros((args.ts, args.ts, bands), np.uint8)

            x, y, z = map(int, tile_key)
            for i in range(len(tiles_map[tile_key])):
                root = os.path.join(splits_path, str(i))
                _, path = tile_from_slippy_map(root, x, y, z)

                if args.label:
                    split = tile_label_from_file(path)
                    split = split.reshape((args.ts, args.ts, 1))  # H,W -> H,W,C
                else:
                    split = tile_image_from_file(path)

                assert image.shape == split.shape
                # NOTE(review): uint8 addition wraps around where several splits hold data.
                image[:, :, :] += split[:, :, :]

            if not args.label and is_border(image):
                progress.update()
                return None

            tile = mercantile.Tile(x=x, y=y, z=z)

            if args.label:
                ret = tile_label_to_file(args.out, tile, palette, image)
            else:
                ret = tile_image_to_file(args.out, tile, image)

            if not ret:
                sys.exit("Error: Unable to write tile {}.".format(str(tile_key)))

            progress.update()
            return tile

        for tiled in executor.map(worker, tiles_map.keys()):
            if tiled is not None:
                tiles.append(tiled)

    progress.close()

    if splits_path and os.path.isdir(splits_path):
        shutil.rmtree(splits_path)  # Delete suffixes dir if any

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, tiles, tiles, ext, template)
def main(args):
    """Export a trained checkpoint, optionally trimming the model's input channels.

    Rebuilds the model described by ``args.config`` and loads the weights from
    ``args.checkpoint``. When ``args.export_channels`` is smaller than the
    number of configured bands, only the filters for the first
    ``export_channels`` input planes of ``resnet.conv1`` are kept. With
    ``args.type == "pth"`` the result is saved to ``args.out`` together with
    the checkpoint's epoch and optimizer state.
    """
    config = load_config(args.config)
    num_classes = len(config["classes"])
    num_channels = sum(len(channel["bands"]) for channel in config["channels"])
    export_channels = num_channels if not args.export_channels else args.export_channels

    if export_channels > num_channels:
        sys.exit("Error: Attempt to export more channels than thoses dataset provide.")

    def map_location(storage, _):
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    models = [name for _, name, _ in pkgutil.iter_modules([os.path.dirname(robosat_pink.models.__file__)])]
    if config["model"]["name"] not in models:
        sys.exit("Unknown model, thoses available are {}".format(models))

    encoder = config["model"]["encoder"]
    pretrained = config["model"]["pretrained"]

    model_module = import_module("robosat_pink.models.{}".format(config["model"]["name"]))
    net = getattr(model_module, "{}".format(config["model"]["name"].title()))(
        num_classes=num_classes, num_channels=num_channels, encoder=encoder, pretrained=pretrained).to(device)

    chkpt = torch.load(args.checkpoint, map_location=map_location)
    net = torch.nn.DataParallel(net)
    net.load_state_dict(chkpt["state_dict"])

    if export_channels < num_channels:
        conv1 = net.module.resnet.conv1
        # Keep only the filters for the first export_channels input planes.
        weights = conv1.weight.data[:, :export_channels, :, :].clone()
        # Build the replacement from the original layer's geometry, declaring
        # export_channels inputs so the layer matches its trimmed weights
        # (it used to claim num_channels inputs).
        trimmed = torch.nn.Conv2d(
            export_channels,
            conv1.out_channels,
            kernel_size=conv1.kernel_size,
            stride=conv1.stride,
            padding=conv1.padding,
            bias=conv1.bias is not None,
        )
        trimmed.weight = torch.nn.Parameter(weights)
        if conv1.bias is not None:
            trimmed.bias = torch.nn.Parameter(conv1.bias.data.clone())
        net.module.resnet.conv1 = trimmed

    # NOTE(review): only the "pth" export type is handled; other values write nothing.
    if args.type == "pth":
        states = {"epoch": chkpt["epoch"], "state_dict": net.state_dict(), "optimizer": chkpt["optimizer"]}
        torch.save(states, args.out)
def main(args):
    """Rasterize GeoJSON or PostGIS features into label tiles for a tile cover.

    Exactly one of ``args.geojson`` (list of GeoJSON files) or ``args.postgis``
    (an SQL query returning a ``geom`` column) must be given. Every tile listed
    in the ``args.cover`` CSV is written to ``args.out`` with ``write_tile``.
    Feature pixels get ``burn_value``; everything else is 0. Optionally
    generates a web UI over the written tiles.
    """
    if (args.geojson and args.postgis) or (not args.geojson and not args.postgis):
        sys.exit("Input features to rasterize must be either GeoJSON or PostGIS")

    config = load_config(args.config)
    tile_size = args.tile_size if args.tile_size else config["model"]["tile_size"]
    colors = [classe["color"] for classe in config["classes"]]
    burn_value = 1

    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, feature_map, polygon, i):
        """Index ``polygon`` (feature number ``i``) under every tile it touches at ``zoom``."""
        try:
            # GeoJSON coordinates could be N dimensionals: keep x, y only. The ring
            # index gets its own name so the warning below still reports feature i.
            for ring_index, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][ring_index] = [[point[0], point[1]] for point in ring]

            for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=zoom):
                feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, feature_map, geometry, i):
        """Index a Polygon/MultiPolygon geometry; other geometry types are skipped."""
        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, feature_map, geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(zoom, feature_map, {"type": "Polygon", "coordinates": polygon}, i)
        else:
            log.log("Notice: {} is a non surfacic geometry type, skipping feature {}".format(geometry["type"], i))

        return feature_map

    if args.geojson:
        tiles = list(tiles_from_csv(args.cover))
        if not tiles:
            sys.exit("Empty cover file: {}".format(args.cover))
        zoom = tiles[0].z

        if [tile for tile in tiles if tile.z != zoom]:
            sys.exit("With GeoJson input, all tiles z values have to be the same, in the csv cover file.")

        feature_map = collections.defaultdict(list)

        # Compute a spatial index like
        for geojson_file in args.geojson:
            with open(geojson_file) as geojson:
                feature_collection = json.load(geojson)
                for i, feature in enumerate(tqdm(feature_collection["features"], ascii=True, unit="feature")):

                    if feature["geometry"]["type"] == "GeometryCollection":
                        for geometry in feature["geometry"]["geometries"]:
                            feature_map = geojson_parse_geometry(zoom, feature_map, geometry, i)
                    else:
                        feature_map = geojson_parse_geometry(zoom, feature_map, feature["geometry"], i)

        # Rasterize tiles
        for tile in tqdm(tiles, ascii=True, unit="tile"):
            if tile in feature_map:
                out = geojson_tile_burn(tile, feature_map[tile], tile_size, burn_value)
            else:
                out = np.zeros(shape=(tile_size, tile_size), dtype=np.uint8)

            write_tile(args.out, tile, colors, out)

    if args.postgis:
        try:
            pg_conn = psycopg2.connect(config["dataset"]["pg_dsn"])
            pg = pg_conn.cursor()
        except Exception:
            sys.exit("Unable to connect PostgreSQL: {}".format(config["dataset"]["pg_dsn"]))

        try:
            pg.execute("SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(args.postgis))
            srid = pg.fetchone()[0]
        except Exception:
            sys.exit("Unable to retrieve geometry SRID.")

        for tile in tqdm(list(tiles_from_csv(args.cover)), ascii=True, unit="tile"):
            # mercantile.bounds returns (west, south, east, north).
            west, south, east, north = mercantile.bounds(tile)

            # uint8 like the GeoJSON path, so write_tile gets the same dtype either way.
            raster = np.zeros((tile_size, tile_size), dtype=np.uint8)
            query = """
            WITH
                 bbox      AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {} ) AS bbox),
                 bbox_merc AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), 3857) AS bbox),

                 rast_a    AS (SELECT ST_AddBand(
                                      ST_SetSRID(
                                        ST_MakeEmptyRaster({}, {}, ST_Xmin(bbox), ST_Ymax(bbox), (ST_YMax(bbox) - ST_YMin(bbox)) / {}),
                                      3857),
                                      '8BUI'::text, 0) AS rast
                               FROM bbox_merc),

                 features  AS (SELECT ST_Union(ST_Transform(ST_Force2D(geom), 3857)) AS geom
                               FROM ({}) AS sub, bbox
                               WHERE ST_Intersects(geom, bbox)),

                 rast_b    AS (SELECT ST_AsRaster(geom, rast, '8BUI', {}) AS rast
                               FROM features, rast_a
                               WHERE NOT ST_IsEmpty(geom))

            SELECT ST_AsBinary(ST_MapAlgebra(rast_a.rast, rast_b.rast, '{}', NULL, 'FIRST')) AS wkb FROM rast_a, rast_b
            """.format(
                west, south, east, north, srid,
                west, south, east, north,
                tile_size, tile_size, tile_size,
                args.postgis, burn_value, burn_value,
            )

            try:
                pg.execute(query)
                row = pg.fetchone()
                if row:
                    raster = np.squeeze(wkb_to_numpy(io.BytesIO(row[0])), axis=2)

            except Exception:
                log.log("Warning: Invalid geometries, skipping {}".format(tile))
                # A failed statement aborts the transaction: start over on a fresh
                # connection, closing the old one instead of leaking it.
                pg_conn.close()
                pg_conn = psycopg2.connect(config["dataset"]["pg_dsn"])
                pg = pg_conn.cursor()

            write_tile(args.out, tile, colors, raster)

        pg_conn.close()

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
def main(args):
    """Load the config file named by ``args[1]`` and print the parsed result."""
    config_path = args[1]
    print(load_config(config_path))
def main(args):
    """Train (or fine-tune) a segmentation model and checkpoint it after every epoch.

    Truthy command-line values for dataset path, learning rate, epochs and batch
    size replace the matching config entries. ``args.checkpoint`` loads model
    weights. Adding ``args.resume`` also restores the optimizer state and the
    epoch counter. Each epoch logs train/validation loss, mIoU, the IoU of
    class 1 and MCC, and then writes a checkpoint into ``args.out``.
    """
    config = load_config(args.config)
    overrides = (
        ("dataset", "path", args.dataset),
        ("model", "lr", args.lr),
        ("model", "epochs", args.epochs),
        ("model", "batch_size", args.batch_size),
    )
    for section, key, value in overrides:
        config[section][key] = value if value else config[section][key]
    model_cfg = config["model"]

    log = Logs(os.path.join(args.out, "log"))

    if not torch.cuda.is_available():
        device = torch.device("cpu")
        log.log("RoboSat - training on CPU, with {} workers".format(args.workers))
    else:
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
        log.log("RoboSat - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))

    num_classes = len(config["classes"])
    num_channels = sum(len(channel["bands"]) for channel in config["channels"])
    pretrained = model_cfg["pretrained"]
    encoder = model_cfg["encoder"]

    available_models = [
        name for _, name, _ in pkgutil.iter_modules([os.path.dirname(robosat_pink.models.__file__)])
    ]
    if model_cfg["name"] not in available_models:
        sys.exit("Unknown model, thoses available are {}".format(available_models))

    model_class = getattr(import_module("robosat_pink.models.{}".format(model_cfg["name"])), model_cfg["name"].title())
    net = model_class(
        num_classes=num_classes, num_channels=num_channels, encoder=encoder, pretrained=pretrained
    ).to(device)
    net = torch.nn.DataParallel(net)
    optimizer = Adam(net.parameters(), lr=model_cfg["lr"], weight_decay=model_cfg["decay"])

    resume = 0
    if args.checkpoint:
        # https://github.com/pytorch/pytorch/issues/7178
        chkpt = torch.load(
            args.checkpoint,
            map_location=lambda storage, _: storage.cuda() if torch.cuda.is_available() else storage.cpu(),
        )
        net.load_state_dict(chkpt["state_dict"])
        log.log("Using checkpoint: {}".format(args.checkpoint))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]

    available_losses = [
        name for _, name, _ in pkgutil.iter_modules([os.path.dirname(robosat_pink.losses.__file__)])
    ]
    if model_cfg["loss"] not in available_losses:
        sys.exit("Unknown loss, thoses available are {}".format(available_losses))

    loss_class = getattr(import_module("robosat_pink.losses.{}".format(model_cfg["loss"])), model_cfg["loss"].title())
    criterion = loss_class().to(device)

    train_loader, val_loader = get_dataset_loaders(config["dataset"]["path"], config, args.workers)

    epochs = model_cfg["epochs"]
    if resume >= epochs:
        sys.exit("Error: Epoch {} set in {} already reached by the checkpoint provided".format(epochs, args.config))

    log.log("")
    log.log("--- Input tensor from Dataset: {} ---".format(config["dataset"]["path"]))
    channel_bands = ((channel["sub"], band) for channel in config["channels"] for band in channel["bands"])
    for number, (sub, band) in enumerate(channel_bands, start=1):
        log.log("Channel {}:\t\t {}[band: {}]".format(number, sub, band))

    log.log("")
    log.log("--- Hyper Parameters ---")
    hyper_parameters = (
        ("Model:\t\t\t", "name"),
        ("Encoder model:\t\t", "encoder"),
        ("Loss function:\t\t", "loss"),
        ("ResNet pre-trained:\t", "pretrained"),
        ("Batch Size:\t\t", "batch_size"),
        ("Tile Size:\t\t", "tile_size"),
        ("Data Augmentation:\t", "data_augmentation"),
        ("Learning Rate:\t\t", "lr"),
        ("Weight Decay:\t\t", "decay"),
    )
    for label, key in hyper_parameters:
        log.log("{} {}".format(label, model_cfg[key]))
    log.log("")

    def summary(stage, hist):
        # Same layout for both stages: loss, mIoU, class-1 IoU and MCC.
        return "{} loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format(
            stage, hist["loss"], hist["miou"], config["classes"][1]["title"], hist["fg_iou"], hist["mcc"])

    for epoch in range(resume, epochs):
        log.log("---")
        log.log("Epoch: {}/{}".format(epoch + 1, epochs))

        log.log(summary("Train", train(train_loader, num_classes, device, net, optimizer, criterion)))
        log.log(summary("Validate", validate(val_loader, num_classes, device, net, criterion)))

        torch.save(
            {"epoch": epoch + 1, "state_dict": net.state_dict(), "optimizer": optimizer.state_dict()},
            os.path.join(args.out, "checkpoint-{:05d}-of-{:05d}.pth".format(epoch + 1, epochs)),
        )
def main(args):
    """Predict binary masks for XYZ tiles with a trained model.

    The checkpoint comes from ``args.checkpoint`` or, failing that,
    ``config["checkpoint"]["path"]``; ``s3://`` paths are read with
    ``args.aws_profile``. Input tiles are either the single folder
    ``args.tiles`` or every bucket entry matching
    ``config["dataset"]["imagery_directory_regex"]``. They can be restricted to
    the ids listed in ``args.tile_ids``. Pixels whose output exceeds
    ``args.threshold`` become foreground. PNGs (and, with ``args.create_tif``,
    GeoTIFFs) are written under ``args.preds/<source path with / -> :>``.
    """
    config = load_config(args.config)
    num_classes = len(config["classes"])
    batch_size = args.batch_size if args.batch_size else config["model"]["batch_size"]
    tile_size = config["model"]["tile_size"]

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    def map_location(storage, _):
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    # The command-line checkpoint wins; otherwise fall back to the config file.
    chkpt = args.checkpoint
    if not chkpt:
        try:
            chkpt = config["checkpoint"]["path"]
        except (KeyError, TypeError):
            chkpt = None
    if not chkpt:
        # Previously crashed with AttributeError on None.startswith().
        sys.exit("Error: no checkpoint provided (use --checkpoint or set checkpoint.path in the config).")

    s3_prefix = "s3://"
    s3_checkpoint = chkpt.startswith(s3_prefix)
    if s3_checkpoint:
        chkpt = chkpt[len(s3_prefix):]

    models = [name for _, name, _ in pkgutil.iter_modules([os.path.dirname(robosat_pink.models.__file__)])]
    if config["model"]["name"] not in models:
        sys.exit("Unknown model, thoses available are {}".format(models))

    num_channels = sum(len(channel["bands"]) for channel in config["channels"])
    pretrained = config["model"]["pretrained"]
    encoder = config["model"]["encoder"]

    model_module = import_module("robosat_pink.models.{}".format(config["model"]["name"]))
    net = getattr(model_module, "{}".format(config["model"]["name"].title()))(
        num_classes=num_classes, num_channels=num_channels, encoder=encoder, pretrained=pretrained
    ).to(device)
    net = torch.nn.DataParallel(net)

    try:
        if s3_checkpoint:
            sess = boto3.Session(profile_name=args.aws_profile)
            fs = s3fs.S3FileSystem(session=sess)
            with s3fs.S3File(fs, chkpt, "rb") as remote:
                state = torch.load(io.BytesIO(remote.read()), map_location=map_location)
        else:
            state = torch.load(chkpt, map_location=map_location)
    except FileNotFoundError:
        # Stop instead of silently predicting with untrained weights.
        sys.exit("{} checkpoint not found.".format(chkpt))

    net.load_state_dict(state["state_dict"], strict=False)
    net.to(device)
    net.eval()

    tile_ids_filter = None
    if args.tile_ids is not None:
        tile_ids_filter = pd.read_csv(args.tile_ids, names=["ids"]).ids.values

    # Build the torch datasets either from the single directory args.tiles or from
    # the config search; --tile_ids restricts which tiles are predicted (e.g. a test set).
    # Both S3 paths share the same buffering options, so unbuffer() below matches.
    s3_options = dict(
        mode="multibands",
        transform=None,
        aws_profile=args.aws_profile,
        ids=tile_ids_filter,
        buffered=args.buffer,
        buffered_overlap=args.buffer_overlap,
        tilesize=tile_size,
        bands=num_channels,
    )
    if args.tiles is not None:
        imagery_locs = [args.tiles]
        if args.tiles.startswith("s3://"):
            datasets = [S3SlippyMapTiles(args.tiles, **s3_options)]
        else:
            datasets = [SlippyMapTiles(args.tiles, mode="multibands", transform=None)]
    else:
        fs = s3fs.S3FileSystem(session=boto3.Session(profile_name=config["dataset"]["aws_profile"]))
        p = pprint.PrettyPrinter()
        imagery_searchpath = config["dataset"]["image_bucket"] + "/" + config["dataset"]["imagery_directory_regex"]
        print("Searching for imagery...({})".format(imagery_searchpath))
        imagery_candidates = fs.ls(config["dataset"]["image_bucket"])
        print("candidates:")
        p.pprint(imagery_candidates)
        imagery_locs = [c for c in imagery_candidates if match(imagery_searchpath, c)]
        print("result:")
        p.pprint(imagery_locs)
        datasets = [S3SlippyMapTiles("s3://" + loc, **s3_options) for loc in imagery_locs]

    palette = make_palette(config["classes"][0]["color"])

    # don't track tensors with autograd during prediction
    with torch.no_grad():
        for dataset, imageloc in zip(datasets, imagery_locs):
            print("Prediction: {}".format(imageloc))
            # Flatten the source path into one directory name instead of recreating it.
            savedir = os.path.join(args.preds, imageloc.replace("/", ":"))
            loader = DataLoader(dataset, batch_size=batch_size, num_workers=args.workers)

            for tiles, images in tqdm(loader, desc="Eval", unit="batch", ascii=True):
                tiles = list(zip(tiles[0], tiles[1], tiles[2]))
                images = images.to(device)
                outputs = net(images)

                for tile, prob in zip(tiles, outputs):
                    tile = Tile(tile[0].item(), tile[1].item(), tile[2].item())
                    # Threshold the per-pixel model output into a binary mask.
                    image = (prob > args.threshold).cpu().numpy().astype(np.uint8)
                    if args.buffer:
                        # Unbuffer with the dataset that produced this batch.
                        image = dataset.unbuffer(image)
                    image = image.squeeze()

                    _write_png(tile, image, savedir, palette)
                    if args.create_tif:
                        _write_tif(tile, image, savedir)
def main(args):
    """Rasterize vector features (GeoJSON files or a PostGIS query) into label tiles.

    For every tile listed in ``args.cover`` a mask is produced by burning
    ``burn_value`` (1) for the features touching that tile. Tiles with no
    feature get an all-zero ``args.ts`` x ``args.ts`` mask. Masks are written
    under ``args.out`` with a complementary palette built from the config
    classes. For GeoJSON input an ``instances.cover`` file is also written,
    listing ``x,y,z count`` per tile.

    Exits via ``sys.exit`` when:
      * both or neither of ``--geojson``/``--postgis`` are given;
      * ``--postgis`` is given without ``--pg_dsn``;
      * the cover is unreadable or mixes zoom levels (GeoJSON mode);
      * a GeoJSON file cannot be parsed;
      * the database cannot be reached or the SRID cannot be read (PostGIS mode).
    """
    if (args.geojson and args.postgis) or (not args.geojson and not args.postgis):
        sys.exit("ERROR: Input features to rasterize must be either GeoJSON or PostGIS")
    if args.postgis and not args.pg_dsn:
        sys.exit("ERROR: With PostGIS input features, --pg_dsn must be provided")

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette(*[classe["color"] for classe in config["classes"]], complementary=True)
    burn_value = 1

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, srid, feature_map, polygon, i):
        """Index one Polygon geometry into ``feature_map`` under every tile it covers at ``zoom``.

        ``i`` is the index of the source feature; it is only used in the warning
        logged when the geometry raises ``ValueError``.
        """
        try:
            if srid != 4326:
                polygon = [
                    xy for xy in geojson_reproject({"type": "feature", "geometry": polygon}, srid, 4326)
                ][0]

            # GeoJSON coordinates could be N dimensional: keep only x and y.
            # The ring index gets its own name so the feature index `i` survives
            # for the warning message below.
            for ring_idx, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][ring_idx] = [[point[0], point[1]] for point in ring]

            if polygon["coordinates"]:
                for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=zoom):
                    feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})
        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, srid, feature_map, geometry, i):
        """Dispatch a Polygon/MultiPolygon geometry to the polygon indexer; log and skip other types."""
        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, srid, feature_map, geometry, i)
        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(
                    zoom, srid, feature_map, {"type": "Polygon", "coordinates": polygon}, i
                )
        else:
            log.log(
                "Notice: {} is a non surfacic geometry type, skipping feature {}".format(geometry["type"], i)
            )

        return feature_map

    if args.geojson:
        # All cover tiles must share one zoom level: features are indexed at that zoom.
        try:
            tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
            zoom = tiles[0].z
        except Exception:
            sys.exit("ERROR: Inconsistent cover {}".format(args.cover))
        if any(tile.z != zoom for tile in tiles):
            sys.exit("ERROR: Inconsistent cover {}".format(args.cover))

        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:
            with open(os.path.expanduser(geojson_file), encoding="utf-8") as geojson:
                try:
                    feature_collection = json.load(geojson)
                except Exception:
                    sys.exit("ERROR: {} is not a valid JSON file.".format(geojson_file))

            # Read the SRID from the optional "crs" member; fall back to EPSG:4326.
            try:
                crs_mapping = {"CRS84": "4326", "900913": "3857"}
                srid = feature_collection["crs"]["properties"]["name"].split(":")[-1]
                srid = int(srid) if srid not in crs_mapping else int(crs_mapping[srid])
            except Exception:
                srid = 4326

            for i, feature in enumerate(tqdm(feature_collection["features"], ascii=True, unit="feature")):
                try:
                    if feature["geometry"]["type"] == "GeometryCollection":
                        for geometry in feature["geometry"]["geometries"]:
                            feature_map = geojson_parse_geometry(zoom, srid, feature_map, geometry, i)
                    else:
                        feature_map = geojson_parse_geometry(zoom, srid, feature_map, feature["geometry"], i)
                except Exception:
                    sys.exit(
                        "ERROR: Unable to parse {} file. Seems not a valid GEOJSON file.".format(geojson_file)
                    )

        log.log(
            "RoboSat.pink - rasterize - rasterizing tiles from {} on cover {}".format(args.geojson, args.cover)
        )
        # Text mode already translates "\n" to the platform line ending;
        # writing os.linesep here would produce "\r\r\n" on Windows.
        with open(os.path.join(args.out, "instances.cover"), mode="w") as cover:
            for tile in tqdm(tiles, ascii=True, unit="tile"):
                try:
                    if tile in feature_map:
                        cover.write("{},{},{} {}\n".format(tile.x, tile.y, tile.z, len(feature_map[tile])))
                        out = geojson_tile_burn(tile, feature_map[tile], 4326, args.ts, burn_value)
                    else:
                        cover.write("{},{},{} {}\n".format(tile.x, tile.y, tile.z, 0))
                        out = np.zeros(shape=(args.ts, args.ts), dtype=np.uint8)

                    tile_label_to_file(args.out, tile, palette, out)
                except Exception:
                    log.log("Warning: Unable to rasterize tile. Skipping {}".format(str(tile)))

    if args.postgis:
        try:
            pg_conn = psycopg2.connect(args.pg_dsn)
            pg = pg_conn.cursor()
        except Exception:
            sys.exit("Unable to connect PostgreSQL: {}".format(args.pg_dsn))

        log.log("RoboSat.pink - rasterize - rasterizing tiles from PostGIS on cover {}".format(args.cover))
        log.log(" SQL {}".format(args.postgis))

        try:
            pg.execute("SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(args.postgis))
            srid = pg.fetchone()[0]
        except Exception:
            sys.exit("Unable to retrieve geometry SRID.")

        for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))), ascii=True, unit="tile"):
            # mercantile.bounds returns (west, south, east, north), i.e. the
            # (xmin, ymin, xmax, ymax) order ST_MakeEnvelope expects.
            w, s, e, n = mercantile.bounds(tile)

            raster = np.zeros((args.ts, args.ts))
            # NOTE: args.postgis is an operator-supplied SQL fragment and is
            # interpolated verbatim on purpose; do not feed untrusted input here.
            query = """
            WITH
              bbox      AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}  ) AS bbox),
              bbox_merc AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), 3857) AS bbox),

              rast_a    AS (SELECT ST_AddBand(
                                   ST_SetSRID(
                                     ST_MakeEmptyRaster({}, {}, ST_Xmin(bbox), ST_Ymax(bbox), (ST_YMax(bbox) - ST_YMin(bbox)) / {}),
                                   3857),
                                   '8BUI'::text, 0) AS rast
                            FROM bbox_merc),

              features  AS (SELECT ST_Union(ST_Transform(ST_Force2D(geom), 3857)) AS geom
                            FROM ({}) AS sub, bbox
                            WHERE ST_Intersects(geom, bbox)),

              rast_b    AS (SELECT ST_AsRaster(geom, rast, '8BUI', {}) AS rast
                            FROM features, rast_a
                            WHERE NOT ST_IsEmpty(geom))

            SELECT ST_AsBinary(ST_MapAlgebra(rast_a.rast, rast_b.rast, '{}', NULL, 'FIRST')) AS wkb FROM rast_a, rast_b
            """.format(w, s, e, n, srid, w, s, e, n, args.ts, args.ts, args.ts, args.postgis, burn_value, burn_value)

            try:
                pg.execute(query)
                row = pg.fetchone()
                if row:
                    raster = np.squeeze(wkb_to_numpy(io.BytesIO(row[0])), axis=2)
            except Exception:
                log.log("Warning: Invalid geometries, skipping {}".format(tile))
                # Start over on a fresh connection so a failed statement cannot
                # leave later queries in an aborted transaction; close the old
                # one instead of leaking it.
                pg_conn.close()
                pg_conn = psycopg2.connect(args.pg_dsn)
                pg = pg_conn.cursor()

            try:
                tile_label_to_file(args.out, tile, palette, raster)
            except Exception:
                log.log("Warning: Unable to rasterize tile. Skipping {}".format(str(tile)))

        pg_conn.close()

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_tiles = list(tiles_from_csv(os.path.expanduser(args.cover)))
        web_ui(args.out, base_url, web_tiles, web_tiles, "png", template)
def main(args):
    """Cut a georeferenced raster into XYZ slippy-map tiles at ``args.zoom``.

    Each tile is warped to EPSG:3857 with bilinear resampling, read at
    ``args.tile_size`` x ``args.tile_size`` pixels, and written under
    ``args.out/<zoom>/<x>/<y>.<ext>``:

    * ``args.type == "label"``: a 1-band paletted PNG;
    * ``args.type == "image"``: a grayscale PNG (1 band) or an RGB WebP (3 bands).
      16/32-bit input is scaled down to 8 bits.

    When ``args.no_data`` is set, tiles whose whole top, bottom, left or right
    border equals ``args.no_data`` on all bands are skipped. They are also left
    out of the optional web UI.
    """
    config = load_config(args.config)
    colors = [classe["color"] for classe in config["classes"]]
    tile_size = args.tile_size

    try:
        raster = rasterio_open(args.raster)
        w, s, e, n = bounds = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
        transform, _, _ = calculate_default_transform(raster.crs, "EPSG:3857", raster.width, raster.height, *bounds)
    except Exception:
        sys.exit("Error: Unable to load raster or deal with it's projection")

    tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
    tiles_nodata = set()
    ext = "png"  # default, so the web UI step has a value even if no tile gets written

    with raster:
        for tile in tqdm(tiles, desc="Tiling", unit="tile", ascii=True):
            w, s, e, n = tile_bounds = mercantile.xy_bounds(tile)

            # Inspired by Rio-Tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
            # Use tile_size for both the VRT transform and the read shape; they must agree.
            with WarpedVRT(
                raster,
                crs="EPSG:3857",
                resampling=Resampling.bilinear,
                add_alpha=False,
                transform=from_bounds(*tile_bounds, tile_size, tile_size),
                width=math.ceil((e - w) / transform.a),
                height=math.ceil((s - n) / transform.e),
            ) as warp_vrt:
                data = warp_vrt.read(
                    out_shape=(len(raster.indexes), tile_size, tile_size), window=warp_vrt.window(w, s, e, n)
                )

            # If no_data is set, remove all tiles with at least one whole border filled only with no_data (on all bands)
            if args.no_data is not None and (
                np.all(data[:, 0, :] == args.no_data)
                or np.all(data[:, -1, :] == args.no_data)
                or np.all(data[:, :, 0] == args.no_data)
                or np.all(data[:, :, -1] == args.no_data)
            ):
                tiles_nodata.add(tile)
                continue

            C = data.shape[0]  # number of bands

            os.makedirs(os.path.join(args.out, str(args.zoom), str(tile.x)), exist_ok=True)
            path = os.path.join(args.out, str(args.zoom), str(tile.x), str(tile.y))

            if args.type == "label":
                if C != 1:
                    sys.exit("Error: Label raster input should be 1 band")

                ext = "png"
                img = Image.fromarray(np.squeeze(data, axis=0), mode="P")
                img.putpalette(make_palette(colors[0], colors[1]))
                img.save("{}.{}".format(path, ext), optimize=True)

            elif args.type == "image":
                if C not in (1, 3):
                    sys.exit("Error: Image raster input should be either 1 or 3 bands")

                # GeoTiff could be 16 or 32bits: map the full range onto 8 bits
                # (uint16 >> 8, uint32 >> 24) so the uint8 cast cannot overflow.
                if data.dtype == "uint16":
                    data = np.uint8(data / 256)
                elif data.dtype == "uint32":
                    data = np.uint8(data / (256 ** 3))

                if C == 1:
                    ext = "png"
                    Image.fromarray(np.squeeze(data, axis=0), mode="L").save(
                        "{}.{}".format(path, ext), optimize=True
                    )
                elif C == 3:
                    ext = "webp"
                    Image.fromarray(np.moveaxis(data, 0, 2), mode="RGB").save(
                        "{}.{}".format(path, ext), optimize=True
                    )

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile in tiles if tile not in tiles_nodata]
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, tiles, tiles, ext, template)
def main(args):
    """Burn GeoJSON polygon features into paletted label tiles for every tile of ``args.cover``.

    Features from each file in ``args.features`` are indexed by the tiles they
    cover at ``args.zoom``. Each cover tile is then burned with ``burn`` (or
    filled with zeros when no feature touches it). The result is merged
    (element-wise max) with any PNG already at ``args.out/z/x/y.png`` and saved
    there with the complementary palette of the first two class colors.

    Exits via ``sys.exit`` if the cover contains tiles at a zoom other than
    ``args.zoom``.
    """
    config = load_config(args.config)

    tile_size = args.tile_size if args.tile_size else config["model"]["tile_size"]
    colors = [classe["color"] for classe in config["classes"]]

    os.makedirs(args.out, exist_ok=True)

    # Read the cover once; it is needed for the zoom check, the burn loop and the web UI.
    cover = list(tiles_from_csv(args.cover))

    # We can only rasterize all tiles at a single zoom.
    if not all(tile.z == args.zoom for tile in cover):
        sys.exit("ERROR: all cover tiles must be at zoom {}".format(args.zoom))

    # Find all tiles the features cover and make a map object for quick lookup.
    feature_map = collections.defaultdict(list)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def parse_polygon(feature_map, polygon, i):
        """Index one Polygon under every tile it covers at ``args.zoom``; ``i`` is the feature index for logs."""
        try:
            # GeoJSON coordinates could be N dimensional: keep only x and y.
            # Use a separate ring index so the feature index `i` is not clobbered.
            for ring_idx, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][ring_idx] = [[point[0], point[1]] for point in ring]

            for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=args.zoom):
                feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})
        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def parse_geometry(feature_map, geometry, i):
        """Dispatch Polygon/MultiPolygon geometries to ``parse_polygon``; log and skip other types."""
        if geometry["type"] == "Polygon":
            feature_map = parse_polygon(feature_map, geometry, i)
        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = parse_polygon(feature_map, {"type": "Polygon", "coordinates": polygon}, i)
        else:
            log.log("Notice: {} is a non surfacic geometry type, skipping feature {}".format(geometry["type"], i))

        return feature_map

    for features_path in args.features:
        with open(features_path) as f:
            fc = json.load(f)

        for i, feature in enumerate(tqdm(fc["features"], ascii=True, unit="feature")):
            if feature["geometry"]["type"] == "GeometryCollection":
                for geometry in feature["geometry"]["geometries"]:
                    feature_map = parse_geometry(feature_map, geometry, i)
            else:
                feature_map = parse_geometry(feature_map, feature["geometry"], i)

    # Burn features to tiles and write to a slippy map directory.
    palette = complementary_palette(make_palette(colors[0], colors[1]))  # identical for every tile
    for tile in tqdm(cover, ascii=True, unit="tile"):
        if tile in feature_map:
            out = burn(tile, feature_map[tile], tile_size)
        else:
            out = np.zeros(shape=(tile_size, tile_size), dtype=np.uint8)

        out_dir = os.path.join(args.out, str(tile.z), str(tile.x))
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(out_dir, "{}.png".format(tile.y))

        # Merge with a label tile already on disk instead of overwriting it.
        # Close the file before the same path is rewritten below.
        if os.path.exists(out_path):
            with Image.open(out_path) as prev_img:
                prev = np.array(prev_img)
            out = np.maximum(out, prev)

        out = Image.fromarray(out, mode="P")
        out.putpalette(palette)
        out.save(out_path, optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, cover, cover, "png", template)
def main(args):
    """Run a segmentation checkpoint over a tile directory and save thresholded masks.

    The model named by ``config["model"]["name"]`` is built with one input
    channel per configured band. Its weights are loaded from
    ``args.checkpoint``, given either as a local path or as ``s3://bucket/key``
    (read with the ``args.aws_profile`` credentials).

    Tiles are read from ``args.tiles`` (local or ``s3://``). Each output greater
    than ``args.threshold`` becomes 1, everything else 0. The mask is saved as a
    paletted PNG at ``args.probs/z/x/y.png``.

    Exits via ``sys.exit`` when the model name is unknown or the checkpoint file
    does not exist.
    """
    config = load_config(args.config)

    num_classes = len(config["classes"])
    batch_size = args.batch_size if args.batch_size else config["model"]["batch_size"]

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    def map_location(storage, _):
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    # https://github.com/pytorch/pytorch/issues/7178
    # s3fs addresses objects as "bucket/key", so strip the "s3://" scheme.
    chkpt = args.checkpoint
    s3_checkpoint = chkpt.startswith("s3://")
    if s3_checkpoint:
        chkpt = chkpt[5:]

    models = [name for _, name, _ in pkgutil.iter_modules([os.path.dirname(robosat_pink.models.__file__)])]
    if config["model"]["name"] not in models:
        sys.exit("Unknown model, thoses available are {}".format(models))

    num_channels = sum(len(channel["bands"]) for channel in config["channels"])

    pretrained = config["model"]["pretrained"]
    encoder = config["model"]["encoder"]

    model_module = import_module("robosat_pink.models.{}".format(config["model"]["name"]))
    net = getattr(model_module, config["model"]["name"].title())(
        num_classes=num_classes, num_channels=num_channels, encoder=encoder, pretrained=pretrained
    ).to(device)
    net = torch.nn.DataParallel(net)

    try:
        if s3_checkpoint:
            sess = boto3.Session(profile_name=args.aws_profile)
            fs = s3fs.S3FileSystem(session=sess)
            with s3fs.S3File(fs, chkpt, "rb") as C:
                state = torch.load(io.BytesIO(C.read()), map_location=map_location)
        else:
            state = torch.load(chkpt, map_location=map_location)
        net.load_state_dict(state["state_dict"])
        net.to(device)
    except FileNotFoundError:
        # Do not go on predicting with randomly initialised weights.
        sys.exit("{} checkpoint not found.".format(args.checkpoint))

    net.eval()

    # NOTE: no input normalization is applied (transform=None). Per-band
    # statistics from an earlier experiment, kept for reference:
    # mean = np.array([[[8237.95084794]],
    #                  [[6467.98702156]],
    #                  [[6446.61743148]],
    #                  [[4520.95360105]]])
    # std = array([[[12067.03414753]],
    #              [[ 8810.00542703]],
    #              [[10710.64289882]],
    #              [[ 9024.92028515]]])
    # transform = Compose([ImageToTensor(), Normalize(mean=mean, std=std)])
    # transform = A.Compose([
    #     A.Normalize(mean = mean, std = std, max_pixel_value = 1.0),
    #     A.ToFloat()
    # ])

    if args.tiles.startswith("s3://"):
        directory = S3SlippyMapTiles(args.tiles, mode="multibands", transform=None, aws_profile=args.aws_profile)
    else:
        # `transform` is not defined in this function; pass None like the S3 branch.
        directory = SlippyMapTiles(args.tiles, mode="multibands", transform=None)
    # directory = BufferedSlippyMapDirectory(args.tiles, transform=transform, size=tile_size, overlap=args.overlap)
    loader = DataLoader(directory, batch_size=batch_size, num_workers=args.workers)

    palette = make_palette(config["classes"][0]["color"])

    # don't track tensors with autograd during prediction
    with torch.no_grad():
        for tiles, images in tqdm(loader, desc="Eval", unit="batch", ascii=True):
            # Regroup the batched coordinates into one (x, y, z) triple per image.
            # Assumes the loader yields them as (xs, ys, zs), as the indexing implies.
            tiles = list(zip(tiles[0], tiles[1], tiles[2]))
            images = images.to(device)
            outputs = net(images)

            for tile, prob in zip(tiles, outputs):
                x, y, z = (coord.item() for coord in tile)

                # Binarize the raw network output; move it off the GPU before numpy.
                image = (prob > args.threshold).cpu().numpy().astype(np.uint8).squeeze()
                out = Image.fromarray(image, mode="P")
                out.putpalette(palette)

                tile_dir = os.path.join(args.probs, str(z), str(x))
                os.makedirs(tile_dir, exist_ok=True)
                out.save(os.path.join(tile_dir, str(y) + ".png"), optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.tiles)]
        web_ui(args.probs, base_url, tiles, tiles, "png", template)
#get_ipython().run_line_magic('matplotlib', 'inline') import rasterio as rio from matplotlib import pyplot as plt import rasterio.plot import os from datetime import datetime as dt from rasterio.io import MemoryFile import tempfile sys.path.append("../model/robosat_pink/") from robosat_pink.config import load_config # original with 5/28 config_location= '/home/ubuntu/planet-snowcover/experiments/co-train.toml' # revised with neighboring watershed - revered May 1st as dont have DEM or VEG for that #config_location= '/home/ubuntu/planet-snowcover/experiments/co-train-neigh.toml' config = load_config(config_location) p = pprint.PrettyPrinter() fs = s3fs.S3FileSystem(session = boto3.Session(profile_name = config['dataset']['aws_profile'])) imagery_searchpath = config['dataset']['image_bucket'] + '/' + config['dataset']['imagery_directory_regex'] print("Searching for imagery...({})".format(imagery_searchpath)) imagery_candidates = fs.ls(config['dataset']['image_bucket']) #print("candidates:") #p.pprint(imagery_candidates) imagery_locs = [c for c in imagery_candidates if match(imagery_searchpath, c)] print("result:") p.pprint(imagery_locs)
def main(args):
    """Train a segmentation model and save one checkpoint per epoch under ``args.out``.

    Command-line values override the matching ``[model]`` config entries when
    they are truthy: loader, bs, lr, ts, nn, loss, da and dap. Training
    samples come from ``<dataset>/training`` and validation samples from
    ``<dataset>/validation``.

    ``--checkpoint`` loads model weights. ``--resume`` (which requires
    ``--checkpoint``) also restores the optimizer state and continues from the
    saved epoch.

    Exits via ``sys.exit`` on a missing dataset directory, on loader, model,
    loss or checkpoint loading failures, and when the checkpoint already
    reached ``args.epochs``.
    """
    config = load_config(args.config)
    args.out = os.path.expanduser(args.out)

    # Default to two DataLoader workers per visible GPU when none were requested
    # (this yields 0 workers on a CPU-only host).
    if not args.workers:
        args.workers = torch.cuda.device_count() * 2

    # CLI overrides for the [model] section, applied only when the flag is truthy.
    for key in ("loader", "bs", "lr", "ts", "nn", "loss", "da", "dap"):
        value = getattr(args, key)
        config["model"][key] = value if value else config["model"][key]

    check_classes(config)
    check_channels(config)
    check_model(config)

    dataset = os.path.expanduser(args.dataset)
    if not os.path.isdir(dataset):
        sys.exit("ERROR: dataset {} is not a directory".format(args.dataset))

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log(
            "(Torch:{} Cuda:{} CudNN:{})".format(
                torch.__version__, torch.version.cuda, torch.backends.cudnn.version()
            )
        )
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - training on CPU, with {} workers - (Torch:{})".format(args.workers, torch.__version__))
        log.log("WARNING: Are you really sure sure about not training on GPU ?")
        device = torch.device("cpu")

    try:
        loader = import_module("robosat_pink.loaders.{}".format(config["model"]["loader"].lower()))
        loader_cls = getattr(loader, config["model"]["loader"])
        loader_train = loader_cls(config, config["model"]["ts"], os.path.join(dataset, "training"), "train")
        loader_val = loader_cls(config, config["model"]["ts"], os.path.join(dataset, "validation"), "train")
    except Exception:
        sys.exit("ERROR: Unable to load data loaders")

    try:
        model_module = import_module("robosat_pink.models.{}".format(config["model"]["nn"].lower()))
    except Exception:
        sys.exit("ERROR: Unable to load {} model".format(config["model"]["nn"]))

    nn = getattr(model_module, config["model"]["nn"])(
        loader_train.shape_in, loader_train.shape_out, config["model"]["pretrained"]
    ).to(device)
    nn = torch.nn.DataParallel(nn)
    optimizer = Adam(nn.parameters(), lr=config["model"]["lr"])

    resume = 0
    if args.checkpoint:
        try:
            chkpt = torch.load(os.path.expanduser(args.checkpoint), map_location=device)
            nn.load_state_dict(chkpt["state_dict"])
            log.log("Using checkpoint: {}".format(args.checkpoint))
        except Exception:
            sys.exit("ERROR: Unable to load {} checkpoint".format(args.checkpoint))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]
            if resume >= args.epochs:
                # Report the epoch the checkpoint reached (the config has no "epochs" key to rely on).
                sys.exit("ERROR: Epoch {} already reached by the given checkpoint".format(resume))
    elif args.resume:
        sys.exit("ERROR: Unable to resume training without a checkpoint")

    try:
        loss_module = import_module("robosat_pink.losses.{}".format(config["model"]["loss"].lower()))
        criterion = getattr(loss_module, config["model"]["loss"])().to(device)
    except Exception:
        sys.exit("ERROR: Unable to load {} loss".format(config["model"]["loss"]))

    bs = config["model"]["bs"]
    train_loader = DataLoader(loader_train, batch_size=bs, shuffle=True, drop_last=True, num_workers=args.workers)
    val_loader = DataLoader(loader_val, batch_size=bs, shuffle=False, drop_last=True, num_workers=args.workers)

    log.log("--- Input tensor from Dataset: {} ---".format(args.dataset))
    num_channel = 1  # 1-based numerotation
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(num_channel, channel["name"], band))
            num_channel += 1

    log.log("--- Hyper Parameters ---")
    for hp in config["model"]:
        log.log("{}{}".format(hp.ljust(25, " "), config["model"][hp]))

    for epoch in range(resume, args.epochs):
        run_uuid = uuid.uuid1()
        log.log("---{}Epoch: {}/{} -- UUID: {}".format(os.linesep, epoch + 1, args.epochs, run_uuid))

        train(train_loader, config, log, device, nn, optimizer, criterion)
        validate(val_loader, config, log, device, nn, criterion)

        try:  # https://github.com/pytorch/pytorch/issues/9176
            nn_doc = nn.module.doc
            nn_version = nn.module.version
        except AttributeError:
            nn_version = nn.version
            nn_doc = nn.doc  # was `==`, which left nn_doc unbound

        states = {
            "uuid": run_uuid,
            "model_version": nn_version,
            "producer_name": "RoboSat.pink",
            "producer_version": "0.4.0",
            "model_licence": "MIT",
            "domain": "pink.RoboSat",  # reverse-DNS
            "doc_string": nn_doc,
            "shape_in": loader_train.shape_in,
            "shape_out": loader_train.shape_out,
            "state_dict": nn.state_dict(),
            "epoch": epoch + 1,
            "nn": config["model"]["nn"],
            "optimizer": optimizer.state_dict(),
            "loader": config["model"]["loader"],
        }
        checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}.pth".format(epoch + 1))
        try:
            torch.save(states, checkpoint_path)
        except Exception:
            sys.exit("ERROR: Unable to save checkpoint {}".format(checkpoint_path))
def main(args):
    """Predict on overlapping (buffered) tiles and write one binary mask PNG per tile.

    The model named in the config is built for the configured classes and
    bands, and its weights are restored from ``args.checkpoint``. Input tiles
    are normalized with the per-channel mean/std from the config.

    Predictions are made on buffered tiles, then cropped back with
    ``unbuffer``. Each pixel's class-1 softmax probability is rounded to 0 or 1.
    The resulting mask is stored as a paletted PNG at ``args.probs/z/x/y.png``.
    """
    config = load_config(args.config)
    model_cfg = config["model"]

    num_classes = len(config["classes"])
    batch_size = args.batch_size or model_cfg["batch_size"]
    tile_size = args.tile_size or model_cfg["tile_size"]

    on_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if on_gpu else "cpu")
    if on_gpu:
        torch.backends.cudnn.benchmark = True

    def map_location(storage, _):
        # Restore every saved tensor on the GPU when available, otherwise on the CPU.
        if torch.cuda.is_available():
            return storage.cuda()
        return storage.cpu()

    # https://github.com/pytorch/pytorch/issues/7178
    chkpt = torch.load(args.checkpoint, map_location=map_location)

    models_dir = os.path.dirname(robosat_pink.models.__file__)
    available = [module_name for _, module_name, _ in pkgutil.iter_modules([models_dir])]
    if model_cfg["name"] not in available:
        sys.exit("Unknown model, thoses available are {}".format(available))

    # Concatenate per-channel normalization stats and count the input bands.
    std, mean, num_channels = [], [], 0
    for channel in config["channels"]:
        std += channel["std"]
        mean += channel["mean"]
        num_channels += len(channel["bands"])

    encoder = model_cfg["encoder"]
    pretrained = model_cfg["pretrained"]

    model_name = model_cfg["name"]
    model_cls = getattr(import_module("robosat_pink.models.{}".format(model_name)), model_name.title())
    base_net = model_cls(num_classes=num_classes, num_channels=num_channels, encoder=encoder, pretrained=pretrained)
    net = torch.nn.DataParallel(base_net.to(device))
    net.load_state_dict(chkpt["state_dict"])
    net.eval()

    directory = BufferedSlippyMapTiles(
        args.tiles,
        transform=Compose([ImageToTensor(), Normalize(mean=mean, std=std)]),
        size=tile_size,
        overlap=args.overlap,
    )
    loader = DataLoader(directory, batch_size=batch_size, num_workers=args.workers)

    classes = config["classes"]
    palette = make_palette(classes[0]["color"], classes[1]["color"])

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True):
            logits = net(images.to(device))
            batch_probs = torch.nn.functional.softmax(logits, dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, batch_probs):
                x, y, z = [int(coord) for coord in tile]

                # The network saw the buffered tile; crop back to the original footprint.
                prob = directory.unbuffer(prob)

                assert prob.shape[0] == 2, "single channel requires binary model"
                assert np.allclose(np.sum(prob, axis=0), 1.0), "single channel requires probabilities to sum up to one"

                mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()
                img = Image.fromarray(mask, mode="P")
                img.putpalette(palette)

                tile_dir = os.path.join(args.probs, str(z), str(x))
                os.makedirs(tile_dir, exist_ok=True)
                img.save(os.path.join(tile_dir, str(y) + ".png"), optimize=True)

    if args.web_ui:
        template = args.web_ui_template or "leaflet.html"
        base_url = args.web_ui_base_url or "./"
        slippy_tiles = [tile for tile, _ in tiles_from_slippy_map(args.tiles)]
        web_ui(args.probs, base_url, slippy_tiles, slippy_tiles, "png", template)
def main(args):
    """Predict masks for a tile directory with a RoboSat.pink checkpoint.

    The network class, input/output shapes and data loader class are all read
    from the checkpoint. Each tile's mask is the rounded class-1 softmax
    probability, written with ``tile_label_to_file`` under ``args.out``.

    A batch whose forward pass fails is logged and skipped, and so is a tile
    whose mask cannot be written. Exits via ``sys.exit`` when the checkpoint or
    its data loader cannot be loaded.
    """
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)

    # Default to two DataLoader workers per visible GPU when none were requested
    # (this yields 0 workers on a CPU-only host).
    if not args.workers:
        args.workers = torch.cuda.device_count() * 2

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log(
            "(Torch:{} Cuda:{} CudNN:{})".format(
                torch.__version__, torch.version.cuda, torch.backends.cudnn.version()
            )
        )
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(args.workers))
        device = torch.device("cpu")

    try:
        chkpt = torch.load(args.checkpoint, map_location=device)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if chkpt["producer_name"] != "RoboSat.pink":
            raise ValueError("checkpoint not produced by RoboSat.pink")
        model_module = import_module("robosat_pink.models.{}".format(chkpt["nn"].lower()))
        nn = getattr(model_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"]).to(device)
        nn = torch.nn.DataParallel(nn)
        nn.load_state_dict(chkpt["state_dict"])
        nn.eval()
    except Exception:
        sys.exit("ERROR: Unable to load {} checkpoint.".format(args.checkpoint))

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    try:
        loader_module = import_module("robosat_pink.loaders.{}".format(chkpt["loader"].lower()))
        loader_predict = getattr(loader_module, chkpt["loader"])(
            config, chkpt["shape_in"][1:3], args.tiles, mode="predict"
        )
    except Exception:
        sys.exit("ERROR: Unable to load {} data loader.".format(chkpt["loader"]))

    loader = DataLoader(loader_predict, batch_size=args.bs, num_workers=args.workers)
    palette = make_palette(config["classes"][0]["color"], config["classes"][1]["color"])

    with torch.no_grad():  # don't track tensors with autograd during prediction
        for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True):
            images = images.to(device)

            try:
                outputs = nn(images)
                probs = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()
            except Exception:
                # `probs` is unbound (first batch) or stale (previous batch) here,
                # so report this batch's tiles directly.
                log.log("WARNING: Skipping batch:")
                for tile in tiles:
                    log.log(" - {}".format(str(tile)))
                continue

            for tile, prob in zip(tiles, probs):
                try:
                    x, y, z = list(map(int, tile))
                    mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()
                    tile_label_to_file(args.out, mercantile.Tile(x, y, z), palette, mask)
                except Exception:
                    log.log("WARNING: Skipping tile {}".format(str(tile)))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.out)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)