Пример #1
0
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        image_size: Tuple[int, int] = (196, 196),
        deserializer: Optional['Deserializer'] = None,
        num_classes: int = None,
        labels_map: Dict[int, Tuple[int, int, int]] = None,
        **data_source_kwargs: Any,
    ) -> None:
        """Set up the preprocess used for semantic segmentation.

        Args:
            train_transform: Transforms applied during training, keyed by name.
            val_transform: Transforms applied during validation, keyed by name.
            test_transform: Transforms applied during testing, keyed by name.
            predict_transform: Transforms applied during prediction, keyed by name.
            image_size: Expected output image size; stored as ``self.image_size``.
            deserializer: Deserializer handed to the parent class. A
                ``SemanticSegmentationDeserializer`` is created when this is
                omitted (or otherwise falsy).
            num_classes: Number of segmentation classes; stored as
                ``self.num_classes``. When truthy and no ``labels_map`` is
                given, a random labels map is generated for it.
            labels_map: Mapping from class index to a 3-int tuple (presumably
                an RGB colour - confirm). When truthy it is registered as
                ``ImageLabelsMap`` state; it is always stored as
                ``self.labels_map``.
            **data_source_kwargs: Extra keyword arguments forwarded to the
                FiftyOne data source constructor.

        Raises:
            ModuleNotFoundError: If ``_IMAGE_AVAILABLE`` is falsy.
        """
        if not _IMAGE_AVAILABLE:
            raise ModuleNotFoundError(
                "Please, pip install 'lightning-flash[image]'")

        # These attributes are assigned before the parent initialiser runs.
        self.image_size = image_size
        self.num_classes = num_classes

        # Only invent a labels map when the class count is known and the
        # caller did not supply one.
        if num_classes and not labels_map:
            labels_map = SegmentationLabels.create_random_labels_map(num_classes)

        # Files and folders are both served by the paths-based data source.
        sources = {
            DefaultDataSources.FIFTYONE: SemanticSegmentationFiftyOneDataSource(**data_source_kwargs),
            DefaultDataSources.FILES: SemanticSegmentationPathsDataSource(),
            DefaultDataSources.FOLDERS: SemanticSegmentationPathsDataSource(),
            DefaultDataSources.TENSORS: SemanticSegmentationTensorDataSource(),
            DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(),
        }

        if not deserializer:
            deserializer = SemanticSegmentationDeserializer()

        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_sources=sources,
            deserializer=deserializer,
            default_data_source=DefaultDataSources.FILES,
        )

        # State can only be set once the parent initialiser has run.
        if labels_map:
            self.set_state(ImageLabelsMap(labels_map))

        self.labels_map = labels_map
Пример #2
0
    def from_data_source(
        cls,
        data_source: str,
        train_data: Any = None,
        val_data: Any = None,
        test_data: Any = None,
        predict_data: Any = None,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        data_fetcher: Optional[BaseDataFetcher] = None,
        preprocess: Optional[Preprocess] = None,
        val_split: Optional[float] = None,
        batch_size: int = 4,
        num_workers: int = 0,
        **preprocess_kwargs: Any,
    ) -> "DataModule":
        """Build a data module from a named data source for segmentation.

        Wraps the parent implementation: it requires ``num_classes``, resolves
        the labels map used by the default data fetcher, and tags the training
        dataset with the class count.

        NOTE(review): the method takes ``cls``, so it is presumably decorated
        with ``@classmethod`` outside this snippet - confirm.

        Args:
            data_source: Name of the data source, passed to the parent.
            train_data: Training data, passed to the parent.
            val_data: Validation data, passed to the parent.
            test_data: Test data, passed to the parent.
            predict_data: Prediction data, passed to the parent.
            train_transform: Transforms applied during training.
            val_transform: Transforms applied during validation.
            test_transform: Transforms applied during testing.
            predict_transform: Transforms applied during prediction.
            data_fetcher: Data fetcher to use. When omitted (or falsy), one is
                built with ``cls.configure_data_fetcher(labels_map)``.
            preprocess: Preprocess instance, passed to the parent.
            val_split: Validation split fraction, passed to the parent.
            batch_size: Batch size, passed to the parent.
            num_workers: Number of workers, passed to the parent.
            **preprocess_kwargs: Extra keyword arguments forwarded to the
                parent. Must contain ``num_classes``; may contain
                ``labels_map``.

        Returns:
            The data module produced by the parent ``from_data_source``. If it
            has a training dataset, ``num_classes`` is set on that dataset.

        Raises:
            MisconfigurationException: If ``num_classes`` is not present in
                ``preprocess_kwargs``.
        """
        if "num_classes" not in preprocess_kwargs:
            raise MisconfigurationException(
                "`num_classes` should be provided during instantiation.")

        num_classes = preprocess_kwargs["num_classes"]

        # ``preprocess_kwargs`` is a plain dict, so the entry must be read with
        # ``.get``. ``getattr`` looks up attributes, never keys, and would
        # always fall through to a random labels map.
        labels_map = (
            preprocess_kwargs.get("labels_map")
            or SegmentationLabels.create_random_labels_map(num_classes)
        )

        data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map)

        # Set the fetcher's ``block_viz_window`` flag when running under tests.
        if flash._IS_TESTING:
            data_fetcher.block_viz_window = True

        dm = super().from_data_source(
            data_source=data_source,
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            predict_data=predict_data,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_fetcher=data_fetcher,
            preprocess=preprocess,
            val_split=val_split,
            batch_size=batch_size,
            num_workers=num_workers,
            **preprocess_kwargs,
        )

        if dm.train_dataset is not None:
            dm.train_dataset.num_classes = num_classes
        return dm