Пример #1
0
    def __init__(self, config, ts, root, mode):
        """Index the tiles of each configured channel, plus labels in train mode.

        Args:
            config: configuration mapping; ``config["channels"]`` entries provide
                ``"name"`` (a sub-directory of ``root``) and ``"bands"``, and
                ``config["classes"]`` sizes the output shape.
            ts: tile size tuple, appended to the channel/class counts.
            root: dataset root directory (``~`` is expanded).
            mode: ``"train"`` (also indexes ``root/labels``) or ``"predict"``.
        """
        super().__init__()

        self.root = os.path.expanduser(root)
        self.config = config
        self.mode = mode

        assert mode in ("train", "predict")

        def indexed(directory):
            # (tile, path) pairs for every tile under directory, ordered by tile.
            pairs = [(tile, path) for tile, path in tiles_from_dir(directory, xyz_path=True)]
            return sorted(pairs, key=lambda pair: pair[0])

        self.tiles = {}
        num_channels = 0
        for channel in config["channels"]:
            name = channel["name"]
            self.tiles[name] = indexed(os.path.join(self.root, name))
            num_channels += len(channel["bands"])

        self.shape_in = (num_channels,) + ts  # C,W,H
        self.shape_out = (len(config["classes"]),) + ts  # C,W,H

        if self.mode == "train":
            self.tiles["labels"] = indexed(os.path.join(self.root, "labels"))
Пример #2
0
    def __init__(self, config, ts, root, cover, mode):
        """Index the tiles of each configured channel, plus labels in train mode.

        Args:
            config: configuration mapping; ``config["channels"]`` entries provide
                ``"name"`` (a sub-directory of ``root``) and ``"bands"``, and
                ``config["classes"]`` sizes the output shape.
            ts: tile size tuple, appended to the channel/class counts.
            root: dataset root directory (``~`` is expanded).
            cover: forwarded to ``tiles_from_dir`` to restrict which tiles are indexed.
            mode: ``"train"`` (also indexes ``root/labels``), ``"predict"`` or
                ``"predict_translate"`` (asks ``tiles_from_dir`` for ``xyz_translate``).

        Raises:
            AssertionError: on an unknown mode, or when no channel was configured or
                any indexed directory yielded no tile.
        """
        super().__init__()

        self.root = os.path.expanduser(root)
        self.config = config
        self.mode = mode
        self.cover = cover

        assert mode in ["train", "predict", "predict_translate"]
        xyz_translate = mode == "predict_translate"

        num_channels = 0
        self.tiles = {}
        for channel in config["channels"]:
            path = os.path.join(self.root, channel["name"])
            self.tiles[channel["name"]] = [
                (tile, path) for tile, path in tiles_from_dir(path, cover=cover, xyz_path=True, xyz_translate=xyz_translate)
            ]
            self.tiles[channel["name"]].sort(key=lambda tile: tile[0])
            num_channels += len(channel["bands"])

        self.shape_in = (num_channels,) + ts  # C,W,H
        self.shape_out = (len(config["classes"]),) + ts  # C,W,H

        if self.mode == "train":
            path = os.path.join(self.root, "labels")
            self.tiles["labels"] = [(tile, path) for tile, path in tiles_from_dir(path, cover=cover, xyz_path=True)]
            self.tiles["labels"].sort(key=lambda tile: tile[0])

        # self.tiles maps directory name -> list of tiles: checking len(self.tiles) only
        # counted channels, so check that every indexed directory actually holds tiles.
        assert self.tiles and all(self.tiles.values()), "Empty Dataset"
