def test_len(self):
    """The fixture dataset should report exactly three items."""
    band_config = [{"sub": "images", "bands": [1, 2, 3]}]
    joint_transform = JointCompose(
        [JointTransform(ImageToTensor(), MaskToTensor())]
    )
    dataset = DatasetTilesConcat(
        "tests/fixtures",
        band_config,
        "tests/fixtures/labels",
        joint_transform,
    )

    expected_count = 3
    self.assertEqual(len(dataset), expected_count)
def test_getitem(self):
    """Indexing the dataset yields an (images, mask, tile) triple.

    Checks that item 0 is paired with the expected slippy-map tile and that
    both the image stack and the mask have been converted to tensors by the
    joint transform.
    """
    path = "tests/fixtures"
    target = "tests/fixtures/labels"
    channels = [{"sub": "images", "bands": [1, 2, 3]}]
    transform = JointCompose([JointTransform(ImageToTensor(), MaskToTensor())])
    dataset = DatasetTilesConcat(path, channels, target, transform)

    # Each item carries a single tile, so name it in the singular.
    images, mask, tile = dataset[0]

    self.assertEqual(tile, mercantile.Tile(69105, 105093, 18))
    # assertIsInstance is the idiomatic type check and reports the actual
    # type on failure, unlike assertEqual(type(x), T).
    self.assertIsInstance(images, torch.Tensor)
    self.assertIsInstance(mask, torch.Tensor)
def _channel_stats(channels):
    """Flatten per-channel normalization statistics into (mean, std) lists.

    Each entry of ``channels`` contributes its ``"mean"`` and ``"std"`` values,
    concatenated in the order the entries appear.
    """
    mean = [value for channel in channels for value in channel["mean"]]
    std = [value for channel in channels for value in channel["std"]]
    return mean, std


def _build_joint_transform(model_config, mean, std, augment):
    """Build the joint image/mask transform pipeline.

    The pipeline resizes, optionally applies ``JointRandomFlipOrRotate``,
    converts image and mask to tensors, then normalizes the image only.
    ``augment`` must be False for evaluation data: random augmentation
    there makes validation metrics non-deterministic.
    """
    steps = [JointResize(model_config["tile_size"])]
    if augment:
        steps.append(JointRandomFlipOrRotate(model_config["data_augmentation"]))
    steps.extend([
        JointTransform(ImageToTensor(), MaskToTensor()),
        JointTransform(Normalize(mean=mean, std=std), None),
    ])
    return JointCompose(steps)


def get_dataset_loaders(path, config, workers):
    """Create the training and validation data loaders.

    Args:
        path: Dataset root. Images are read from ``<path>/training`` and
            ``<path>/validation``; labels from the ``labels`` subdirectory of
            each split.
        config: Configuration dict. ``config["channels"]`` is a list of
            channel entries, each with ``"mean"`` and ``"std"`` lists.
            ``config["model"]`` provides ``"tile_size"``,
            ``"data_augmentation"`` and ``"batch_size"``.
        workers: Number of loader worker processes (``num_workers``).

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles, drops
        the last incomplete batch and applies data augmentation. The
        validation loader does none of these, so every validation tile is
        evaluated deterministically.
    """
    channels = config["channels"]
    model_config = config["model"]
    mean, std = _channel_stats(channels)

    def make_dataset(split, joint_transform):
        # Each split keeps its imagery under <path>/<split> and its masks
        # under <path>/<split>/labels.
        return SlippyMapTilesConcatenation(
            os.path.join(path, split),
            channels,
            os.path.join(path, split, "labels"),
            joint_transform=joint_transform,
        )

    train_dataset = make_dataset(
        "training", _build_joint_transform(model_config, mean, std, augment=True)
    )
    # Validation must not be randomly flipped/rotated, or its metrics would
    # vary between runs and not reflect the real data.
    val_dataset = make_dataset(
        "validation", _build_joint_transform(model_config, mean, std, augment=False)
    )

    batch_size = model_config["batch_size"]
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=workers,
    )
    # drop_last=False: dropping the trailing partial batch would skip
    # validation tiles, and would yield no batches at all when the
    # validation set is smaller than batch_size.
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=workers,
    )
    return train_loader, val_loader