Пример #1
0
    def test_len(self):
        """Expect the fixture dataset to expose three samples."""
        dataset = DatasetTilesConcat(
            "tests/fixtures",
            [{"sub": "images", "bands": [1, 2, 3]}],
            "tests/fixtures/labels",
            JointCompose([JointTransform(ImageToTensor(), MaskToTensor())]),
        )

        self.assertEqual(3, len(dataset))
Пример #2
0
    def test_getitem(self):
        """Indexing the dataset yields ``(images, mask, tile)`` with tensor data."""
        path = "tests/fixtures"
        target = "tests/fixtures/labels"
        channels = [{"sub": "images", "bands": [1, 2, 3]}]

        transform = JointCompose(
            [JointTransform(ImageToTensor(), MaskToTensor())])
        dataset = DatasetTilesConcat(path, channels, target, transform)

        # A single sample carries exactly one tile coordinate.
        images, mask, tile = dataset[0]
        self.assertEqual(tile, mercantile.Tile(69105, 105093, 18))
        # assertIsInstance is the idiomatic type check and gives a clearer
        # failure message than comparing type(...) with ==.
        self.assertIsInstance(images, torch.Tensor)
        self.assertIsInstance(mask, torch.Tensor)
Пример #3
0
def get_dataset_loaders(path, config, workers):
    """Build the training and validation data loaders.

    Args:
      path: dataset root directory. Tiles are read from its ``training`` and
        ``validation`` sub-directories, with masks under ``<split>/labels``.
      config: parsed configuration. Reads ``config["channels"]`` (each entry
        providing ``mean`` and ``std`` lists) and the ``tile_size``,
        ``data_augmentation`` and ``batch_size`` keys of ``config["model"]``.
      workers: number of worker processes passed to each ``DataLoader``.

    Returns:
      A ``(train_loader, val_loader)`` tuple.
    """
    # Per-band normalization statistics, concatenated in channel order.
    mean = [value for channel in config["channels"] for value in channel["mean"]]
    std = [value for channel in config["channels"] for value in channel["std"]]

    tile_size = config["model"]["tile_size"]

    # Random flips/rotations are training-time augmentation only: applying
    # them to the validation split would make validation metrics vary from
    # one run to the next.
    train_transform = JointCompose([
        JointResize(tile_size),
        JointRandomFlipOrRotate(config["model"]["data_augmentation"]),
        JointTransform(ImageToTensor(), MaskToTensor()),
        JointTransform(Normalize(mean=mean, std=std), None),
    ])
    val_transform = JointCompose([
        JointResize(tile_size),
        JointTransform(ImageToTensor(), MaskToTensor()),
        JointTransform(Normalize(mean=mean, std=std), None),
    ])

    train_dataset = SlippyMapTilesConcatenation(
        os.path.join(path, "training"),
        config["channels"],
        os.path.join(path, "training", "labels"),
        joint_transform=train_transform,
    )

    val_dataset = SlippyMapTilesConcatenation(
        os.path.join(path, "validation"),
        config["channels"],
        os.path.join(path, "validation", "labels"),
        joint_transform=val_transform,
    )

    batch_size = config["model"]["batch_size"]
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=workers)
    # Keep the last partial batch when validating. Otherwise trailing samples
    # are never evaluated, and a validation split smaller than batch_size
    # would produce no batches at all.
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            drop_last=False,
                            num_workers=workers)

    return train_loader, val_loader
