Example #1
0
def main(args):
    """Copy (or move) the tiles listed in a cover file from ``args.dir`` to ``args.out``.

    Reads from ``args``:
        cover: CSV file passed to ``tiles_from_csv`` to obtain the tiles to keep.
        dir: source directory laid out as ``<z>/<x>/<y>.<ext>``.
        out: destination directory, populated with the same layout; also holds the log.
        move: when truthy the source files are moved instead of copied.
        web_ui, web_ui_template: when ``web_ui`` is truthy, ``web_ui`` is called on the
            output directory, with "leaflet.html" as the default template.

    A tile is skipped with a warning unless exactly one file matches ``<y>.*`` in
    the source directory. Any failure while copying/moving a tile terminates the
    program through ``sys.exit``.
    """
    log = Log(os.path.join(args.out, "log"), out=sys.stderr)

    tiles = set(tiles_from_csv(args.cover))
    extension = ""

    for tile in tqdm(tiles, desc="Subset", unit="tiles", ascii=True):

        paths = glob(os.path.join(args.dir, str(tile.z), str(tile.x), "{}.*".format(tile.y)))
        if len(paths) != 1:
            log.log("Warning: {} skipped.".format(tile))
            continue
        src = paths[0]

        try:
            # The extension of the last processed tile is the one handed to web_ui below.
            extension = os.path.splitext(src)[1][1:]
            dst_dir = os.path.join(args.out, str(tile.z), str(tile.x))
            dst = os.path.join(dst_dir, "{}.{}".format(tile.y, extension))
            os.makedirs(dst_dir, exist_ok=True)
            if args.move:
                # Explicit check rather than `assert`, which is stripped under `python -O`.
                if not os.path.isfile(src):
                    raise FileNotFoundError(src)
                shutil.move(src, dst)
            else:
                shutil.copyfile(src, dst)
        except Exception:
            # `Exception` (not a bare except) so Ctrl-C / SystemExit are not
            # misreported as a tile processing error.
            sys.exit("Error: Unable to process {}".format(tile))

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, args.web_ui, tiles, tiles, extension, template)
Example #2
0
def main(args):
    """Train a UNet from a model config and a dataset config, checkpointing every epoch.

    Reads from ``args``:
        model: path of the model configuration (``common`` and ``opt`` sections).
        dataset: path of the dataset configuration (``common`` and optional ``weights``).
        checkpoint: optional checkpoint file whose ``state_dict`` initialises the network.
        resume: when truthy (and a checkpoint is given) also restore the optimizer
            state and continue from the checkpoint's epoch.
        workers: forwarded to ``get_dataset_loaders``.

    After each epoch a history plot and a checkpoint are written into the
    ``[common][checkpoint]`` directory. Configuration errors terminate the
    program through ``sys.exit``.
    """
    model = load_config(args.model)
    dataset = load_config(args.dataset)

    device = torch.device("cuda" if model["common"]["cuda"] else "cpu")

    if model["common"]["cuda"] and not torch.cuda.is_available():
        sys.exit("Error: CUDA requested but not available")

    os.makedirs(model["common"]["checkpoint"], exist_ok=True)

    num_classes = len(dataset["common"]["classes"])
    net = UNet(num_classes)
    net = DataParallel(net)
    net = net.to(device)

    if model["common"]["cuda"]:
        torch.backends.cudnn.benchmark = True

    # Explicit sentinel instead of probing locals(): None means the dataset
    # config provides no class weights.
    weight = None
    try:
        weight = torch.Tensor(dataset["weights"]["values"])
    except KeyError:
        # Weights are only mandatory for the losses that consume them.
        if model["opt"]["loss"] in ("CrossEntropy", "mIoU", "Focal"):
            sys.exit(
                "Error: The loss function used, need dataset weights values")

    optimizer = Adam(net.parameters(),
                     lr=model["opt"]["lr"],
                     weight_decay=model["opt"]["decay"])

    resume = 0
    if args.checkpoint:

        def map_location(storage, _):
            return storage.cuda() if model["common"]["cuda"] else storage.cpu()

        # https://github.com/pytorch/pytorch/issues/7178
        chkpt = torch.load(args.checkpoint, map_location=map_location)
        net.load_state_dict(chkpt["state_dict"])

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]

    if model["opt"]["loss"] == "CrossEntropy":
        criterion = CrossEntropyLoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "mIoU":
        criterion = mIoULoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "Focal":
        criterion = FocalLoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "Lovasz":
        criterion = LovaszLoss2d().to(device)
    else:
        sys.exit("Error: Unknown [opt][loss] value !")

    train_loader, val_loader = get_dataset_loaders(model, dataset,
                                                   args.workers)

    num_epochs = model["opt"]["epochs"]
    if resume >= num_epochs:
        sys.exit(
            "Error: Epoch {} set in {} already reached by the checkpoint provided"
            .format(num_epochs, args.model))

    history = collections.defaultdict(list)
    log = Log(os.path.join(model["common"]["checkpoint"], "log"))

    log.log("--- Hyper Parameters on Dataset: {} ---".format(
        dataset["common"]["dataset"]))
    log.log("Batch Size:\t {}".format(model["common"]["batch_size"]))
    log.log("Image Size:\t {}".format(model["common"]["image_size"]))
    log.log("Learning Rate:\t {}".format(model["opt"]["lr"]))
    log.log("Weight Decay:\t {}".format(model["opt"]["decay"]))
    log.log("Loss function:\t {}".format(model["opt"]["loss"]))
    if weight is not None:
        log.log("Weights :\t {}".format(dataset["weights"]["values"]))
    log.log("---")

    for epoch in range(resume, num_epochs):
        log.log("Epoch: {}/{}".format(epoch + 1, num_epochs))

        train_hist = train(train_loader, num_classes, device, net, optimizer,
                           criterion)
        log.log(
            "Train    loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".
            format(
                train_hist["loss"],
                train_hist["miou"],
                dataset["common"]["classes"][1],
                train_hist["fg_iou"],
                train_hist["mcc"],
            ))

        for k, v in train_hist.items():
            history["train " + k].append(v)

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        log.log(
            "Validate loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".
            format(val_hist["loss"], val_hist["miou"],
                   dataset["common"]["classes"][1], val_hist["fg_iou"],
                   val_hist["mcc"]))

        for k, v in val_hist.items():
            history["val " + k].append(v)

        # The history plot and the checkpoint are both stamped with the
        # 1-based epoch number.
        visual = "history-{:05d}-of-{:05d}.png".format(epoch + 1, num_epochs)
        plot(os.path.join(model["common"]["checkpoint"], visual), history)

        checkpoint = "checkpoint-{:05d}-of-{:05d}.pth".format(
            epoch + 1, num_epochs)

        # "epoch" stores the number of completed epochs, which is what a
        # resumed run uses as its starting index.
        states = {
            "epoch": epoch + 1,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict()
        }

        torch.save(states,
                   os.path.join(model["common"]["checkpoint"], checkpoint))
