def test_serialize():
    # serialize() should argmax over the leading class dimension of the
    # prediction tensor, yielding one class index per pixel.
    serializer = SegmentationLabels()

    preds = torch.zeros(5, 2, 3)
    preds[1, 1, 2] = 1  # peak for class index 1 at pixel (1, 2)
    preds[3, 0, 1] = 1  # peak for class index 3 at pixel (0, 1)

    labels = torch.tensor(serializer.serialize({DefaultDataKeys.PREDS: preds}))
    assert labels[1, 2] == 1
    assert labels[0, 1] == 3
# --- Ejemplo n.º 2 (scraped example separator; "0" was a vote count) ---
    def _show_images_and_labels(self, data: List[Any], num_samples: int,
                                title: str):
        """Plot ``num_samples`` (image, label) pairs from ``data`` on a grid.

        Args:
            data: Samples to display; each is either a dict keyed by
                ``DefaultDataKeys`` or an ``(image, label)`` tuple.
            num_samples: How many samples to draw from ``data``.
            title: Figure title.

        Raises:
            TypeError: If a sample is neither a dict nor a tuple.
        """
        # Define the image grid. Ceil-divide so samples that do not fill a
        # complete row still get an axis (floor division dropped them before).
        cols: int = min(num_samples, self.max_cols)
        rows: int = -(-num_samples // cols)  # ceil(num_samples / cols)

        # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, so
        # `.ravel()` below always works (a bare Axes has no `.ravel`).
        fig, axs = plt.subplots(rows, cols, squeeze=False)
        fig.suptitle(title)

        for i, ax in enumerate(axs.ravel()):
            if i >= num_samples:
                # Hide the unused trailing axes of the last (partial) row.
                ax.axis("off")
                continue
            # unpack images and labels
            sample = data[i]
            if isinstance(sample, dict):
                image = sample[DefaultDataKeys.INPUT]
                label = sample[DefaultDataKeys.TARGET]
            elif isinstance(sample, tuple):
                image = sample[0]
                label = sample[1]
            else:
                # Report the offending sample's type (was `type(data)`,
                # which always printed `list`).
                raise TypeError(f"Unknown data type. Got: {type(sample)}.")
            # convert images and labels to numpy and stack horizontally
            image_vis: np.ndarray = self._to_numpy(image.byte())
            label_tmp: torch.Tensor = SegmentationLabels.labels_to_image(
                label.squeeze().byte(), self.labels_map)
            label_vis: np.ndarray = self._to_numpy(label_tmp)
            img_vis = np.hstack((image_vis, label_vis))
            # send to visualiser
            ax.imshow(img_vis)
            ax.axis("off")
        plt.show(block=self.block_viz_window)
# --- Ejemplo n.º 3 (scraped example separator; "0" was a vote count) ---
    def __init__(
        self,
        num_classes: int,
        backbone: Union[str, nn.Module] = "resnet50",
        backbone_kwargs: Optional[Dict] = None,
        head: str = "fcn",
        head_kwargs: Optional[Dict] = None,
        pretrained: bool = True,
        loss_fn: Optional[Callable] = None,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW,
        # NOTE: `None` inside the Union was redundant — Optional already adds it.
        metrics: Optional[Union[Callable, Mapping, Sequence]] = None,
        learning_rate: float = 1e-3,
        multi_label: bool = False,
        serializer: Optional[Union[Serializer, Mapping[str,
                                                       Serializer]]] = None,
        postprocess: Optional[Postprocess] = None,
    ) -> None:
        """Build a semantic segmentation task from a backbone and a head.

        Args:
            num_classes: Number of output classes predicted by the head.
            backbone: Name looked up in ``self.backbones`` or a ready-made
                ``nn.Module`` used as-is.
            backbone_kwargs: Extra keyword arguments for the backbone builder.
            head: Name looked up in ``self.heads``.
            head_kwargs: Extra keyword arguments for the head builder.
            pretrained: Whether to load pretrained backbone weights.
            loss_fn: Loss function; defaults to ``F.cross_entropy``.
            optimizer: Optimizer class to instantiate for training.
            metrics: Metric(s) to track; defaults to ``IoU(num_classes)``.
            learning_rate: Learning rate passed to the optimizer.
            multi_label: Multi-label segmentation (not supported yet).
            serializer: Prediction serializer; defaults to
                ``SegmentationLabels()``.
            postprocess: Postprocess; defaults to ``self.postprocess_cls()``.

        Raises:
            ModuleNotFoundError: If a backbone name is given but torchvision
                or timm is unavailable.
            NotImplementedError: If ``multi_label`` is ``True``.
        """
        if isinstance(backbone, str) and (not _TORCHVISION_AVAILABLE
                                          or not _TIMM_AVAILABLE):
            raise ModuleNotFoundError(
                "Please, pip install 'lightning-flash[image]'")

        if metrics is None:
            metrics = IoU(num_classes=num_classes)

        if loss_fn is None:
            loss_fn = F.cross_entropy

        # TODO: need to check for multi_label
        if multi_label:
            raise NotImplementedError("Multi-label not supported yet.")

        super().__init__(model=None,
                         loss_fn=loss_fn,
                         optimizer=optimizer,
                         metrics=metrics,
                         learning_rate=learning_rate,
                         serializer=serializer or SegmentationLabels(),
                         postprocess=postprocess or self.postprocess_cls())

        self.save_hyperparameters()

        if not backbone_kwargs:
            backbone_kwargs = {}

        if not head_kwargs:
            head_kwargs = {}

        if isinstance(backbone, nn.Module):
            self.backbone = backbone
        else:
            self.backbone = self.backbones.get(backbone)(pretrained=pretrained,
                                                         **backbone_kwargs)

        self.head = self.heads.get(head)(self.backbone, num_classes,
                                         **head_kwargs)
# --- Ejemplo n.º 4 (scraped example separator; "0" was a vote count) ---
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        image_size: Tuple[int, int] = (196, 196),
        deserializer: Optional['Deserializer'] = None,
        num_classes: Optional[int] = None,
        labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None,
        **data_source_kwargs: Any,
    ) -> None:
        """Preprocess pipeline for semantic segmentation tasks.

        Args:
            train_transform: Dictionary with the set of transforms to apply during training.
            val_transform: Dictionary with the set of transforms to apply during validation.
            test_transform: Dictionary with the set of transforms to apply during testing.
            predict_transform: Dictionary with the set of transforms to apply during prediction.
            image_size: A tuple with the expected output image size.
            deserializer: Deserializer for raw prediction inputs; defaults to
                ``SemanticSegmentationDeserializer()``.
            num_classes: Number of segmentation classes. When given and no
                ``labels_map`` is supplied, a random colour map is generated.
            labels_map: Mapping from class index to an ``(R, G, B)`` colour
                used for visualisation; shared via ``ImageLabelsMap`` state.
            **data_source_kwargs: Additional arguments passed on to the data source constructors.

        Raises:
            ModuleNotFoundError: If the image extras are not installed.
        """
        if not _IMAGE_AVAILABLE:
            raise ModuleNotFoundError(
                "Please, pip install 'lightning-flash[image]'")
        self.image_size = image_size
        self.num_classes = num_classes
        # Only auto-generate a colour map when the class count is known.
        if num_classes:
            labels_map = labels_map or SegmentationLabels.create_random_labels_map(
                num_classes)

        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_sources={
                DefaultDataSources.FIFTYONE:
                SemanticSegmentationFiftyOneDataSource(**data_source_kwargs),
                DefaultDataSources.FILES:
                SemanticSegmentationPathsDataSource(),
                DefaultDataSources.FOLDERS:
                SemanticSegmentationPathsDataSource(),
                DefaultDataSources.TENSORS:
                SemanticSegmentationTensorDataSource(),
                DefaultDataSources.NUMPY:
                SemanticSegmentationNumpyDataSource(),
            },
            deserializer=deserializer or SemanticSegmentationDeserializer(),
            default_data_source=DefaultDataSources.FILES,
        )

        # Publish the colour map so downstream visualisers can pick it up.
        if labels_map:
            self.set_state(ImageLabelsMap(labels_map))

        self.labels_map = labels_map
