def test_len(self):
    """The concatenated fixture dataset should expose exactly three tiles."""
    fixtures_dir = "tests/fixtures"
    labels_dir = "tests/fixtures/labels"
    channel_spec = [{"sub": "images", "bands": [1, 2, 3]}]
    joint = JointCompose([JointTransform(ImageToTensor(), MaskToTensor())])

    dataset = DatasetTilesConcat(fixtures_dir, channel_spec, labels_dir, joint)

    self.assertEqual(len(dataset), 3)
def test_getitem(self):
    """Indexing the dataset yields (images, mask, tile) with tensor payloads."""
    fixtures_dir = "tests/fixtures"
    labels_dir = "tests/fixtures/labels"
    channel_spec = [{"sub": "images", "bands": [1, 2, 3]}]
    joint = JointCompose([JointTransform(ImageToTensor(), MaskToTensor())])

    dataset = DatasetTilesConcat(fixtures_dir, channel_spec, labels_dir, joint)
    images, mask, tiles = dataset[0]

    # The first fixture is the known z18 tile; both payloads are torch tensors.
    self.assertEqual(tiles, mercantile.Tile(69105, 105093, 18))
    self.assertEqual(type(images), torch.Tensor)
    self.assertEqual(type(mask), torch.Tensor)
def get_dataset_loaders(path, config, workers):
    """Build the training and validation DataLoaders from slippy-map tiles.

    Args:
        path: root directory holding "training" and "validation" tile trees.
        config: parsed configuration dict (channels, model settings).
        workers: number of DataLoader worker processes.

    Returns:
        A (train_loader, val_loader) pair.
    """
    # Per-band normalization stats, concatenated across all channel groups.
    mean, std = [], []
    for channel in config["channels"]:
        mean.extend(channel["mean"])
        std.extend(channel["std"])

    transform = JointCompose([
        JointResize(config["model"]["tile_size"]),
        JointRandomFlipOrRotate(config["model"]["data_augmentation"]),
        JointTransform(ImageToTensor(), MaskToTensor()),
        JointTransform(Normalize(mean=mean, std=std), None),
    ])

    def dataset_for(split):
        # One dataset per split; images live under the split root, masks under "labels".
        return SlippyMapTilesConcatenation(
            os.path.join(path, split),
            config["channels"],
            os.path.join(path, split, "labels"),
            joint_transform=transform,
        )

    batch_size = config["model"]["batch_size"]

    # NOTE(review): drop_last=True on the validation loader discards a partial
    # final batch — presumably intentional (matches upstream); confirm.
    train_loader = DataLoader(
        dataset_for("training"), batch_size=batch_size, shuffle=True, drop_last=True, num_workers=workers
    )
    val_loader = DataLoader(
        dataset_for("validation"), batch_size=batch_size, shuffle=False, drop_last=True, num_workers=workers
    )

    return train_loader, val_loader
def main(args):
    """Run tile-wise segmentation prediction and write probability PNGs.

    Loads the model named in the config, restores the checkpoint given on the
    command line, predicts over every buffered slippy-map tile in args.tiles,
    and saves a paletted per-tile probability image under args.probs/z/x/y.png.
    Optionally emits a Leaflet web UI over the output tiles.
    """
    config = load_config(args.config)
    num_classes = len(config["classes"])

    # CLI flags override the values from the configuration file.
    batch_size = args.batch_size if args.batch_size else config["model"]["batch_size"]
    tile_size = args.tile_size if args.tile_size else config["model"]["tile_size"]

    if torch.cuda.is_available():
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    def map_location(storage, _):
        # Remap checkpoint tensors onto whatever device is actually available.
        # https://github.com/pytorch/pytorch/issues/7178
        return storage.cuda() if torch.cuda.is_available() else storage.cpu()

    chkpt = torch.load(args.checkpoint, map_location=map_location)

    # Discover the model modules shipped with robosat_pink and validate the
    # configured name against them. (Was: redundant `[model for model in models]`
    # identity comprehensions; also fixes the "thoses" typo in the error message.)
    models = [name for _, name, _ in pkgutil.iter_modules([os.path.dirname(robosat_pink.models.__file__)])]
    if config["model"]["name"] not in models:
        sys.exit("Unknown model, those available are {}".format(models))

    # Per-band normalization stats and total input channel count.
    std = []
    mean = []
    num_channels = 0
    for channel in config["channels"]:
        std.extend(channel["std"])
        mean.extend(channel["mean"])
        num_channels += len(channel["bands"])

    encoder = config["model"]["encoder"]
    pretrained = config["model"]["pretrained"]

    # Convention: module "foo" exposes class "Foo" (hence .title()).
    model_module = import_module("robosat_pink.models.{}".format(config["model"]["name"]))
    net = getattr(model_module, config["model"]["name"].title())(
        num_classes=num_classes, num_channels=num_channels, encoder=encoder, pretrained=pretrained
    ).to(device)
    net = torch.nn.DataParallel(net)
    net.load_state_dict(chkpt["state_dict"])
    net.eval()

    transform = Compose([ImageToTensor(), Normalize(mean=mean, std=std)])
    directory = BufferedSlippyMapTiles(args.tiles, transform=transform, size=tile_size, overlap=args.overlap)
    loader = DataLoader(directory, batch_size=batch_size, num_workers=args.workers)

    # Two-class palette: background color first, foreground second.
    palette = make_palette(config["classes"][0]["color"], config["classes"][1]["color"])

    # don't track tensors with autograd during prediction
    with torch.no_grad():
        for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True):
            images = images.to(device)
            outputs = net(images)

            # manually compute segmentation mask class probabilities per pixel
            probs = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()

            for tile, prob in zip(tiles, probs):
                x, y, z = list(map(int, tile))

                # we predicted on buffered tiles; now get back probs for original image
                prob = directory.unbuffer(prob)

                assert prob.shape[0] == 2, "single channel requires binary model"
                assert np.allclose(np.sum(prob, axis=0), 1.0), "single channel requires probabilities to sum up to one"

                # Foreground probability rounded to {0,1}, stored as a paletted PNG.
                image = np.around(prob[1:, :, :]).astype(np.uint8).squeeze()

                out = Image.fromarray(image, mode="P")
                out.putpalette(palette)

                os.makedirs(os.path.join(args.probs, str(z), str(x)), exist_ok=True)
                path = os.path.join(args.probs, str(z), str(x), str(y) + ".png")
                out.save(path, optimize=True)

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        base_url = args.web_ui_base_url if args.web_ui_base_url else "./"
        tiles = [tile for tile, _ in tiles_from_slippy_map(args.tiles)]
        web_ui(args.probs, base_url, tiles, tiles, "png", template)