Code example #1
0
    def test_len(self):
        """The dataset built from the bundled fixtures should expose 3 tiles."""
        fixtures_dir = "tests/fixtures"
        labels_dir = "tests/fixtures/labels"
        channels = [{"sub": "images", "bands": [1, 2, 3]}]

        # Same pipeline as production: image and mask both converted to tensors.
        joint = JointCompose([JointTransform(ImageToTensor(), MaskToTensor())])
        dataset = DatasetTilesConcat(fixtures_dir, channels, labels_dir, joint)

        self.assertEqual(3, len(dataset))
Code example #2
0
    def test_getitem(self):
        """Indexing the dataset should yield tensor image/mask plus the tile id."""
        fixtures_dir = "tests/fixtures"
        labels_dir = "tests/fixtures/labels"
        channels = [{"sub": "images", "bands": [1, 2, 3]}]

        joint = JointCompose([JointTransform(ImageToTensor(), MaskToTensor())])
        dataset = DatasetTilesConcat(fixtures_dir, channels, labels_dir, joint)

        images, mask, tiles = dataset[0]
        # The fixture directory layout encodes this exact z18 tile.
        self.assertEqual(tiles, mercantile.Tile(69105, 105093, 18))
        self.assertEqual(type(images), torch.Tensor)
        self.assertEqual(type(mask), torch.Tensor)
Code example #3
0
File: train.py  Project: martham93/robosat.pink
def get_dataset_loaders(path, config, workers):
    """Build the training and validation DataLoaders for slippy-map tiles.

    Args:
        path: root directory containing "training" and "validation" subtrees,
            each with a "labels" subdirectory.
        config: parsed configuration dict; reads config["channels"] (each
            channel carrying "mean" and "std" lists) and config["model"]
            ("tile_size", "data_augmentation", "batch_size").
        workers: number of worker processes handed to each DataLoader.

    Returns:
        (train_loader, val_loader) tuple; training shuffles, validation
        does not, and both drop the last incomplete batch.
    """
    # Flatten the per-channel normalization stats into flat lists so a
    # single Normalize covers the concatenated bands.
    mean, std = [], []
    for channel in config["channels"]:
        mean.extend(channel["mean"])
        std.extend(channel["std"])

    transform = JointCompose(
        [
            JointResize(config["model"]["tile_size"]),
            JointRandomFlipOrRotate(config["model"]["data_augmentation"]),
            JointTransform(ImageToTensor(), MaskToTensor()),
            JointTransform(Normalize(mean=mean, std=std), None),
        ]
    )

    # Both splits share the same channels and joint transform; only the
    # split directory differs.
    def build_split(split):
        return SlippyMapTilesConcatenation(
            os.path.join(path, split),
            config["channels"],
            os.path.join(path, split, "labels"),
            joint_transform=transform,
        )

    batch_size = config["model"]["batch_size"]

    train_loader = DataLoader(
        build_split("training"),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=workers,
    )
    val_loader = DataLoader(
        build_split("validation"),
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=workers,
    )

    return train_loader, val_loader