コード例 #1
0
    def test_serialize(self):
        """Expect ``serialize`` to return, per pixel, the first-axis index holding the peak."""
        serializer = SegmentationLabels()

        # All-zero scores with a single peak planted at two different pixels.
        # NOTE(review): presumably laid out as (classes, height, width) - confirm.
        scores = torch.zeros(5, 2, 3)
        scores[1, 1, 2] = 1  # peak at first-axis index 1 for pixel (1, 2)
        scores[3, 0, 1] = 1  # peak at first-axis index 3 for pixel (0, 1)

        labels = serializer.serialize(scores)
        assert labels[1, 2] == 1
        assert labels[0, 1] == 3
コード例 #2
0
    def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str):
        """Plot a grid where each cell shows an input image next to its colourised labels.

        The grid has ``min(num_samples, self.max_cols)`` columns and
        ``num_samples // cols`` rows, so only ``rows * cols`` samples are drawn
        when ``num_samples`` is not a multiple of the column count.

        Args:
            data: Samples to display. Each sample is either a dict holding
                ``DefaultDataKeys.INPUT`` and ``DefaultDataKeys.TARGET`` or an
                ``(image, label)`` tuple.
            num_samples: Number of samples requested; determines the grid shape.
            title: Title placed above the whole figure.

        Raises:
            MisconfigurationException: If matplotlib is not available.
            TypeError: If a sample is neither a dict nor a tuple.
        """
        # fail fast, before anything touches ``plt``
        if not _MATPLOTLIB_AVAILABLE:
            raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib")

        # define the image grid
        cols: int = min(num_samples, self.max_cols)
        rows: int = num_samples // cols

        # create figure and set title
        # ``squeeze=False`` always yields a 2D array of axes; without it a 1x1 grid
        # returns a bare Axes object, which has no ``ravel()``
        fig, axs = plt.subplots(rows, cols, squeeze=False)
        fig.suptitle(title)

        for i, ax in enumerate(axs.ravel()):
            # unpack images and labels
            sample = data[i]
            if isinstance(sample, dict):
                image = sample[DefaultDataKeys.INPUT]
                label = sample[DefaultDataKeys.TARGET]
            elif isinstance(sample, tuple):
                image = sample[0]
                label = sample[1]
            else:
                # report the offending sample's type, not the container's
                raise TypeError(f"Unknown data type. Got: {type(sample)}.")
            # convert images and labels to numpy and stack horizontally
            image_vis: np.ndarray = self._to_numpy(image.byte())
            label_tmp: torch.Tensor = SegmentationLabels.labels_to_image(label.squeeze().byte(), self.labels_map)
            label_vis: np.ndarray = self._to_numpy(label_tmp)
            img_vis = np.hstack((image_vis, label_vis))
            # send to visualiser
            ax.imshow(img_vis)
            ax.axis('off')
        plt.show(block=self.block_viz_window)
コード例 #3
0
    def test_exception(self):
        """Expect ``serialize`` to raise for both a 4D and a 2D input tensor."""
        serializer = SegmentationLabels()

        # one shape with too many dimensions, one with too few
        for bad_shape in ((1, 5, 2, 3), (2, 3)):
            bad_sample = torch.zeros(*bad_shape)
            with pytest.raises(Exception):
                serializer.serialize(bad_sample)