# --- Ejemplo n.º 5 (scraped example separator; "0" was a vote count) ---
    def from_data_source(
        cls,
        data_source: str,
        train_data: Any = None,
        val_data: Any = None,
        test_data: Any = None,
        predict_data: Any = None,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        data_fetcher: Optional[BaseDataFetcher] = None,
        preprocess: Optional[Preprocess] = None,
        val_split: Optional[float] = None,
        batch_size: int = 4,
        num_workers: int = 0,
        **preprocess_kwargs: Any,
    ) -> "DataModule":
        """Build a segmentation ``DataModule`` from a named data source.

        ``num_classes`` is required in ``preprocess_kwargs``; an optional
        ``labels_map`` there overrides the randomly generated colour map
        used by the visualisation data fetcher.

        Raises:
            MisconfigurationException: If ``num_classes`` is missing.
        """
        if "num_classes" not in preprocess_kwargs:
            raise MisconfigurationException(
                "`num_classes` should be provided during instantiation.")

        num_classes = preprocess_kwargs["num_classes"]

        # BUG FIX: `preprocess_kwargs` is a dict, so `getattr(...)` looked up
        # an *attribute* and always returned None, discarding any user-supplied
        # labels_map. Use a key lookup instead.
        labels_map = preprocess_kwargs.get(
            "labels_map",
            None) or SegmentationLabels.create_random_labels_map(num_classes)

        data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map)

        # During test runs, don't block on the matplotlib window.
        if flash._IS_TESTING:
            data_fetcher.block_viz_window = True

        dm = super().from_data_source(
            data_source=data_source,
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            predict_data=predict_data,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_fetcher=data_fetcher,
            preprocess=preprocess,
            val_split=val_split,
            batch_size=batch_size,
            num_workers=num_workers,
            **preprocess_kwargs,
        )

        if dm.train_dataset is not None:
            dm.train_dataset.num_classes = num_classes
        return dm
    def test_exception():
        # Tensors with the wrong rank must be rejected by serialize().
        serializer = SegmentationLabels()

        # 4-D: a batch dimension is not accepted.
        batched = torch.zeros(1, 5, 2, 3)
        with pytest.raises(Exception):
            serializer.serialize(batched)

        # 2-D: missing the class dimension entirely.
        flat = torch.zeros(2, 3)
        with pytest.raises(Exception):
            serializer.serialize(flat)
 def test_smoke():
     # A freshly constructed serializer exposes its documented defaults.
     serializer = SegmentationLabels()
     assert serializer is not None
     assert serializer.labels_map is None
     assert serializer.visualize is False
    num_classes=21,
)