Пример #3
0
def main(args):
    """Vectorize one class of mask tiles into a GeoJSON FeatureCollection.

    Each mask under ``args.masks`` is compared against the index of the config class
    titled ``args.type``. Every shape that ``rasterio.features.shapes`` extracts from
    the matching pixels is written as a Polygon feature to ``args.out``.

    Raises:
        AssertionError: if no config class is titled ``args.type``.
    """
    config = load_config(args.config)
    check_classes(config)
    # First class whose title matches; None when absent (index 0 is a valid answer).
    index = next((i for i, classe in enumerate(config["classes"]) if classe["title"] == args.type), None)
    assert index is not None, "Requested type {} not found among classes title in the config file.".format(args.type)
    print("RoboSat.pink - vectorize {} from {}".format(args.type, args.masks), file=sys.stderr, flush=True)

    with open(args.out, "w", encoding="utf-8") as out:
        out.write('{"type":"FeatureCollection","features":[')

        first = True
        # Materialized so tqdm knows the total.
        for tile, path in tqdm(list(tiles_from_dir(args.masks, xyz_path=True)), ascii=True, unit="mask"):
            with Image.open(path) as image:
                mask = (np.array(image.convert("P"), dtype=np.uint8) == index).astype(np.uint8)

            # numpy shape is (rows, cols) = (height, width); from_bounds takes width first.
            height, width = mask.shape[-2:]
            transform = rasterio.transform.from_bounds(*mercantile.bounds(tile.x, tile.y, tile.z), width, height)

            for shape, _ in rasterio.features.shapes(mask, transform=transform, mask=mask):
                geom = '"geometry":{{"type": "Polygon", "coordinates":{}}}'.format(json.dumps(shape["coordinates"]))
                out.write('{}{{"type":"Feature",{}}}'.format("" if first else ",", geom))
                first = False

        out.write("]}")
Пример #4
0
    def test_slippy_map_directory(self):
        """The fixture image directory yields three sorted (Tile, path) pairs."""
        fixture_root = "tests/fixtures/images"
        pairs = sorted((tile, path) for tile, path in tiles_from_dir(fixture_root, xyz_path=True))

        self.assertEqual(len(pairs), 3)

        first_tile, first_path = pairs[0]
        self.assertEqual(type(first_tile), mercantile.Tile)
        self.assertEqual(first_path, "tests/fixtures/images/18/69105/105093.jpg")
