示例#1
0
    def __init__(self, config, ts, root, mode):
        """Index the per-channel (and, when training, label) slippy map tiles.

        Args:
          config: parsed configuration; ``config["channels"]`` entries provide a
            ``"name"`` (sub-directory of ``root``) and a ``"bands"`` list, and
            ``len(config["classes"])`` gives the output channel count.
          ts: tile size, appended to the channel/class count to build
            ``shape_in`` / ``shape_out``. Any sequence is accepted (tuple, list
            as read from a config or checkpoint, ``torch.Size``...).
          root: dataset root directory; ``~`` is expanded.
          mode: ``"train"`` (``<root>/labels`` is indexed too) or ``"predict"``.
        """
        super().__init__()

        self.root = os.path.expanduser(root)
        self.config = config
        self.mode = mode

        assert mode == "train" or mode == "predict"

        def sorted_tiles(directory):
            # (tile, path) pairs sorted by tile, so per-directory lists can be
            # matched index by index (assuming identical coverage).
            return sorted(
                ((tile, tile_path) for tile, tile_path in tiles_from_slippy_map(directory)),
                key=lambda item: item[0],
            )

        num_channels = 0
        self.tiles = {}
        for channel in config["channels"]:
            self.tiles[channel["name"]] = sorted_tiles(os.path.join(self.root, channel["name"]))
            num_channels += len(channel["bands"])

        ts = tuple(ts)  # tuple concatenation below rejects lists
        self.shape_in = (num_channels, ) + ts  # C,W,H
        self.shape_out = (len(config["classes"]), ) + ts  # C,W,H

        if self.mode == "train":
            self.tiles["labels"] = sorted_tiles(os.path.join(self.root, "labels"))
示例#2
0
    def __init__(self, root, mode, transform=None):
        """Index every tile found under a slippy map directory.

        Args:
          root: slippy map directory to scan.
          mode: stored unchanged on the instance.
          transform: optional callable stored for later use.
        """
        super().__init__()

        self.transform = transform

        # (tile, path) pairs, ordered by tile for a deterministic iteration order.
        pairs = [(tile, path) for tile, path in tiles_from_slippy_map(root)]
        self.tiles = sorted(pairs, key=lambda pair: pair[0])
        self.mode = mode
示例#3
0
    def test_slippy_map_directory(self):
        # The fixture directory holds three images; once sorted, the first one
        # is the z18 / x69105 / y105093 tile.
        found = sorted(tiles_from_slippy_map("tests/fixtures/images"))

        self.assertEqual(len(found), 3)

        first_tile, first_path = found[0]
        self.assertEqual(type(first_tile), mercantile.Tile)
        self.assertEqual(first_path, "tests/fixtures/images/18/69105/105093.jpg")
示例#4
0
    def __init__(self, root, transform=None, size=512, overlap=32):
        """Index the tiles of a slippy map directory for buffered access.

        Args:
          root: slippy map directory to scan.
          transform: optional transformation stored for later use.
          size: tile size in pixels; must be at least 256.
          overlap: overlap in pixels; must be non-negative.
        """
        super().__init__()

        assert size >= 256
        assert overlap >= 0

        self.tiles = list(tiles_from_slippy_map(root))
        self.transform = transform
        self.overlap = overlap
        self.size = size
示例#5
0
    def __init__(self, root, mode, transform=None, tile_index=False):
        """Index the tiles found under a slippy map directory.

        Args:
          root: slippy map directory to scan.
          mode: stored unchanged on the instance.
          transform: optional callable stored for later use.
          tile_index: when true, ``self.tiles`` is a ``{tile: path}`` dict
            instead of a list of ``(tile, path)`` pairs.
        """
        super().__init__()

        self.transform = transform
        self.tile_index = tile_index

        # Kept in discovery order: no sort is applied here.
        pairs = [(tile, path) for tile, path in tiles_from_slippy_map(root)]
        self.tiles = dict(pairs) if tile_index else pairs
        self.mode = mode