コード例 #4
0
    def __init__(
        self,
        num_classes: int,
        backbone: Union[str, Tuple[nn.Module,
                                   int]] = "torchvision/fcn_resnet50",
        backbone_kwargs: Optional[Dict] = None,
        pretrained: bool = True,
        loss_fn: Optional[Callable] = None,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW,
        metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None,
        learning_rate: float = 1e-3,
        multi_label: bool = False,
        serializer: Optional[Union[Serializer, Mapping[str,
                                                       Serializer]]] = None,
    ) -> None:
        """Set up the task: fill in defaults, initialise the parent, build the backbone.

        Args:
            num_classes: Number of classes; used for the default ``IoU`` metric
                and passed as the first argument to the backbone constructor.
            backbone: Key looked up in ``self.backbones``.
                NOTE(review): the annotation also allows an ``(nn.Module, int)``
                tuple, but the body only performs the registry lookup - confirm
                whether tuples are actually supported.
            backbone_kwargs: Extra keyword arguments forwarded to the backbone
                constructor; a falsy value is replaced by an empty dict.
            pretrained: Forwarded to the backbone constructor as ``pretrained``.
            loss_fn: Loss callable; defaults to ``F.cross_entropy`` when ``None``.
            optimizer: Optimizer class forwarded to the parent constructor.
            metrics: Metrics forwarded to the parent constructor; defaults to
                ``IoU(num_classes=num_classes)`` when ``None``.
            learning_rate: Learning rate forwarded to the parent constructor.
            multi_label: Must be ``False``; ``True`` is rejected.
            serializer: Serializer forwarded to the parent constructor; a falsy
                value is replaced by ``SegmentationLabels()``.

        Raises:
            NotImplementedError: If ``multi_label`` is true.
        """

        # fall back to IoU over ``num_classes`` when no metric is supplied
        if metrics is None:
            metrics = IoU(num_classes=num_classes)

        # fall back to cross entropy when no loss is supplied
        if loss_fn is None:
            loss_fn = F.cross_entropy

        # TODO: need to check for multi_label
        if multi_label:
            raise NotImplementedError("Multi-label not supported yet.")

        # the parent receives no model; the network is attached below as ``self.backbone``
        super().__init__(
            model=None,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
            serializer=serializer or SegmentationLabels(),
        )

        # NOTE(review): presumably records the init arguments as they stand at this
        # point (defaults for ``metrics``/``loss_fn`` already applied, ``backbone_kwargs``
        # not yet normalised) - confirm before reordering the surrounding statements
        self.save_hyperparameters()

        if not backbone_kwargs:
            backbone_kwargs = {}

        # TODO: pretrained to True causes some issues
        # look up the backbone constructor in the registry and instantiate it
        self.backbone = self.backbones.get(backbone)(num_classes,
                                                     pretrained=pretrained,
                                                     **backbone_kwargs)
コード例 #5
0
 def test_smoke(self):
     """A default-constructed ``SegmentationLabels`` has no labels map and visualisation off."""
     serializer = SegmentationLabels()
     assert serializer is not None
     assert serializer.labels_map is None
     assert serializer.visualize is False
コード例 #6
0
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
download_data(
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
    "data/",
)

# The random labels map and the model are both built for this many classes.
NUM_CLASSES = 21

# 2.1 Load the data: RGB frames are the inputs, segmentation frames the targets
datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    batch_size=4,
    val_split=0.3,
    image_size=(200, 200),  # (600, 800)
)

# 2.2 Visualise the samples using a randomly generated labels map
labels_map = SegmentationLabels.create_random_labels_map(num_classes=NUM_CLASSES)
datamodule.set_labels_map(labels_map)
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3. Build the model
model = SemanticSegmentation(backbone="torchvision/fcn_resnet50", num_classes=NUM_CLASSES)

# 4. Create the trainer.
trainer = flash.Trainer(max_epochs=1, fast_dev_run=1)
コード例 #7
0
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.data.utils import download_data
from flash.vision import SemanticSegmentation
from flash.vision.segmentation.serialization import SegmentationLabels

# 1. Download the data
# The dataset contains semantic segmentation labels generated with the CARLA
# self-driving simulator as part of the Lyft Udacity Challenge.
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
download_data(
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
    "data/",
)

# 2. Load the model from a checkpoint
model = SemanticSegmentation.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
)
# replace the serializer with one that has visualisation switched on
model.serializer = SegmentationLabels(visualize=True)

# 3. Predict what's on a few images and visualize!
image_paths = [
    "data/CameraRGB/F61-1.png",
    "data/CameraRGB/F62-1.png",
    "data/CameraRGB/F63-1.png",
]
predictions = model.predict(image_paths)