Пример #5
0
def main(args):
    """Compute a tile cover from exactly one input source and write it out.

    Input sources (mutually exclusive): ``--bbox``, ``--geojson``, ``--sql`` (requires
    ``--pg``), ``--dir``, ``--raster`` or ``--cover``. When ``--zoom`` is set, cover
    tiles at another zoom are replaced by the deduplicated tiles at that zoom that
    span their EPSG:4326 bounds.

    Output is one of:
    - with ``--splits``: the shuffled cover, cut by percentage, one CSV per out path;
    - with ``--extent``: the "w,s,e,n" EPSG:4326 bounds of the cover, written to
      ``args.out[0]`` or printed when no out path is given;
    - otherwise: the whole cover as CSV in ``args.out[0]``.
    """
    assert not (args.extent and args.splits), "--splits and --extent are mutually exclusive options."
    assert not (args.extent and len(args.out) > 1), "--extent option imply a single output."
    assert not (args.sql and not args.pg), "--sql option imply --pg"
    assert (
        int(args.bbox is not None)
        + int(args.geojson is not None)
        + int(args.sql is not None)
        + int(args.dir is not None)
        + int(args.raster is not None)
        + int(args.cover is not None)
        == 1
    ), "One, and only one, input type must be provided, among: --dir, --bbox, --cover, --raster, --geojson or --sql"

    if args.bbox:
        try:
            # "w,s,e,n,crs": bounds expressed in the given CRS.
            w, s, e, n, crs = args.bbox.split(",")
            w, s, e, n = map(float, (w, s, e, n))
        except ValueError:
            # Plain "w,s,e,n": bounds used as-is, without reprojection.
            crs = None
            w, s, e, n = map(float, args.bbox.split(","))
        assert isinstance(w, float) and isinstance(s, float), "Invalid bbox parameter."

    if args.splits:
        splits = [int(split) for split in args.splits.split("/")]
        assert len(splits) == len(args.out) and 0 < sum(splits) <= 100, "Invalid split value or incoherent with out paths."

    assert not (not args.zoom and (args.geojson or args.bbox or args.raster)), "Zoom parameter is required."

    args.out = [os.path.expanduser(out) for out in args.out]

    cover = []

    if args.raster:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.raster, args.zoom), file=sys.stderr, flush=True)
        with rasterio_open(os.path.expanduser(args.raster)) as r:
            w, s, e, n = transform_bounds(r.crs, "EPSG:4326", *r.bounds)
            assert isinstance(w, float) and isinstance(s, float), "Unable to deal with raster projection"

            cover = list(tiles(w, s, e, n, args.zoom))

    if args.geojson:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.geojson, args.zoom), file=sys.stderr, flush=True)
        with open(os.path.expanduser(args.geojson)) as f:
            feature_collection = json.load(f)
            srid = geojson_srid(feature_collection)
            feature_map = collections.defaultdict(list)

            for feature in tqdm(feature_collection["features"], ascii=True, unit="feature"):
                feature_map = geojson_parse_feature(args.zoom, srid, feature_map, feature)

        cover = feature_map.keys()

    if args.sql:
        print("RoboSat.pink - cover from {} {} at zoom {}".format(args.sql, args.pg, args.zoom), file=sys.stderr, flush=True)

        # args.sql is embedded verbatim: its first column must be a geometry.
        query = """
            WITH
              sql  AS ({}),
              geom AS (SELECT "1" AS geom FROM sql AS t("1"))
              SELECT '{{"type": "Feature", "geometry": '
                     || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6)
                     || '}}' AS features
              FROM geom
            """.format(args.sql)

        conn = psycopg2.connect(args.pg)
        assert conn, "Unable to connect to PostgreSQL database."
        try:
            db = conn.cursor()
            db.execute(query)
            assert db.rowcount is not None and db.rowcount != -1, "SQL Query return no result."

            feature_map = collections.defaultdict(list)
            # FIXME: fetchall will not always fit in memory...
            for feature in tqdm(db.fetchall(), ascii=True, unit="feature"):
                feature_map = geojson_parse_feature(args.zoom, 4326, feature_map, json.loads(feature[0]))
        finally:
            conn.close()  # previously leaked for the rest of the process

        cover = feature_map.keys()

    if args.bbox:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.bbox, args.zoom), file=sys.stderr, flush=True)
        if crs:
            w, s, e, n = transform_bounds(crs, "EPSG:4326", w, s, e, n)
            assert isinstance(w, float) and isinstance(s, float), "Unable to deal with raster projection"

        cover = list(tiles(w, s, e, n, args.zoom))

    if args.cover:
        print("RoboSat.pink - cover from {}".format(args.cover), file=sys.stderr, flush=True)
        cover = list(tiles_from_csv(args.cover))

    if args.dir:
        print("RoboSat.pink - cover from {}".format(args.dir), file=sys.stderr, flush=True)
        cover = list(tiles_from_dir(args.dir, xyz=not args.no_xyz))

    # Re-zoom tiles that are not at args.zoom. The set gives O(1) membership instead of
    # rescanning the output list for every candidate tile.
    unique_cover = []
    seen = set()
    extent_w, extent_s, extent_e, extent_n = (180.0, 90.0, -180.0, -90.0)
    for tile in tqdm(cover, ascii=True, unit="tile"):
        if args.zoom and tile.z != args.zoom:
            w, s, e, n = transform_bounds("EPSG:3857", "EPSG:4326", *xy_bounds(tile))
            for t in tiles(w, s, e, n, args.zoom):
                if t not in seen:
                    seen.add(t)
                    unique_cover.append(t)
        else:
            if args.extent:
                w, s, e, n = transform_bounds("EPSG:3857", "EPSG:4326", *xy_bounds(tile))
            seen.add(tile)
            unique_cover.append(tile)

        if args.extent:
            extent_w, extent_s = min(extent_w, w), min(extent_s, s)
            extent_e, extent_n = max(extent_e, e), max(extent_n, n)

    cover = unique_cover

    if args.splits:
        shuffle(cover)  # in-place
        sizes = [math.floor(len(cover) * split / 100) for split in splits]
        if len(splits) > 1 and sum(splits) == 100 and len(cover) > sum(splits):
            sizes[0] = len(cover) - sum(sizes[1:])  # no tile waste
        covers = []
        start = 0
        for size in sizes:
            covers.append(cover[start : start + size])
            start += size
    else:
        covers = [cover]

    if args.extent:
        out_dir = os.path.dirname(args.out[0]) if args.out else ""
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir, exist_ok=True)

        extent = "{:.8f},{:.8f},{:.8f},{:.8f}".format(extent_w, extent_s, extent_e, extent_n)

        if args.out:
            with open(args.out[0], "w") as fp:
                fp.write(extent)
        else:
            print(extent)
    else:
        for i, split_cover in enumerate(covers):
            out_dir = os.path.dirname(args.out[i])
            if out_dir and not os.path.isdir(out_dir):
                os.makedirs(out_dir, exist_ok=True)

            # newline="" as required by the csv module, avoids doubled line endings.
            with open(args.out[i], "w", newline="") as fp:
                csv.writer(fp).writerows(split_cover)