示例#6
0
def main(args):
    """Vectorize one class of a slippy map of masks into a GeoJSON FeatureCollection.

    Every mask tile under ``args.masks`` is read as a palette image. Pixels whose
    value equals the position of the class titled ``args.type`` in
    ``config["classes"]`` become Polygon features, one per connected region,
    carrying the tile x/y/z as properties. Features are streamed to ``args.out``.

    Exits with an error message when ``args.type`` is not a class title or a
    tile cannot be vectorized.
    """
    config = load_config(args.config)
    check_classes(config)

    titles = [classe["title"] for classe in config["classes"]]
    if args.type not in titles:
        sys.exit(
            "ERROR: Requested type {} not found among classes title in the config file."
            .format(args.type))
    # A single scalar index: comparing pixels against a list of several
    # indices (duplicate titles) would broadcast incorrectly.
    index = titles.index(args.type)

    print("RoboSat.pink - vectorize {} from {}".format(args.type, args.masks))

    with open(args.out, "w", encoding="utf-8") as out:
        first = True
        out.write('{"type":"FeatureCollection","features":[')

        for tile, path in tqdm(list(tiles_from_slippy_map(args.masks)),
                               ascii=True,
                               unit="mask"):
            try:
                features = (np.array(Image.open(path).convert("P"),
                                     dtype=np.uint8) == index).astype(np.uint8)
                # numpy shape is (rows, cols) == (height, width); from_bounds
                # takes width before height.
                height, width = features.shape[-2:]
                transform = rasterio.transform.from_bounds(
                    *mercantile.bounds(tile.x, tile.y, tile.z), width, height)

                prop = '"properties":{{"x":{},"y":{},"z":{}}}'.format(
                    int(tile.x), int(tile.y), int(tile.z))
                # mask=features restricts polygonization to foreground pixels;
                # without it shapes() also yields the background (0) regions.
                for shape, _ in rasterio.features.shapes(
                        features, transform=transform, mask=features):
                    geom = '"geometry":{{"type": "Polygon", "coordinates":{}}}'.format(
                        json.dumps(shape["coordinates"]))
                    out.write('{}{{"type":"Feature",{},{}}}'.format(
                        "," if not first else "", geom, prop))
                    first = False
            except Exception:
                sys.exit("ERROR: Unable to vectorize tile {}.".format(
                    str(tile)))

        out.write("]}")
示例#7
0
def main(args):
    """Apply the feature handler for ``args.type`` to every mask tile, then save it.

    The class index is the position of ``args.type`` in
    ``dataset["common"]["classes"]``. Each mask tile is reduced to a 0/1 array
    for that class before being handed to the handler.
    """
    dataset = load_config(args.dataset)

    classes = dataset["common"]["classes"]
    assert set(classes).issuperset(set(handlers.keys())), "handlers have a class label"
    class_index = classes.index(args.type)

    handler = handlers[args.type]()

    mask_tiles = list(tiles_from_slippy_map(args.masks))
    for tile, mask_path in tqdm(mask_tiles, ascii=True, unit="mask"):
        pixels = np.array(Image.open(mask_path).convert("P"), dtype=np.uint8)
        handler.apply(tile, (pixels == class_index).astype(np.uint8))

    handler.save(args.out)
示例#8
0
    def __init__(self, root, transform=None, size=512, overlap=32):
        """Prepare buffered access to the tiles of a slippy map directory.

        Args:
          root: slippy map directory root, laid out as ``z/x/y.png``.
          transform: transformation to run on each buffered tile.
          size: slippy map tile size, in pixels (at least 256).
          overlap: border added on every side of a tile, in pixels (non-negative).

        Note:
          The overlap must not span multiple tiles.

          Use `unbuffer` to get back the original tile.
        """
        super().__init__()

        assert overlap >= 0
        assert size >= 256

        self.tiles = [*tiles_from_slippy_map(root)]
        self.overlap = overlap
        self.size = size
        self.transform = transform
示例#9
0
def main(args):
    """Extract the features of one class from a slippy map of masks.

    The handler module named ``args.type`` is searched in ``args.path`` (when
    given) and then in the bundled ``features`` package, excluding ``core``.
    Its ``<Type>Handler`` class is applied to every mask tile and the result is
    saved to ``args.out``.

    Exits with an error message when the type is unknown or is not among the
    class titles of the config.
    """
    module_search_path = [args.path] if args.path else []
    module_search_path.append(
        os.path.join(Path(__file__).parent.parent, "features"))
    # iter_modules reports each name once, from the first path providing it.
    available = [name for _, name, _ in pkgutil.iter_modules(module_search_path)
                 if name != "core"]
    if args.type not in available:
        sys.exit("Unknown type, thoses available are {}".format(available))

    config = load_config(args.config)
    # NOTE(review): assumes this config format stores class titles under
    # config["classes"]["titles"] (other formats use a list of {"title": ...}).
    labels = config["classes"]["titles"]
    if args.type not in labels:
        sys.exit(
            "The type you asked is not consistent with yours classes in the config file provided."
        )
    index = labels.index(args.type)

    # Import from the user directory only when the module actually lives there,
    # so bundled types stay importable even when args.path is given.
    user_types = ({name for _, name, _ in pkgutil.iter_modules([args.path])}
                  if args.path else set())
    if args.type in user_types:
        sys.path.append(args.path)
        module = import_module(args.type)
    else:
        module = import_module("robosat_pink.features.{}".format(args.type))

    handler = getattr(module, "{}Handler".format(args.type.title()))()

    for tile, path in tqdm(list(tiles_from_slippy_map(args.masks)),
                           ascii=True,
                           unit="mask"):
        image = np.array(Image.open(path).convert("P"), dtype=np.uint8)
        handler.apply(tile, (image == index).astype(np.uint8))

    handler.save(args.out)
