def main(args):
    model = load_config(args.model)
    dataset = load_config(args.dataset)

    cuda = model['common']['cuda']

    if cuda and not torch.cuda.is_available():
        sys.exit('Error: CUDA requested but not available')

    global size
    size = args.size

    global token
    token = os.getenv('MAPBOX_ACCESS_TOKEN')

    if not token:
        sys.exit('Error: map token needed for visualizing results; export MAPBOX_ACCESS_TOKEN')

    global session
    session = requests.Session()

    global tiles
    tiles = args.url

    global predictor
    predictor = Predictor(args.checkpoint, model, dataset)

    app.run(host=args.host, port=args.port, threaded=False)
Code example #2
def main(args):
    model = load_config(args.model)
    dataset = load_config(args.dataset)

    cuda = model['common']['cuda']

    if cuda and not torch.cuda.is_available():
        sys.exit('Error: CUDA requested but not available')

    global size
    size = args.size

    global token
    token = os.getenv('MAPBOX_ACCESS_TOKEN')

    if not token:
        sys.exit(
            'Error: map token needed for visualizing results; export MAPBOX_ACCESS_TOKEN'
        )

    global session
    session = requests.Session()

    global tiles
    tiles = args.url

    global predictor
    predictor = Predictor(args.checkpoint, model, dataset)

    app.run(host=args.host, port=args.port, threaded=False)
Code example #3
File: weights.py Project: zzm422/robosat
def main(args):
    dataset = load_config(args.dataset)

    path = dataset["common"]["dataset"]
    num_classes = len(dataset["common"]["classes"])

    train_transform = Compose([ConvertImageMode(mode="P"), MaskToTensor()])

    train_dataset = SlippyMapTiles(os.path.join(path, "training", "labels"), transform=train_transform)

    n = 0
    counts = np.zeros(num_classes, dtype=np.int64)

    loader = DataLoader(train_dataset, batch_size=1)
    for images, tile in tqdm(loader, desc="Loading", unit="image", ascii=True):
        image = torch.squeeze(images)

        image = np.array(image, dtype=np.uint8)
        n += image.shape[0] * image.shape[1]
        counts += np.bincount(image.ravel(), minlength=num_classes)

    # Class weighting scheme `w = 1 / ln(c + p)` see:
    # - https://arxiv.org/abs/1707.03718
    #     LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation
    # - https://arxiv.org/abs/1606.02147
    #     ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation

    probs = counts / n
    weights = 1 / np.log(1.02 + probs)

    weights.round(6, out=weights)
    print(weights.tolist())
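
A quick numeric check of the `w = 1 / ln(1.02 + p)` scheme implemented above (values approximate): a dominant class at `p = 0.95` gets `w = 1 / ln(1.97) ≈ 1.48`, while a rare class at `p = 0.05` gets `w = 1 / ln(1.07) ≈ 14.8`, so pixels of infrequent classes weigh roughly ten times more in a weighted loss.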
Code example #4
File: features.py Project: hzitoun/robosat
def main(args):

    module_search_path = [args.path] if args.path else []
    module_search_path.append(os.path.join(Path(__file__).parent.parent, "features"))
    modules = [(path, name) for path, name, _ in pkgutil.iter_modules(module_search_path) if name != "core"]
    if args.type not in [name for _, name in modules]:
        sys.exit("Unknown type, thoses available are {}".format([name for _, name in modules]))

    config = load_config(args.config)
    labels = config["classes"]["titles"]
    if args.type not in labels:
        sys.exit("The type you asked is not consistent with yours classes in the config file provided.")
    index = labels.index(args.type)

    if args.path:
        sys.path.append(args.path)
        module = import_module(args.type)
    else:
        module = import_module("robosat.features.{}".format(args.type))

    handler = getattr(module, "{}Handler".format(args.type.title()))()

    for tile, path in tqdm(list(tiles_from_slippy_map(args.masks)), ascii=True, unit="mask"):
        image = np.array(Image.open(path).convert("P"), dtype=np.uint8)
        mask = (image == index).astype(np.uint8)
        handler.apply(tile, mask)

    handler.save(args.out)
Code example #5
def main(args):
    dataset = load_config(args.dataset)

    classes = dataset["common"]["classes"]
    colors = dataset["common"]["colors"]
    assert len(classes) == len(colors), "classes and colors coincide"

    assert len(colors) == 2, "only binary models supported right now"
    bg = colors[0]
    fg = colors[1]

    os.makedirs(args.out, exist_ok=True)

    # We can only rasterize all tiles at a single zoom.
    assert all(tile.z == args.zoom for tile in tiles_from_csv(args.tiles))

    with open(args.features) as f:
        fc = json.load(f)

    # Find all tiles the features cover and make a map object for quick lookup.
    feature_map = collections.defaultdict(list)
    for i, feature in enumerate(
            tqdm(fc["features"], ascii=True, unit="feature")):

        if feature["geometry"]["type"] != "Polygon":
            continue

        try:
            for tile in burntiles.burn([feature], zoom=args.zoom):
                feature_map[mercantile.Tile(*tile)].append(feature)
        except ValueError as e:
            print("Warning: invalid feature {}, skipping".format(i),
                  file=sys.stderr)
            continue

    # Burn features to tiles and write to a slippy map directory.
    for tile in tqdm(list(tiles_from_csv(args.tiles)), ascii=True,
                     unit="tile"):
        if tile in feature_map:
            out = burn(tile, feature_map[tile], args.size)
        else:
            out = np.zeros(shape=(args.size, args.size), dtype=np.uint8)

        out_dir = os.path.join(args.out, str(tile.z), str(tile.x))
        os.makedirs(out_dir, exist_ok=True)

        out_path = os.path.join(out_dir, "{}.png".format(tile.y))

        if os.path.exists(out_path):
            prev = np.array(Image.open(out_path))
            out = np.maximum(out, prev)

        out = Image.fromarray(out, mode="P")

        palette = make_palette(bg, fg)
        out.putpalette(palette)

        out.save(out_path, optimize=True)
Code example #6
def main(args):
    dataset = load_config(args.dataset)

    classes = dataset['common']['classes']
    colors = dataset['common']['colors']
    assert len(classes) == len(colors), 'classes and colors coincide'

    assert len(colors) == 2, 'only binary models supported right now'
    bg = colors[0]
    fg = colors[1]

    os.makedirs(args.out, exist_ok=True)

    # We can only rasterize all tiles at a single zoom.
    assert all(tile.z == args.zoom for tile in tiles_from_csv(args.tiles))

    with open(args.features) as f:
        fc = json.load(f)

    # Find all tiles the features cover and make a map object for quick lookup.
    feature_map = collections.defaultdict(list)
    for i, feature in enumerate(
            tqdm(fc['features'], ascii=True, unit='feature')):

        if feature['geometry']['type'] != 'Polygon':
            continue

        try:
            for tile in burntiles.burn([feature], zoom=args.zoom):
                feature_map[mercantile.Tile(*tile)].append(feature)
        except ValueError as e:
            print('Warning: invalid feature {}, skipping'.format(i),
                  file=sys.stderr)
            continue

    # Burn features to tiles and write to a slippy map directory.
    for tile in tqdm(list(tiles_from_csv(args.tiles)), ascii=True,
                     unit='tile'):
        if tile in feature_map:
            out = burn(tile, feature_map[tile], args.size)
        else:
            out = Image.fromarray(np.zeros(shape=(args.size, args.size),
                                           dtype=np.uint8),
                                  mode='P')

        palette = make_palette(bg, fg)
        out.putpalette(palette)

        out_path = os.path.join(args.out, str(tile.z), str(tile.x))
        os.makedirs(out_path, exist_ok=True)

        out.save(os.path.join(out_path, '{}.png'.format(tile.y)),
                 optimize=True)
Code example #7
File: export.py Project: zzm422/robosat
def main(args):
    dataset = load_config(args.dataset)

    num_classes = len(dataset["common"]["classes"])
    net = UNet(num_classes)

    chkpt = torch.load(args.checkpoint, map_location="cpu")
    net.load_state_dict(chkpt)
    net = torch.nn.DataParallel(net)

    # Todo: make input channels configurable, not hard-coded to three channels for RGB
    batch = torch.autograd.Variable(torch.randn(1, 3, args.image_size, args.image_size))

    torch.onnx.export(net, batch, args.model)
Code example #8

def main(args):
    dataset = load_config(args.dataset)

    classes = dataset['common']['classes']
    colors = dataset['common']['colors']
    assert len(classes) == len(colors), 'classes and colors coincide'

    assert len(colors) == 2, 'only binary models supported right now'
    bg = colors[0]
    fg = colors[1]

    os.makedirs(args.out, exist_ok=True)

    # We can only rasterize all tiles at a single zoom.
    assert all(tile.z == args.zoom for tile in tiles_from_csv(args.tiles))

    with open(args.features) as f:
        fc = json.load(f)

    # Find all tiles the features cover and make a map object for quick lookup.
    feature_map = collections.defaultdict(list)
    for i, feature in enumerate(tqdm(fc['features'], ascii=True, unit='feature')):

        if feature['geometry']['type'] != 'Polygon':
            continue

        try:
            for tile in burntiles.burn([feature], zoom=args.zoom):
                feature_map[mercantile.Tile(*tile)].append(feature)
        except ValueError as e:
            print('Warning: invalid feature {}, skipping'.format(i), file=sys.stderr)
            continue

    # Burn features to tiles and write to a slippy map directory.
    for tile in tqdm(list(tiles_from_csv(args.tiles)), ascii=True, unit='tile'):
        if tile in feature_map:
            out = burn(tile, feature_map[tile], args.size)
        else:
            out = Image.fromarray(np.zeros(shape=(args.size, args.size), dtype=np.uint8), mode='P')

        palette = make_palette(bg, fg)
        out.putpalette(palette)

        out_path = os.path.join(args.out, str(tile.z), str(tile.x))
        os.makedirs(out_path, exist_ok=True)

        out.save(os.path.join(out_path, '{}.png'.format(tile.y)), optimize=True)
Code example #9
def main(args):
    config = load_config(args.config)

    if args.type == "onnx":
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        # Workaround: PyTorch ONNX, DataParallel with GPU issue, cf https://github.com/pytorch/pytorch/issues/5315

    num_classes = len(config["classes"]["classes"])
    num_channels = 0
    for channel in config["channels"]:
        num_channels += len(channel["bands"])

    export_channels = num_channels if not args.export_channels else args.export_channels
    assert num_channels >= export_channels, "Cannot export more channels than the dataset provides"

    def map_location(storage, _):
        return storage.cpu()

    net = UNet(num_classes, num_channels=num_channels).to("cpu")
    chkpt = torch.load(args.checkpoint, map_location=map_location)
    net = torch.nn.DataParallel(net)
    net.load_state_dict(chkpt["state_dict"])

    if export_channels < num_channels:
        # Keep only the first export_channels input planes of the stem convolution
        weights = net.module.resnet.conv1.weight.data[:, :export_channels, :, :]
        net.module.resnet.conv1 = nn.Conv2d(export_channels,
                                            64,
                                            kernel_size=7,
                                            stride=2,
                                            padding=3,
                                            bias=False)
        net.module.resnet.conv1.weight = nn.Parameter(weights)

    if args.type == "onnx":
        batch = torch.autograd.Variable(
            torch.randn(1, export_channels, args.image_size, args.image_size))
        torch.onnx.export(net, batch, args.out)

    elif args.type == "pth":
        states = {
            "epoch": chkpt["epoch"],
            "state_dict": net.state_dict(),
            "optimizer": chkpt["optimizer"]
        }
        torch.save(states, args.out)
Code example #10

def main(args):
    dataset = load_config(args.dataset)

    labels = dataset['common']['classes']
    assert set(labels).issuperset(set(handlers.keys())), 'handlers have a class label'
    index = labels.index(args.type)

    handler = handlers[args.type]()

    tiles = list(tiles_from_slippy_map(args.masks))

    for tile, path in tqdm(tiles, ascii=True, unit='mask'):
        image = np.array(Image.open(path).convert('P'), dtype=np.uint8)
        mask = (image == index).astype(np.uint8)

        handler.apply(tile, mask)

    handler.save(args.out)
Code example #11
File: features.py Project: anannuoqi91/robosat202101
def main(args):
    dataset = load_config(args.dataset)

    labels = dataset["common"]["classes"]
    assert set(labels).issuperset(set(handlers.keys())), "handlers have a class label"
    index = labels.index(args.type)

    handler = handlers[args.type]()

    tiles = list(tiles_from_slippy_map(args.masks))

    for tile, path in tqdm(tiles, ascii=True, unit="mask"):
        image = np.array(Image.open(path).convert("P"), dtype=np.uint8)
        mask = (image == index).astype(np.uint8)

        handler.apply(tile, mask)

    handler.save(args.out)
Code example #12

def main(args):
    dataset = load_config(args.dataset)
    path = dataset['common']['dataset']

    train_transform = Compose([
        ConvertImageMode(mode='RGB'),
        ImageToTensor()
    ])

    train_dataset = SlippyMapTiles(os.path.join(path, 'training', 'images'), transform=train_transform)

    n = 0
    mean = np.zeros(3, dtype=np.float64)

    loader = DataLoader(train_dataset, batch_size=1)
    for images, tile in tqdm(loader, desc='Loading', unit='image', ascii=True):
        image = torch.squeeze(images)
        assert image.size(0) == 3, 'channel first'

        image = np.array(image, dtype=np.float64)
        n += image.shape[1] * image.shape[2]

        mean += np.sum(image, axis=(1, 2))

    mean /= n
    mean.round(decimals=6, out=mean)
    print('mean: {}'.format(mean.tolist()))

    std = np.zeros(3, dtype=np.float64)

    loader = DataLoader(train_dataset, batch_size=1)
    for images, tile in tqdm(loader, desc='Loading', unit='image', ascii=True):
        image = torch.squeeze(images)
        assert image.size(0) == 3, 'channel first'

        image = np.array(image, dtype=np.float64)
        difference = np.transpose(image, (1, 2, 0)) - mean
        std += np.sum(np.square(difference), axis=(0, 1))

    std = np.sqrt(std / (n - 1))
    std.round(decimals=6, out=std)
    print('std: {}'.format(std.tolist()))
Code example #13
def main(args):
    dataset = load_config(args.dataset)
    path = dataset["common"]["dataset"]

    train_transform = Compose([ConvertImageMode(mode="RGB"), ImageToTensor()])

    train_dataset = SlippyMapTiles(os.path.join(path, "training", "images"),
                                   transform=train_transform)

    n = 0
    mean = np.zeros(3, dtype=np.float64)

    loader = DataLoader(train_dataset, batch_size=1)
    for images, tile in tqdm(loader, desc="Loading", unit="image", ascii=True):
        image = torch.squeeze(images)
        assert image.size(0) == 3, "channel first"

        image = np.array(image, dtype=np.float64)
        n += image.shape[1] * image.shape[2]

        mean += np.sum(image, axis=(1, 2))

    mean /= n
    mean.round(decimals=6, out=mean)
    print("mean: {}".format(mean.tolist()))

    std = np.zeros(3, dtype=np.float64)

    loader = DataLoader(train_dataset, batch_size=1)
    for images, tile in tqdm(loader, desc="Loading", unit="image", ascii=True):
        image = torch.squeeze(images)
        assert image.size(0) == 3, "channel first"

        image = np.array(image, dtype=np.float64)
        difference = np.transpose(image, (1, 2, 0)) - mean
        std += np.sum(np.square(difference), axis=(0, 1))

    std = np.sqrt(std / (n - 1))
    std.round(decimals=6, out=std)
    print("std: {}".format(std.tolist()))
Code example #14
File: serve.py Project: hzitoun/robosat
def main(args):
    config = load_config(args.config)

    global size
    size = args.tile_size

    global token
    token = os.getenv("MAPBOX_ACCESS_TOKEN")

    if not token:
        sys.exit(
            "Error: map token needed visualizing results; export MAPBOX_ACCESS_TOKEN"
        )

    global session
    session = requests.Session()

    global tiles
    tiles = args.url

    global predictor
    predictor = Predictor(args.checkpoint, config)

    app.run(host=args.host, port=args.port, threaded=False)
Code example #15
def main(args):
    dataset = load_config(args.dataset)

    classes = dataset["common"]["classes"]
    colors = dataset["common"]["colors"]
    assert len(classes) == len(colors), "classes and colors coincide"

    assert len(colors) == 2, "only binary models supported right now"
    bg = colors[0]
    fg = colors[1]

    os.makedirs(args.out, exist_ok=True)

    # We can only rasterize all tiles at a single zoom.
    assert all(tile.z == args.zoom for tile in tiles_from_csv(args.tiles))

    with open(args.features) as f:
        fc = json.load(f)

    # Find all tiles the features cover and make a map object for quick lookup.
    feature_map = collections.defaultdict(list)
    for i, feature in enumerate(
            tqdm(fc["features"], ascii=True, unit="feature")):

        if feature["geometry"]["type"] != "Polygon":
            continue

        try:
            for tile in burntiles.burn([feature], zoom=args.zoom):
                feature_map[mercantile.Tile(*tile)].append(feature)
        except ValueError as e:
            print("Warning: invalid feature {}, skipping".format(i),
                  file=sys.stderr)
            continue

    single_burning(args, feature_map, bg, fg)
Code example #16
def main(args):
    config = load_config(args.config)

    classes = config["classes"]["titles"]
    colors = config["classes"]["colors"]
    assert len(classes) == len(colors), "classes and colors coincide"
    assert len(colors) == 2, "only binary models supported right now"

    os.makedirs(args.out, exist_ok=True)

    # We can only rasterize all tiles at a single zoom.
    assert all(tile.z == args.zoom for tile in tiles_from_csv(args.cover))

    # Find all tiles the features cover and make a map object for quick lookup.
    feature_map = collections.defaultdict(list)
    log = Log(os.path.join(args.out, "log"), out=sys.stderr)

    def parse_polygon(feature_map, polygon, i):

        try:
            # GeoJSON coordinates can be N-dimensional; keep only x and y
            for i, ring in enumerate(polygon["coordinates"]):
                polygon["coordinates"][i] = [[point[0], point[1]] for point in ring]

            for tile in burntiles.burn([{
                    "type": "feature",
                    "geometry": polygon
            }],
                                       zoom=args.zoom):
                feature_map[mercantile.Tile(*tile)].append({
                    "type": "feature",
                    "geometry": polygon
                })

        except ValueError as e:
            log.log("Warning: invalid feature {}, skipping".format(i))

        return feature_map

    def parse_geometry(feature_map, geometry, i):

        if geometry["type"] == "Polygon":
            feature_map = parse_polygon(feature_map, geometry, i)

        elif geometry["type"] == "MultiPolygon":
            for polygon in geometry["coordinates"]:
                feature_map = parse_polygon(feature_map, {
                    "type": "Polygon",
                    "coordinates": polygon
                }, i)
        else:
            log.log(
                "Notice: {} is not a polygonal geometry type, skipping feature {}"
                .format(geometry["type"], i))

        return feature_map

    for feature in args.features:
        with open(feature) as f:
            fc = json.load(f)
            for i, feature in enumerate(
                    tqdm(fc["features"], ascii=True, unit="feature")):

                if feature["geometry"]["type"] == "GeometryCollection":
                    for geometry in feature["geometry"]["geometries"]:
                        feature_map = parse_geometry(feature_map, geometry, i)
                else:
                    feature_map = parse_geometry(feature_map,
                                                 feature["geometry"], i)

    # Burn features to tiles and write to a slippy map directory.
    for tile in tqdm(list(tiles_from_csv(args.cover)), ascii=True,
                     unit="tile"):
        if tile in feature_map:
            out = burn(tile, feature_map[tile], args.size)
        else:
            out = np.zeros(shape=(args.size, args.size), dtype=np.uint8)

        out_dir = os.path.join(args.out, str(tile.z), str(tile.x))
        os.makedirs(out_dir, exist_ok=True)

        out_path = os.path.join(out_dir, "{}.png".format(tile.y))

        if os.path.exists(out_path):
            prev = np.array(Image.open(out_path))
            out = np.maximum(out, prev)

        out = Image.fromarray(out, mode="P")

        out_path = os.path.join(args.out, str(tile.z), str(tile.x))
        os.makedirs(out_path, exist_ok=True)

        out.putpalette(
            complementary_palette(make_palette(colors[0], colors[1])))
        out.save(os.path.join(out_path, "{}.png".format(tile.y)),
                 optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile in tiles_from_csv(args.cover)]
        web_ui(args.out, args.web_ui, tiles, tiles, "png", template)
Code example #17
def main(args):
    model = load_config(args.model)
    dataset = load_config(args.dataset)

    device = torch.device("cuda" if model["common"]["cuda"] else "cpu")

    if model["common"]["cuda"] and not torch.cuda.is_available():
        sys.exit("Error: CUDA requested but not available")

    os.makedirs(model["common"]["checkpoint"], exist_ok=True)

    num_classes = len(dataset["common"]["classes"])
    net = UNet(num_classes)
    net = DataParallel(net)
    net = net.to(device)

    if model["common"]["cuda"]:
        torch.backends.cudnn.benchmark = True

    try:
        weight = torch.Tensor(dataset["weights"]["values"])
    except KeyError:
        if model["opt"]["loss"] in ("CrossEntropy", "mIoU", "Focal"):
            sys.exit(
                "Error: The loss function used, need dataset weights values")

    optimizer = Adam(net.parameters(),
                     lr=model["opt"]["lr"],
                     weight_decay=model["opt"]["decay"])

    resume = 0
    if args.checkpoint:

        def map_location(storage, _):
            return storage.cuda() if model["common"]["cuda"] else storage.cpu()

        # https://github.com/pytorch/pytorch/issues/7178
        chkpt = torch.load(args.checkpoint, map_location=map_location)
        net.load_state_dict(chkpt["state_dict"])

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]

    if model["opt"]["loss"] == "CrossEntropy":
        criterion = CrossEntropyLoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "mIoU":
        criterion = mIoULoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "Focal":
        criterion = FocalLoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "Lovasz":
        criterion = LovaszLoss2d().to(device)
    else:
        sys.exit("Error: Unknown [opt][loss] value !")

    train_loader, val_loader = get_dataset_loaders(model, dataset,
                                                   args.workers)

    num_epochs = model["opt"]["epochs"]
    if resume >= num_epochs:
        sys.exit(
            "Error: Epoch {} set in {} already reached by the checkpoint provided"
            .format(num_epochs, args.model))

    history = collections.defaultdict(list)
    log = Log(os.path.join(model["common"]["checkpoint"], "log"))

    log.log("--- Hyper Parameters on Dataset: {} ---".format(
        dataset["common"]["dataset"]))
    log.log("Batch Size:\t {}".format(model["common"]["batch_size"]))
    log.log("Image Size:\t {}".format(model["common"]["image_size"]))
    log.log("Learning Rate:\t {}".format(model["opt"]["lr"]))
    log.log("Weight Decay:\t {}".format(model["opt"]["decay"]))
    log.log("Loss function:\t {}".format(model["opt"]["loss"]))
    if "weight" in locals():
        log.log("Weights :\t {}".format(dataset["weights"]["values"]))
    log.log("---")

    for epoch in range(resume, num_epochs):
        log.log("Epoch: {}/{}".format(epoch + 1, num_epochs))

        train_hist = train(train_loader, num_classes, device, net, optimizer,
                           criterion)
        log.log(
            "Train    loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".
            format(
                train_hist["loss"],
                train_hist["miou"],
                dataset["common"]["classes"][1],
                train_hist["fg_iou"],
                train_hist["mcc"],
            ))

        for k, v in train_hist.items():
            history["train " + k].append(v)

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        log.log(
            "Validate loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".
            format(val_hist["loss"], val_hist["miou"],
                   dataset["common"]["classes"][1], val_hist["fg_iou"],
                   val_hist["mcc"]))

        for k, v in val_hist.items():
            history["val " + k].append(v)

        visual = "history-{:05d}-of-{:05d}.png".format(epoch + 1, num_epochs)
        plot(os.path.join(model["common"]["checkpoint"], visual), history)

        checkpoint = "checkpoint-{:05d}-of-{:05d}.pth".format(
            epoch + 1, num_epochs)

        states = {
            "epoch": epoch + 1,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict()
        }

        torch.save(states,
                   os.path.join(model["common"]["checkpoint"], checkpoint))
Code example #18
def main(args):
    model = load_config(args.model)
    dataset = load_config(args.dataset)

    device = torch.device("cuda" if model["common"]["cuda"] else "cpu")

    if model["common"]["cuda"] and not torch.cuda.is_available():
        sys.exit("Error: CUDA requested but not available")

    # if args.batch_size < 2:
    #     sys.exit('Error: PSPNet requires more than one image for BatchNorm in Pyramid Pooling')

    os.makedirs(model["common"]["checkpoint"], exist_ok=True)

    num_classes = len(dataset["common"]["classes"])
    net = UNet(num_classes)
    net = DataParallel(net)
    net = net.to(device)

    if model["common"]["cuda"]:
        torch.backends.cudnn.benchmark = True

    if args.checkpoint:

        def map_location(storage, _):
            return storage.cuda() if model["common"]["cuda"] else storage.cpu()

        # https://github.com/pytorch/pytorch/issues/7178
        chkpt = torch.load(args.checkpoint, map_location=map_location)
        net.load_state_dict(chkpt)

    optimizer = Adam(net.parameters(),
                     lr=model["opt"]["lr"],
                     weight_decay=model["opt"]["decay"])

    weight = torch.Tensor(dataset["weights"]["values"])

    criterion = CrossEntropyLoss2d(weight=weight).to(device)
    # criterion = FocalLoss2d(weight=weight).to(device)

    train_loader, val_loader = get_dataset_loaders(model, dataset)

    num_epochs = model["opt"]["epochs"]

    history = collections.defaultdict(list)

    for epoch in range(num_epochs):
        print("Epoch: {}/{}".format(epoch + 1, num_epochs))

        train_hist = train(train_loader, num_classes, device, net, optimizer,
                           criterion)
        print("Train loss: {:.4f}, mean IoU: {:.4f}".format(
            train_hist["loss"], train_hist["iou"]))

        for k, v in train_hist.items():
            history["train " + k].append(v)

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        print("Validate loss: {:.4f}, mean IoU: {:.4f}".format(
            val_hist["loss"], val_hist["iou"]))

        for k, v in val_hist.items():
            history["val " + k].append(v)

        visual = "history-{:05d}-of-{:05d}.png".format(epoch + 1, num_epochs)
        plot(os.path.join(model["common"]["checkpoint"], visual), history)

        checkpoint = "checkpoint-{:05d}-of-{:05d}.pth".format(
            epoch + 1, num_epochs)
        torch.save(net.state_dict(),
                   os.path.join(model["common"]["checkpoint"], checkpoint))
Code example #19
def main(args):
    config = load_config(args.config)
    num_classes = len(config["classes"]["titles"])

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    def map_location(storage, _):
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    # https://github.com/pytorch/pytorch/issues/7178
    chkpt = torch.load(args.checkpoint, map_location=map_location)

    net = UNet(num_classes).to(device)
    net = nn.DataParallel(net)

    net.load_state_dict(chkpt["state_dict"])
    net.eval()

    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    transform = Compose([ImageToTensor(), Normalize(mean=mean, std=std)])

    directory = BufferedSlippyMapDirectory(args.tiles,
                                           transform=transform,
                                           size=args.tile_size,
                                           overlap=args.overlap)
    loader = DataLoader(directory,
                        batch_size=args.batch_size,
                        num_workers=args.workers)

    if args.masks_output:
        palette = make_palette(config["classes"]["colors"][0],
                               config["classes"]["colors"][1])
    else:
        palette = continuous_palette_for_color("pink", 256)

    # don't track tensors with autograd during prediction
    with torch.no_grad():
        for images, tiles in tqdm(loader,
                                  desc="Eval",
                                  unit="batch",
                                  ascii=True):
            images = images.to(device)
            outputs = net(images)

            # manually compute segmentation mask class probabilities per pixel
            probs = nn.functional.softmax(outputs, dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))

                # we predicted on buffered tiles; now get back probs for original image
                prob = directory.unbuffer(prob)

                assert prob.shape[
                    0] == 2, "single channel requires binary model"
                assert np.allclose(
                    np.sum(prob, axis=0), 1.0
                ), "single channel requires probabilities to sum up to one"

                if args.masks_output:
                    image = np.around(prob[1:, :, :]).astype(
                        np.uint8).squeeze()
                else:
                    image = (prob[1:, :, :] * 255).astype(np.uint8).squeeze()

                out = Image.fromarray(image, mode="P")
                out.putpalette(palette)

                os.makedirs(os.path.join(args.probs, str(z), str(x)),
                            exist_ok=True)
                path = os.path.join(args.probs, str(z), str(x),
                                    str(y) + ".png")

                out.save(path, optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.tiles)]
        web_ui(args.probs, args.web_ui, tiles, tiles, "png", template)
Code example #20
def main(args):
    model = load_config(args.model)
    dataset = load_config(args.dataset)

    cuda = model["common"]["cuda"]

    device = torch.device("cuda" if cuda else "cpu")

    def map_location(storage, _):
        return storage.cuda() if cuda else storage.cpu()

    if cuda and not torch.cuda.is_available():
        sys.exit("Error: CUDA requested but not available")

    num_classes = len(dataset["common"]["classes"])

    # https://github.com/pytorch/pytorch/issues/7178
    chkpt = torch.load(args.checkpoint, map_location=map_location)

    net = UNet(num_classes).to(device)
    net = nn.DataParallel(net)

    if cuda:
        torch.backends.cudnn.benchmark = True

    net.load_state_dict(chkpt["state_dict"])
    net.eval()

    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    transform = Compose([
        ConvertImageMode(mode="RGB"),
        ImageToTensor(),
        Normalize(mean=mean, std=std)
    ])

    directory = BufferedSlippyMapDirectory(args.tiles,
                                           transform=transform,
                                           size=args.tile_size,
                                           overlap=args.overlap)
    loader = DataLoader(directory,
                        batch_size=args.batch_size,
                        num_workers=args.workers)

    # don't track tensors with autograd during prediction
    with torch.no_grad():
        for images, tiles in tqdm(loader,
                                  desc="Eval",
                                  unit="batch",
                                  ascii=True):
            images = images.to(device)
            outputs = net(images)

            # manually compute segmentation mask class probabilities per pixel
            probs = nn.functional.softmax(outputs, dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))

                # we predicted on buffered tiles; now get back probs for original image
                prob = directory.unbuffer(prob)

                # Quantize the floating point probabilities in [0,1] to [0,255] and store
                # a single-channel `.png` file with a continuous color palette attached.

                assert prob.shape[
                    0] == 2, "single channel requires binary model"
                assert np.allclose(
                    np.sum(prob, axis=0), 1.
                ), "single channel requires probabilities to sum up to one"
                foreground = prob[1:, :, :]

                anchors = np.linspace(0, 1, 256)
                quantized = np.digitize(foreground, anchors).astype(np.uint8)

                palette = continuous_palette_for_color("pink", 256)

                out = Image.fromarray(quantized.squeeze(), mode="P")
                out.putpalette(palette)

                os.makedirs(os.path.join(args.probs, str(z), str(x)),
                            exist_ok=True)
                path = os.path.join(args.probs, str(z), str(x),
                                    str(y) + ".png")

                out.save(path, optimize=True)
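
Several of the prediction snippets above buffer each tile with `overlap` pixels from its neighbours before inference and then crop that border back off via `directory.unbuffer(prob)`. A minimal sketch of that crop, assuming channel-first probabilities and a buffer of `overlap` pixels on every side (the actual `BufferedSlippyMapDirectory.unbuffer` implementation may differ):

def unbuffer(probs, overlap):
    # probs has shape (channels, size + 2 * overlap, size + 2 * overlap);
    # slice the overlap margin off every spatial border.
    _, height, width = probs.shape
    return probs[:, overlap:height - overlap, overlap:width - overlap]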
Code example #21

def main(args):
    model = load_config(args.model)
    dataset = load_config(args.dataset)

    device = torch.device('cuda' if model['common']['cuda'] else 'cpu')

    if model['common']['cuda'] and not torch.cuda.is_available():
        sys.exit('Error: CUDA requested but not available')

    # if args.batch_size < 2:
    #     sys.exit('Error: PSPNet requires more than one image for BatchNorm in Pyramid Pooling')

    os.makedirs(model['common']['checkpoint'], exist_ok=True)

    num_classes = len(dataset['common']['classes'])
    net = UNet(num_classes).to(device)

    if args.resume:
        path = os.path.join(model['common']['checkpoint'], args.resume)

        cuda = model['common']['cuda']

        def map_location(storage, _):
            return storage.cuda() if cuda else storage.cpu()

        chkpt = torch.load(path, map_location=map_location)
        net.load_state_dict(chkpt)
        resume_at_epoch = int(args.resume[11:16])
    else:
        resume_at_epoch = 0

    if model['common']['cuda']:
        torch.backends.cudnn.benchmark = True
        net = DataParallel(net)

    optimizer = SGD(net.parameters(), lr=model['opt']['lr'], momentum=model['opt']['momentum'])

    scheduler = MultiStepLR(optimizer, milestones=model['opt']['milestones'], gamma=model['opt']['gamma'])

    weight = torch.Tensor(dataset['weights']['values'])

    for i in range(resume_at_epoch):
        scheduler.step()

    criterion = CrossEntropyLoss2d(weight=weight).to(device)
    # criterion = FocalLoss2d(weight=weight).to(device)

    train_loader, val_loader = get_dataset_loaders(model, dataset)

    num_epochs = model['opt']['epochs']

    history = collections.defaultdict(list)

    for epoch in range(resume_at_epoch, num_epochs):
        print('Epoch: {}/{}'.format(epoch + 1, num_epochs))

        train_hist = train(train_loader, num_classes, device, net, optimizer, scheduler, criterion)
        print('Train loss: {:.4f}, mean IoU: {:.4f}'.format(train_hist['loss'], train_hist['iou']))

        for k, v in train_hist.items():
            history['train ' + k].append(v)

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        print('Validate loss: {:.4f}, mean IoU: {:.4f}'.format(val_hist['loss'], val_hist['iou']))

        for k, v in val_hist.items():
            history['val ' + k].append(v)

        visual = 'history-{:05d}-of-{:05d}.png'.format(epoch + 1, num_epochs)
        plot(os.path.join(model['common']['checkpoint'], visual), history)

        checkpoint = 'checkpoint-{:05d}-of-{:05d}.pth'.format(epoch + 1, num_epochs)
        torch.save(net.state_dict(), os.path.join(model['common']['checkpoint'], checkpoint))
Code example #22
File: compare.py Project: hzitoun/robosat
def main(args):

    if not args.masks or not args.labels or not args.config:
        if args.mode == "list":
            sys.exit(
                "Parameters masks, labels and config, are all mandatories in list mode."
            )
        if args.minimum_fg > 0 or args.maximum_fg < 100 or args.minimum_qod > 0 or args.maximum_qod < 100:
            sys.exit(
                "Parameters masks, labels and config, are all mandatories in QoD filtering."
            )

    if args.images:
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])]
        for image in args.images[1:]:
            assert sorted(tiles) == sorted([
                tile for tile, _ in tiles_from_slippy_map(image)
            ]), "inconsistent coverages"

    if args.labels and args.masks:
        tiles_masks = [tile for tile, _ in tiles_from_slippy_map(args.masks)]
        tiles_labels = [tile for tile, _ in tiles_from_slippy_map(args.labels)]
        if args.images:
            assert sorted(tiles) == sorted(tiles_masks) == sorted(
                tiles_labels), "inconsistent coverages"
        else:
            assert sorted(tiles_masks) == sorted(
                tiles_labels), "inconsistent coverages"
            tiles = tiles_masks

    if args.mode == "list":
        out = open(args.out, mode="w")
        if args.geojson:
            out.write('{"type":"FeatureCollection","features":[')
            first = True

    tiles_compare = []
    for tile in tqdm(list(tiles), desc="Compare", unit="tile", ascii=True):

        x, y, z = list(map(str, tile))

        if args.masks and args.labels and args.config:
            classes = load_config(args.config)["classes"]["classes"]
            dist, fg_ratio, qod = compare(args.masks, args.labels, tile,
                                          classes)
            if not args.minimum_fg <= fg_ratio <= args.maximum_fg or not args.minimum_qod <= qod <= args.maximum_qod:
                continue

        tiles_compare.append(tile)

        if args.mode == "side":

            for i, image in enumerate(args.images):
                img = tile_image(image, x, y, z)

                if i == 0:
                    side = np.zeros(
                        (img.shape[0], img.shape[1] * len(args.images), 3))
                    side = np.swapaxes(side, 0, 1) if args.vertical else side
                    image_shape = img.shape
                else:
                    assert image_shape == img.shape, "Unconsistent image size to compare"

                if args.vertical:
                    side[i * image_shape[0]:(i + 1) *
                         image_shape[0], :, :] = img
                else:
                    side[:,
                         i * image_shape[0]:(i + 1) * image_shape[0], :] = img

            os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
            side = Image.fromarray(np.uint8(side))
            side.save(os.path.join(args.out, z, x, "{}.{}".format(y,
                                                                  args.ext)),
                      optimize=True)

        elif args.mode == "stack":

            for i, image in enumerate(args.images):
                img = tile_image(image, x, y, z)

                if i == 0:
                    image_shape = img.shape[0:2]
                    stack = img / len(args.images)
                else:
                    assert image_shape == img.shape[
                        0:2], "Inconsistent image sizes to compare"
                    stack = stack + (img / len(args.images))

            os.makedirs(os.path.join(args.out, str(z), str(x)), exist_ok=True)
            stack = Image.fromarray(np.uint8(stack))
            stack.save(os.path.join(args.out, str(z), str(x),
                                    "{}.{}".format(y, args.ext)),
                       optimize=True)

        elif args.mode == "list":
            if args.geojson:
                prop = '"properties":{{"x":{},"y":{},"z":{},"fg":{:.1f},"qod":{:.1f}}}'.format(
                    x, y, z, fg_ratio, qod)
                geom = '"geometry":{}'.format(
                    json.dumps(feature(tile, precision=6)["geometry"]))
                out.write('{}{{"type":"Feature",{},{}}}'.format(
                    "," if not first else "", geom, prop))
                first = False
            else:
                out.write("{},{},{}\t\t{:.1f}\t\t{:.1f}{}".format(
                    x, y, z, fg_ratio, qod, os.linesep))

    if args.mode == "list":
        if args.geojson:
            out.write("]}")
        out.close()

    elif args.mode == "side" and args.web_ui:
        template = "compare.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, args.web_ui, None, tiles_compare, args.ext, template)

    elif args.mode == "stack" and args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])]
        web_ui(args.out, args.web_ui, tiles, tiles_compare, args.ext, template)
Code example #23
File: zq_test.py Project: anannuoqi91/robosat202101
import numpy as np
from PIL import Image
from tqdm import tqdm

from robosat.tiles import tiles_from_slippy_map
from robosat.config import load_config

from robosat.features.parking import ParkingHandler

# Register post-processing handlers here; they need to support an `apply(tile, mask)` function
# for handling one mask and a `save(path)` function for GeoJSON serialization to a file.
handlers = {"parking": ParkingHandler}

args_dataset = r'/Users/zhangqi/Documents/GitHub/robosat/data/dataset-building-predict.toml'
args_masks = r'/Users/zhangqi/Documents/GitHub/robosat/data/predict_segmentation-masks'
args_type = 'parking'
args_out = r'/Users/zhangqi/Documents/GitHub/robosat/data/predict_geojson_features'

dataset = load_config(args_dataset)

labels = dataset["common"]["classes"]
assert set(labels).issuperset(set(
    handlers.keys())), "handlers have a class label"
index = labels.index(args_type)

handler = handlers[args_type]()

tiles = list(tiles_from_slippy_map(args_masks))

for tile, path in tqdm(tiles, ascii=True, unit="mask"):
    image = np.array(Image.open(path).convert("P"), dtype=np.uint8)
    mask = (image == index).astype(np.uint8)

    handler.apply(tile, mask)

handler.save(args_out)
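
The registration comment above spells out the handler contract: an `apply(tile, mask)` method called once per mask tile, and a `save(path)` method that serializes everything collected to GeoJSON. Below is a minimal sketch of a custom handler written against that contract; the geometry logic (one tile-bounds polygon per non-empty mask) is purely illustrative, and only `mercantile.bounds` is assumed from the real dependencies.

import json

import mercantile


class BoundsHandler:
    """Toy handler: emits one tile-bounds polygon per non-empty mask."""

    def __init__(self):
        self.features = []

    def apply(self, tile, mask):
        # Skip tiles whose mask contains no foreground pixels at all.
        if not mask.any():
            return

        # Use the tile's geographic bounds as a stand-in geometry.
        west, south, east, north = mercantile.bounds(tile)
        ring = [[west, south], [east, south], [east, north], [west, north], [west, south]]
        self.features.append({"type": "Feature",
                              "geometry": {"type": "Polygon", "coordinates": [ring]}})

    def save(self, path):
        # Serialize everything collected by `apply` as a GeoJSON FeatureCollection.
        with open(path, "w") as fp:
            json.dump({"type": "FeatureCollection", "features": self.features}, fp)

Registering it is then just `handlers = {"parking": BoundsHandler}`-style wiring, as in the snippet above.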
Code example #24
def main(args):
    config = load_config(args.config)
    print(config)
    lr = args.lr if args.lr else config["model"]["lr"]
    dataset_path = args.dataset if args.dataset else config["dataset"]["path"]
    num_epochs = args.epochs if args.epochs else config["model"]["epochs"]

    log = Log(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        device = torch.device("cuda")

        torch.backends.cudnn.benchmark = True
        log.log("RoboSat - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
    else:
        device = torch.device("cpu")
        print(args.workers)
        log.log("RoboSat - training on CPU, with {} workers". format(args.workers))

    num_classes = len(config["classes"]["titles"])
    num_channels = 0
    for channel in config["channels"]:
        num_channels += len(channel["bands"])
    pretrained = config["model"]["pretrained"]
    net = DataParallel(UNet(num_classes, num_channels=num_channels, pretrained=pretrained)).to(device)

    if config["model"]["loss"] in ("CrossEntropy", "mIoU", "Focal"):
        try:
            weight = torch.Tensor(config["classes"]["weights"])
        except KeyError:
            sys.exit("Error: The loss function used, need dataset weights values")

    optimizer = Adam(net.parameters(), lr=lr, weight_decay=config["model"]["decay"])

    resume = 0
    if args.checkpoint:

        def map_location(storage, _):
            return storage.cuda() if torch.cuda.is_available() else storage.cpu()

        # https://github.com/pytorch/pytorch/issues/7178
        chkpt = torch.load(args.checkpoint, map_location=map_location)
        net.load_state_dict(chkpt["state_dict"])
        log.log("Using checkpoint: {}".format(args.checkpoint))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]

    if config["model"]["loss"] == "CrossEntropy":
        criterion = CrossEntropyLoss2d(weight=weight).to(device)
    elif config["model"]["loss"] == "mIoU":
        criterion = mIoULoss2d(weight=weight).to(device)
    elif config["model"]["loss"] == "Focal":
        criterion = FocalLoss2d(weight=weight).to(device)
    elif config["model"]["loss"] == "Lovasz":
        criterion = LovaszLoss2d().to(device)
    else:
        sys.exit("Error: Unknown [model][loss] value !")

    train_loader, val_loader = get_dataset_loaders(dataset_path, config, args.workers)

    if resume >= num_epochs:
        sys.exit("Error: Epoch {} set in {} already reached by the checkpoint provided".format(num_epochs, args.config))

    history = collections.defaultdict(list)

    log.log("")
    log.log("--- Input tensor from Dataset: {} ---".format(dataset_path))
    num_channel = 1
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(num_channel, channel["sub"], band))
            num_channel += 1
    log.log("")
    log.log("--- Hyper Parameters ---")
    log.log("Batch Size:\t\t {}".format(config["model"]["batch_size"]))
    log.log("Image Size:\t\t {}".format(config["model"]["image_size"]))
    log.log("Data Augmentation:\t {}".format(config["model"]["data_augmentation"]))
    log.log("Learning Rate:\t\t {}".format(lr))
    log.log("Weight Decay:\t\t {}".format(config["model"]["decay"]))
    log.log("Loss function:\t\t {}".format(config["model"]["loss"]))
    log.log("ResNet pre-trained:\t {}".format(config["model"]["pretrained"]))
    if "weight" in locals():
        log.log("Weights :\t\t {}".format(config["dataset"]["weights"]))
    log.log("")

    for epoch in range(resume, num_epochs):

        log.log("---")
        log.log("Epoch: {}/{}".format(epoch + 1, num_epochs))

        train_hist = train(train_loader, num_classes, device, net, optimizer, criterion)
        log.log(
            "Train    loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format(
                train_hist["loss"],
                train_hist["miou"],
                config["classes"]["titles"][1],
                train_hist["fg_iou"],
                train_hist["mcc"],
            )
        )

        for k, v in train_hist.items():
            history["train " + k].append(v)

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        log.log(
            "Validate loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format(
                val_hist["loss"], val_hist["miou"], config["classes"]["titles"][1], val_hist["fg_iou"], val_hist["mcc"]
            )
        )

        for k, v in val_hist.items():
            history["val " + k].append(v)
        visual_path = os.path.join(args.out, "history-{:05d}-of-{:05d}.png".format(epoch + 1, num_epochs))
        plot(visual_path, history)

        if args.save_intermed:
            states = {"epoch": epoch + 1, "state_dict": net.state_dict(), "optimizer": optimizer.state_dict()}
            checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}-of-{:05d}.pth".format(epoch + 1, num_epochs))
            torch.save(states, checkpoint_path)
Code example #25
def main(args):
    model = load_config(args.model)
    dataset = load_config(args.dataset)

    device = torch.device('cuda' if model['common']['cuda'] else 'cpu')

    if model['common']['cuda'] and not torch.cuda.is_available():
        sys.exit('Error: CUDA requested but not available')

    # if args.batch_size < 2:
    #     sys.exit('Error: PSPNet requires more than one image for BatchNorm in Pyramid Pooling')

    os.makedirs(model['common']['checkpoint'], exist_ok=True)

    num_classes = len(dataset['common']['classes'])
    net = UNet(num_classes).to(device)

    if args.resume:
        path = os.path.join(model['common']['checkpoint'], args.resume)

        cuda = model['common']['cuda']

        def map_location(storage, _):
            return storage.cuda() if cuda else storage.cpu()

        chkpt = torch.load(path, map_location=map_location)
        net.load_state_dict(chkpt)
        resume_at_epoch = int(args.resume[11:16])
    else:
        resume_at_epoch = 0

    if model['common']['cuda']:
        torch.backends.cudnn.benchmark = True
        net = DataParallel(net)

    optimizer = SGD(net.parameters(),
                    lr=model['opt']['lr'],
                    momentum=model['opt']['momentum'])

    scheduler = MultiStepLR(optimizer,
                            milestones=model['opt']['milestones'],
                            gamma=model['opt']['gamma'])

    weight = torch.Tensor(dataset['weights']['values'])

    for i in range(resume_at_epoch):
        scheduler.step()

    criterion = CrossEntropyLoss2d(weight=weight).to(device)
    # criterion = FocalLoss2d(weight=weight).to(device)

    train_loader, val_loader = get_dataset_loaders(model, dataset)

    num_epochs = model['opt']['epochs']

    history = collections.defaultdict(list)

    for epoch in range(resume_at_epoch, num_epochs):
        print('Epoch: {}/{}'.format(epoch + 1, num_epochs))

        train_hist = train(train_loader, num_classes, device, net, optimizer,
                           scheduler, criterion)
        print('Train loss: {:.4f}, mean IoU: {:.4f}'.format(
            train_hist['loss'], train_hist['iou']))

        for k, v in train_hist.items():
            history['train ' + k].append(v)

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        print('Validate loss: {:.4f}, mean IoU: {:.4f}'.format(
            val_hist['loss'], val_hist['iou']))

        for k, v in val_hist.items():
            history['val ' + k].append(v)

        visual = 'history-{:05d}-of-{:05d}.png'.format(epoch + 1, num_epochs)
        plot(os.path.join(model['common']['checkpoint'], visual), history)

        checkpoint = 'checkpoint-{:05d}-of-{:05d}.pth'.format(
            epoch + 1, num_epochs)
        torch.save(net.state_dict(),
                   os.path.join(model['common']['checkpoint'], checkpoint))
Code example #26
def main(args):

    if args.type == "label":
        try:
            config = load_config(args.config)
        except Exception:
            sys.exit("Error: Unable to load DataSet config file")

        classes = config["classes"]["title"]
        colors = config["classes"]["colors"]
        assert len(classes) == len(colors), "classes and colors coincide"
        assert len(colors) == 2, "only binary models supported right now"

    try:
        raster = rasterio_open(args.raster)
        w, s, e, n = bounds = transform_bounds(raster.crs, "EPSG:4326",
                                               *raster.bounds)
        transform, _, _ = calculate_default_transform(raster.crs, "EPSG:3857",
                                                      raster.width,
                                                      raster.height, *bounds)
    except Exception:
        sys.exit("Error: Unable to load raster or deal with its projection")

    tiles = [
        mercantile.Tile(x=x, y=y, z=z)
        for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)
    ]
    tiles_nodata = []

    for tile in tqdm(tiles, desc="Tiling", unit="tile", ascii=True):

        w, s, e, n = tile_bounds = mercantile.xy_bounds(tile)

        # Inspired by Rio-Tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
        warp_vrt = WarpedVRT(
            raster,
            crs="EPSG:3857",
            resampling=Resampling.bilinear,
            add_alpha=False,
            transform=from_bounds(*tile_bounds, args.size, args.size),
            width=math.ceil((e - w) / transform.a),
            height=math.ceil((s - n) / transform.e),
        )
        data = warp_vrt.read(out_shape=(len(raster.indexes), args.size,
                                        args.size),
                             window=warp_vrt.window(w, s, e, n))

        # If no_data is set, remove all tiles with at least one whole border filled only with no_data (on all bands)
        if args.no_data is not None and (
                np.all(data[:, 0, :] == args.no_data)
                or np.all(data[:, -1, :] == args.no_data)
                or np.all(data[:, :, 0] == args.no_data)
                or np.all(data[:, :, -1] == args.no_data)):
            tiles_nodata.append(tile)
            continue

        C, W, H = data.shape

        os.makedirs(os.path.join(args.out, str(args.zoom), str(tile.x)),
                    exist_ok=True)
        path = os.path.join(args.out, str(args.zoom), str(tile.x), str(tile.y))

        if args.type == "label":
            assert C == 1, "Error: Label raster input should be 1 band"

            ext = "png"
            img = Image.fromarray(np.squeeze(data, axis=0), mode="P")
            img.putpalette(make_palette(colors[0], colors[1]))
            img.save("{}.{}".format(path, ext), optimize=True)

        elif args.type == "image":
            assert C == 1 or C == 3, "Error: Image raster input should be either 1 or 3 bands"

            # GeoTiff could be 16 or 32bits
            if data.dtype == "uint16":
                data = np.uint8(data / 256)
            elif data.dtype == "uint32":
                data = np.uint8(data / (256 * 256))

            if C == 1:
                ext = "png"
                Image.fromarray(np.squeeze(data, axis=0),
                                mode="L").save("{}.{}".format(path, ext),
                                               optimize=True)
            elif C == 3:
                ext = "webp"
                Image.fromarray(np.moveaxis(data, 0, 2),
                                mode="RGB").save("{}.{}".format(path, ext),
                                                 optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        tiles = [tile for tile in tiles if tile not in tiles_nodata]
        web_ui(args.out, args.web_ui, tiles, tiles, ext, template)
Code example #27

def main(args):
    model = load_config(args.model)
    dataset = load_config(args.dataset)

    cuda = model['common']['cuda']

    device = torch.device('cuda' if cuda else 'cpu')

    def map_location(storage, _):
        return storage.cuda() if cuda else storage.cpu()

    if cuda and not torch.cuda.is_available():
        sys.exit('Error: CUDA requested but not available')

    num_classes = len(dataset['common']['classes'])

    # https://github.com/pytorch/pytorch/issues/7178
    chkpt = torch.load(args.checkpoint, map_location=map_location)

    net = UNet(num_classes).to(device)
    net = nn.DataParallel(net)

    if cuda:
        torch.backends.cudnn.benchmark = True

    net.load_state_dict(chkpt)
    net.eval()

    transform = Compose([
        ConvertImageMode(mode='RGB'),
        ImageToTensor(),
        Normalize(mean=dataset['stats']['mean'], std=dataset['stats']['std'])
    ])

    directory = BufferedSlippyMapDirectory(args.tiles, transform=transform, size=args.tile_size, overlap=args.overlap)
    loader = DataLoader(directory, batch_size=args.batch_size)

    # don't track tensors with autograd during prediction
    with torch.no_grad():
        for images, tiles in tqdm(loader, desc='Eval', unit='batch', ascii=True):
            images = images.to(device)
            outputs = net(images)

            # manually compute segmentation mask class probabilities per pixel
            probs = nn.functional.softmax(outputs, dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))

                # we predicted on buffered tiles; now get back probs for original image
                prob = directory.unbuffer(prob)

                # Quantize the floating point probabilities in [0,1] to [0,255] and store
                # a single-channel `.png` file with a continuous color palette attached.

                assert prob.shape[0] == 2, 'single channel requires binary model'
                assert np.allclose(np.sum(prob, axis=0), 1.), 'single channel requires probabilities to sum up to one'
                foreground = prob[1:, :, :]

                anchors = np.linspace(0, 1, 256)
                quantized = np.digitize(foreground, anchors).astype(np.uint8)

                palette = continuous_palette_for_color('pink', 256)

                out = Image.fromarray(quantized.squeeze(), mode='P')
                out.putpalette(palette)

                os.makedirs(os.path.join(args.probs, str(z), str(x)), exist_ok=True)
                path = os.path.join(args.probs, str(z), str(x), str(y) + '.png')

                out.save(path, optimize=True)