# NOTE: imports reconstructed so the tests below are self-contained; the exact
# modules for Crop, RandomCrop, Resize, split_label_image and
# CityscapesDataSource are inferred from the package layout and may differ.
import pathlib as pl
from typing import List, Tuple

import numpy as np
from PIL import Image

from sem_seg.data.data_source import DataSource, KittiDataSource, CityscapesDataSource
from sem_seg.data.generator import DataGenerator
from sem_seg.data.transformations import (Crop, RandomCrop, Resize,
                                          from_pil_to_np, merge_label_images,
                                          split_label_image)
from sem_seg.utils.labels import CityscapesLabels
from sem_seg.utils.paths import CITYSCAPES_BASE_DIR, KITTI_BASE_DIR


def test_generator_combine_data_sources():
    kitti_data_source: KittiDataSource = KittiDataSource(KITTI_BASE_DIR)
    cityscapes_data_source: CityscapesDataSource = CityscapesDataSource(
        CITYSCAPES_BASE_DIR)
    data_sources: List[DataSource] = [
        kitti_data_source, cityscapes_data_source
    ]

    train_data_generator: DataGenerator = DataGenerator(
        data_sources=data_sources,
        phase='train',
        batch_size=4,
        transformation=Crop((256, 256)),
        target_size=(256, 256),
        active_labels=[0, 1],
        random_seed=42)

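    # __len__ counts batches, not samples: 160 KITTI plus (presumably) 2975
    # Cityscapes train images give ceil(3135 / 4) == 784 batches. The val
    # generator below similarly yields (40 + 500) / 4 == 135 batches.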
    assert len(train_data_generator) == 784

    val_data_generator: DataGenerator = DataGenerator(
        data_sources=data_sources,
        phase='val',
        batch_size=4,
        transformation=Crop((256, 256)),
        target_size=(256, 256),
        active_labels=[0, 1],
        random_seed=42)

    assert len(val_data_generator) == 135


def test_generator_combine_data_sources_returns_sample_weight():
    kitti_data_source: KittiDataSource = KittiDataSource(KITTI_BASE_DIR)
    cityscapes_data_source: CityscapesDataSource = CityscapesDataSource(
        CITYSCAPES_BASE_DIR)
    data_sources: List[DataSource] = [
        kitti_data_source, cityscapes_data_source
    ]

    data_generator: DataGenerator = DataGenerator(data_sources=data_sources,
                                                  phase='train',
                                                  batch_size=4,
                                                  transformation=Crop(
                                                      (256, 256)),
                                                  target_size=(256, 256),
                                                  active_labels=[0, 1],
                                                  random_seed=42)

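    # get_batch() exposes the raw, untransformed samples; each tuple is
    # (image, labels, data-source name), so element [2] tells which dataset
    # the sample was drawn from.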
    original_batch = data_generator.get_batch(0)
    original_batch_tuple = original_batch[0]
    original_batch_ds_name = original_batch_tuple[2]
    assert original_batch_ds_name == 'cityscapes'

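    # Indexing the generator yields per-dataset sample weights: 1 for samples
    # drawn from that dataset, 0 otherwise. This batch is all Cityscapes, so
    # 'cityscapes' gets ones and 'kitti' gets zeros.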
    batch_images, batch_masks, batch_sample_weights = data_generator[0]

    assert isinstance(batch_sample_weights, dict)
    assert np.equal(batch_sample_weights['kitti'], np.asarray([0, 0, 0,
                                                               0])).all()
    assert np.equal(batch_sample_weights['cityscapes'],
                    np.asarray([1, 1, 1, 1])).all()


def test_label_integrity():
    """
    Check that the resized version of the original labels can be reconstructed from the matrix which will be the
    neural network's input
    """

    kitti_data_source: KittiDataSource = KittiDataSource(KITTI_BASE_DIR)
    train_data_generator: DataGenerator = DataGenerator(
        data_sources=[kitti_data_source],
        phase='train',
        batch_size=4,
        transformation=Crop((256, 256)),
        target_size=(256, 256),
        active_labels=[0, 1],
        random_seed=42)

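    # Round trip: resize the original labels, split them into per-label images,
    # merge them back, and check that nothing was lost.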
    original_image, original_labels, _ = train_data_generator.get_batch(0)[0]
    resized_original = Resize((256, 256))(original_labels)[0]
    resized_original_array = from_pil_to_np(resized_original)

    resized_split = split_label_image(resized_original_array,
                                      CityscapesLabels.ALL)
    resized_merged = merge_label_images(resized_split, CityscapesLabels.ALL)

    assert (resized_original_array == resized_merged).all()


def test_kitti_data_source_with_limit():
    kitti_data_source: KittiDataSource = KittiDataSource(KITTI_BASE_DIR,
                                                         limit=4)

    train_data = kitti_data_source.get_train_data()
    assert len(train_data) == 4
    val_data = kitti_data_source.get_val_data()
    assert len(val_data) == 4
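

# NOTE: `_read_datasets` below is excerpted from a larger class (it relies on
# `self.limit`); `DatasetConfigReader` is a hypothetical wrapper added here so
# the method can be read and run in context.
class DatasetConfigReader:
    def __init__(self, limit=None):
        self.limit = limit
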
    def _read_datasets(self, config_data) -> List[DataSource]:
        dataset_names: List[str] = config_data['dataset']

        datasets: List[DataSource] = []
        if 'kitti' in dataset_names:
            datasets.append(KittiDataSource(KITTI_BASE_DIR, limit=self.limit))
        if 'cityscapes' in dataset_names:
            datasets.append(
                CityscapesDataSource(CITYSCAPES_BASE_DIR, limit=self.limit))

        return datasets


def test_kitti_data_source():
    kitti_data_source: KittiDataSource = KittiDataSource(KITTI_BASE_DIR)

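    # Each sample is a (camera image path, label image path) pair; paired files
    # are expected to share the same file name.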
    train_data = kitti_data_source.get_train_data()
    assert len(train_data) == 160
    image_name: str = pl.Path(train_data[0][0]).name
    label_name: str = pl.Path(train_data[0][1]).name
    assert image_name == label_name
    assert image_name == '000079_10.png'

    val_data = kitti_data_source.get_val_data()
    assert len(val_data) == 40
    image_name = pl.Path(val_data[0][0]).name
    label_name = pl.Path(val_data[0][1]).name
    assert image_name == label_name
    assert image_name == '000095_10.png'


def test_repeated_random_crop_returns_different_images():
    # OPEN AN IMAGE
    kitti_data_source: KittiDataSource = KittiDataSource(KITTI_BASE_DIR)
    train_data: List[Tuple[str, str]] = kitti_data_source.get_train_data()
    test_camera_image_path: str = train_data[0][0]
    test_camera_image: Image.Image = Image.open(test_camera_image_path)

    # CROP IT TWICE
    random_crop: RandomCrop = RandomCrop(target_size=(16, 16))
    cropped_1: List[Image.Image] = random_crop(test_camera_image)
    cropped_2: List[Image.Image] = random_crop(test_camera_image)

    cropped_1_np: np.ndarray = from_pil_to_np(cropped_1[0])
    cropped_2_np: np.ndarray = from_pil_to_np(cropped_2[0])

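    # Two independent 16x16 random crops of a full-size KITTI frame should land
    # on different regions with overwhelming probability.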
    assert not np.array_equal(cropped_1_np, cropped_2_np)


def test_argmax_on_split_images():
    kitti_data_source: KittiDataSource = KittiDataSource(KITTI_BASE_DIR)
    train_data_generator: DataGenerator = DataGenerator(
        data_sources=[kitti_data_source],
        phase='train',
        batch_size=4,
        transformation=Crop((256, 256)),
        target_size=(256, 256),
        active_labels=CityscapesLabels.ALL,
        random_seed=42)

    original_image, original_labels, _ = train_data_generator.get_batch(0)[0]
    cropped_original_labels = Crop((256, 256))(original_labels)[0]
    cropped_original_labels_np = from_pil_to_np(cropped_original_labels)

    input_image, input_labels, _ = train_data_generator[0]
    input_labels = input_labels['kitti']
    input_labels = input_labels[0]
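    # argmax over the last (per-label channel) axis collapses the split label
    # planes back into a single label-id image, which should match the cropped
    # original labels.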
    input_labels_merged = np.argmax(input_labels, axis=-1)

    assert (input_labels_merged == cropped_original_labels_np).all()


from typing import List

import numpy as np

from sem_seg.data.data_source import DataSource, KittiDataSource
from sem_seg.data.generator import DataGenerator
from sem_seg.data.transformations import merge_label_images, Fit, from_pil_to_np
from sem_seg.utils.labels import generate_semantic_rgb, CityscapesLabels
from sem_seg.utils.paths import KITTI_BASE_DIR

if __name__ == '__main__':
    """
    This demo allows to visually inspect what the generator is feeding to the model.
    """

    labels = CityscapesLabels.ALL
    index = 6

    # CREATE GENERATOR
    data_sources: List[DataSource] = [KittiDataSource(KITTI_BASE_DIR)]
    generator: DataGenerator = DataGenerator(data_sources=data_sources,
                                             phase='train',
                                             transformation=Fit((256, 256)),
                                             batch_size=1,
                                             target_size=(256, 256),
                                             active_labels=labels)

    # GENERATOR ORIGINAL IMAGES
    original_image, original_labels, _ = generator.get_batch(index)[0]
    original_image_np: np.ndarray = from_pil_to_np(original_image)
    original_labels_np: np.ndarray = from_pil_to_np(original_labels)
    original_labels_rgb: np.ndarray = generate_semantic_rgb(original_labels_np)
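    # generate_semantic_rgb maps the single-channel label ids to RGB colours
    # for visual inspection.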

    # GENERATOR PRE-PROCESSED IMAGES
    image_batch, labels_batch, _ = generator[index]


from typing import List, Tuple

from sem_seg.data.data_source import DataSource, KittiDataSource, CityscapesDataSource  # module path for CityscapesDataSource inferred
from sem_seg.data.generator import DataGenerator
from sem_seg.data.transformations import Crop  # module path inferred
from sem_seg.models.unet import unet
from sem_seg.utils.labels import CityscapesLabels, generate_semantic_rgb
from sem_seg.utils.paths import CITYSCAPES_BASE_DIR, MODELS_DIR, KITTI_BASE_DIR

if __name__ == '__main__':
    # PARAMETERS
    image_size: Tuple[int, int] = (128, 128)
    input_size: Tuple[int, int, int] = image_size + (3, )
    batch_size: int = 4
    num_epochs: int = 1
    patience: int = 1
    limit: int = 32
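    # `limit` caps the number of samples taken from each data source, keeping
    # this run small.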

    # CREATE DATA GENERATOR
    data_sources: List[DataSource] = [
        KittiDataSource(KITTI_BASE_DIR, limit=limit),
        CityscapesDataSource(CITYSCAPES_BASE_DIR, limit=limit)
    ]
    train_generator = DataGenerator(data_sources=data_sources,
                                    phase='train',
                                    target_size=image_size,
                                    batch_size=batch_size,
                                    transformation=Crop(image_size),
                                    active_labels=CityscapesLabels.ALL)
    validation_generator = DataGenerator(data_sources=data_sources,
                                         phase='val',
                                         target_size=image_size,
                                         batch_size=batch_size,
                                         transformation=Crop(image_size),
                                         active_labels=CityscapesLabels.ALL)