Example #3
0
def main(args):
    """Train a UNet described by a single config file, logging and plotting per epoch.

    Reads from ``args``:
        config: path of the configuration (``model``, ``dataset``, ``classes``, ``channels``).
        lr, dataset, epochs: optional overrides of ``[model][lr]``, ``[dataset][path]``
            and ``[model][epochs]``; a falsy value falls back to the config.
        out: directory receiving the log, the history plots and the checkpoints.
        workers: forwarded to ``get_dataset_loaders``.
        checkpoint: optional checkpoint file whose ``state_dict`` initialises the network.
        resume: when truthy (and a checkpoint is given) also restore the optimizer
            state and continue from the checkpoint's epoch.
        save_intermed: when truthy a checkpoint is saved after every epoch.

    Configuration errors terminate the program through ``sys.exit``.
    """
    config = load_config(args.config)
    print(config)
    lr = args.lr if args.lr else config["model"]["lr"]
    dataset_path = args.dataset if args.dataset else config["dataset"]["path"]
    num_epochs = args.epochs if args.epochs else config["model"]["epochs"]

    log = Log(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        device = torch.device("cuda")

        torch.backends.cudnn.benchmark = True
        log.log("RoboSat - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
    else:
        device = torch.device("cpu")
        print(args.workers)
        log.log("RoboSat - training on CPU, with {} workers".format(args.workers))

    num_classes = len(config["classes"]["titles"])
    # The network takes one input channel per configured band.
    num_channels = sum(len(channel["bands"]) for channel in config["channels"])
    pretrained = config["model"]["pretrained"]
    net = DataParallel(UNet(num_classes, num_channels=num_channels, pretrained=pretrained)).to(device)

    # Explicit sentinel instead of probing locals(): None means the selected
    # loss does not use class weights.
    weight = None
    if config["model"]["loss"] in ("CrossEntropy", "mIoU", "Focal"):
        try:
            weight = torch.Tensor(config["classes"]["weights"])
        except KeyError:
            sys.exit("Error: The loss function used, need dataset weights values")

    optimizer = Adam(net.parameters(), lr=lr, weight_decay=config["model"]["decay"])

    resume = 0
    if args.checkpoint:

        def map_location(storage, _):
            return storage.cuda() if torch.cuda.is_available() else storage.cpu()

        # https://github.com/pytorch/pytorch/issues/7178
        chkpt = torch.load(args.checkpoint, map_location=map_location)
        net.load_state_dict(chkpt["state_dict"])
        log.log("Using checkpoint: {}".format(args.checkpoint))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]

    if config["model"]["loss"] == "CrossEntropy":
        criterion = CrossEntropyLoss2d(weight=weight).to(device)
    elif config["model"]["loss"] == "mIoU":
        criterion = mIoULoss2d(weight=weight).to(device)
    elif config["model"]["loss"] == "Focal":
        criterion = FocalLoss2d(weight=weight).to(device)
    elif config["model"]["loss"] == "Lovasz":
        criterion = LovaszLoss2d().to(device)
    else:
        sys.exit("Error: Unknown [model][loss] value !")

    train_loader, val_loader = get_dataset_loaders(dataset_path, config, args.workers)

    if resume >= num_epochs:
        sys.exit("Error: Epoch {} set in {} already reached by the checkpoint provided".format(num_epochs, args.config))

    history = collections.defaultdict(list)

    log.log("")
    log.log("--- Input tensor from Dataset: {} ---".format(dataset_path))
    num_channel = 1
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(num_channel, channel["sub"], band))
            num_channel += 1
    log.log("")
    log.log("--- Hyper Parameters ---")
    log.log("Batch Size:\t\t {}".format(config["model"]["batch_size"]))
    log.log("Image Size:\t\t {}".format(config["model"]["image_size"]))
    log.log("Data Augmentation:\t {}".format(config["model"]["data_augmentation"]))
    log.log("Learning Rate:\t\t {}".format(lr))
    log.log("Weight Decay:\t\t {}".format(config["model"]["decay"]))
    log.log("Loss function:\t\t {}".format(config["model"]["loss"]))
    log.log("ResNet pre-trained:\t {}".format(config["model"]["pretrained"]))
    if weight is not None:
        # Report the weights from the same config entry they were loaded from.
        log.log("Weights :\t\t {}".format(config["classes"]["weights"]))
    log.log("")

    for epoch in range(resume, num_epochs):

        log.log("---")
        log.log("Epoch: {}/{}".format(epoch + 1, num_epochs))

        train_hist = train(train_loader, num_classes, device, net, optimizer, criterion)
        log.log(
            "Train    loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format(
                train_hist["loss"],
                train_hist["miou"],
                config["classes"]["titles"][1],
                train_hist["fg_iou"],
                train_hist["mcc"],
            )
        )

        for k, v in train_hist.items():
            history["train " + k].append(v)

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        log.log(
            "Validate loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format(
                val_hist["loss"], val_hist["miou"], config["classes"]["titles"][1], val_hist["fg_iou"], val_hist["mcc"]
            )
        )

        for k, v in val_hist.items():
            history["val " + k].append(v)
        visual_path = os.path.join(args.out, "history-{:05d}-of-{:05d}.png".format(epoch + 1, num_epochs))
        plot(visual_path, history)

        if args.save_intermed:
            # "epoch" stores the number of completed epochs, which is what a
            # resumed run uses as its starting index.
            states = {"epoch": epoch + 1, "state_dict": net.state_dict(), "optimizer": optimizer.state_dict()}
            checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}-of-{:05d}.pth".format(epoch + 1, num_epochs))
            torch.save(states, checkpoint_path)
