Code example #1
# Stdlib/third-party imports inferred from usage; project helpers such as
# load_config, S3SlippyMapTiles, model, make_palette and predict are assumed
# to come from the surrounding package.
import argparse
import os

import boto3
import s3fs
from torch.utils.data import DataLoader


def main():
    ap = argparse.ArgumentParser(description='run model on tileset')
    ap.add_argument('model', help='path to model checkpoint')
    ap.add_argument('config', help='path to model config')
    ap.add_argument('tiles', help='path to XYZ tile folder')
    ap.add_argument('outputdir', help='name for tile output')
    ap.add_argument('--aws_profile',
                    help='AWS Profile Name',
                    default='default')

    args = ap.parse_args()

    config = load_config(args.config)
    tiles = S3SlippyMapTiles(args.tiles,
                             mode='multibands',
                             aws_profile=args.aws_profile)
    net = model(config, args.model)

    loader = DataLoader(tiles,
                        batch_size=config['model']['batch_size'],
                        shuffle=True,
                        num_workers=1)

    palette = make_palette(config["classes"][0]["color"])

    # Use the profile given via --aws_profile rather than a hardcoded one
    fs = s3fs.S3FileSystem(session=boto3.Session(profile_name=args.aws_profile))

    outputdir = args.outputdir[5:] + '/' + os.path.basename(args.tiles)  # strip the "s3://" scheme
    print("Saving predictions to {}.".format(outputdir))

    predict(net, loader, outputdir, palette, fs)
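
For context, a typical invocation of this script would look something like the following (the script name and all paths are placeholders, not taken from the source):

# Hypothetical usage:
#   python predict_tiles.py model.pth config.toml s3://bucket/xyz-tiles s3://bucket/predictions --aws_profile default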
Code example #2
File: vectorize.py Project: yzuaiyou/robosat.pink
def main(args):
    config = load_config(args.config)
    check_classes(config)
    index = [i for i in range(len(config["classes"]))
             if config["classes"][i]["title"] == args.type]
    if not index:
        sys.exit(
            "ERROR: Requested type {} not found among class titles in the config file."
            .format(args.type))

    print("RoboSat.pink - vectorize {} from {}".format(args.type, args.masks))

    with open(args.out, "w", encoding="utf-8") as out:
        first = True
        out.write('{"type":"FeatureCollection","features":[')

        for tile, path in tqdm(list(tiles_from_slippy_map(args.masks)),
                               ascii=True,
                               unit="mask"):
            try:
                features = (np.array(Image.open(path).convert("P"),
                                     dtype=np.uint8) == index).astype(np.uint8)
                try:
                    C, W, H = features.shape
                except ValueError:  # single-band mask has no channel axis
                    W, H = features.shape
                transform = rasterio.transform.from_bounds(
                    *mercantile.bounds(tile.x, tile.y, tile.z), W, H)

                for shape, value in rasterio.features.shapes(
                        features, transform=transform):
                    prop = '"properties":{{"x":{},"y":{},"z":{}}}'.format(
                        int(tile.x), int(tile.y), int(tile.z))
                    geom = '"geometry":{{"type": "Polygon", "coordinates":{}}}'.format(
                        json.dumps(shape["coordinates"]))
                    out.write('{}{{"type":"Feature",{},{}}}'.format(
                        "," if not first else "", geom, prop))
                    first = False
            except Exception:
                sys.exit("ERROR: Unable to vectorize tile {}.".format(
                    str(tile)))

        out.write("]}")
Code example #3
File: features.py Project: ajijohn/planet-snowcover
def main(args):

    module_search_path = [args.path] if args.path else []
    module_search_path.append(
        os.path.join(Path(__file__).parent.parent, "features"))
    modules = [(path, name)
               for path, name, _ in pkgutil.iter_modules(module_search_path)
               if name != "core"]
    if args.type not in [name for _, name in modules]:
        sys.exit("Unknown type, thoses available are {}".format(
            [name for _, name in modules]))

    config = load_config(args.config)
    labels = config["classes"]["titles"]
    if args.type not in labels:
        sys.exit(
            "The requested type is not consistent with the classes in the provided config file."
        )
    index = labels.index(args.type)

    if args.path:
        sys.path.append(args.path)
        module = import_module(args.type)
    else:
        module = import_module("robosat_pink.features.{}".format(args.type))

    handler = getattr(module, "{}Handler".format(args.type.title()))()

    for tile, path in tqdm(list(tiles_from_slippy_map(args.masks)),
                           ascii=True,
                           unit="mask"):
        image = np.array(Image.open(path).convert("P"), dtype=np.uint8)
        mask = (image == index).astype(np.uint8)
        handler.apply(tile, mask)

    handler.save(args.out)
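
The lookup above (pkgutil.iter_modules, then import_module, then a Title-cased getattr) is a small plugin pattern. A condensed, reusable sketch of the same idea (function and argument names here are illustrative):

import pkgutil
from importlib import import_module


def load_handler(package_dir, package_name, kind):
    # Discover candidate modules on disk, then import the requested one.
    names = [name for _, name, _ in pkgutil.iter_modules([package_dir])]
    if kind not in names:
        raise ValueError("Unknown type {}, those available are {}".format(kind, names))
    module = import_module("{}.{}".format(package_name, kind))
    # By naming convention, module `building` exposes class `BuildingHandler`.
    return getattr(module, "{}Handler".format(kind.title()))()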
Code example #4
def main(args):
    config = load_config(args.config)
    print(config)

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        device = torch.device("cuda")

        torch.backends.cudnn.benchmark = True
        log.log("RoboSat - training on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
    else:
        device = torch.device("cpu")
        log.log("RoboSat - training on CPU, with {} workers".format(
            args.workers))

    num_classes = len(config["classes"])
    num_channels = 0
    for channel in config["channels"]:
        num_channels += len(channel["bands"])
    pretrained = config["model"]["pretrained"]
    encoder = config["model"]["encoder"]

    models = [
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.models.__file__)])
    ]
    if config["model"]["name"] not in [model for model in models]:
        sys.exit("Unknown model, thoses available are {}".format(
            [model for model in models]))

    model_module = import_module("robosat_pink.models.{}".format(
        config["model"]["name"]))
    net = getattr(model_module, "{}".format(config["model"]["name"].title()))(
        num_classes=num_classes,
        num_channels=num_channels,
        encoder=encoder,
        pretrained=pretrained).to(device)

    net = torch.nn.DataParallel(net)
    optimizer = Adam(net.parameters(),
                     lr=config["model"]["lr"],
                     weight_decay=config["model"]["decay"])

    resume = 0

    # Check checkpoint situation + load if necessary
    checkpoint = None  # no checkpoint
    if args.checkpoint:  # command line checkpoint
        checkpoint = args.checkpoint
    else:
        try:  # config file checkpoint
            checkpoint = config["checkpoint"]['path']
        except KeyError:
            # no checkpoint in config file
            pass

    S3_CHECKPOINT = False
    if checkpoint and checkpoint.startswith("s3://"):
        S3_CHECKPOINT = True
        # Load from S3: strip the scheme and set up a filesystem handle
        checkpoint = checkpoint[5:]
        sess = boto3.Session(profile_name=config['dataset']['aws_profile'])
        fs = s3fs.S3FileSystem(session=sess)

    if checkpoint is not None:

        def map_location(storage, _):
            return storage.cuda() if torch.cuda.is_available(
            ) else storage.cpu()

        try:
            if S3_CHECKPOINT:
                with s3fs.S3File(fs, checkpoint, 'rb') as C:
                    state = torch.load(io.BytesIO(C.read()),
                                       map_location=map_location)
            else:
                state = torch.load(checkpoint)
            optimizer.load_state_dict(state['optimizer'])
            net.load_state_dict(state['state_dict'])
            net.to(device)
        except FileNotFoundError:
            print("{} checkpoint not found.".format(checkpoint))

        log.log("Using checkpoint: {}".format(checkpoint))

    losses = [
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.losses.__file__)])
    ]
    if config["model"]["loss"] not in [loss for loss in losses]:
        sys.exit("Unknown loss, thoses available are {}".format(
            [loss for loss in losses]))

    loss_module = import_module("robosat_pink.losses.{}".format(
        config["model"]["loss"]))
    criterion = getattr(loss_module, "{}".format(
        config["model"]["loss"].title()))().to(device)

    train_loader, val_loader = get_dataset_loaders(config,
                                                   args.workers,
                                                   idDir=args.out)

    if resume >= config["model"]["epochs"]:
        sys.exit(
            "Error: Epoch {} set in {} already reached by the checkpoint provided"
            .format(config["model"]["epochs"], args.config))

    log.log("")
    log.log("--- Input tensor from Dataset: {} ---".format(
        config["dataset"]["image_bucket"] + '/' +
        config['dataset']['imagery_directory_regex']))

    log.log("")
    log.log("--- Hyper Parameters ---")
    log.log("Model:\t\t\t {}".format(config["model"]["name"]))
    log.log("Encoder model:\t\t {}".format(config["model"]["encoder"]))
    log.log("Loss function:\t\t {}".format(config["model"]["loss"]))
    log.log("ResNet pre-trained:\t {}".format(config["model"]["pretrained"]))
    log.log("Batch Size:\t\t {}".format(config["model"]["batch_size"]))
    log.log("Tile Size:\t\t {}".format(config["model"]["tile_size"]))
    log.log("Data Augmentation:\t {}".format(
        config["model"]["data_augmentation"]))
    log.log("Learning Rate:\t\t {}".format(config["model"]["lr"]))
    log.log("Weight Decay:\t\t {}".format(config["model"]["decay"]))
    log.log("")

    for epoch in range(resume, config["model"]["epochs"]):

        log.log("---")
        log.log("Epoch: {}/{}".format(epoch + 1, config["model"]["epochs"]))

        train_hist = train(train_loader, num_classes, device, net, optimizer,
                           criterion)
        log.log(
            "Train    loss: {:.4f}, mIoU: {:.3f}, IoU: {:.3f}, precision:  {:.3f}, recall: {:.3f}"
            .format(
                train_hist["loss"],
                train_hist["miou"],
                train_hist["fg_iou"],
                train_hist["precision"],
                train_hist["recall"],
            ))

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        log.log(
            "Validate loss: {:.4f}, mIoU: {:.3f}, IoU: {:.3f}, precision:  {:.3f}, recall: {:.3f}"
            .format(
                val_hist["loss"],
                val_hist["miou"],
                val_hist["fg_iou"],
                val_hist["precision"],
                val_hist["recall"],
            ))

        states = {
            "epoch": epoch + 1,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        checkpoint_path = os.path.join(
            args.out, "checkpoint-{:05d}-of-{:05d}.pth".format(
                epoch + 1, config["model"]["epochs"]))
        torch.save(states, checkpoint_path)
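
Each epoch ends by saving a dict with epoch, state_dict and optimizer keys. Restoring is the mirror image; a self-contained sketch with a stand-in model (the real script does the same with its DataParallel net):

import torch
from torch.optim import Adam

net = torch.nn.Linear(4, 2)  # stand-in for the real network
optimizer = Adam(net.parameters())
torch.save({"epoch": 1, "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict()}, "/tmp/checkpoint.pth")

state = torch.load("/tmp/checkpoint.pth", map_location="cpu")
net.load_state_dict(state["state_dict"])
optimizer.load_state_dict(state["optimizer"])
resume = state["epoch"]  # training would continue at range(resume, epochs)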
Code example #5
File: compare.py Project: martham93/robosat.pink
def main(args):

    if not args.masks or not args.labels or not args.config:
        if args.mode == "list":
            sys.exit(
                "Parameters masks, labels and config are all mandatory in list mode."
            )
        if args.minimum_fg > 0 or args.maximum_fg < 100 or args.minimum_qod > 0 or args.maximum_qod < 100:
            sys.exit(
                "Parameters masks, labels and config are all mandatory for QoD filtering."
            )

    if args.images:
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])]
        for image in args.images[1:]:
            assert sorted(tiles) == sorted([
                tile for tile, _ in tiles_from_slippy_map(image)
            ]), "inconsistent coverages"

    if args.labels and args.masks:
        tiles_masks = [tile for tile, _ in tiles_from_slippy_map(args.masks)]
        tiles_labels = [tile for tile, _ in tiles_from_slippy_map(args.labels)]
        if args.images:
            assert sorted(tiles) == sorted(tiles_masks) == sorted(
                tiles_labels), "inconsistent coverages"
        else:
            assert sorted(tiles_masks) == sorted(
                tiles_labels), "inconsistent coverages"
            tiles = tiles_masks

    if args.mode == "list":
        out = open(args.out, mode="w")
        if args.geojson:
            out.write('{"type":"FeatureCollection","features":[')
            first = True

    tiles_compare = []
    for tile in tqdm(list(tiles), desc="Compare", unit="tile", ascii=True):

        x, y, z = list(map(str, tile))

        if args.masks and args.labels and args.config:
            titles = [
                classe["title"]
                for classe in load_config(args.config)["classes"]
            ]
            dist, fg_ratio, qod = compare(args.masks, args.labels, tile,
                                          titles)
            if not args.minimum_fg <= fg_ratio <= args.maximum_fg or not args.minimum_qod <= qod <= args.maximum_qod:
                continue

        tiles_compare.append(tile)

        if args.mode == "side":

            for i, root in enumerate(args.images):
                img = tile_image(tile_from_slippy_map(root, x, y, z)[1])

                if i == 0:
                    side = np.zeros(
                        (img.shape[0], img.shape[1] * len(args.images), 3))
                    side = np.swapaxes(side, 0, 1) if args.vertical else side
                    image_shape = img.shape
                else:
                    assert image_shape == img.shape, "Inconsistent image size to compare"

                if args.vertical:
                    side[i * image_shape[0]:(i + 1) *
                         image_shape[0], :, :] = img
                else:
                    # slices run along the width axis, hence the image width
                    side[:,
                         i * image_shape[1]:(i + 1) * image_shape[1], :] = img

            os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
            side = Image.fromarray(np.uint8(side))
            side.save(os.path.join(args.out, z, x, "{}.{}".format(y,
                                                                  args.ext)),
                      optimize=True)

        elif args.mode == "stack":

            for i, root in enumerate(args.images):
                img = tile_image(tile_from_slippy_map(root, x, y, z)[1])

                if i == 0:
                    image_shape = img.shape[0:2]
                    stack = img / len(args.images)
                else:
                    assert image_shape == img.shape[
                        0:2], "Inconsistent image size to compare"
                    stack = stack + (img / len(args.images))

            os.makedirs(os.path.join(args.out, str(z), str(x)), exist_ok=True)
            stack = Image.fromarray(np.uint8(stack))
            stack.save(os.path.join(args.out, str(z), str(x),
                                    "{}.{}".format(y, args.ext)),
                       optimize=True)

        elif args.mode == "list":
            if args.geojson:
                prop = '"properties":{{"x":{},"y":{},"z":{},"fg":{:.1f},"qod":{:.1f}}}'.format(
                    x, y, z, fg_ratio, qod)
                geom = '"geometry":{}'.format(
                    json.dumps(feature(tile, precision=6)["geometry"]))
                out.write('{}{{"type":"Feature",{},{}}}'.format(
                    "," if not first else "", geom, prop))
                first = False
            else:
                out.write("{},{},{}\t\t{:.1f}\t\t{:.1f}{}".format(
                    x, y, z, fg_ratio, qod, os.linesep))

    if args.mode == "list":
        if args.geojson:
            out.write("]}")
        out.close()

    base_url = args.web_ui_base_url if args.web_ui_base_url else "./"

    if args.mode == "side" and args.web_ui:
        template = "compare.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, base_url, None, tiles_compare, args.ext, template)

    if args.mode == "stack" and args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])]
        web_ui(args.out, base_url, tiles, tiles_compare, args.ext, template)
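
The "side" mode above composes tiles by writing each image into a slice of a pre-allocated strip. np.concatenate expresses the same composition more directly; a toy check on synthetic images:

import numpy as np

imgs = [np.full((256, 256, 3), v, dtype=np.uint8) for v in (0, 127, 255)]
side = np.concatenate(imgs, axis=1)  # horizontal strip; axis=0 for --vertical
assert side.shape == (256, 768, 3)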
Code example #6
def main(args):

    if not args.workers:
        args.workers = max(1, math.floor(os.cpu_count() * 0.5))

    if args.label:
        config = load_config(args.config)
        check_classes(config)
        colors = [classe["color"] for classe in config["classes"]]
        palette = make_palette(*colors)

    splits_path = os.path.join(os.path.expanduser(args.out), ".splits")
    tiles_map = {}

    print("RoboSat.pink - tile on CPU, with {} workers".format(args.workers))

    bands = -1
    for path in args.rasters:
        try:
            raster = rasterio_open(path)
            w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
        except Exception:
            sys.exit("Error: Unable to load raster {} or deal with its projection".format(path))

        if bands != -1:
            assert bands == len(raster.indexes), "Coverage must be bands consistent"
        bands = len(raster.indexes)

        tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
        for tile in tiles:
            tile_key = (str(tile.x), str(tile.y), str(tile.z))
            if tile_key not in tiles_map.keys():
                tiles_map[tile_key] = []
            tiles_map[tile_key].append(path)

    if args.label:
        ext = "png"
        bands = 1
    if not args.label:
        if bands == 1:
            ext = "png"
        if bands == 3:
            ext = "webp"
        if bands > 3:
            ext = "tiff"

    tiles = []
    progress = tqdm(total=len(tiles_map), ascii=True, unit="tile")
    # Begin to tile plain tiles
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(path):

            raster = rasterio_open(path)
            w, s, e, n = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
            transform, _, _ = calculate_default_transform(raster.crs, "EPSG:3857", raster.width, raster.height, w, s, e, n)
            tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
            tiled = []

            for tile in tiles:

                try:
                    w, s, e, n = mercantile.xy_bounds(tile)

                    # inspired by rio-tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
                    warp_vrt = WarpedVRT(
                        raster,
                        crs="epsg:3857",
                        resampling=Resampling.bilinear,
                        add_alpha=False,
                        transform=from_bounds(w, s, e, n, args.ts, args.ts),
                        width=math.ceil((e - w) / transform.a),
                        height=math.ceil((s - n) / transform.e),
                    )
                    data = warp_vrt.read(
                        out_shape=(len(raster.indexes), args.ts, args.ts), window=warp_vrt.window(w, s, e, n)
                    )
                    image = np.moveaxis(data, 0, 2)  # C,H,W -> H,W,C
                except Exception:
                    sys.exit("Error: Unable to tile {} from raster {}.".format(str(tile), raster))

                tile_key = (str(tile.x), str(tile.y), str(tile.z))
                if not args.label and len(tiles_map[tile_key]) == 1 and is_border(image):
                    progress.update()
                    continue

                if len(tiles_map[tile_key]) > 1:
                    out = os.path.join(splits_path, str(tiles_map[tile_key].index(path)))
                else:
                    out = args.out

                x, y, z = map(int, tile)

                if not args.label:
                    ret = tile_image_to_file(out, mercantile.Tile(x=x, y=y, z=z), image)
                if args.label:
                    ret = tile_label_to_file(out, mercantile.Tile(x=x, y=y, z=z), palette, image)

                if not ret:
                    sys.exit("Error: Unable to write tile {} from raster {}.".format(str(tile), raster))

                if len(tiles_map[tile_key]) == 1:
                    progress.update()
                    tiled.append(mercantile.Tile(x=x, y=y, z=z))

            return tiled

        for tiled in executor.map(worker, args.rasters):
            if tiled is not None:
                tiles.extend(tiled)

    # Aggregate remaining tiles splits
    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile_key):

            if len(tiles_map[tile_key]) == 1:
                return

            image = np.zeros((args.ts, args.ts, bands), np.uint8)

            x, y, z = map(int, tile_key)
            for i in range(len(tiles_map[tile_key])):
                root = os.path.join(splits_path, str(i))
                _, path = tile_from_slippy_map(root, x, y, z)

                if not args.label:
                    split = tile_image_from_file(path)
                if args.label:
                    split = tile_label_from_file(path)
                    split = split.reshape((args.ts, args.ts, 1))  # H,W -> H,W,C

                assert image.shape == split.shape
                image[:, :, :] += split[:, :, :]

            if not args.label and is_border(image):
                progress.update()
                return

            tile = mercantile.Tile(x=x, y=y, z=z)

            if not args.label:
                ret = tile_image_to_file(args.out, tile, image)

            if args.label:
                ret = tile_label_to_file(args.out, tile, palette, image)

            if not ret:
                sys.exit("Error: Unable to write tile {}.".format(str(tile_key)))

            progress.update()
            return tile

        for tiled in executor.map(worker, tiles_map.keys()):
            if tiled is not None:
                tiles.append(tiled)

        if splits_path and os.path.isdir(splits_path):
            shutil.rmtree(splits_path)  # Delete suffixes dir if any

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, tiles, tiles, ext, template)
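
The two thread-pool passes hinge on tiles_map, which records every raster touching each tile: tiles covered once are written directly, tiles covered several times go to per-raster .splits folders and are summed afterwards. A toy version of that bookkeeping (raster names and tile keys are invented):

import collections

coverage = {"a.tif": [(1, 2, 16), (1, 3, 16)], "b.tif": [(1, 3, 16)]}
tiles_map = collections.defaultdict(list)
for path, tiles in coverage.items():
    for xyz in tiles:
        tiles_map[xyz].append(path)

plain = [t for t, paths in tiles_map.items() if len(paths) == 1]  # first pass
merged = [t for t, paths in tiles_map.items() if len(paths) > 1]  # second pass
assert plain == [(1, 2, 16)] and merged == [(1, 3, 16)]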
Code example #7
File: export.py Project: opengbdx/robosat.pink
def main(args):
    config = load_config(args.config)

    num_classes = len(config["classes"])
    num_channels = 0
    for channel in config["channels"]:
        num_channels += len(channel["bands"])

    export_channels = num_channels if not args.export_channels else args.export_channels
    if export_channels > num_channels:
        sys.exit(
            "Error: Attempt to export more channels than the dataset provides."
        )

    def map_location(storage, _):
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    models = [
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.models.__file__)])
    ]
    if config["model"]["name"] not in [model for model in models]:
        sys.exit("Unknown model, thoses available are {}".format(
            [model for model in models]))

    encoder = config["model"]["encoder"]
    pretrained = config["model"]["pretrained"]

    model_module = import_module("robosat_pink.models.{}".format(
        config["model"]["name"]))
    net = getattr(model_module, "{}".format(config["model"]["name"].title()))(
        num_classes=num_classes,
        num_channels=num_channels,
        encoder=encoder,
        pretrained=pretrained).to(device)

    chkpt = torch.load(args.checkpoint, map_location=map_location)
    net = torch.nn.DataParallel(net)
    net.load_state_dict(chkpt["state_dict"])

    if export_channels < num_channels:
        weights = torch.zeros((64, export_channels, 7, 7))
        weights.data = net.module.resnet.conv1.weight.data[:, :export_channels, :, :]
        # The replacement conv takes export_channels inputs, matching the
        # truncated kernel copied above
        net.module.resnet.conv1 = torch.nn.Conv2d(export_channels,
                                                  64,
                                                  kernel_size=7,
                                                  stride=2,
                                                  padding=3,
                                                  bias=False)
        net.module.resnet.conv1.weight = torch.nn.Parameter(weights)

    if args.type == "pth":
        states = {
            "epoch": chkpt["epoch"],
            "state_dict": net.state_dict(),
            "optimizer": chkpt["optimizer"]
        }
        torch.save(states, args.out)
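
The channel-truncation step above copies the first export_channels slices of the pretrained first-layer kernel into a fresh convolution. A self-contained sketch of that weight surgery (a bare Conv2d stands in for net.module.resnet.conv1):

import torch

conv1 = torch.nn.Conv2d(5, 64, kernel_size=7, stride=2, padding=3, bias=False)
export_channels = 3

weights = conv1.weight.data[:, :export_channels, :, :].clone()
new_conv1 = torch.nn.Conv2d(export_channels, 64, kernel_size=7, stride=2,
                            padding=3, bias=False)
new_conv1.weight = torch.nn.Parameter(weights)

out = new_conv1(torch.zeros(1, export_channels, 224, 224))
assert out.shape == (1, 64, 112, 112)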
Code example #8
def main(args):

    if (args.geojson and args.postgis) or (not args.geojson and not args.postgis):
        sys.exit("Input features to rasterize must be either GeoJSON or PostGIS")

    config = load_config(args.config)
    tile_size = args.tile_size if args.tile_size else config["model"]["tile_size"]
    colors = [classe["color"] for classe in config["classes"]]
    burn_value = 1

    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, feature_map, polygon, i):

        try:
            for i, ring in enumerate(polygon["coordinates"]):  # GeoJSON coordinates could be N dimensionals
                polygon["coordinates"][i] = [[x, y] for point in ring for x, y in zip([point[0]], [point[1]])]

            for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=zoom):
                feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, feature_map, geometry, i):

        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, feature_map, geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(zoom, feature_map, {"type": "Polygon", "coordinates": polygon}, i)
        else:
            log.log("Notice: {} is a non surfacic geometry type, skipping feature {}".format(geometry["type"], i))

        return feature_map

    if args.geojson:

        tiles = [tile for tile in tiles_from_csv(args.cover)]
        zoom = tiles[0].z

        if [tile for tile in tiles if tile.z != zoom]:
            sys.exit("With GeoJson input, all tiles z values have to be the same, in the csv cover file.")

        feature_map = collections.defaultdict(list)

        # Compute a spatial index like
        for geojson_file in args.geojson:
            with open(geojson_file) as geojson:
                feature_collection = json.load(geojson)
                for i, feature in enumerate(tqdm(feature_collection["features"], ascii=True, unit="feature")):

                    if feature["geometry"]["type"] == "GeometryCollection":
                        for geometry in feature["geometry"]["geometries"]:
                            feature_map = geojson_parse_geometry(zoom, feature_map, geometry, i)
                    else:
                        feature_map = geojson_parse_geometry(zoom, feature_map, feature["geometry"], i)

        # Rasterize tiles
        for tile in tqdm(list(tiles_from_csv(args.cover)), ascii=True, unit="tile"):
            if tile in feature_map:
                out = geojson_tile_burn(tile, feature_map[tile], tile_size, burn_value)
            else:
                out = np.zeros(shape=(tile_size, tile_size), dtype=np.uint8)

            write_tile(args.out, tile, colors, out)

    if args.postgis:

        try:
            pg_conn = psycopg2.connect(config["dataset"]["pg_dsn"])
            pg = pg_conn.cursor()
        except Exception:
            sys.exit("Unable to connect to PostgreSQL: {}".format(config["dataset"]["pg_dsn"]))

        try:
            pg.execute("SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(args.postgis))
            srid = pg.fetchone()[0]
        except Exception:
            sys.exit("Unable to retrieve geometry SRID.")

        for tile in tqdm(list(tiles_from_csv(args.cover)), ascii=True, unit="tile"):

            w, s, e, n = mercantile.bounds(tile)  # (west, south, east, north)
            raster = np.zeros((tile_size, tile_size))

            query = """
WITH
     bbox      AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}  ) AS bbox),
     bbox_merc AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), 3857) AS bbox),

     rast_a    AS (SELECT ST_AddBand(
                           ST_SetSRID(
                             ST_MakeEmptyRaster({}, {}, ST_Xmin(bbox), ST_Ymax(bbox), (ST_YMax(bbox) - ST_YMin(bbox)) / {}),
                           3857),
                          '8BUI'::text, 0) AS rast
                   FROM bbox_merc),

     features  AS (SELECT ST_Union(ST_Transform(ST_Force2D(geom), 3857)) AS geom
                   FROM ({}) AS sub, bbox
                   WHERE ST_Intersects(geom, bbox)),

     rast_b    AS (SELECT ST_AsRaster(geom, rast, '8BUI', {}) AS rast
                   FROM features, rast_a
                   WHERE NOT ST_IsEmpty(geom))

SELECT ST_AsBinary(ST_MapAlgebra(rast_a.rast, rast_b.rast, '{}', NULL, 'FIRST')) AS wkb FROM rast_a, rast_b

""".format(
                w, s, e, n, srid, w, s, e, n, tile_size, tile_size, tile_size, args.postgis, burn_value, burn_value
            )

            try:
                pg.execute(query)
                row = pg.fetchone()
                if row:
                    raster = np.squeeze(wkb_to_numpy(io.BytesIO(row[0])), axis=2)

            except Exception:
                log.log("Warning: Invalid geometries, skipping {}".format(tile))
                pg_conn = psycopg2.connect(config["dataset"]["pg_dsn"])
                pg = pg_conn.cursor()

            write_tile(args.out, tile, colors, raster)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
Code example #9
File: validate.py Project: ajijohn/planet-snowcover
def main(args):
    configLoc = args[1]
    config = load_config(configLoc)

    print(config)
Code example #10
File: train.py Project: martham93/robosat.pink
def main(args):
    config = load_config(args.config)
    config["dataset"][
        "path"] = args.dataset if args.dataset else config["dataset"]["path"]
    config["model"]["lr"] = args.lr if args.lr else config["model"]["lr"]
    config["model"][
        "epochs"] = args.epochs if args.epochs else config["model"]["epochs"]
    config["model"][
        "batch_size"] = args.batch_size if args.batch_size else config[
            "model"]["batch_size"]

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        device = torch.device("cuda")

        torch.backends.cudnn.benchmark = True
        log.log("RoboSat - training on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
    else:
        device = torch.device("cpu")
        log.log("RoboSat - training on CPU, with {} workers".format(
            args.workers))

    num_classes = len(config["classes"])
    num_channels = 0
    for channel in config["channels"]:
        num_channels += len(channel["bands"])
    pretrained = config["model"]["pretrained"]
    encoder = config["model"]["encoder"]

    models = [
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.models.__file__)])
    ]
    if config["model"]["name"] not in [model for model in models]:
        sys.exit("Unknown model, thoses available are {}".format(
            [model for model in models]))

    model_module = import_module("robosat_pink.models.{}".format(
        config["model"]["name"]))
    net = getattr(model_module, "{}".format(config["model"]["name"].title()))(
        num_classes=num_classes,
        num_channels=num_channels,
        encoder=encoder,
        pretrained=pretrained).to(device)

    net = torch.nn.DataParallel(net)
    optimizer = Adam(net.parameters(),
                     lr=config["model"]["lr"],
                     weight_decay=config["model"]["decay"])

    resume = 0
    if args.checkpoint:

        def map_location(storage, _):
            return storage.cuda() if torch.cuda.is_available(
            ) else storage.cpu()

        # https://github.com/pytorch/pytorch/issues/7178
        chkpt = torch.load(args.checkpoint, map_location=map_location)
        net.load_state_dict(chkpt["state_dict"])
        log.log("Using checkpoint: {}".format(args.checkpoint))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]

    losses = [
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.losses.__file__)])
    ]
    if config["model"]["loss"] not in [loss for loss in losses]:
        sys.exit("Unknown loss, thoses available are {}".format(
            [loss for loss in losses]))

    loss_module = import_module("robosat_pink.losses.{}".format(
        config["model"]["loss"]))
    criterion = getattr(loss_module, "{}".format(
        config["model"]["loss"].title()))().to(device)

    train_loader, val_loader = get_dataset_loaders(config["dataset"]["path"],
                                                   config, args.workers)

    if resume >= config["model"]["epochs"]:
        sys.exit(
            "Error: Epoch {} set in {} already reached by the checkpoint provided"
            .format(config["model"]["epochs"], args.config))

    log.log("")
    log.log("--- Input tensor from Dataset: {} ---".format(
        config["dataset"]["path"]))
    num_channel = 1
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(
                num_channel, channel["sub"], band))
            num_channel += 1
    log.log("")
    log.log("--- Hyper Parameters ---")
    log.log("Model:\t\t\t {}".format(config["model"]["name"]))
    log.log("Encoder model:\t\t {}".format(config["model"]["encoder"]))
    log.log("Loss function:\t\t {}".format(config["model"]["loss"]))
    log.log("ResNet pre-trained:\t {}".format(config["model"]["pretrained"]))
    log.log("Batch Size:\t\t {}".format(config["model"]["batch_size"]))
    log.log("Tile Size:\t\t {}".format(config["model"]["tile_size"]))
    log.log("Data Augmentation:\t {}".format(
        config["model"]["data_augmentation"]))
    log.log("Learning Rate:\t\t {}".format(config["model"]["lr"]))
    log.log("Weight Decay:\t\t {}".format(config["model"]["decay"]))
    log.log("")

    for epoch in range(resume, config["model"]["epochs"]):

        log.log("---")
        log.log("Epoch: {}/{}".format(epoch + 1, config["model"]["epochs"]))

        train_hist = train(train_loader, num_classes, device, net, optimizer,
                           criterion)
        log.log(
            "Train    loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".
            format(
                train_hist["loss"],
                train_hist["miou"],
                config["classes"][1]["title"],
                train_hist["fg_iou"],
                train_hist["mcc"],
            ))

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        log.log(
            "Validate loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".
            format(val_hist["loss"], val_hist["miou"],
                   config["classes"][1]["title"], val_hist["fg_iou"],
                   val_hist["mcc"]))

        states = {
            "epoch": epoch + 1,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        checkpoint_path = os.path.join(
            args.out, "checkpoint-{:05d}-of-{:05d}.pth".format(
                epoch + 1, config["model"]["epochs"]))
        torch.save(states, checkpoint_path)
Code example #11
File: predict.py Project: ajijohn/robosat.pink
def main(args):
    config = load_config(args.config)
    num_classes = len(config["classes"])
    batch_size = args.batch_size if args.batch_size else config["model"]["batch_size"]
    tile_size = config["model"]["tile_size"]

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    def map_location(storage, _):
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    # Check checkpoint situation + load if necessary
    chkpt = None  # no checkpoint
    if args.checkpoint:  # command line checkpoint
        chkpt = args.checkpoint
    else:
        try:  # config file checkpoint
            chkpt = config["checkpoint"]['path']
        except KeyError:
            # no checkpoint in config file
            pass

    S3_CHECKPOINT = False
    if chkpt is None:
        sys.exit("Error: no checkpoint provided, on the command line or in the config file.")
    if chkpt.startswith("s3://"):
        S3_CHECKPOINT = True
        # load from s3
        chkpt = chkpt[5:]

    models = [name for _, name, _ in pkgutil.iter_modules([os.path.dirname(robosat_pink.models.__file__)])]
    if config["model"]["name"] not in [model for model in models]:
        sys.exit("Unknown model, thoses available are {}".format([model for model in models]))

    num_channels = 0
    for channel in config["channels"]:
        num_channels += len(channel["bands"])

    pretrained = config["model"]["pretrained"]
    encoder = config["model"]["encoder"]

    model_module = import_module("robosat_pink.models.{}".format(config["model"]["name"]))

    net = getattr(model_module, "{}".format(config["model"]["name"].title()))(
        num_classes=num_classes, num_channels=num_channels, encoder=encoder, pretrained=pretrained
    ).to(device)

    net = torch.nn.DataParallel(net)


    try:
        if S3_CHECKPOINT:
            sess = boto3.Session(profile_name=args.aws_profile)
            fs = s3fs.S3FileSystem(session=sess)
            with s3fs.S3File(fs, chkpt, 'rb') as C:
                state = torch.load(io.BytesIO(C.read()), map_location=map_location)
        else:
            state = torch.load(chkpt, map_location=map_location)
        net.load_state_dict(state['state_dict'], strict=False)
        net.to(device)
    except FileNotFoundError:
        print("{} checkpoint not found.".format(chkpt))


    net.eval()

    tile_ids_filter = None
    if args.tile_ids is not None:
        tile_ids_filter = pd.read_csv(args.tile_ids, names=['ids']).ids.values



    ## Construct torch Dataset, either from a single directory (if args.tiles is given) or from config.
    ## The --tile_ids argument determines how to filter the resulting tiles (e.g. to only run prediction on a test set).
    if args.tiles is not None:
        imagery_locs = [args.tiles]
        # use tiledir  provided
        if args.tiles.startswith('s3://'):
            allImageryDatasets = [S3SlippyMapTiles(args.tiles, mode='multibands', transform=None, aws_profile = args.aws_profile, ids = tile_ids_filter, buffered=args.buffer, buffered_overlap=args.buffer_overlap, tilesize=tile_size, bands=num_channels)]
        else:
            allImageryDatasets = [SlippyMapTiles(args.tiles, mode="multibands", transform = None)]
        # directory = BufferedSlippyMapDirectory(args.tiles, transform=transform, size=tile_size, overlap=args.overlap)
    else: # use config to search for tiles
        fs = s3fs.S3FileSystem(session = boto3.Session(profile_name = config['dataset']['aws_profile']))
        p = pprint.PrettyPrinter()
        imagery_searchpath = config['dataset']['image_bucket']  + '/' +  config['dataset']['imagery_directory_regex']
        print("Searching for imagery...({})".format(imagery_searchpath))
        imagery_candidates = fs.ls(config['dataset']['image_bucket'])
        print("candidates:")
        p.pprint(imagery_candidates)
        imagery_locs = [c for c in imagery_candidates if match(imagery_searchpath, c)]
        print("result:")
        p.pprint(imagery_locs)

        allImageryDatasets = [
            S3SlippyMapTiles("s3://" +  loc, mode='multibands', transform=None, aws_profile=args.aws_profile, ids=tile_ids_filter)
            for loc in imagery_locs
        ]


    palette = make_palette(config["classes"][0]["color"])


    # don't track tensors with autograd during prediction
    with torch.no_grad():
        for dataset, imageloc in zip(allImageryDatasets, imagery_locs):
            print("Prediction: {}".format(imageloc))
            imageloc_path = imageloc.replace("/", ":") # to not recreate directory structure when saving
            loader = DataLoader(dataset, batch_size=batch_size, num_workers=args.workers)
            for tiles, images in tqdm(loader, desc="Eval", unit="batch", ascii=True):
                tiles = list(zip(tiles[0], tiles[1], tiles[2]))
                images = images.to(device)
                outputs = net(images)


                for i, (tile, prob) in enumerate(zip(tiles, outputs)):
                    tile = Tile(tile[0].item(), tile[1].item(), tile[2].item())
                    savedir = args.preds

                    # manually compute segmentation mask class probabilities per pixel
                    image = (prob > args.threshold).cpu().numpy().astype(np.uint8)

                    if args.buffer:
                        image = allImageryDatasets[0].unbuffer(image)

                    image = image.squeeze()
                    
                    _write_png(tile, image, os.path.join(savedir, imageloc_path), palette)

                    if args.create_tif:
                        _write_tif(tile, image, os.path.join(savedir, imageloc_path))
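
Loading a torch checkpoint straight from S3, as above, only requires wrapping the object's bytes in a BytesIO before torch.load. A condensed sketch (the bucket and key are placeholders; assumes valid AWS credentials):

import io

import boto3
import s3fs
import torch

fs = s3fs.S3FileSystem(session=boto3.Session(profile_name="default"))
with fs.open("my-bucket/path/to/checkpoint.pth", "rb") as f:  # placeholder key
    state = torch.load(io.BytesIO(f.read()), map_location="cpu")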
Code example #12
File: rasterize.py Project: yzuaiyou/robosat.pink
def main(args):

    if (args.geojson and args.postgis) or (not args.geojson
                                           and not args.postgis):
        sys.exit(
            "ERROR: Input features to rasterize must be either GeoJSON or PostGIS"
        )

    if args.postgis and not args.pg_dsn:
        sys.exit(
            "ERROR: With PostGIS input features, --pg_dsn must be provided")

    config = load_config(args.config)
    check_classes(config)
    palette = make_palette(*[classe["color"] for classe in config["classes"]],
                           complementary=True)
    burn_value = 1

    args.out = os.path.expanduser(args.out)
    os.makedirs(args.out, exist_ok=True)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def geojson_parse_polygon(zoom, srid, feature_map, polygon, i):

        try:
            if srid != 4326:
                polygon = [xy for xy in geojson_reproject(
                    {"type": "feature", "geometry": polygon}, srid, 4326)][0]

            # GeoJSON coordinates could be N dimensional
            for i, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][i] = [[x, y] for point in ring
                                             for x, y in zip([point[0]], [point[1]])]

            if polygon["coordinates"]:
                for tile in burntiles.burn([{"type": "feature", "geometry": polygon}],
                                           zoom=zoom):
                    feature_map[mercantile.Tile(*tile)].append(
                        {"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def geojson_parse_geometry(zoom, srid, feature_map, geometry, i):

        if geometry["type"] == "Polygon":
            feature_map = geojson_parse_polygon(zoom, srid, feature_map,
                                                geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = geojson_parse_polygon(zoom, srid, feature_map, {
                    "type": "Polygon",
                    "coordinates": polygon
                }, i)
        else:
            log.log(
                "Notice: {} is a non surfacic geometry type, skipping feature {}"
                .format(geometry["type"], i))

        return feature_map

    if args.geojson:

        try:
            tiles = [
                tile for tile in tiles_from_csv(os.path.expanduser(args.cover))
            ]
            zoom = tiles[0].z
            assert not [tile for tile in tiles if tile.z != zoom]
        except Exception:
            sys.exit("ERROR: Inconsistent cover {}".format(args.cover))

        feature_map = collections.defaultdict(list)

        log.log("RoboSat.pink - rasterize - Compute spatial index")
        for geojson_file in args.geojson:

            with open(os.path.expanduser(geojson_file)) as geojson:
                try:
                    feature_collection = json.load(geojson)
                except ValueError:
                    sys.exit("ERROR: {} is not a valid JSON file.".format(
                        geojson_file))

                try:
                    crs_mapping = {"CRS84": "4326", "900913": "3857"}
                    srid = feature_collection["crs"]["properties"][
                        "name"].split(":")[-1]
                    srid = int(srid) if srid not in crs_mapping else int(
                        crs_mapping[srid])
                except (KeyError, ValueError):
                    srid = 4326

                for i, feature in enumerate(
                        tqdm(feature_collection["features"],
                             ascii=True,
                             unit="feature")):

                    try:
                        if feature["geometry"]["type"] == "GeometryCollection":
                            for geometry in feature["geometry"]["geometries"]:
                                feature_map = geojson_parse_geometry(
                                    zoom, srid, feature_map, geometry, i)
                        else:
                            feature_map = geojson_parse_geometry(
                                zoom, srid, feature_map, feature["geometry"],
                                i)
                    except Exception:
                        sys.exit(
                            "ERROR: Unable to parse {}. It does not appear to be a valid GeoJSON file."
                            .format(geojson_file))

        log.log(
            "RoboSat.pink - rasterize - rasterizing tiles from {} on cover {}".
            format(args.geojson, args.cover))
        with open(os.path.join(os.path.expanduser(args.out),
                               "instances.cover"),
                  mode="w") as cover:
            for tile in tqdm(list(
                    tiles_from_csv(os.path.expanduser(args.cover))),
                             ascii=True,
                             unit="tile"):

                try:
                    if tile in feature_map:
                        cover.write("{},{},{}  {}{}".format(
                            tile.x, tile.y, tile.z, len(feature_map[tile]),
                            os.linesep))
                        out = geojson_tile_burn(tile, feature_map[tile], 4326,
                                                args.ts, burn_value)
                    else:
                        cover.write("{},{},{}  {}{}".format(
                            tile.x, tile.y, tile.z, 0, os.linesep))
                        out = np.zeros(shape=(args.ts, args.ts),
                                       dtype=np.uint8)

                    tile_label_to_file(args.out, tile, palette, out)
                except Exception:
                    log.log("Warning: Unable to rasterize tile. Skipping {}".
                            format(str(tile)))

    if args.postgis:

        try:
            pg_conn = psycopg2.connect(args.pg_dsn)
            pg = pg_conn.cursor()
        except Exception:
            sys.exit("Unable to connect PostgreSQL: {}".format(args.pg_dsn))

        log.log(
            "RoboSat.pink - rasterize - rasterizing tiles from PostGIS on cover {}"
            .format(args.cover))
        log.log(" SQL {}".format(args.postgis))
        try:
            pg.execute(
                "SELECT ST_Srid(geom) AS srid FROM ({} LIMIT 1) AS sub".format(
                    args.postgis))
            srid = pg.fetchone()[0]
        except Exception:
            sys.exit("Unable to retrieve geometry SRID.")

        for tile in tqdm(list(tiles_from_csv(args.cover)),
                         ascii=True,
                         unit="tile"):

            w, s, e, n = mercantile.bounds(tile)  # (west, south, east, north)
            raster = np.zeros((args.ts, args.ts))

            query = """
WITH
     bbox      AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {}  ) AS bbox),
     bbox_merc AS (SELECT ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), 3857) AS bbox),

     rast_a    AS (SELECT ST_AddBand(
                           ST_SetSRID(
                             ST_MakeEmptyRaster({}, {}, ST_Xmin(bbox), ST_Ymax(bbox), (ST_YMax(bbox) - ST_YMin(bbox)) / {}),
                           3857),
                          '8BUI'::text, 0) AS rast
                   FROM bbox_merc),

     features  AS (SELECT ST_Union(ST_Transform(ST_Force2D(geom), 3857)) AS geom
                   FROM ({}) AS sub, bbox
                   WHERE ST_Intersects(geom, bbox)),

     rast_b    AS (SELECT ST_AsRaster(geom, rast, '8BUI', {}) AS rast
                   FROM features, rast_a
                   WHERE NOT ST_IsEmpty(geom))

SELECT ST_AsBinary(ST_MapAlgebra(rast_a.rast, rast_b.rast, '{}', NULL, 'FIRST')) AS wkb FROM rast_a, rast_b

""".format(s, w, e, n, srid, s, w, e, n, args.ts, args.ts, args.ts,
            args.postgis, burn_value, burn_value)

            try:
                pg.execute(query)
                row = pg.fetchone()
                if row:
                    raster = np.squeeze(wkb_to_numpy(io.BytesIO(row[0])),
                                        axis=2)

            except Exception:
                log.log(
                    "Warning: Invalid geometries, skipping {}".format(tile))
                pg_conn = psycopg2.connect(args.pg_dsn)
                pg = pg_conn.cursor()

            try:
                tile_label_to_file(args.out, tile, palette, raster)
            except Exception:
                log.log(
                    "Warning: Unable to rasterize tile. Skipping {}".format(
                        str(tile)))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
Code example #13
def main(args):

    config = load_config(args.config)
    colors = [classe["color"] for classe in config["classes"]]
    tile_size = args.tile_size

    try:
        raster = rasterio_open(args.raster)
        w, s, e, n = bounds = transform_bounds(raster.crs, "EPSG:4326",
                                               *raster.bounds)
        transform, _, _ = calculate_default_transform(raster.crs, "EPSG:3857",
                                                      raster.width,
                                                      raster.height, *bounds)
    except Exception:
        sys.exit("Error: Unable to load raster or deal with its projection")

    tiles = [
        mercantile.Tile(x=x, y=y, z=z)
        for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)
    ]
    tiles_nodata = []

    for tile in tqdm(tiles, desc="Tiling", unit="tile", ascii=True):

        w, s, e, n = tile_bounds = mercantile.xy_bounds(tile)

        # Inspired by Rio-Tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
        warp_vrt = WarpedVRT(
            raster,
            crs="EPSG:3857",
            resampling=Resampling.bilinear,
            add_alpha=False,
            transform=from_bounds(*tile_bounds, tile_size, tile_size),  # match the read() out_shape below
            width=math.ceil((e - w) / transform.a),
            height=math.ceil((s - n) / transform.e),
        )
        data = warp_vrt.read(out_shape=(len(raster.indexes), tile_size,
                                        tile_size),
                             window=warp_vrt.window(w, s, e, n))

        # If no_data is set, remove all tiles with at least one whole border filled only with no_data (on all bands)
        if args.no_data is not None and (
                np.all(data[:, 0, :] == args.no_data)
                or np.all(data[:, -1, :] == args.no_data)
                or np.all(data[:, :, 0] == args.no_data)
                or np.all(data[:, :, -1] == args.no_data)):
            tiles_nodata.append(tile)
            continue

        C, W, H = data.shape

        os.makedirs(os.path.join(args.out, str(args.zoom), str(tile.x)),
                    exist_ok=True)
        path = os.path.join(args.out, str(args.zoom), str(tile.x), str(tile.y))

        if args.type == "label":
            assert C == 1, "Error: Label raster input should be 1 band"

            ext = "png"
            img = Image.fromarray(np.squeeze(data, axis=0), mode="P")
            img.putpalette(make_palette(colors[0], colors[1]))
            img.save("{}.{}".format(path, ext), optimize=True)

        elif args.type == "image":
            assert C == 1 or C == 3, "Error: Image raster input should be either 1 or 3 bands"

            # GeoTiff could be 16 or 32bits
            if data.dtype == "uint16":
                data = np.uint8(data / 256)
            elif data.dtype == "uint32":
                data = np.uint8(data / (256 * 256))

            if C == 1:
                ext = "png"
                Image.fromarray(np.squeeze(data, axis=0),
                                mode="L").save("{}.{}".format(path, ext),
                                               optimize=True)
            elif C == 3:
                ext = "webp"
                Image.fromarray(np.moveaxis(data, 0, 2),
                                mode="RGB").save("{}.{}".format(path, ext),
                                                 optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile in tiles if tile not in tiles_nodata]
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        web_ui(args.out, base_url, tiles, tiles, ext, template)
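
The border test above discards a tile when any complete edge is pure no_data across all bands, which catches tiles hanging off the raster's footprint. The predicate in isolation, on synthetic data:

import numpy as np

no_data = 0
data = np.random.randint(1, 255, (3, 256, 256), dtype=np.uint8)
data[:, 0, :] = no_data  # poison the top row on every band

border = (np.all(data[:, 0, :] == no_data) or np.all(data[:, -1, :] == no_data)
          or np.all(data[:, :, 0] == no_data) or np.all(data[:, :, -1] == no_data))
assert border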
Code example #14
def main(args):
    config = load_config(args.config)
    tile_size = args.tile_size if args.tile_size else config["model"]["tile_size"]
    colors = [classe["color"] for classe in config["classes"]]

    os.makedirs(args.out, exist_ok=True)

    # We can only rasterize all tiles at a single zoom.
    assert all(tile.z == args.zoom for tile in tiles_from_csv(args.cover))

    # Find all tiles the features cover and make a map object for quick lookup.
    feature_map = collections.defaultdict(list)
    log = Logs(os.path.join(args.out, "log"), out=sys.stderr)

    def parse_polygon(feature_map, polygon, i):

        try:
            for i, ring in enumerate(polygon["coordinates"]):  # GeoJSON coordinates could be N dimensionals
                polygon["coordinates"][i] = [[x, y] for point in ring for x, y in zip([point[0]], [point[1]])]

            for tile in burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=args.zoom):
                feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon})

        except ValueError:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def parse_geometry(feature_map, geometry, i):

        if geometry["type"] == "Polygon":
            feature_map = parse_polygon(feature_map, geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = parse_polygon(feature_map, {"type": "Polygon", "coordinates": polygon}, i)
        else:
            log.log("Notice: {} is a non surfacic geometry type, skipping feature {}".format(geometry["type"], i))

        return feature_map

    for feature_file in args.features:
        with open(feature_file) as f:
            fc = json.load(f)
            for i, feature in enumerate(tqdm(fc["features"], ascii=True, unit="feature")):

                if feature["geometry"]["type"] == "GeometryCollection":
                    for geometry in feature["geometry"]["geometries"]:
                        feature_map = parse_geometry(feature_map, geometry, i)
                else:
                    feature_map = parse_geometry(feature_map, feature["geometry"], i)

    # Burn features to tiles and write to a slippy map directory.
    for tile in tqdm(list(tiles_from_csv(args.cover)), ascii=True, unit="tile"):
        if tile in feature_map:
            out = burn(tile, feature_map[tile], tile_size)
        else:
            out = np.zeros(shape=(tile_size, tile_size), dtype=np.uint8)

        out_dir = os.path.join(args.out, str(tile.z), str(tile.x))
        os.makedirs(out_dir, exist_ok=True)

        out_path = os.path.join(out_dir, "{}.png".format(tile.y))

        # Merge with an existing mask so overlapping features accumulate.
        if os.path.exists(out_path):
            prev = np.array(Image.open(out_path))
            out = np.maximum(out, prev)

        out = Image.fromarray(out, mode="P")
        out.putpalette(complementary_palette(make_palette(colors[0], colors[1])))
        out.save(out_path, optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
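
The rasterize step above relies on supermercado's burntiles.burn to find every tile a polygon covers at the requested zoom. A small standalone sketch, assuming supermercado and mercantile are installed (the polygon coordinates are made up):

import mercantile
from supermercado import burntiles

square = {
    "type": "Polygon",
    "coordinates": [[[-105.30, 39.90], [-105.20, 39.90],
                     [-105.20, 40.00], [-105.30, 40.00], [-105.30, 39.90]]],
}

# burn() returns an (N, 3) array of [x, y, z] tiles covering the feature.
for x, y, z in burntiles.burn([{"type": "Feature", "geometry": square}], zoom=12):
    print(mercantile.Tile(int(x), int(y), int(z)))
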
Code Example #15
def main(args):
    config = load_config(args.config)
    num_classes = len(config["classes"])
    batch_size = args.batch_size if args.batch_size else config["model"][
        "batch_size"]
    tile_size = args.tile_size if args.tile_size else config["model"][
        "tile_size"]

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    def map_location(storage, _):
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    # https://github.com/pytorch/pytorch/issues/7178
    # chkpt = torch.load(args.checkpoint, map_location=map_location)
    S3_CHECKPOINT = False
    chkpt = args.checkpoint
    if chkpt.startswith("s3://"):
        S3_CHECKPOINT = True
        # load from s3
        chkpt = chkpt[5:]

    models = [
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.models.__file__)])
    ]
    if config["model"]["name"] not in [model for model in models]:
        sys.exit("Unknown model, thoses available are {}".format(
            [model for model in models]))

    num_channels = 0
    for channel in config["channels"]:
        num_channels += len(channel["bands"])

    pretrained = config["model"]["pretrained"]
    encoder = config["model"]["encoder"]

    model_module = import_module("robosat_pink.models.{}".format(
        config["model"]["name"]))

    net = getattr(model_module, "{}".format(config["model"]["name"].title()))(
        num_classes=num_classes,
        num_channels=num_channels,
        encoder=encoder,
        pretrained=pretrained).to(device)

    net = torch.nn.DataParallel(net)

    try:
        if S3_CHECKPOINT:
            sess = boto3.Session(profile_name=args.aws_profile)
            fs = s3fs.S3FileSystem(session=sess)
            with s3fs.S3File(fs, chkpt, 'rb') as C:
                state = torch.load(io.BytesIO(C.read()),
                                   map_location=map_location)
        else:
            state = torch.load(chkpt, map_location=map_location)
        net.load_state_dict(state['state_dict'])
        net.to(device)
    except FileNotFoundError:
        sys.exit("ERROR: checkpoint {} not found.".format(chkpt))

    net.eval()
    # Optional per-band normalization (disabled); mean/std computed from the training imagery:
    # mean = [8237.95084794, 6467.98702156, 6446.61743148, 4520.95360105]
    # std  = [12067.03414753, 8810.00542703, 10710.64289882, 9024.92028515]
    # transform = A.Compose([A.Normalize(mean=mean, std=std, max_pixel_value=1.0), A.ToFloat()])

    if args.tiles.startswith('s3://'):
        directory = S3SlippyMapTiles(args.tiles,
                                     mode='multibands',
                                     transform=None,
                                     aws_profile=args.aws_profile)
    else:
        directory = SlippyMapTiles(args.tiles,
                                   mode="multibands",
                                   transform=None)  # normalization transform disabled above
    # directory = BufferedSlippyMapDirectory(args.tiles, transform=transform, size=tile_size, overlap=args.overlap)
    loader = DataLoader(directory,
                        batch_size=batch_size,
                        num_workers=args.workers)

    palette = make_palette(config["classes"][0]["color"])

    # don't track tensors with autograd during prediction
    with torch.no_grad():
        for tiles, images in tqdm(loader,
                                  desc="Eval",
                                  unit="batch",
                                  ascii=True):
            tiles = list(zip(tiles[0], tiles[1], tiles[2]))
            images = images.to(device)
            outputs = net(images)

            for tile, prob in zip(tiles, outputs):
                x = tile[0].item()
                y = tile[1].item()
                z = tile[2].item()

                # threshold the predicted probabilities into a binary mask
                image = (prob.data.cpu().numpy() >
                         args.threshold).astype(np.uint8).squeeze()

                out = Image.fromarray(image, mode="P")
                out.putpalette(palette)

                os.makedirs(os.path.join(args.probs, str(z), str(x)),
                            exist_ok=True)
                path = os.path.join(args.probs, str(z), str(x),
                                    str(y) + ".png")

                out.save(path, optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.tiles)]
        web_ui(args.probs, base_url, tiles, tiles, "png", template)
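
Loading the checkpoint through s3fs, as done above, generalizes to a small helper. A sketch under the same assumptions as the code (an AWS profile with read access and an s3:// checkpoint path); load_checkpoint is hypothetical:

import io

import boto3
import s3fs
import torch

def load_checkpoint(path, aws_profile="default", map_location="cpu"):
    # Read the checkpoint bytes from S3 and hand them to torch.load.
    if path.startswith("s3://"):
        fs = s3fs.S3FileSystem(session=boto3.Session(profile_name=aws_profile))
        with fs.open(path[5:], "rb") as f:
            return torch.load(io.BytesIO(f.read()), map_location=map_location)
    return torch.load(path, map_location=map_location)
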
Code Example #16
#get_ipython().run_line_magic('matplotlib', 'inline')
import os
import sys
import pprint
import tempfile
from datetime import datetime as dt
from re import match

import boto3
import s3fs
import rasterio as rio
import rasterio.plot
from rasterio.io import MemoryFile
from matplotlib import pyplot as plt

sys.path.append("../model/robosat_pink/")
from robosat_pink.config import load_config
# original config, with 5/28 imagery
config_location = '/home/ubuntu/planet-snowcover/experiments/co-train.toml'
# revised config with the neighboring watershed - reverted May 1st since we
# don't have DEM or VEG layers for that area
#config_location = '/home/ubuntu/planet-snowcover/experiments/co-train-neigh.toml'
config = load_config(config_location)


p = pprint.PrettyPrinter()

fs = s3fs.S3FileSystem(session=boto3.Session(profile_name=config['dataset']['aws_profile']))

imagery_searchpath = config['dataset']['image_bucket'] + '/' + config['dataset']['imagery_directory_regex']
print("Searching for imagery...({})".format(imagery_searchpath))
imagery_candidates = fs.ls(config['dataset']['image_bucket'])
#print("candidates:")
#p.pprint(imagery_candidates)
imagery_locs = [c for c in imagery_candidates if match(imagery_searchpath, c)]
print("result:")
p.pprint(imagery_locs)
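
One detail worth noting in the filter above: re.match anchors at the start of the candidate string, so the regex must cover the full key prefix returned by fs.ls(). A quick check with a made-up bucket layout:

from re import match

candidates = [
    "my-imagery-bucket/20180601_planet_chip",
    "my-imagery-bucket/readme.txt",
]
pattern = "my-imagery-bucket/" + r"\d{8}_planet.*"

print([c for c in candidates if match(pattern, c)])
# -> ['my-imagery-bucket/20180601_planet_chip']
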
Code Example #17
def main(args):
    config = load_config(args.config)
    args.out = os.path.expanduser(args.out)
    args.workers = (torch.cuda.device_count() * 2
                    if torch.cuda.is_available() and not args.workers
                    else args.workers)
    config["model"][
        "loader"] = args.loader if args.loader else config["model"]["loader"]
    config["model"]["bs"] = args.bs if args.bs else config["model"]["bs"]
    config["model"]["lr"] = args.lr if args.lr else config["model"]["lr"]
    config["model"]["ts"] = args.ts if args.ts else config["model"]["ts"]
    config["model"]["nn"] = args.nn if args.nn else config["model"]["nn"]
    config["model"][
        "loss"] = args.loss if args.loss else config["model"]["loss"]
    config["model"]["da"] = args.da if args.da else config["model"]["da"]
    config["model"]["dap"] = args.dap if args.dap else config["model"]["dap"]
    check_classes(config)
    check_channels(config)
    check_model(config)

    if not os.path.isdir(os.path.expanduser(args.dataset)):
        sys.exit("ERROR: dataset {} is not a directory".format(args.dataset))

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - training on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(
            torch.__version__, torch.version.cuda,
            torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - training on CPU, with {} workers - (Torch:{})".
                format(args.workers, torch.__version__))
        log.log("WARNING: Are you really sure about not training on GPU?")
        device = torch.device("cpu")

    try:
        loader = import_module("robosat_pink.loaders.{}".format(
            config["model"]["loader"].lower()))
        loader_train = getattr(loader, config["model"]["loader"])(
            config, config["model"]["ts"],
            os.path.join(args.dataset, "training"), "train")
        loader_val = getattr(loader, config["model"]["loader"])(
            config, config["model"]["ts"],
            os.path.join(args.dataset, "validation"), "train")
    except:
        sys.exit("ERROR: Unable to load data loaders")

    try:
        model_module = import_module("robosat_pink.models.{}".format(
            config["model"]["nn"].lower()))
    except:
        sys.exit("ERROR: Unable to load {} model".format(
            config["model"]["nn"]))

    nn = getattr(model_module, config["model"]["nn"])(
        loader_train.shape_in, loader_train.shape_out,
        config["model"]["pretrained"]).to(device)
    nn = torch.nn.DataParallel(nn)
    optimizer = Adam(nn.parameters(), lr=config["model"]["lr"])

    resume = 0
    if args.checkpoint:
        try:
            chkpt = torch.load(os.path.expanduser(args.checkpoint),
                               map_location=device)
            nn.load_state_dict(chkpt["state_dict"])
            log.log("Using checkpoint: {}".format(args.checkpoint))

        except:
            sys.exit("ERROR: Unable to load {} checkpoint".format(
                args.checkpoint))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]
            if resume >= args.epochs:
                sys.exit(
                    "ERROR: Epoch {} already reached by the given checkpoint".
                    format(args.epochs))

    try:
        loss_module = import_module("robosat_pink.losses.{}".format(
            config["model"]["loss"].lower()))
        criterion = getattr(loss_module, config["model"]["loss"])().to(device)
    except:
        sys.exit("ERROR: Unable to load {} loss".format(
            config["model"]["loss"]))

    bs = config["model"]["bs"]
    train_loader = DataLoader(loader_train,
                              batch_size=bs,
                              shuffle=True,
                              drop_last=True,
                              num_workers=args.workers)
    val_loader = DataLoader(loader_val,
                            batch_size=bs,
                            shuffle=False,
                            drop_last=True,
                            num_workers=args.workers)

    log.log("--- Input tensor from Dataset: {} ---".format(args.dataset))
    num_channel = 1  # channels are numbered from 1 in the log
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(
                num_channel, channel["name"], band))
            num_channel += 1

    log.log("--- Hyper Parameters ---")
    for hp in config["model"]:
        log.log("{}{}".format(hp.ljust(25, " "), config["model"][hp]))

    for epoch in range(resume, args.epochs):
        UUID = uuid.uuid1()
        log.log("---{}Epoch: {}/{} -- UUID: {}".format(os.linesep, epoch + 1,
                                                       args.epochs, UUID))

        train(train_loader, config, log, device, nn, optimizer, criterion)
        validate(val_loader, config, log, device, nn, criterion)

        try:  # https://github.com/pytorch/pytorch/issues/9176
            nn_doc = nn.module.doc
            nn_version = nn.module.version
        except AttributeError:
            nn_version = nn.version
            nn_doc = nn.doc

        states = {
            "uuid": UUID,
            "model_version": nn_version,
            "producer_name": "RoboSat.pink",
            "producer_version": "0.4.0",
            "model_licence": "MIT",
            "domain": "pink.RoboSat",  # reverse-DNS
            "doc_string": nn_doc,
            "shape_in": loader_train.shape_in,
            "shape_out": loader_train.shape_out,
            "state_dict": nn.state_dict(),
            "epoch": epoch + 1,
            "nn": config["model"]["nn"],
            "optimizer": optimizer.state_dict(),
            "loader": config["model"]["loader"],
        }
        checkpoint_path = os.path.join(
            args.out, "checkpoint-{:05d}.pth".format(epoch + 1))
        try:
            torch.save(states, checkpoint_path)
        except:
            sys.exit(
                "ERROR: Unable to save checkpoint {}".format(checkpoint_path))
Code Example #18
def main(args):
    config = load_config(args.config)
    num_classes = len(config["classes"])
    batch_size = args.batch_size if args.batch_size else config["model"][
        "batch_size"]
    tile_size = args.tile_size if args.tile_size else config["model"][
        "tile_size"]

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    def map_location(storage, _):
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    # https://github.com/pytorch/pytorch/issues/7178
    chkpt = torch.load(args.checkpoint, map_location=map_location)

    models = [
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.models.__file__)])
    ]
    if config["model"]["name"] not in [model for model in models]:
        sys.exit("Unknown model, thoses available are {}".format(
            [model for model in models]))

    std = []
    mean = []
    num_channels = 0
    for channel in config["channels"]:
        std.extend(channel["std"])
        mean.extend(channel["mean"])
        num_channels += len(channel["bands"])

    encoder = config["model"]["encoder"]
    pretrained = config["model"]["pretrained"]

    model_module = import_module("robosat_pink.models.{}".format(
        config["model"]["name"]))

    net = getattr(model_module, "{}".format(config["model"]["name"].title()))(
        num_classes=num_classes,
        num_channels=num_channels,
        encoder=encoder,
        pretrained=pretrained).to(device)

    net = torch.nn.DataParallel(net)

    net.load_state_dict(chkpt["state_dict"])
    net.eval()

    transform = Compose([ImageToTensor(), Normalize(mean=mean, std=std)])
    directory = BufferedSlippyMapTiles(args.tiles,
                                       transform=transform,
                                       size=tile_size,
                                       overlap=args.overlap)
    loader = DataLoader(directory,
                        batch_size=batch_size,
                        num_workers=args.workers)

    palette = make_palette(config["classes"][0]["color"],
                           config["classes"][1]["color"])

    # don't track tensors with autograd during prediction
    with torch.no_grad():
        for images, tiles in tqdm(loader,
                                  desc="Eval",
                                  unit="batch",
                                  ascii=True):
            images = images.to(device)
            outputs = net(images)

            # manually compute segmentation mask class probabilities per pixel
            probs = torch.nn.functional.softmax(outputs,
                                                dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))

                # we predicted on buffered tiles; now get back probs for original image
                prob = directory.unbuffer(prob)

                assert prob.shape[
                    0] == 2, "single channel requires binary model"
                assert np.allclose(
                    np.sum(prob, axis=0), 1.0
                ), "single channel requires probabilities to sum up to one"

                image = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()

                out = Image.fromarray(image, mode="P")
                out.putpalette(palette)

                os.makedirs(os.path.join(args.probs, str(z), str(x)),
                            exist_ok=True)
                path = os.path.join(args.probs, str(z), str(x),
                                    str(y) + ".png")

                out.save(path, optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.tiles)]
        web_ui(args.probs, base_url, tiles, tiles, "png", template)
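
Prediction above runs on buffered tiles, and unbuffer recovers the original extent. A sketch of what BufferedSlippyMapTiles.unbuffer presumably does (an assumption: crop `overlap` pixels from each side of the C x H x W probability array):

import numpy as np

def unbuffer(prob, overlap):
    _, height, width = prob.shape
    return prob[:, overlap:height - overlap, overlap:width - overlap]

probs = np.random.rand(2, 544, 544)  # e.g. 512 px tile + 16 px overlap per side
print(unbuffer(probs, 16).shape)     # -> (2, 512, 512)
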
Code Example #19
File: predict.py Project: yzuaiyou/robosat.pink
def main(args):
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    args.workers = (torch.cuda.device_count() * 2
                    if torch.cuda.is_available() and not args.workers
                    else args.workers)

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(
            torch.__version__, torch.version.cuda,
            torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(
            args.workers))
        device = torch.device("cpu")

    try:
        chkpt = torch.load(args.checkpoint, map_location=device)
        assert chkpt["producer_name"] == "RoboSat.pink"
        model_module = import_module("robosat_pink.models.{}".format(
            chkpt["nn"].lower()))
        nn = getattr(model_module, chkpt["nn"])(chkpt["shape_in"],
                                                chkpt["shape_out"]).to(device)
        nn = torch.nn.DataParallel(nn)
        nn.load_state_dict(chkpt["state_dict"])
        nn.eval()
    except:
        sys.exit("ERROR: Unable to load {} checkpoint.".format(
            args.checkpoint))

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    try:
        loader_module = import_module("robosat_pink.loaders.{}".format(
            chkpt["loader"].lower()))
        loader_predict = getattr(loader_module,
                                 chkpt["loader"])(config,
                                                  chkpt["shape_in"][1:3],
                                                  args.tiles,
                                                  mode="predict")
    except:
        sys.exit("ERROR: Unable to load {} data loader.".format(
            chkpt["loader"]))

    loader = DataLoader(loader_predict,
                        batch_size=args.bs,
                        num_workers=args.workers)
    palette = make_palette(config["classes"][0]["color"],
                           config["classes"][1]["color"])

    # don't track tensors with autograd during prediction
    with torch.no_grad():

        for images, tiles in tqdm(loader,
                                  desc="Eval",
                                  unit="batch",
                                  ascii=True):

            images = images.to(device)

            try:
                outputs = nn(images)
                probs = torch.nn.functional.softmax(outputs,
                                                    dim=1).data.cpu().numpy()
            except:
                log.log("WARNING: Skipping batch:")
                for tile in tiles:
                    log.log(" - {}".format(str(tile)))
                continue

            for tile, prob in zip(tiles, probs):

                try:
                    x, y, z = list(map(int, tile))
                    mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()
                    tile_label_to_file(args.out, mercantile.Tile(x, y, z),
                                       palette, mask)
                except:
                    log.log("WARNING: Skipping tile {}".format(str(tile)))

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.out)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
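
The mask written per tile comes from rounding the foreground softmax channel at 0.5, exactly as in np.around(prob[1:, :, :]) above. A tiny worked example with made-up logits:

import numpy as np
import torch

logits = torch.tensor([[[[0.2, 1.0]], [[1.5, -0.3]]]])  # N=1, C=2, H=1, W=2
probs = torch.nn.functional.softmax(logits, dim=1).numpy()

mask = np.around(probs[0, 1:, :, :]).astype(np.uint8).squeeze()
print(mask)  # -> [1 0]: foreground where its probability exceeds 0.5
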