def test_serialize():
    """``SegmentationLabels.serialize`` should argmax over the class (channel) dim."""
    serial = SegmentationLabels()
    # Logits of shape (num_classes=5, height=2, width=3), all zeros.
    sample = torch.zeros(5, 2, 3)
    sample[1, 1, 2] = 1  # peak for class index 1 at pixel (1, 2)
    sample[3, 0, 1] = 1  # peak for class index 3 at pixel (0, 1)
    classes = serial.serialize({DefaultDataKeys.PREDS: sample})
    assert torch.tensor(classes)[1, 2] == 1
    assert torch.tensor(classes)[0, 1] == 3
def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str):
    """Plot ``num_samples`` (image, label) pairs from ``data`` on a grid.

    Args:
        data: samples as either ``{INPUT: image, TARGET: label}`` dicts or
            ``(image, label)`` tuples.
        num_samples: how many samples from ``data`` to display.
        title: figure title.

    Raises:
        TypeError: if a sample is neither a dict nor a tuple.
    """
    # define the image grid; ceil division so samples that only partially
    # fill the last row are still shown (floor division dropped them)
    cols: int = min(num_samples, self.max_cols)
    rows: int = (num_samples + cols - 1) // cols

    # squeeze=False guarantees ``axs`` is a 2D array even for a 1x1 grid
    # (a bare Axes object has no ``ravel``)
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    fig.suptitle(title)

    for i, ax in enumerate(axs.ravel()):
        # the grid may have more cells than samples: blank out the extras
        if i >= num_samples:
            ax.axis("off")
            continue

        # unpack images and labels
        sample = data[i]
        if isinstance(sample, dict):
            image = sample[DefaultDataKeys.INPUT]
            label = sample[DefaultDataKeys.TARGET]
        elif isinstance(sample, tuple):
            image = sample[0]
            label = sample[1]
        else:
            # report the offending element's type (the original reported
            # ``type(data)``, the container, which is always ``list``)
            raise TypeError(f"Unknown data type. Got: {type(sample)}.")

        # convert images and labels to numpy and stack horizontally
        image_vis: np.ndarray = self._to_numpy(image.byte())
        label_tmp: torch.Tensor = SegmentationLabels.labels_to_image(
            label.squeeze().byte(), self.labels_map)
        label_vis: np.ndarray = self._to_numpy(label_tmp)
        img_vis = np.hstack((image_vis, label_vis))

        # send to visualiser
        ax.imshow(img_vis)
        ax.axis("off")
    plt.show(block=self.block_viz_window)
