コード例 #1
0
def predict(config, cover, args, palette, chkpt, nn, device, mode):
    """Run one prediction pass over ``args.dataset`` and write one file per tile.

    The loader class named by ``chkpt["loader"]`` is imported from
    ``robosat_pink.loaders`` and built with ``mode``. Every batch is pushed
    through ``nn``. The softmax output is taken for every class after index 0,
    rounded to a uint8 mask, and written:
    - with ``tile_label_to_file`` when mode is "predict";
    - with ``tile_translate_to_file`` when mode is "predict_translate".

    This function does not enter ``torch.no_grad()`` itself; callers that want
    autograd disabled must wrap the call.

    Returns:
        list of ``mercantile.Tile``, one per processed tile, in loader order.
    """
    assert mode in ["predict", "predict_translate"], "Predict unknown mode"

    loader_name = chkpt["loader"]
    loader_cls = getattr(load_module("robosat_pink.loaders.{}".format(loader_name.lower())), loader_name)
    dataset = loader_cls(config, chkpt["shape_in"][1:3], args.dataset, cover, mode=mode)

    batches = DataLoader(dataset, batch_size=args.bs, num_workers=args.workers)
    assert len(batches), "Empty predict dataset directory. Check your path."

    predicted = []
    for images, tiles in tqdm(batches, desc="Eval", unit="batch", ascii=True):
        outputs = nn(images.to(device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()

        for tile, prob in zip(tiles, probabilities):
            x, y, z = (int(coord) for coord in tile)
            target = mercantile.Tile(x, y, z)
            # Drop channel 0, keep every remaining class channel as a 0/1 mask.
            mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()
            if mode == "predict":
                tile_label_to_file(args.out, target, palette, mask)
            elif mode == "predict_translate":
                tile_translate_to_file(args.out, target, palette, mask, config["model"]["ms"])
            predicted.append(target)

    return predicted
コード例 #2
0
def main(args):
    """Export a RoboSat.pink checkpoint to ONNX (``args.type == "onnx"``) or TorchScript (``"jit"``).

    The network class named by ``chkpt["nn"]`` is imported from
    ``robosat_pink.models`` and rebuilt on CPU from ``shape_in``/``shape_out``.
    The checkpoint weights are then loaded into it before tracing, and the
    result is written to ``args.out``.
    """
    chkpt = torch.load(args.checkpoint, map_location=torch.device("cpu"))

    model_module = load_module("robosat_pink.models.{}".format(chkpt["nn"].lower()))
    nn = getattr(model_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"]).to("cpu")

    print("RoboSat.pink - export model to {}".format(args.type))
    print("Model: {} - UUID: {} - Torch {}".format(chkpt["nn"], chkpt["uuid"], torch.__version__))
    print(chkpt["doc_string"])

    # Restore the trained weights. ``Module.state_dict(dest)`` only copies the
    # module's current (random) parameters INTO ``dest``; ``load_state_dict`` is
    # what restores them. A checkpoint saved from a DataParallel wrapper has
    # every key prefixed with "module.", which this bare model does not use,
    # so the prefix is stripped (https://github.com/pytorch/pytorch/issues/9176).
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in chkpt["state_dict"].items()
    }
    nn.load_state_dict(state_dict)

    nn.eval()

    # Dummy input used only to trace the graph; the batch axis is left dynamic for ONNX.
    batch = torch.rand(1, *chkpt["shape_in"])
    if args.type == "onnx":
        torch.onnx.export(
            nn,
            batch,
            args.out,
            input_names=["input", "shape_in", "shape_out"],
            output_names=["output"],
            dynamic_axes={"input": {0: "num_batch"}, "output": {0: "num_batch"}},
        )

    if args.type == "jit":
        torch.jit.trace(nn, batch).save(args.out)
コード例 #3
0
 def __init__(self, metrics, config=None):
     """Prepare one entry per metric name.

     Args:
         metrics: iterable of metric names. Each name ``m`` gets a value
             initialised to 0.0 in ``self.metrics`` and its module
             ``robosat_pink.metrics.m`` loaded into ``self.modules``.
         config: optional configuration, stored unchanged on ``self.config``.
     """
     self.config = config
     self.metrics = dict.fromkeys(metrics, 0.0)
     self.modules = {name: load_module("robosat_pink.metrics." + name) for name in metrics}
     # Counter starting at zero (presumably the number of updates seen — confirm).
     self.n = 0
コード例 #4
0
    def __init__(self, shape_in, shape_out, encoder="resnet50", train_config=None):
        """Build an Albunet: a torchvision ResNet-family encoder feeding a U-Net style decoder.

        Args:
            shape_in: input shape; ``shape_in[0]`` is the number of input channels.
            shape_out: output shape; ``shape_out[0]`` is the number of output classes.
            encoder: torchvision constructor name. It must be one of the ResNet,
                ResNeXt or Wide-ResNet variants checked below (lowercase).
            train_config: optional config mapping.
                ``train_config["model"]["pretrained"]`` selects pretrained
                encoder weights. A ``None`` config or a missing key means no
                pretraining.

        Raises:
            AssertionError: unsupported ``encoder`` or zero input channels.
        """
        # torch requires nn.Module.__init__ to run before attributes are assigned on the module.
        super().__init__()

        self.doc = """
            U-Net inspired encoder-decoder architecture with a ResNet encoder as proposed by Alexander Buslaev.

            - https://arxiv.org/abs/1505.04597 - U-Net: Convolutional Networks for Biomedical Image Segmentation
            - https://arxiv.org/pdf/1804.08024 - Angiodysplasia Detection and Localization Using DCNN
            - https://arxiv.org/abs/1806.00844 - TernausNetV2: Fully Convolutional Network for Instance Segmentation
        """

        if encoder in ["resnet50", "resnet101", "resnet152"]:
            self.doc += "https://arxiv.org/abs/1512.03385 - Deep Residual Learning for Image Recognition"
        elif encoder in ["resnext50_32x4d", "resnext101_32x8d"]:
            self.doc += "https://arxiv.org/pdf/1611.05431 - Aggregated Residual Transformations for DNN"
        elif encoder in ["wide_resnet50_2", "wide_resnet101_2"]:
            self.doc += "https://arxiv.org/abs/1605.07146 - Wide Residual Networks"
        else:
            encoders = "Resnet50, Resnet101, Resnet152, Resnext50_32x4d, Resnext101_32x8d, Wide_resnet50_2, Wide_resnet101_2"
            assert False, "Albunet, expects as encoder: " + encoders

        self.version = 1

        num_filters = 32
        num_channels = shape_in[0]
        num_classes = shape_out[0]

        try:
            pretrained = train_config["model"]["pretrained"]
        except (KeyError, TypeError):
            # No config (None) or no "model"/"pretrained" entry: train from scratch.
            pretrained = False

        models = load_module("torchvision.models")
        self.resnet = getattr(models, encoder)(pretrained=pretrained)

        assert num_channels
        if num_channels != 3:
            # Replace the RGB stem with one accepting num_channels inputs. Xavier
            # init everywhere, then reuse the pretrained kernels for the first
            # (up to) 3 channels.
            weights = nn.init.xavier_uniform_(torch.zeros((64, num_channels, 7, 7)))
            if pretrained:
                for c in range(min(num_channels, 3)):
                    weights.data[:, c, :, :] = self.resnet.conv1.weight.data[:, c, :, :]
            self.resnet.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.resnet.conv1.weight = nn.Parameter(weights)

        # No encoder reference, cf: https://github.com/pytorch/pytorch/issues/8392

        # Channel counts below assume a 2048-channel final encoder stage.
        self.center = DecoderBlock(2048, num_filters * 8)

        self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8)
        self.dec1 = DecoderBlock(1024 + num_filters * 8, num_filters * 8)
        self.dec2 = DecoderBlock(512 + num_filters * 8, num_filters * 2)
        self.dec3 = DecoderBlock(256 + num_filters * 2, num_filters * 2 * 2)
        self.dec4 = DecoderBlock(num_filters * 2 * 2, num_filters)
        self.dec5 = ConvRelu(num_filters, num_filters)

        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
コード例 #5
0
def main(args):
    """Extract OSM features of kind ``args.type`` from ``args.pbf`` and save them to ``args.out``.

    The handler class ``<type>Handler`` is looked up in
    ``robosat_pink.osm.<type lowercased>``. It is applied to the whole file
    with ``locations=True`` and then asked to save its result.
    """
    banner = "RoboSat.pink - extract {} from {}. Could take some time. Please wait."
    print(banner.format(args.type, args.pbf))

    osm_module = load_module("robosat_pink.osm.{}".format(args.type.lower()))
    handler_cls = getattr(osm_module, "{}Handler".format(args.type))
    handler = handler_cls()

    pbf_path = os.path.expanduser(args.pbf)
    handler.apply_file(filename=pbf_path, locations=True)
    handler.save(os.path.expanduser(args.out))
コード例 #6
0
def main(args):
    """Train and/or validate a RoboSat.pink model, checkpointing into ``args.out``.

    Non-empty command-line options override the matching ``config["model"]``
    entries. Unless disabled with ``--no_training`` / ``--no_validation``,
    ``process`` is given ``training.csv`` / ``validation.csv`` (opened in
    append mode inside ``args.out``). Both files are closed when this function
    exits, including through ``sys.exit()``.
    """
    config = load_config(args.config)
    args.out = os.path.expanduser(args.out)
    # Expand once so the isdir() check and the loader paths agree. Previously a
    # "~/..." dataset passed the check but reached the loaders unexpanded.
    args.dataset = os.path.expanduser(args.dataset)

    overrides = {
        "loader": args.loader,
        "bs": args.bs,
        "lr": args.lr,
        "ts": tuple(map(int, args.ts.split(","))) if args.ts else None,
        "nn": args.nn,
        "encoder": args.encoder,
        "loss": args.loss,
        "da": args.da,
        "dap": args.dap,
    }
    for key, value in overrides.items():
        # Indexing (not .get) keeps the KeyError for a key missing from both sources.
        config["model"][key] = value if value else config["model"][key]
    args.workers = config["model"]["bs"] if not args.workers else args.workers
    check_classes(config)
    check_channels(config)
    check_model(config)

    assert os.path.isdir(args.dataset), "Dataset is not a directory"
    if args.no_training and args.no_validation:
        sys.exit()

    log = Logs(os.path.join(args.out, "log"))
    csv_train = csv_val = None
    try:
        if not args.no_training:
            csv_train = open(os.path.join(args.out, "training.csv"), mode="a")
        if not args.no_validation:
            csv_val = open(os.path.join(args.out, "validation.csv"), mode="a")

        if torch.cuda.is_available():
            log.log("RoboSat.pink - training on {} GPUs, with {} workers".format(
                torch.cuda.device_count(), args.workers))
            log.log("(Torch:{} Cuda:{} CudNN:{})".format(
                torch.__version__, torch.version.cuda,
                torch.backends.cudnn.version()))
            device = torch.device("cuda")
            torch.backends.cudnn.benchmark = True
        else:
            log.log("RoboSat.pink - training on CPU, with {} workers - (Torch:{})".
                    format(args.workers, torch.__version__))
            log.log("")
            log.log("==========================================================")
            log.log("WARNING: Are you -really- sure about not training on GPU ?")
            log.log("==========================================================")
            log.log("")
            device = torch.device("cpu")

        log.log("--- Input tensor from Dataset: {} ---".format(args.dataset))
        num_channel = 1  # 1-based numerotation
        for channel in config["channels"]:
            for band in channel["bands"]:
                log.log("Channel {}:\t\t {}[band: {}]".format(
                    num_channel, channel["name"], band))
                num_channel += 1

        log.log("--- Output Classes ---")
        for c, classe in enumerate(config["classes"]):
            log.log("Class {}:\t\t {}".format(c, classe["title"]))

        log.log("--- Hyper Parameters ---")
        for hp in config["model"]:
            log.log("{}{}".format(hp.ljust(25, " "), config["model"][hp]))

        loader = load_module("robosat_pink.loaders.{}".format(
            config["model"]["loader"].lower()))
        loader_train = getattr(loader, config["model"]["loader"])(
            config, config["model"]["ts"], os.path.join(args.dataset, "training"), None, "train")
        loader_val = getattr(loader, config["model"]["loader"])(
            config, config["model"]["ts"], os.path.join(args.dataset, "validation"), None, "train")

        encoder = config["model"]["encoder"].lower()
        nn_module = load_module("robosat_pink.nn.{}".format(
            config["model"]["nn"].lower()))
        nn = getattr(nn_module, config["model"]["nn"])(loader_train.shape_in,
                                                       loader_train.shape_out,
                                                       encoder, config).to(device)
        nn = torch.nn.DataParallel(nn)
        optimizer = Adam(nn.parameters(), lr=config["model"]["lr"])

        resume = 0
        if args.checkpoint:
            chkpt = torch.load(os.path.expanduser(args.checkpoint),
                               map_location=device)
            nn.load_state_dict(chkpt["state_dict"])
            log.log("--- Using Checkpoint ---")
            log.log("Path:\t\t {}".format(args.checkpoint))
            log.log("UUID:\t\t {}".format(chkpt["uuid"]))

            if args.resume:
                optimizer.load_state_dict(chkpt["optimizer"])
                resume = chkpt["epoch"]
                assert resume < args.epochs, "Epoch asked, already reached by the given checkpoint"

        loss_module = load_module("robosat_pink.losses.{}".format(
            config["model"]["loss"].lower()))
        criterion = getattr(loss_module, config["model"]["loss"])().to(device)

        bs = config["model"]["bs"]
        train_loader = DataLoader(loader_train, batch_size=bs, shuffle=True,
                                  drop_last=True, num_workers=args.workers)
        val_loader = DataLoader(loader_val, batch_size=bs, shuffle=False,
                                drop_last=True, num_workers=args.workers)

        if args.no_training:
            epoch = 0
            process(val_loader, config, log, csv_val, epoch, device, nn, criterion,
                    "eval")
            sys.exit()

        for epoch in range(resume + 1, args.epochs + 1):  # 1-N based
            UUID = uuid.uuid1()
            log.log("---{}Epoch: {}/{} -- UUID: {}".format(os.linesep, epoch,
                                                           args.epochs, UUID))

            process(train_loader, config, log, csv_train, epoch, device, nn,
                    criterion, "train", optimizer)

            try:  # https://github.com/pytorch/pytorch/issues/9176
                nn_doc = nn.module.doc
                nn_version = nn.module.version
            except AttributeError:
                nn_doc = nn.doc
                nn_version = nn.version

            states = {
                "uuid": UUID,
                "model_version": nn_version,
                "producer_name": "RoboSat.pink",
                "producer_version": rsp.__version__,
                "model_licence": "MIT",
                "domain": "pink.RoboSat",  # reverse-DNS
                "doc_string": nn_doc,
                "shape_in": loader_train.shape_in,
                "shape_out": loader_train.shape_out,
                "state_dict": nn.state_dict(),
                "epoch": epoch,
                "nn": config["model"]["nn"],
                "encoder": config["model"]["encoder"],
                "optimizer": optimizer.state_dict(),
                "loader": config["model"]["loader"],
            }
            checkpoint_path = os.path.join(args.out,
                                           "checkpoint-{:05d}.pth".format(epoch))
            # Save every args.saving epochs, and always after the last one.
            if epoch == args.epochs or not (epoch % args.saving):
                log.log("[Saving checkpoint]")
                torch.save(states, checkpoint_path)

            if not args.no_validation:
                process(val_loader, config, log, csv_val, epoch, device, nn,
                        criterion, "eval")
    finally:
        # Runs on normal return, on exceptions, and on sys.exit() (SystemExit).
        for csv_file in (csv_train, csv_val):
            if csv_file is not None:
                csv_file.close()
コード例 #7
0
ファイル: predict.py プロジェクト: dselivanov/robosat.pink
def main(args):
    """Predict masks for every tile of ``args.dataset`` with a trained checkpoint.

    Each predicted tile is written under ``args.out``:
    - with ``tile_translate_to_file`` when ``args.translate`` is set;
    - with ``tile_label_to_file`` otherwise.

    Unless ``args.no_web_ui`` or ``args.translate`` is set, ``web_ui`` is then
    generated over the predicted tiles (template ``leaflet.html`` by default).
    """
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    palette = make_palette([classe["color"] for classe in config["classes"]])

    # NOTE(review): torch.device("cuda") appears to be always truthy, so this only
    # depends on --workers being unset; torch.cuda.is_available() was probably
    # intended (device_count() is 0 without CUDA anyway) — confirm.
    if torch.device("cuda") and not args.workers:
        args.workers = torch.cuda.device_count() * 2

    cover = list(tiles_from_csv(os.path.expanduser(args.cover))) if args.cover else None

    log = Logs(os.path.join(args.out, "log"))

    if not torch.cuda.is_available():
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(args.workers))
        log.log("")
        log.log("============================================================")
        log.log("WARNING: Are you -really- sure about not predicting on GPU ?")
        log.log("============================================================")
        log.log("")
        device = torch.device("cpu")
    else:
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(torch.__version__, torch.version.cuda,
                                                     torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

    chkpt = torch.load(args.checkpoint, map_location=device)
    model_name = chkpt["nn"]
    model_cls = getattr(load_module("robosat_pink.models.{}".format(model_name.lower())), model_name)
    nn = torch.nn.DataParallel(model_cls(chkpt["shape_in"], chkpt["shape_out"]).to(device))
    nn.load_state_dict(chkpt["state_dict"])
    nn.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    mode = "predict_translate" if args.translate else "predict"
    loader_name = chkpt["loader"]
    loader_cls = getattr(load_module("robosat_pink.loaders.{}".format(loader_name.lower())), loader_name)
    dataset = loader_cls(config, chkpt["shape_in"][1:3], args.dataset, cover, mode=mode)

    loader = DataLoader(dataset, batch_size=args.bs, num_workers=args.workers)
    assert len(loader), "Empty predict dataset directory. Check your path."

    # Both writers share the same call signature in this block.
    write_tile = tile_translate_to_file if args.translate else tile_label_to_file

    tiled = []
    with torch.no_grad():  # inference only: no autograd bookkeeping
        for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True):
            scores = nn(images.to(device))
            probs = torch.nn.functional.softmax(scores, dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))
                target = mercantile.Tile(x, y, z)
                # Skip channel 0; round the remaining class probabilities to a uint8 mask.
                mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()
                write_tile(args.out, target, palette, mask)
                tiled.append(target)

    if args.no_web_ui or args.translate:
        return

    template = args.web_ui_template if args.web_ui_template else "leaflet.html"
    base_url = args.web_ui_base_url or "."
    web_ui(args.out, base_url, tiled, tiled, "png", template)