示例#10
0
def main(args):
    """Predict per-tile masks for ``args.tiles`` with the configured model.

    Builds the model named in ``config["model"]`` and loads its weights from
    ``args.checkpoint``. It then runs the model over buffered, normalized tiles
    and writes ``args.probs/z/x/y.png`` palette images holding the rounded
    foreground probability. With ``args.web_ui``, a web UI is generated for the
    output.

    Exits with an error message if the configured model is unknown.
    """
    config = load_config(args.config)
    num_classes = len(config["classes"])
    batch_size = args.batch_size or config["model"]["batch_size"]
    tile_size = args.tile_size or config["model"]["tile_size"]

    # Validate the model name before the (potentially slow) checkpoint load.
    models = [
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.models.__file__)])
    ]
    if config["model"]["name"] not in models:
        sys.exit("Unknown model, thoses available are {}".format(models))

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    def map_location(storage, _):
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    # https://github.com/pytorch/pytorch/issues/7178
    chkpt = torch.load(args.checkpoint, map_location=map_location)

    std, mean, num_channels = [], [], 0
    for channel in config["channels"]:
        std.extend(channel["std"])
        mean.extend(channel["mean"])
        num_channels += len(channel["bands"])

    model_name = config["model"]["name"]
    model_module = import_module("robosat_pink.models.{}".format(model_name))
    net = getattr(model_module, model_name.title())(
        num_classes=num_classes,
        num_channels=num_channels,
        encoder=config["model"]["encoder"],
        pretrained=config["model"]["pretrained"]).to(device)

    net = torch.nn.DataParallel(net)
    net.load_state_dict(chkpt["state_dict"])
    net.eval()

    transform = Compose([ImageToTensor(), Normalize(mean=mean, std=std)])
    directory = BufferedSlippyMapTiles(args.tiles,
                                       transform=transform,
                                       size=tile_size,
                                       overlap=args.overlap)
    loader = DataLoader(directory,
                        batch_size=batch_size,
                        num_workers=args.workers)

    palette = make_palette(config["classes"][0]["color"],
                           config["classes"][1]["color"])

    with torch.no_grad():  # no autograd tracking needed for inference
        for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True):
            outputs = net(images.to(device))

            # per-pixel class probabilities
            probs = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))

                # predictions were made on buffered tiles; crop back to the original tile
                prob = directory.unbuffer(prob)

                assert prob.shape[0] == 2, "single channel requires binary model"
                assert np.allclose(
                    np.sum(prob, axis=0), 1.0
                ), "single channel requires probabilities to sum up to one"

                # round the foreground probability to a 0/1 mask
                image = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()

                out = Image.fromarray(image, mode="P")
                out.putpalette(palette)

                tile_dir = os.path.join(args.probs, str(z), str(x))
                os.makedirs(tile_dir, exist_ok=True)
                out.save(os.path.join(tile_dir, str(y) + ".png"), optimize=True)

    if args.web_ui:
        template = args.web_ui_template or "leaflet.html"
        base_url = args.web_ui_base_url or "./"
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.tiles)]
        web_ui(args.probs, base_url, tiles, tiles, "png", template)