Пример #6
0
def _side_by_side(images, vertical=False):
    """Lay same-sized 3-channel images out left to right (top to bottom if ``vertical``).

    Args:
        images: non-empty sequence of arrays shaped (height, width, 3).
        vertical: stack the images in a column instead of a row.

    Returns:
        A float array of shape (height, width * len(images), 3), or
        (height * len(images), width, 3) when ``vertical``.

    Raises:
        AssertionError: if the images do not all share the same height and width.
    """
    height, width = images[0].shape[0:2]
    for image in images[1:]:
        assert (height, width) == image.shape[0:2], "Unconsistent image size to compare"

    count = len(images)
    side = np.zeros((height * count, width, 3) if vertical else (height, width * count, 3))
    for i, image in enumerate(images):
        if vertical:
            side[i * height : (i + 1) * height, :, :] = image
        else:
            side[:, i * width : (i + 1) * width, :] = image
    return side


def main(args):
    """Compare tiles across directories.

    Modes:
    - ``side``: write the ``args.images`` tiles next to each other (``--vertical``
      stacks them);
    - ``stack``: write the average of the ``args.images`` tiles;
    - ``list``: write per-tile foreground ratio and QoD (CSV-like or GeoJSON).

    When both masks and labels are given, tiles whose fg ratio / QoD fall outside the
    requested ranges are skipped, and tiles whose comparison fails are logged.
    """
    args.out = os.path.expanduser(args.out)

    if not args.workers:
        args.workers = os.cpu_count()

    print("RoboSat.pink - compare {} on CPU, with {} workers".format(args.mode, args.workers), file=sys.stderr, flush=True)

    if not args.masks or not args.labels:
        assert args.mode != "list", "Parameters masks and labels are mandatories in list mode."
        assert args.minimum_fg == 0.0 and args.maximum_fg == 100.0, "Both masks and labels mandatory in QoD filtering."
        assert args.minimum_qod == 0.0 and args.maximum_qod == 100.0, "Both masks and labels mandatory in QoD filtering."

    # Without this, `tiles` below would be unbound (NameError).
    assert args.images or (args.masks and args.labels), "Either images, or both masks and labels, are required."

    if args.images:
        tiles = list(tiles_from_dir(args.images[0]))
        for image in args.images[1:]:
            assert sorted(tiles) == sorted(tiles_from_dir(image)), "Unconsistent images directories"

    if args.labels and args.masks:
        tiles_masks = list(tiles_from_dir(args.masks))
        tiles_labels = list(tiles_from_dir(args.labels))
        if args.images:
            assert sorted(tiles) == sorted(tiles_masks) == sorted(tiles_labels), "Unconsistent images/label/mask directories"
        else:
            assert sorted(tiles_masks) == sorted(tiles_labels), "Label and Mask directories are not consistent"
            tiles = tiles_masks

    tiles_list = []
    tiles_compare = []
    progress = tqdm(total=len(tiles), ascii=True, unit="tile")
    log = False if args.mode == "list" else Logs(os.path.join(args.out, "log"))

    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile):
            """Process one tile; return (ok, tile), ok being False when compare() failed."""
            x, y, z = list(map(str, tile))

            if args.masks and args.labels:

                label = np.array(Image.open(os.path.join(args.labels, z, x, "{}.png".format(y))))
                mask = np.array(Image.open(os.path.join(args.masks, z, x, "{}.png".format(y))))

                assert label.shape == mask.shape, "Inconsistent tiles (size or dimensions)"

                try:
                    dist, fg_ratio, qod = compare(torch.as_tensor(label, device="cpu"), torch.as_tensor(mask, device="cpu"))
                except Exception:
                    progress.update()
                    return False, tile

                # Outside the requested ranges: filtered out, which is not an error.
                if not args.minimum_fg <= fg_ratio <= args.maximum_fg or not args.minimum_qod <= qod <= args.maximum_qod:
                    progress.update()
                    return True, tile

            tiles_compare.append(tile)

            if args.mode == "side":
                images = [tile_image_from_file(tile_from_xyz(root, x, y, z)[1]) for root in args.images]
                tile_image_to_file(args.out, tile, np.uint8(_side_by_side(images, vertical=args.vertical)))

            elif args.mode == "stack":
                for i, root in enumerate(args.images):
                    tile_image = tile_image_from_file(tile_from_xyz(root, x, y, z)[1])

                    if i == 0:
                        image_shape = tile_image.shape[0:2]
                        stack = tile_image / len(args.images)
                    else:
                        assert image_shape == tile_image.shape[0:2], "Unconsistent image size to compare"
                        stack = stack + (tile_image / len(args.images))

                tile_image_to_file(args.out, tile, np.uint8(stack))

            elif args.mode == "list":
                tiles_list.append([tile, fg_ratio, qod])

            progress.update()
            return True, tile

        # worker returns (ok, tile): unpacking it as (tile, ok) silently disabled the warnings.
        for ok, tile in executor.map(worker, tiles):
            if not ok and log:
                log.log("Warning: skipping. {}".format(str(tile)))

    progress.close()

    if args.mode == "list":
        with open(args.out, mode="w") as out:

            if args.geojson:
                out.write('{"type":"FeatureCollection","features":[')

            first = True
            for tile, fg_ratio, qod in tiles_list:
                x, y, z = list(map(str, tile))
                if args.geojson:
                    prop = '"properties":{{"x":{},"y":{},"z":{},"fg":{:.1f},"qod":{:.1f}}}'.format(x, y, z, fg_ratio, qod)
                    geom = '"geometry":{}'.format(json.dumps(feature(tile, precision=6)["geometry"]))
                    out.write('{}{{"type":"Feature",{},{}}}'.format("," if not first else "", geom, prop))
                    first = False
                else:
                    out.write("{},{},{}\t{:.1f}\t{:.1f}{}".format(x, y, z, fg_ratio, qod, os.linesep))

            if args.geojson:
                out.write("]}")

    base_url = args.web_ui_base_url if args.web_ui_base_url else "."

    if args.mode == "side" and not args.no_web_ui:
        template = "compare.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, base_url, tiles, tiles_compare, args.format, template, union_tiles=False)

    if args.mode == "stack" and not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = list(tiles_from_dir(args.images[0]))
        web_ui(args.out, base_url, tiles, tiles_compare, args.format, template)