def __init__(
    self,
    num_classes: int,
    backbone: Union[str, nn.Module] = "resnet50",
    backbone_kwargs: Optional[Dict] = None,
    head: str = "fcn",
    head_kwargs: Optional[Dict] = None,
    pretrained: bool = True,
    loss_fn: Optional[Callable] = None,
    optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW,
    metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None,
    learning_rate: float = 1e-3,
    multi_label: bool = False,
    serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
    postprocess: Optional[Postprocess] = None,
) -> None:
    """Build a semantic-segmentation task from a backbone and a head.

    Args:
        num_classes: number of segmentation classes to predict.
        backbone: either a registry key (str) or an already-built ``nn.Module``.
        backbone_kwargs: extra kwargs forwarded to the backbone constructor
            (ignored when ``backbone`` is an ``nn.Module``).
        head: registry key of the segmentation head (default ``"fcn"``).
        head_kwargs: extra kwargs forwarded to the head constructor.
        pretrained: whether to load pretrained backbone weights
            (only used when ``backbone`` is a registry key).
        loss_fn: loss function; defaults to ``F.cross_entropy``.
        optimizer: optimizer class to instantiate.
        metrics: metrics to track; defaults to ``IoU(num_classes)``.
        learning_rate: learning rate for the optimizer.
        multi_label: not supported yet; raises if True.
        serializer: prediction serializer; defaults to ``SegmentationLabels()``.
        postprocess: postprocess pipeline; defaults to ``self.postprocess_cls()``.

    Raises:
        ModuleNotFoundError: if a string backbone is requested but torchvision
            or timm is unavailable.
        NotImplementedError: if ``multi_label`` is True.
    """
    # string backbones are resolved via torchvision/timm registries, so both
    # extras must be installed for that path
    if isinstance(backbone, str) and (not _TORCHVISION_AVAILABLE or not _TIMM_AVAILABLE):
        raise ModuleNotFoundError("Please, pip install 'lightning-flash[image]'")

    # defaults are materialized *before* super().__init__ /
    # save_hyperparameters so the resolved values are what gets captured
    if metrics is None:
        metrics = IoU(num_classes=num_classes)

    if loss_fn is None:
        loss_fn = F.cross_entropy

    # TODO: need to check for multi_label
    if multi_label:
        raise NotImplementedError("Multi-label not supported yet.")

    # model=None: the forward pass is backbone + head built below, not a
    # single wrapped module
    super().__init__(
        model=None,
        loss_fn=loss_fn,
        optimizer=optimizer,
        metrics=metrics,
        learning_rate=learning_rate,
        serializer=serializer or SegmentationLabels(),
        postprocess=postprocess or self.postprocess_cls(),
    )

    # record init args (via frame inspection) for checkpointing/reload
    self.save_hyperparameters()

    if not backbone_kwargs:
        backbone_kwargs = {}

    if not head_kwargs:
        head_kwargs = {}

    if isinstance(backbone, nn.Module):
        self.backbone = backbone
    else:
        self.backbone = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

    self.head = self.heads.get(head)(self.backbone, num_classes, **head_kwargs)
def __init__(
    self,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
    image_size: Tuple[int, int] = (196, 196),
    deserializer: Optional['Deserializer'] = None,
    num_classes: int = None,
    labels_map: Dict[int, Tuple[int, int, int]] = None,
    **data_source_kwargs: Any,
) -> None:
    """Preprocess pipeline for semantic segmentation tasks.

    Args:
        train_transform: Dictionary with the set of transforms to apply during training.
        val_transform: Dictionary with the set of transforms to apply during validation.
        test_transform: Dictionary with the set of transforms to apply during testing.
        predict_transform: Dictionary with the set of transforms to apply during prediction.
        image_size: A tuple with the expected output image size.
        deserializer: Deserializer for raw prediction inputs; defaults to
            ``SemanticSegmentationDeserializer()``.
        num_classes: Number of segmentation classes; when given and no
            ``labels_map`` is supplied, a random color map is generated.
        labels_map: Mapping from class index to an (R, G, B) color used for
            visualisation; stored as ``ImageLabelsMap`` state when provided.
        **data_source_kwargs: Additional arguments passed on to the data source constructors.

    Raises:
        ModuleNotFoundError: if the image extras are not installed.
    """
    if not _IMAGE_AVAILABLE:
        raise ModuleNotFoundError("Please, pip install 'lightning-flash[image]'")

    self.image_size = image_size
    self.num_classes = num_classes
    # auto-generate a color map so visualisation works even when the caller
    # only knows the class count
    if num_classes:
        labels_map = labels_map or SegmentationLabels.create_random_labels_map(num_classes)

    super().__init__(
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        data_sources={
            DefaultDataSources.FIFTYONE: SemanticSegmentationFiftyOneDataSource(**data_source_kwargs),
            DefaultDataSources.FILES: SemanticSegmentationPathsDataSource(),
            DefaultDataSources.FOLDERS: SemanticSegmentationPathsDataSource(),
            DefaultDataSources.TENSORS: SemanticSegmentationTensorDataSource(),
            DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(),
        },
        deserializer=deserializer or SemanticSegmentationDeserializer(),
        default_data_source=DefaultDataSources.FILES,
    )

    # expose the map both as process state (for serializers/visualisers)
    # and as an attribute
    if labels_map:
        self.set_state(ImageLabelsMap(labels_map))

    self.labels_map = labels_map