示例#11
0
def main(args):
    """Predict thresholded masks for the tiles of ``args.tiles`` (local or ``s3://``).

    The checkpoint may also live on S3 (``s3://`` prefix). It is then read
    through ``boto3``/``s3fs`` with the ``args.aws_profile`` profile. Each
    network output is compared with ``args.threshold`` and written as a palette
    PNG to ``args.probs/z/x/y.png``. With ``args.web_ui``, a web UI is generated.

    Exits with an error message if the model is unknown or the checkpoint is
    not found.
    """
    config = load_config(args.config)
    num_classes = len(config["classes"])
    batch_size = args.batch_size or config["model"]["batch_size"]

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    def map_location(storage, _):
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    # https://github.com/pytorch/pytorch/issues/7178
    checkpoint = args.checkpoint
    s3_checkpoint = checkpoint.startswith("s3://")
    if s3_checkpoint:
        # strip the scheme before handing the bucket/key path to s3fs
        checkpoint = checkpoint[len("s3://"):]

    models = [
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.models.__file__)])
    ]
    if config["model"]["name"] not in models:
        sys.exit("Unknown model, thoses available are {}".format(models))

    num_channels = sum(len(channel["bands"]) for channel in config["channels"])

    model_name = config["model"]["name"]
    model_module = import_module("robosat_pink.models.{}".format(model_name))
    net = getattr(model_module, model_name.title())(
        num_classes=num_classes,
        num_channels=num_channels,
        encoder=config["model"]["encoder"],
        pretrained=config["model"]["pretrained"]).to(device)

    net = torch.nn.DataParallel(net)

    try:
        if s3_checkpoint:
            session = boto3.Session(profile_name=args.aws_profile)
            fs = s3fs.S3FileSystem(session=session)
            with s3fs.S3File(fs, checkpoint, 'rb') as remote:
                state = torch.load(io.BytesIO(remote.read()),
                                   map_location=map_location)
        else:
            state = torch.load(checkpoint, map_location=map_location)
    except FileNotFoundError:
        # Stop here: predicting with untrained weights would silently produce garbage.
        sys.exit("{} checkpoint not found.".format(args.checkpoint))

    net.load_state_dict(state['state_dict'])
    net.to(device)
    net.eval()

    # No input transform: tiles are fed to the network as the dataset returns
    # them (a mean/std normalization was drafted for this script but never
    # enabled). TODO(review): confirm both datasets return ready-to-use tensors.
    transform = None
    if args.tiles.startswith('s3://'):
        directory = S3SlippyMapTiles(args.tiles,
                                     mode='multibands',
                                     transform=transform,
                                     aws_profile=args.aws_profile)
    else:
        directory = SlippyMapTiles(args.tiles,
                                   mode="multibands",
                                   transform=transform)
    loader = DataLoader(directory,
                        batch_size=batch_size,
                        num_workers=args.workers)

    palette = make_palette(config["classes"][0]["color"])

    with torch.no_grad():  # no autograd tracking needed for inference
        for tiles, images in tqdm(loader, desc="Eval", unit="batch", ascii=True):
            # presumably the loader collates (x, y, z) as three per-batch
            # tensors; regroup them into one (x, y, z) tuple per tile
            tiles = list(zip(tiles[0], tiles[1], tiles[2]))
            outputs = net(images.to(device))

            for tile, prob in zip(tiles, outputs):
                x, y, z = (coord.item() for coord in tile)

                # NOTE(review): the raw network output is thresholded directly
                # (no sigmoid/softmax here); confirm args.threshold matches that scale.
                # Move to numpy before astype; squeeze drops a single channel axis.
                image = (prob > args.threshold).cpu().numpy().astype(np.uint8).squeeze()

                out = Image.fromarray(image, mode="P")
                out.putpalette(palette)

                tile_dir = os.path.join(args.probs, str(z), str(x))
                os.makedirs(tile_dir, exist_ok=True)
                out.save(os.path.join(tile_dir, str(y) + ".png"), optimize=True)

    if args.web_ui:
        template = args.web_ui_template or "leaflet.html"
        base_url = args.web_ui_base_url or "./"
        # An s3:// input is presumably not listable by tiles_from_slippy_map;
        # in that case list the tiles just written locally instead.
        listing = args.probs if args.tiles.startswith("s3://") else args.tiles
        tiles = [tile for tile, _ in tiles_from_slippy_map(listing)]
        web_ui(args.probs, base_url, tiles, tiles, "png", template)