Пример #4
0
def main(args):
    """Run a trained binary segmentation model over a tile directory.

    Reads these attributes from ``args``: ``config``, ``batch_size``,
    ``tile_size``, ``checkpoint``, ``tiles``, ``overlap``, ``workers``,
    ``probs``, ``web_ui``, ``web_ui_template`` and ``web_ui_base_url``.
    ``batch_size`` and ``tile_size`` fall back to ``config["model"]`` when
    they are falsy.

    For every input tile, the rounded foreground probability (a 0/1 class
    index per pixel) is saved as a paletted PNG at
    ``<args.probs>/<z>/<x>/<y>.png``. When ``args.web_ui`` is set, a web UI
    page is also generated for the output directory.

    Exits via ``sys.exit`` when the config does not describe exactly two
    classes or names an unknown model.
    """
    config = load_config(args.config)
    num_classes = len(config["classes"])
    batch_size = args.batch_size if args.batch_size else config["model"]["batch_size"]
    tile_size = args.tile_size if args.tile_size else config["model"]["tile_size"]

    # Validate the configuration before any expensive work (checkpoint
    # loading, model construction). The palette below reads classes[0] and
    # classes[1], and the mask extraction keeps a single foreground channel,
    # so only binary models are supported.
    if num_classes != 2:
        sys.exit("single channel requires binary model")

    models = {
        name for _, name, _ in pkgutil.iter_modules(
            [os.path.dirname(robosat_pink.models.__file__)])
    }
    if config["model"]["name"] not in models:
        sys.exit("Unknown model, those available are {}".format(sorted(models)))

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    def map_location(storage, _):
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    # https://github.com/pytorch/pytorch/issues/7178
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    chkpt = torch.load(args.checkpoint, map_location=map_location)

    # Per-band normalization statistics and total input band count, in
    # channel order.
    std = []
    mean = []
    num_channels = 0
    for channel in config["channels"]:
        std.extend(channel["std"])
        mean.extend(channel["mean"])
        num_channels += len(channel["bands"])

    encoder = config["model"]["encoder"]
    pretrained = config["model"]["pretrained"]

    # The model class is named after its module, title-cased.
    model_module = import_module("robosat_pink.models.{}".format(config["model"]["name"]))
    net = getattr(model_module, config["model"]["name"].title())(
        num_classes=num_classes,
        num_channels=num_channels,
        encoder=encoder,
        pretrained=pretrained,
    ).to(device)

    # The state_dict is loaded into the DataParallel wrapper. This presumably
    # matches checkpoints saved from a DataParallel-wrapped model; confirm
    # against the training script.
    net = torch.nn.DataParallel(net)
    net.load_state_dict(chkpt["state_dict"])
    net.eval()

    transform = Compose([ImageToTensor(), Normalize(mean=mean, std=std)])
    directory = BufferedSlippyMapTiles(args.tiles,
                                       transform=transform,
                                       size=tile_size,
                                       overlap=args.overlap)
    loader = DataLoader(directory,
                        batch_size=batch_size,
                        num_workers=args.workers)

    palette = make_palette(config["classes"][0]["color"],
                           config["classes"][1]["color"])

    # don't track tensors with autograd during prediction
    with torch.no_grad():
        for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True):
            outputs = net(images.to(device))

            # manually compute segmentation mask class probabilities per pixel
            probs = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = map(int, tile)

                # we predicted on buffered tiles; now get back probs for original image
                prob = directory.unbuffer(prob)

                # Internal invariants: two class channels whose softmax
                # probabilities sum to one at every pixel.
                assert prob.shape[0] == 2, "single channel requires binary model"
                assert np.allclose(
                    np.sum(prob, axis=0), 1.0
                ), "single channel requires probabilities to sum up to one"

                # Rounding the foreground probability yields a 0/1 class
                # index per pixel, used as the palette index.
                mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()

                out = Image.fromarray(mask, mode="P")
                out.putpalette(palette)

                tile_dir = os.path.join(args.probs, str(z), str(x))
                os.makedirs(tile_dir, exist_ok=True)
                out.save(os.path.join(tile_dir, str(y) + ".png"), optimize=True)

    if args.web_ui:
        template = args.web_ui_template if args.web_ui_template else "leaflet.html"
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.tiles)]
        web_ui(args.probs, base_url, tiles, tiles, "png", template)