コード例 #8
0
def main(args):
    """Export a model to a RoboSat.pink ``.pth``, ONNX or TorchScript (jit) file.

    Metadata missing from the input checkpoint (``nn``, ``encoder``, ``loader``,
    ``doc_string``, ``shape_in``, ``shape_out``) must be supplied through the
    command-line option of the same name. Shapes are given as comma-separated
    integers. Progress messages are printed to stderr.
    """
    chkpt = torch.load(args.checkpoint, map_location=torch.device("cpu"))
    UUID = chkpt["uuid"] if "uuid" in chkpt else uuid.uuid1()

    def from_chkpt_or_arg(key, convert=None):
        """Return ``chkpt[key]``, else the (mandatory) ``--key`` command-line value."""
        try:
            return chkpt[key]
        except KeyError:
            # Only read the option when needed, as the original code did.
            value = getattr(args, key)
            assert value, "--{} mandatory as not already in input .pth".format(key)
            return convert(value) if convert else value

    def parse_shape(value):
        """Parse "C,H,W"-style comma-separated integers into a tuple."""
        return tuple(map(int, value.split(",")))

    nn_name = from_chkpt_or_arg("nn")
    encoder = from_chkpt_or_arg("encoder")
    # The fallback previously assigned doc_string here, leaving ``loader``
    # unbound (UnboundLocalError) when exporting to "pth".
    loader = from_chkpt_or_arg("loader")
    doc_string = from_chkpt_or_arg("doc_string")
    shape_in = from_chkpt_or_arg("shape_in", parse_shape)
    shape_out = from_chkpt_or_arg("shape_out", parse_shape)

    nn_module = load_module("robosat_pink.nn.{}".format(nn_name.lower()))
    nn = getattr(nn_module, nn_name)(shape_in, shape_out,
                                     encoder.lower()).to("cpu")

    # ``file=`` must go to print(), not to str.format() (where it was ignored).
    print("RoboSat.pink - export model to {}".format(args.type), file=sys.stderr)
    print("Model: {}".format(nn_name), file=sys.stderr)
    print("UUID: {}".format(UUID), file=sys.stderr)

    if args.type == "pth":
        # NOTE(review): this stores the weights of the freshly constructed ``nn``,
        # not the input checkpoint's weights — confirm that is intended.
        states = {
            "uuid": UUID,
            "model_version": None,
            "producer_name": "RoboSat.pink",
            "producer_version": rsp.__version__,
            "model_licence": "MIT",
            "domain": "pink.RoboSat",  # reverse-DNS
            "doc_string": doc_string,
            "shape_in": shape_in,
            "shape_out": shape_out,
            "state_dict": nn.state_dict(),
            "epoch": 0,
            "nn": nn_name,
            "encoder": encoder,
            "optimizer": None,
            "loader": loader,
        }

        torch.save(states, args.out)

    else:
        # Restore the trained weights. ``state_dict(dest)`` only copies the
        # current parameters INTO ``dest``; ``load_state_dict`` loads them.
        # Keys saved from a DataParallel wrapper carry a "module." prefix that
        # this bare model lacks (https://github.com/pytorch/pytorch/issues/9176).
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in chkpt["state_dict"].items()
        }
        nn.load_state_dict(state_dict)

        nn.eval()

        # Dummy input used only for tracing; batch axis left dynamic for ONNX.
        batch = torch.rand(1, *shape_in)

        if args.type == "onnx":
            torch.onnx.export(
                nn,
                batch,
                args.out,
                input_names=["input", "shape_in", "shape_out"],
                output_names=["output"],
                dynamic_axes={
                    "input": {
                        0: "num_batch"
                    },
                    "output": {
                        0: "num_batch"
                    }
                },
            )

        if args.type == "jit":
            torch.jit.trace(nn, batch).save(args.out)