示例#12
0
def main(args):
    """Compare slippy map tiles: side-by-side images, stacked images, or a QoD list.

    Modes (``args.mode``):
      * ``side``: concatenate the tiles of every ``args.images`` directory,
        horizontally (or vertically with ``args.vertical``), under ``args.out``.
      * ``stack``: average the tiles of every ``args.images`` directory.
      * ``list``: write each kept tile with its foreground ratio and QoD to the
        ``args.out`` file, as text or as GeoJSON (``args.geojson``).

    When masks, labels and config are all given, tiles outside the
    ``minimum/maximum`` foreground and QoD bounds are skipped. Exits with an
    error message on missing parameters or missing input tiles.
    """
    if not args.masks or not args.labels or not args.config:
        if args.mode == "list":
            sys.exit(
                "Parameters masks, labels and config, are all mandatories in list mode."
            )
        if args.minimum_fg > 0 or args.maximum_fg < 100 or args.minimum_qod > 0 or args.maximum_qod < 100:
            sys.exit(
                "Parameters masks, labels and config, are all mandatories in QoD filtering."
            )

    tiles = None
    if args.images:
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])]
        for image in args.images[1:]:
            assert sorted(tiles) == sorted([
                tile for tile, _ in tiles_from_slippy_map(image)
            ]), "inconsistent coverages"

    if args.labels and args.masks:
        tiles_masks = [tile for tile, _ in tiles_from_slippy_map(args.masks)]
        tiles_labels = [tile for tile, _ in tiles_from_slippy_map(args.labels)]
        if args.images:
            assert sorted(tiles) == sorted(tiles_masks) == sorted(
                tiles_labels), "inconsistent coverages"
        else:
            assert sorted(tiles_masks) == sorted(
                tiles_labels), "inconsistent coverages"
            tiles = tiles_masks

    if tiles is None:
        sys.exit("ERROR: no input tiles: provide images, or both masks and labels.")

    qod_filter = bool(args.masks and args.labels and args.config)
    if qod_filter:
        # Class titles do not depend on the tile: load the config once.
        titles = [classe["title"] for classe in load_config(args.config)["classes"]]

    out = open(args.out, mode="w") if args.mode == "list" else None
    try:
        first = True
        if out is not None and args.geojson:
            out.write('{"type":"FeatureCollection","features":[')

        tiles_compare = []
        for tile in tqdm(list(tiles), desc="Compare", unit="tile", ascii=True):

            x, y, z = list(map(str, tile))

            if qod_filter:
                _, fg_ratio, qod = compare(args.masks, args.labels, tile, titles)
                if not args.minimum_fg <= fg_ratio <= args.maximum_fg or not args.minimum_qod <= qod <= args.maximum_qod:
                    continue

            tiles_compare.append(tile)

            if args.mode == "side":
                count = len(args.images)
                for i, root in enumerate(args.images):
                    img = tile_image(tile_from_slippy_map(root, x, y, z)[1])
                    height, width = img.shape[0], img.shape[1]

                    if i == 0:
                        # Allocate along the concatenation axis so non-square tiles fit.
                        side = np.zeros((height * count, width, 3) if args.vertical
                                        else (height, width * count, 3))
                        image_shape = img.shape
                    else:
                        assert image_shape == img.shape, "Unconsistent image size to compare"

                    if args.vertical:
                        side[i * height:(i + 1) * height, :, :] = img
                    else:
                        side[:, i * width:(i + 1) * width, :] = img

                os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
                side = Image.fromarray(np.uint8(side))
                side.save(os.path.join(args.out, z, x, "{}.{}".format(y, args.ext)),
                          optimize=True)

            elif args.mode == "stack":

                for i, root in enumerate(args.images):
                    img = tile_image(tile_from_slippy_map(root, x, y, z)[1])

                    if i == 0:
                        image_shape = img.shape[0:2]
                        stack = img / len(args.images)
                    else:
                        assert image_shape == img.shape[
                            0:2], "Unconsistent image size to compare"
                        stack = stack + (img / len(args.images))

                os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
                stack = Image.fromarray(np.uint8(stack))
                stack.save(os.path.join(args.out, z, x, "{}.{}".format(y, args.ext)),
                           optimize=True)

            elif args.mode == "list":
                if args.geojson:
                    prop = '"properties":{{"x":{},"y":{},"z":{},"fg":{:.1f},"qod":{:.1f}}}'.format(
                        x, y, z, fg_ratio, qod)
                    geom = '"geometry":{}'.format(
                        json.dumps(feature(tile, precision=6)["geometry"]))
                    out.write('{}{{"type":"Feature",{},{}}}'.format(
                        "," if not first else "", geom, prop))
                    first = False
                else:
                    out.write("{},{},{}\t\t{:.1f}\t\t{:.1f}{}".format(
                        x, y, z, fg_ratio, qod, os.linesep))

        if out is not None and args.geojson:
            out.write("]}")
    finally:
        # close the list output even if a tile fails mid-way
        if out is not None:
            out.close()

    base_url = args.web_ui_base_url if args.web_ui_base_url else "./"

    if args.mode == "side" and args.web_ui:
        template = "compare.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, base_url, None, tiles_compare, args.ext, template)

    if args.mode == "stack" and args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])]
        web_ui(args.out, base_url, tiles, tiles_compare, args.ext, template)