Пример #7
0
def main(args):
    """Compute a tile cover from exactly one input source and write it out.

    Input sources (mutually exclusive): ``--bbox``, ``--geojson``, ``--dir``,
    ``--raster`` or ``--cover``. When ``--zoom`` is set, cover tiles at another zoom
    are replaced by the deduplicated tiles at that zoom that span their EPSG:4326
    bounds.

    Output is one of:
    - with ``--splits``: the shuffled cover, cut by percentage, one CSV per out path;
    - with ``--extent``: the "w,s,e,n" EPSG:4326 bounds of the cover, written to
      ``args.out[0]`` or printed when no out path is given;
    - otherwise: the whole cover as CSV in ``args.out[0]``.
    """
    assert not (args.extent and args.splits), "--splits and --extent are mutually exclusive options."
    assert not (args.extent and len(args.out) > 1), "--extent option imply a single output."
    assert (
        int(args.bbox is not None)
        + int(args.geojson is not None)
        + int(args.dir is not None)
        + int(args.raster is not None)
        + int(args.cover is not None)
        == 1
    ), "One, and only one, input type must be provided, among: --dir, --bbox, --cover, --raster or --geojson."

    if args.bbox:
        try:
            # "w,s,e,n,crs": bounds expressed in the given CRS.
            w, s, e, n, crs = args.bbox.split(",")
            w, s, e, n = map(float, (w, s, e, n))
        except ValueError:
            # Plain "w,s,e,n": bounds used as-is, without reprojection.
            crs = None
            w, s, e, n = map(float, args.bbox.split(","))
        assert isinstance(w, float) and isinstance(s, float), "Invalid bbox parameter."

    if args.splits:
        splits = [int(split) for split in args.splits.split("/")]
        assert len(splits) == len(args.out) and 0 < sum(splits) <= 100, "Invalid split value or incoherent with out paths."

    assert not (not args.zoom and (args.geojson or args.bbox or args.raster)), "Zoom parameter is required."

    args.out = [os.path.expanduser(out) for out in args.out]

    cover = []

    if args.raster:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.raster, args.zoom), file=sys.stderr, flush=True)
        with rasterio_open(os.path.expanduser(args.raster)) as r:
            w, s, e, n = transform_bounds(r.crs, "EPSG:4326", *r.bounds)
            assert isinstance(w, float) and isinstance(s, float), "Unable to deal with raster projection"

            cover = list(tiles(w, s, e, n, args.zoom))

    if args.geojson:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.geojson, args.zoom), file=sys.stderr, flush=True)
        with open(os.path.expanduser(args.geojson)) as f:
            feature_collection = json.load(f)
            srid = geojson_srid(feature_collection)
            feature_map = collections.defaultdict(list)

            for feature in tqdm(feature_collection["features"], ascii=True, unit="feature"):
                feature_map = geojson_parse_feature(args.zoom, srid, feature_map, feature)

        cover = feature_map.keys()

    if args.bbox:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.bbox, args.zoom), file=sys.stderr, flush=True)
        if crs:
            w, s, e, n = transform_bounds(crs, "EPSG:4326", w, s, e, n)
            assert isinstance(w, float) and isinstance(s, float), "Unable to deal with raster projection"

        cover = list(tiles(w, s, e, n, args.zoom))

    if args.cover:
        print("RoboSat.pink - cover from {}".format(args.cover), file=sys.stderr, flush=True)
        cover = list(tiles_from_csv(args.cover))

    if args.dir:
        print("RoboSat.pink - cover from {}".format(args.dir), file=sys.stderr, flush=True)
        cover = list(tiles_from_dir(args.dir, xyz=not args.no_xyz))

    # Re-zoom tiles that are not at args.zoom. The set gives O(1) membership instead of
    # rescanning the output list for every candidate tile.
    unique_cover = []
    seen = set()
    extent_w, extent_s, extent_e, extent_n = (180.0, 90.0, -180.0, -90.0)
    for tile in tqdm(cover, ascii=True, unit="tile"):
        if args.zoom and tile.z != args.zoom:
            w, s, e, n = transform_bounds("EPSG:3857", "EPSG:4326", *xy_bounds(tile))
            for t in tiles(w, s, e, n, args.zoom):
                if t not in seen:
                    seen.add(t)
                    unique_cover.append(t)
        else:
            if args.extent:
                w, s, e, n = transform_bounds("EPSG:3857", "EPSG:4326", *xy_bounds(tile))
            seen.add(tile)
            unique_cover.append(tile)

        if args.extent:
            extent_w, extent_s = min(extent_w, w), min(extent_s, s)
            extent_e, extent_n = max(extent_e, e), max(extent_n, n)

    cover = unique_cover

    if args.splits:
        shuffle(cover)  # in-place
        sizes = [math.floor(len(cover) * split / 100) for split in splits]
        if len(splits) > 1 and sum(splits) == 100 and len(cover) > sum(splits):
            sizes[0] = len(cover) - sum(sizes[1:])  # no tile waste
        covers = []
        start = 0
        for size in sizes:
            covers.append(cover[start : start + size])
            start += size
    else:
        covers = [cover]

    if args.extent:
        out_dir = os.path.dirname(args.out[0]) if args.out else ""
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir, exist_ok=True)

        extent = "{:.8f},{:.8f},{:.8f},{:.8f}".format(extent_w, extent_s, extent_e, extent_n)

        if args.out:
            with open(args.out[0], "w") as fp:
                fp.write(extent)
        else:
            print(extent)
    else:
        for i, split_cover in enumerate(covers):
            out_dir = os.path.dirname(args.out[i])
            if out_dir and not os.path.isdir(out_dir):
                os.makedirs(out_dir, exist_ok=True)

            # newline="" as required by the csv module, avoids doubled line endings.
            with open(args.out[i], "w", newline="") as fp:
                csv.writer(fp).writerows(split_cover)