def from_data_source(
    cls,
    data_source: str,
    train_data: Any = None,
    val_data: Any = None,
    test_data: Any = None,
    predict_data: Any = None,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
    data_fetcher: Optional[BaseDataFetcher] = None,
    preprocess: Optional[Preprocess] = None,
    val_split: Optional[float] = None,
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs: Any,
) -> "DataModule":
    """Create the segmentation DataModule, requiring ``num_classes`` in kwargs.

    Ensures a ``labels_map`` exists (user-supplied or randomly generated) for
    the visualisation data fetcher, then delegates to the parent
    ``from_data_source`` and propagates ``num_classes`` onto the train dataset.

    Raises:
        MisconfigurationException: if ``num_classes`` is missing from
            ``preprocess_kwargs``.
    """
    if "num_classes" not in preprocess_kwargs:
        raise MisconfigurationException("`num_classes` should be provided during instantiation.")

    num_classes = preprocess_kwargs["num_classes"]
    # FIX: ``preprocess_kwargs`` is a dict, so the previous
    # ``getattr(preprocess_kwargs, "labels_map", None)`` always returned None
    # and silently discarded a user-supplied labels_map; use ``dict.get``.
    labels_map = preprocess_kwargs.get("labels_map", None) or SegmentationLabels.create_random_labels_map(num_classes)

    data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map)

    # keep CI non-interactive: don't block on matplotlib windows while testing
    if flash._IS_TESTING:
        data_fetcher.block_viz_window = True

    dm = super().from_data_source(
        data_source=data_source,
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        predict_data=predict_data,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        data_fetcher=data_fetcher,
        preprocess=preprocess,
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )

    if dm.train_dataset is not None:
        dm.train_dataset.num_classes = num_classes
    return dm
def test_exception():
    """``serialize`` rejects raw tensors whose rank is not the expected (C, H, W)."""
    serializer = SegmentationLabels()
    # a 4D (batched) tensor and a 2D tensor are both invalid inputs
    for bad_shape in ((1, 5, 2, 3), (2, 3)):
        with pytest.raises(Exception):
            serializer.serialize(torch.zeros(*bad_shape))
def test_smoke():
    """A default-constructed serializer has no labels map and visualisation off."""
    serializer = SegmentationLabels()
    assert serializer is not None
    assert serializer.labels_map is None
    assert serializer.visualize is False
    num_classes=21,  # CARLA-style label set — TODO confirm against the dataset
)

# 2.2 Visualise the samples
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3.a List available backbones and heads
print(f"Backbones: {SemanticSegmentation.available_backbones()}")
print(f"Heads: {SemanticSegmentation.available_heads()}")

# 3.b Build the model
model = SemanticSegmentation(
    backbone="mobilenet_v3_large",
    head="fcn",
    num_classes=datamodule.num_classes,  # reuse the class count from the datamodule
    serializer=SegmentationLabels(visualize=False),  # no plot windows during predict
)

# 4. Create the trainer. fast_dev_run runs a single batch as a smoke test.
trainer = flash.Trainer(fast_dev_run=True)

# 5. Train the model — "freeze" keeps the backbone fixed, training only the head
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Segment a few images!
predictions = model.predict([
    "data/CameraRGB/F61-1.png",
    "data/CameraRGB/F62-1.png",
    "data/CameraRGB/F63-1.png",
])
# 2.1 Load images and segmentation masks from paired folders
datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    batch_size=4,
    val_split=0.3,
    image_size=(200, 200),  # (600, 800)
    num_classes=21,
)

# 2.2 Visualise the samples
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3. Build the model; visualize=True plots predicted masks when serializing
model = SemanticSegmentation(
    backbone="torchvision/fcn_resnet50",
    num_classes=datamodule.num_classes,
    serializer=SegmentationLabels(visualize=True),
)

# 4. Create the trainer. fast_dev_run=1 runs a single batch as a smoke test.
trainer = flash.Trainer(
    max_epochs=1,
    fast_dev_run=1,
)

# 5. Train the model — "freeze" keeps the backbone fixed, training only the head
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Segment a few images
predictions = model.predict([
    "data/CameraRGB/F61-1.png",
    "data/CameraRGB/F62-1.png",
    "data/CameraRGB/F63-1.png",
])