示例#13
0
def main(args):
    """Compare slippy map tiles in parallel: side-by-side, stacked, or as a QoD list.

    Modes (``args.mode``):
      * ``side``: concatenate the tiles of every ``args.images`` directory,
        horizontally (or vertically with ``args.vertical``), under ``args.out``.
      * ``stack``: average the tiles of every ``args.images`` directory.
      * ``list``: write each kept tile with its foreground ratio and QoD to the
        ``args.out`` file, as text or as GeoJSON (``args.geojson``).

    With masks and labels, tiles outside the ``minimum/maximum`` foreground or
    QoD bounds are skipped. Outside ``list`` mode, tiles whose comparison fails
    are logged to ``args.out/log``. Exits with an error message on missing
    parameters, missing input tiles, or inconsistent input coverage.
    """
    args.out = os.path.expanduser(args.out)

    if not args.workers:
        args.workers = max(1, math.floor(os.cpu_count() * 0.5))

    print("RoboSat.pink - compare {} on CPU, with {} workers".format(
        args.mode, args.workers))

    if not args.masks or not args.labels:
        if args.mode == "list":
            sys.exit(
                "ERROR: Parameters masks and labels are mandatories in list mode."
            )
        if args.minimum_fg > 0 or args.maximum_fg < 100 or args.minimum_qod > 0 or args.maximum_qod < 100:
            sys.exit(
                "ERROR: Parameters masks and labels are mandatories in QoD filtering."
            )

    tiles = None
    try:
        # Explicit checks rather than asserts, which are stripped under -O.
        if args.images:
            tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])]
            for image in args.images[1:]:
                if sorted(tiles) != sorted(
                        [tile for tile, _ in tiles_from_slippy_map(image)]):
                    raise ValueError("inconsistent coverage")

        if args.labels and args.masks:
            tiles_masks = [
                tile for tile, _ in tiles_from_slippy_map(args.masks)
            ]
            tiles_labels = [
                tile for tile, _ in tiles_from_slippy_map(args.labels)
            ]
            if args.images:
                if not sorted(tiles) == sorted(tiles_masks) == sorted(tiles_labels):
                    raise ValueError("inconsistent coverage")
            else:
                if sorted(tiles_masks) != sorted(tiles_labels):
                    raise ValueError("inconsistent coverage")
                tiles = tiles_masks
    except Exception:
        sys.exit("ERROR: inconsistent input coverage")

    if tiles is None:
        sys.exit("ERROR: no input tiles: provide images, or both masks and labels.")

    tiles_list = []
    tiles_compare = []
    progress = tqdm(total=len(tiles), ascii=True, unit="tile")
    log = False if args.mode == "list" else Logs(os.path.join(args.out, "log"))

    with futures.ThreadPoolExecutor(args.workers) as executor:

        def worker(tile):
            """Process one tile; return ``(ok, tile)``, ``ok`` False when comparison failed."""
            x, y, z = list(map(str, tile))

            if args.masks and args.labels:

                label = np.array(
                    Image.open(
                        os.path.join(args.labels, z, x, "{}.png".format(y))))
                mask = np.array(
                    Image.open(
                        os.path.join(args.masks, z, x, "{}.png".format(y))))

                assert label.shape == mask.shape

                try:
                    _, fg_ratio, qod = compare(
                        torch.as_tensor(label, device="cpu"),
                        torch.as_tensor(mask, device="cpu"))
                except Exception:
                    progress.update()
                    return False, tile

                if not args.minimum_fg <= fg_ratio <= args.maximum_fg or not args.minimum_qod <= qod <= args.maximum_qod:
                    progress.update()
                    return True, tile

            tiles_compare.append(tile)

            if args.mode == "side":
                count = len(args.images)
                for i, root in enumerate(args.images):
                    img = tile_image_from_file(
                        tile_from_slippy_map(root, x, y, z)[1])
                    height, width = img.shape[0], img.shape[1]

                    if i == 0:
                        # Allocate along the concatenation axis so non-square tiles fit.
                        side = np.zeros((height * count, width, 3) if args.vertical
                                        else (height, width * count, 3))
                        image_shape = img.shape
                    else:
                        assert image_shape[0:2] == img.shape[
                            0:2], "Unconsistent image size to compare"

                    if args.vertical:
                        side[i * height:(i + 1) * height, :, :] = img
                    else:
                        side[:, i * width:(i + 1) * width, :] = img

                tile_image_to_file(args.out, tile, np.uint8(side))

            elif args.mode == "stack":
                for i, root in enumerate(args.images):
                    tile_image = tile_image_from_file(
                        tile_from_slippy_map(root, x, y, z)[1])

                    if i == 0:
                        image_shape = tile_image.shape[0:2]
                        stack = tile_image / len(args.images)
                    else:
                        assert image_shape == tile_image.shape[
                            0:2], "Unconsistent image size to compare"
                        stack = stack + (tile_image / len(args.images))

                tile_image_to_file(args.out, tile, np.uint8(stack))

            elif args.mode == "list":
                tiles_list.append([tile, fg_ratio, qod])

            progress.update()
            return True, tile

        # worker returns (ok, tile): unpack in that order
        for ok, tile in executor.map(worker, tiles):
            if not ok and log:
                log.log("Warning: skipping. {}".format(str(tile)))

    progress.close()

    if args.mode == "list":
        with open(args.out, mode="w") as out:

            if args.geojson:
                out.write('{"type":"FeatureCollection","features":[')

            first = True
            for tile, fg_ratio, qod in tiles_list:
                x, y, z = list(map(str, tile))
                if args.geojson:
                    prop = '"properties":{{"x":{},"y":{},"z":{},"fg":{:.1f},"qod":{:.1f}}}'.format(
                        x, y, z, fg_ratio, qod)
                    geom = '"geometry":{}'.format(
                        json.dumps(feature(tile, precision=6)["geometry"]))
                    out.write('{}{{"type":"Feature",{},{}}}'.format(
                        "," if not first else "", geom, prop))
                    first = False
                else:
                    out.write("{},{},{}\t{:.1f}\t{:.1f}{}".format(
                        x, y, z, fg_ratio, qod, os.linesep))

            if args.geojson:
                out.write("]}")
        print("Cover {} generated, with {} tiles, selected from initials {}".
              format(args.out, len(tiles_list), len(tiles)))

    base_url = args.web_ui_base_url if args.web_ui_base_url else "./"

    if args.mode == "side" and not args.no_web_ui:
        template = "compare.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, base_url, tiles, tiles_compare, args.format, template)

    if args.mode == "stack" and not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])]
        web_ui(args.out, base_url, tiles, tiles_compare, args.format, template)