# 2.2 Visualise the samples: plot the outputs of the named preprocess hooks
# for one training batch.
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3.a List the backbones and heads registered on the task.
print(f"Backbones: {SemanticSegmentation.available_backbones()}")
print(f"Heads: {SemanticSegmentation.available_heads()}")

# 3.b Build the model: MobileNetV3 backbone with an FCN head, one output
# channel per dataset class.
model = SemanticSegmentation(
    backbone="mobilenet_v3_large",
    head="fcn",
    num_classes=datamodule.num_classes,
    serializer=SegmentationLabels(visualize=False),
)

# 4. Create the trainer. fast_dev_run executes a single batch per stage,
# so this script only smoke-tests the pipeline.
trainer = flash.Trainer(fast_dev_run=True)

# 5. Fine-tune with the backbone frozen; only the head is trained.
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Segment a few images!
predictions = model.predict([
    "data/CameraRGB/F61-1.png",
    "data/CameraRGB/F62-1.png",
    "data/CameraRGB/F63-1.png",
])
# 2.1 Load the data: RGB frames and per-pixel class masks live in parallel
# folders (CARLA-style naming; 21 classes — TODO confirm against the dataset).
datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    batch_size=4,
    val_split=0.3,
    image_size=(200, 200),  # (600, 800)
    num_classes=21,
)

# 2.2 Visualise the samples: plot the outputs of the named preprocess hooks
# for one training batch.
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3. Build the model: torchvision FCN-ResNet50; visualize=True pops up the
# predicted masks when serializing.
model = SemanticSegmentation(backbone="torchvision/fcn_resnet50",
                             num_classes=datamodule.num_classes,
                             serializer=SegmentationLabels(visualize=True))

# 4. Create the trainer. fast_dev_run=1 limits each stage to a single batch.
trainer = flash.Trainer(
    max_epochs=1,
    fast_dev_run=1,
)

# 5. Fine-tune with the backbone frozen; only the head is trained.
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Segment a few images.
predictions = model.predict([
    "data/CameraRGB/F61-1.png",
    "data/CameraRGB/F62-1.png",
    "data/CameraRGB/F63-1.png",
])