コード例 #9
0
def main(args):
    """Train a RoboSat.pink model, saving one checkpoint per epoch into ``args.out``.

    Non-empty command-line options override the matching ``config["model"]``
    entries. ``--checkpoint`` restores the weights; ``--resume`` additionally
    restores the optimizer state and continues from the saved epoch.
    Validation runs after every epoch unless ``--no_validation`` is set.
    """
    config = load_config(args.config)
    args.out = os.path.expanduser(args.out)
    # Expand once so the isdir() check and the loader paths below agree.
    args.dataset = os.path.expanduser(args.dataset)
    # NOTE(review): torch.device("cuda") appears always truthy, so this defaults to
    # two workers per visible GPU (0 without CUDA) whenever --workers is unset.
    args.workers = torch.cuda.device_count() * 2 if torch.device("cuda") and not args.workers else args.workers
    config["model"]["loader"] = args.loader if args.loader else config["model"]["loader"]
    config["model"]["bs"] = args.bs if args.bs else config["model"]["bs"]
    config["model"]["lr"] = args.lr if args.lr else config["model"]["lr"]
    config["model"]["ts"] = args.ts if args.ts else config["model"]["ts"]
    config["model"]["nn"] = args.nn if args.nn else config["model"]["nn"]
    config["model"]["loss"] = args.loss if args.loss else config["model"]["loss"]
    config["model"]["da"] = args.da if args.da else config["model"]["da"]
    config["model"]["dap"] = args.dap if args.dap else config["model"]["dap"]
    check_classes(config)
    check_channels(config)
    check_model(config)

    if not os.path.isdir(args.dataset):
        sys.exit("ERROR: dataset {} is not a directory".format(args.dataset))

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(torch.__version__, torch.version.cuda, torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - training on CPU, with {} workers - (Torch:{})".format(args.workers, torch.__version__))
        log.log("WARNING: Are you really sure sure about not training on GPU ?")
        device = torch.device("cpu")

    loader = load_module("robosat_pink.loaders.{}".format(config["model"]["loader"].lower()))
    loader_train = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], os.path.join(args.dataset, "training"), "train"
    )
    loader_val = getattr(loader, config["model"]["loader"])(
        config, config["model"]["ts"], os.path.join(args.dataset, "validation"), "train"
    )

    model_module = load_module("robosat_pink.models.{}".format(config["model"]["nn"].lower()))

    nn = getattr(model_module, config["model"]["nn"])(loader_train.shape_in, loader_train.shape_out, config).to(device)
    nn = torch.nn.DataParallel(nn)
    optimizer = Adam(nn.parameters(), lr=config["model"]["lr"])

    resume = 0
    if args.checkpoint:
        chkpt = torch.load(os.path.expanduser(args.checkpoint), map_location=device)
        nn.load_state_dict(chkpt["state_dict"])
        log.log("Using checkpoint: {}".format(args.checkpoint))

        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]
            if resume >= args.epochs:
                # Report the requested epoch count; config["model"] may have no "epochs" key.
                sys.exit("ERROR: Epoch {} already reached by the given checkpoint".format(args.epochs))

    loss_module = load_module("robosat_pink.losses.{}".format(config["model"]["loss"].lower()))
    criterion = getattr(loss_module, config["model"]["loss"])().to(device)

    bs = config["model"]["bs"]
    train_loader = DataLoader(loader_train, batch_size=bs, shuffle=True, drop_last=True, num_workers=args.workers)
    val_loader = DataLoader(loader_val, batch_size=bs, shuffle=False, drop_last=True, num_workers=args.workers)

    log.log("--- Input tensor from Dataset: {} ---".format(args.dataset))
    num_channel = 1  # 1-based numerotation
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(num_channel, channel["name"], band))
            num_channel += 1

    log.log("--- Hyper Parameters ---")
    for hp in config["model"]:
        log.log("{}{}".format(hp.ljust(25, " "), config["model"][hp]))

    for epoch in range(resume, args.epochs):
        UUID = uuid.uuid1()
        log.log("---{}Epoch: {}/{} -- UUID: {}".format(os.linesep, epoch + 1, args.epochs, UUID))

        process(train_loader, config, log, device, nn, criterion, "train", optimizer)
        if not args.no_validation:
            process(val_loader, config, log, device, nn, criterion, "eval")

        try:  # https://github.com/pytorch/pytorch/issues/9176
            nn_doc = nn.module.doc
            nn_version = nn.module.version
        except AttributeError:
            nn_version = nn.version
            nn_doc = nn.doc  # was "==" (a comparison), leaving nn_doc undefined

        states = {
            "uuid": UUID,
            "model_version": nn_version,
            "producer_name": "RoboSat.pink",
            "producer_version": "0.4.0",
            "model_licence": "MIT",
            "domain": "pink.RoboSat",  # reverse-DNS
            "doc_string": nn_doc,
            "shape_in": loader_train.shape_in,
            "shape_out": loader_train.shape_out,
            "state_dict": nn.state_dict(),
            "epoch": epoch + 1,
            "nn": config["model"]["nn"],
            "optimizer": optimizer.state_dict(),
            "loader": config["model"]["loader"],
        }
        checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}.pth".format(epoch + 1))
        torch.save(states, checkpoint_path)