示例#14
0
def main(args):
    """Predict a mask for every tile of ``args.dataset`` with a trained checkpoint.

    The following are all read from the checkpoint:
      * the network class (``chkpt["nn"]``);
      * its input/output shapes;
      * the data loader class (``chkpt["loader"]``).
    Each tile's rounded foreground probability is written with
    ``tile_label_to_file`` under ``args.out``. A web UI is generated unless
    ``args.no_web_ui`` is set.

    Exits with an error message when the predict dataset yields no batch.
    """
    import sys

    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    palette = make_palette(config["classes"][0]["color"],
                           config["classes"][1]["color"])

    if not args.workers:
        # torch.device("cuda") is always truthy, so query availability instead.
        # 0 workers means the DataLoader loads batches in the main process.
        args.workers = torch.cuda.device_count() * 2 if torch.cuda.is_available() else 0

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(
            torch.__version__, torch.version.cuda,
            torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(
            args.workers))
        device = torch.device("cpu")

    chkpt = torch.load(args.checkpoint, map_location=device)
    model_module = load_module("robosat_pink.models.{}".format(
        chkpt["nn"].lower()))
    # named `model` rather than `nn` to avoid shadowing the usual torch.nn alias
    model = getattr(model_module, chkpt["nn"])(chkpt["shape_in"],
                                               chkpt["shape_out"]).to(device)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(chkpt["state_dict"])
    model.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    loader_module = load_module("robosat_pink.loaders.{}".format(
        chkpt["loader"].lower()))
    loader_predict = getattr(loader_module,
                             chkpt["loader"])(config,
                                              chkpt["shape_in"][1:3],
                                              args.dataset,
                                              mode="predict")

    loader = DataLoader(loader_predict,
                        batch_size=args.bs,
                        num_workers=args.workers)
    # explicit check: an assert would be stripped under `python -O`
    if not len(loader):
        sys.exit("ERROR: Empty predict dataset directory. Check your path.")

    with torch.no_grad():  # no autograd tracking needed for inference

        for images, tiles in tqdm(loader,
                                  desc="Eval",
                                  unit="batch",
                                  ascii=True):

            outputs = model(images.to(device))
            probs = torch.nn.functional.softmax(outputs,
                                                dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))
                # round the foreground probability to a 0/1 mask
                mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()
                tile_label_to_file(args.out, mercantile.Tile(x, y, z), palette,
                                   mask)

    if not args.no_web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.out)]
        web_ui(args.out, base_url, tiles, tiles, "png", template)