Example #4
0
def main(args):
    """Download the tiles listed in a CSV file from an XYZ, TMS or WMS endpoint.

    Reads from ``args``:
        tiles: CSV file passed to ``tiles_from_csv``.
        url: URL template; formatted with ``x``/``y``/``z`` for XYZ and TMS, or
            with ``xmin``/``ymin``/``xmax``/``ymax`` for WMS.
        type: one of "XYZ", "TMS" or "WMS".
        rate: used both as the thread pool size and in the per-request throttle.
        timeout: forwarded to ``fetch_image``.
        out, ext: images are saved as ``<out>/<z>/<x>/<y>.<ext>``; the log also
            lives in ``out``.
        web_ui, web_ui_template: when ``web_ui`` is truthy, ``web_ui`` is called on the
            output directory, with "leaflet.html" as the default template.

    Tiles whose file already exists are not fetched again. Failed tiles are
    logged and skipped.

    Raises:
        ValueError: if a tile has to be fetched and ``args.type`` is not one of
            the supported values.
    """
    tiles = list(tiles_from_csv(args.tiles))
    already_dl = 0
    dl = 0

    with requests.Session() as session:
        num_workers = args.rate

        os.makedirs(args.out, exist_ok=True)
        log = Log(os.path.join(args.out, "log"), out=sys.stderr)
        log.log("Begin download from {}".format(args.url))

        def worker(tile):
            """Fetch and save one tile; return ``(tile, url, ok)``.

            ``url`` is None when the file was already on disk.
            """
            tick = time.monotonic()

            x, y, z = map(str, [tile.x, tile.y, tile.z])

            os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
            path = os.path.join(args.out, z, x, "{}.{}".format(y, args.ext))

            if os.path.isfile(path):
                return tile, None, True

            if args.type == "XYZ":
                url = args.url.format(x=tile.x, y=tile.y, z=tile.z)
            elif args.type == "TMS":
                # Flip y into a local instead of assigning to tile.y: the tile
                # object is shared with the caller (log messages, web_ui) and the
                # file path above is keyed on the original y.
                tms_y = (2 ** tile.z) - tile.y - 1
                url = args.url.format(x=tile.x, y=tms_y, z=tile.z)
            elif args.type == "WMS":
                xmin, ymin, xmax, ymax = xy_bounds(tile)
                url = args.url.format(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
            else:
                # Fail with a clear message rather than an UnboundLocalError on `url`.
                raise ValueError("Unknown tile service type: {!r}".format(args.type))

            res = fetch_image(session, url, args.timeout)
            if not res:
                return tile, url, False

            try:
                image = Image.open(res)
                image.save(path, optimize=True)
            except OSError:
                return tile, url, False

            tock = time.monotonic()

            time_for_req = tock - tick
            time_per_worker = num_workers / args.rate

            # Throttle: pad fast requests up to the per-worker time budget.
            if time_for_req < time_per_worker:
                time.sleep(time_per_worker - time_for_req)

            return tile, url, True

        # tqdm has problems with concurrent.futures.ThreadPoolExecutor; explicitly call `.update`
        # https://github.com/tqdm/tqdm/issues/97
        # The bar is advanced here, once per tile result (downloaded, already on
        # disk or failed), so it reaches its total; the `with` closes it on exit.
        with tqdm(total=len(tiles), ascii=True, unit="image") as progress:
            with futures.ThreadPoolExecutor(num_workers) as executor:
                for tile, url, ok in executor.map(worker, tiles):
                    progress.update()
                    if url and ok:
                        dl += 1
                    elif not url and ok:
                        already_dl += 1
                    else:
                        log.log("Warning:\n {} failed, skipping.\n {}\n".format(tile, url))

    if already_dl:
        log.log("Notice:\n {} tiles were already downloaded previously, and so skipped now.".format(already_dl))
    if already_dl + dl == len(tiles):
        log.log(" Coverage is fully downloaded.")

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, args.web_ui, tiles, tiles, args.ext, template)