Пример #8
0
def main(args):
    """Compute a tile cover from exactly one input source and write it as CSV.

    Input sources (mutually exclusive): ``--bbox``, ``--geojson``, ``--dir``,
    ``--xyz``, ``--raster`` or ``--cover``. When ``--zoom`` is set, cover tiles at
    another zoom are replaced by the deduplicated tiles at that zoom that span their
    EPSG:4326 bounds. With ``--splits`` (percentages summing to 100) the cover is
    shuffled and cut into one CSV per out path; otherwise it goes to ``args.out[0]``.

    Invalid arguments terminate the process through ``sys.exit`` with an error message.
    """
    if (
        int(args.bbox is not None)
        + int(args.geojson is not None)
        + int(args.dir is not None)
        + int(args.xyz is not None)
        + int(args.raster is not None)
        + int(args.cover is not None)
        != 1
    ):
        sys.exit("ERROR: One, and only one, input type must be provided, among: --dir, --xyz, --bbox, --cover, --raster or --geojson.")

    if args.bbox:
        try:
            # "w,s,e,n,crs": bounds expressed in the given CRS.
            w, s, e, n, crs = args.bbox.split(",")
            w, s, e, n = map(float, (w, s, e, n))
        except ValueError:
            try:
                # Plain "w,s,e,n": bounds used as-is, without reprojection.
                crs = None
                w, s, e, n = map(float, args.bbox.split(","))
            except ValueError:
                sys.exit("ERROR: invalid bbox parameter.")

    if args.splits:
        try:
            splits = [int(split) for split in args.splits.split("/")]
        except ValueError:
            splits = None
        # Explicit checks rather than asserts, which `python -O` strips.
        if splits is None or len(splits) != len(args.out) or sum(splits) != 100:
            sys.exit("ERROR: Invalid split value or incoherent with provided out paths.")

    if not args.zoom and (args.geojson or args.bbox or args.raster):
        sys.exit("ERROR: Zoom parameter is required.")

    args.out = [os.path.expanduser(out) for out in args.out]

    cover = []

    if args.raster:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.raster, args.zoom))
        with rasterio_open(os.path.expanduser(args.raster)) as r:
            try:
                w, s, e, n = transform_bounds(r.crs, "EPSG:4326", *r.bounds)
            except Exception:
                sys.exit("ERROR: unable to deal with raster projection")

            cover = list(tiles(w, s, e, n, args.zoom))

    if args.geojson:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.geojson, args.zoom))
        with open(os.path.expanduser(args.geojson)) as f:
            features = json.load(f)

        try:
            for feature in tqdm(features["features"], ascii=True, unit="feature"):
                cover.extend(map(tuple, burntiles.burn([feature], args.zoom).tolist()))
        except Exception:
            sys.exit("ERROR: invalid or unsupported GeoJSON.")

        cover = list(set(cover))  # tiles can overlap for multiple features; unique tile ids

    if args.bbox:
        print("RoboSat.pink - cover from {} at zoom {}".format(args.bbox, args.zoom))
        if crs:
            try:
                w, s, e, n = transform_bounds(crs, "EPSG:4326", w, s, e, n)
            except Exception:
                sys.exit("ERROR: unable to deal with raster projection")

        cover = list(tiles(w, s, e, n, args.zoom))

    if args.cover:
        print("RoboSat.pink - cover from {}".format(args.cover))
        cover = list(tiles_from_csv(args.cover))

    if args.dir:
        print("RoboSat.pink - cover from {}".format(args.dir))
        cover = list(tiles_from_dir(args.dir, xyz=False))

    if args.xyz:
        print("RoboSat.pink - cover from {}".format(args.xyz))
        cover = list(tiles_from_dir(args.xyz, xyz=True))

    # Re-zoom tiles that are not at args.zoom; the set gives O(1) membership tests.
    unique_cover = []
    seen = set()
    for tile in tqdm(cover, ascii=True, unit="tile"):
        # Index the zoom rather than reading tile.z: --geojson yields plain (x, y, z)
        # tuples, which have no .z attribute.
        if args.zoom and tile[2] != args.zoom:
            w, s, e, n = transform_bounds("EPSG:3857", "EPSG:4326", *xy_bounds(tile))
            for t in tiles(w, s, e, n, args.zoom):
                if t not in seen:
                    seen.add(t)
                    unique_cover.append(t)
        else:
            seen.add(tile)
            unique_cover.append(tile)
    cover = unique_cover

    if args.splits:
        shuffle(cover)  # in-place
        sizes = [math.floor(len(cover) * split / 100) for split in splits]
        covers = []
        start = 0
        for size in sizes:
            # Slice end is exclusive: the former `start + size - 1` dropped one tile per split.
            covers.append(cover[start : start + size])
            start += size
    else:
        covers = [cover]

    for i, split_cover in enumerate(covers):
        out_dir = os.path.dirname(args.out[i])
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir, exist_ok=True)

        # newline="" as required by the csv module, avoids doubled line endings.
        with open(args.out[i], "w", newline="") as fp:
            csv.writer(fp).writerows(split_cover)