示例#15
0
def main(args):
    """Generate tile cover(s) from a bbox, GeoJSON, raster, slippy map directory or cover CSV.

    Exactly one input among ``--bbox``, ``--geojson``, ``--raster``, ``--dir``
    and ``--cover`` must be given. ``--bbox`` is ``w,s,e,n`` with an optional
    fifth CRS field. The tiles are written as CSV rows to ``args.out``. With
    ``--splits`` (e.g. ``70/20/10``), the shuffled cover is divided between the
    several ``args.out`` paths; the last split receives any rounding remainder.

    Exits with an error message on invalid or incoherent parameters.
    """
    inputs = (args.bbox, args.geojson, args.dir, args.raster, args.cover)
    if sum(value is not None for value in inputs) != 1:
        sys.exit(
            "ERROR: One, and only one, input type must be provided, among: --dir, --bbox, --raster, --cover or --geojson."
        )

    if args.bbox:
        parts = args.bbox.split(",")
        try:
            if len(parts) == 5:
                crs = parts[4]
            elif len(parts) == 4:
                crs = None
            else:
                raise ValueError("bbox expects 4 or 5 comma separated values")
            w, s, e, n = map(float, parts[:4])
        except ValueError:
            sys.exit("ERROR: invalid bbox parameter.")

    if args.splits:
        try:
            splits = [int(split) for split in args.splits.split("/")]
            if len(splits) != len(args.out) or sum(splits) != 100 or min(splits) < 0:
                raise ValueError("incoherent splits")
        except (ValueError, TypeError):
            sys.exit(
                "ERROR: Invalid split value or incoherent with provided out paths."
            )

    if not args.zoom and (args.geojson or args.bbox or args.raster):
        sys.exit("ERROR: Zoom parameter is required.")

    args.out = [os.path.expanduser(out) for out in args.out]

    cover = []

    if args.raster:
        print("RoboSat.pink - cover from {} at zoom {}".format(
            args.raster, args.zoom))
        with rasterio_open(os.path.expanduser(args.raster)) as r:
            try:
                w, s, e, n = transform_bounds(r.crs, "EPSG:4326", *r.bounds)
            except Exception:
                sys.exit("ERROR: unable to deal with raster projection")

            # Materialize: shuffle() and slicing below need a list, and
            # tiles() may return a generator (mercantile.tiles does).
            cover = list(tiles(w, s, e, n, args.zoom))

    if args.geojson:
        print("RoboSat.pink - cover from {} at zoom {}".format(
            args.geojson, args.zoom))
        with open(os.path.expanduser(args.geojson)) as f:
            features = json.load(f)

        try:
            for feature in tqdm(features["features"],
                                ascii=True,
                                unit="feature"):
                cover.extend(
                    map(tuple,
                        burntiles.burn([feature], args.zoom).tolist()))
        except Exception:
            sys.exit("ERROR: invalid or unsupported GeoJSON.")

        cover = list(set(
            cover))  # tiles can overlap for multiple features; unique tile ids

    if args.bbox:
        print("RoboSat.pink - cover from {} at zoom {}".format(
            args.bbox, args.zoom))
        if crs:
            try:
                w, s, e, n = transform_bounds(crs, "EPSG:4326", w, s, e, n)
            except Exception:
                sys.exit("ERROR: unable to deal with raster projection")

        cover = list(tiles(w, s, e, n, args.zoom))

    if args.dir:
        print("RoboSat.pink - cover from {}".format(args.dir))
        cover = [tile for tile, _ in tiles_from_slippy_map(args.dir)]

    if args.cover:
        print("RoboSat.pink - cover from {}".format(args.cover))
        cover = [tile for tile in tiles_from_csv(args.cover)]

    if args.splits:
        shuffle(cover)  # in-place
        covers = []
        start = 0
        for i, split in enumerate(splits):
            if i == len(splits) - 1:
                # the last split takes the remainder, so flooring drops no tile
                stop = len(cover)
            else:
                stop = start + math.floor(len(cover) * split / 100)
            covers.append(cover[start:stop])
            start = stop
    else:
        covers = [cover]

    for out_path, split_cover in zip(args.out, covers):

        out_dir = os.path.dirname(out_path)
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir, exist_ok=True)

        # newline="" lets the csv module control line endings
        with open(out_path, "w", newline="") as fp:
            csv.writer(fp).writerows(split_cover)