コード例 #10
0
def main(args):
    """Run ``rsp predict``: label the tiles of ``args.dataset`` with a trained checkpoint.

    ``args.passes`` picks which passes run through ``predict``:
    - "first": the plain prediction pass;
    - "second": the translated pass;
    - "both": both passes.

    Unless ``--no_web_ui`` is set, ``web_ui`` is generated over the tiles of
    the first pass (template ``leaflet.html`` by default).
    """
    config = load_config(args.config)
    check_channels(config)
    check_classes(config)
    palette = make_palette([classe["color"] for classe in config["classes"]])
    # Expand "~" like the training CLI does for its paths.
    args.out = os.path.expanduser(args.out)

    if not args.bs:
        try:
            args.bs = config["model"]["bs"]
        except (KeyError, TypeError):
            pass  # left unset; the explicit check below reports it

    assert args.bs, "For rsp predict, model/bs must be set either in config file, or pass through parameter --bs"
    args.workers = args.bs if not args.workers else args.workers
    cover = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))
             ] if args.cover else None

    log = Logs(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(
            torch.cuda.device_count(), args.workers))
        log.log("(Torch:{} Cuda:{} CudNN:{})".format(
            torch.__version__, torch.version.cuda,
            torch.backends.cudnn.version()))
        device = torch.device("cuda")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
    else:
        log.log("RoboSat.pink - predict on CPU, with {} workers".format(
            args.workers))
        log.log("")
        log.log("============================================================")
        log.log("WARNING: Are you -really- sure about not predicting on GPU ?")
        log.log("============================================================")
        log.log("")
        device = torch.device("cpu")

    chkpt = torch.load(os.path.expanduser(args.checkpoint), map_location=device)
    nn_module = load_module("robosat_pink.nn.{}".format(chkpt["nn"].lower()))
    nn = getattr(nn_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"],
                                         chkpt["encoder"].lower()).to(device)
    nn = torch.nn.DataParallel(nn)
    nn.load_state_dict(chkpt["state_dict"])
    nn.eval()

    log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"]))

    with torch.no_grad():  # don't track tensors with autograd during prediction

        tiled = []
        if args.passes in ["first", "both"]:
            log.log("== Predict First Pass ==")
            tiled = predict(config, cover, args, palette, chkpt, nn, device,
                            "predict")

        if args.passes in ["second", "both"]:
            log.log("== Predict Second Pass ==")
            predict(config, cover, args, palette, chkpt, nn, device,
                    "predict_translate")

    # The web UI only lists first-pass tiles; a second-pass-only run skips it.
    if not args.no_web_ui and tiled:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "."
        web_ui(args.out, base_url, tiled, tiled, "png", template)