Пример #9
0
def main(args):
    """Run a trained checkpoint over ``args.dataset`` and write the predicted mask tiles.

    The model and loader classes are resolved by name from the checkpoint. Masks are
    written under ``args.out`` using a palette built from the config class colors, and
    a web UI is generated there unless ``args.no_web_ui`` is set.

    Raises:
        AssertionError: if the predict loader yields no batch.
    """
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    palette = make_palette([classe["color"] for classe in config["classes"]])

    if not args.workers:
        # Two loader workers per visible GPU. The former `torch.device("cuda")` guard was
        # always truthy, so it never checked anything; device_count() alone decides.
        args.workers = torch.cuda.device_count() * 2

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(torch.__version__, torch.version.cuda, torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(args.workers))
        device = torch.device("cpu")

    # NOTE: torch.load unpickles the checkpoint; only load checkpoints from trusted sources.
    chkpt = torch.load(args.checkpoint, map_location=device)
    model_module = load_module("robosat_pink.models.{}".format(chkpt["nn"].lower()))
    nn = getattr(model_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"]).to(device)
    nn = torch.nn.DataParallel(nn)
    nn.load_state_dict(chkpt["state_dict"])
    nn.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    loader_module = load_module("robosat_pink.loaders.{}".format(chkpt["loader"].lower()))
    loader_predict = getattr(loader_module, chkpt["loader"])(config, chkpt["shape_in"][1:3], args.dataset, mode="predict")

    loader = DataLoader(loader_predict, batch_size=args.bs, num_workers=args.workers)
    assert len(loader), "Empty predict dataset directory. Check your path."

    with torch.no_grad():  # don't track tensors with autograd during prediction

        for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True):

            images = images.to(device)

            outputs = nn(images)
            probs = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))
                # NOTE(review): rounds the probabilities of classes 1..C-1; squeeze() only
                # yields a 2-D mask with a single foreground class — confirm for C > 2.
                mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()
                tile_label_to_file(args.out, mercantile.Tile(x, y, z), palette, mask)

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        tiles = list(tiles_from_dir(args.out))
        web_ui(args.out, base_url, tiles, tiles, "png", template)