Example No. 1
def test_predict_numpy():
    img = np.ones((1, 3, 64, 64))
    model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
    assert isinstance(out[0], list)
    assert len(out[0]) == 64
    assert len(out[0][0]) == 64
Example No. 2
def test_predict_numpy():
    img = np.ones((1, 3, 10, 20))
    model = SemanticSegmentation(2)
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(
        num_classes=1))
    out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
    assert isinstance(out[0], torch.Tensor)
    assert out[0].shape == (10, 20)
Example No. 3
def test_predict_numpy():
    img = np.ones((1, 3, 10, 20))
    model = SemanticSegmentation(2)
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(
        num_classes=1))
    out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
    assert isinstance(out[0], list)
    assert len(out[0]) == 10
    assert len(out[0][0]) == 20
Example No. 4
def test_predict_tensor():
    img = torch.rand(1, 3, 10, 20)
    model = SemanticSegmentation(2)
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(
        num_classes=1))
    out = model.predict(img, data_source="tensors", data_pipeline=data_pipe)
    assert isinstance(out[0], list)
    assert len(out[0]) == 10
    assert len(out[0][0]) == 20
Example No. 5
def test_jit(tmpdir, jitter, args):
    path = os.path.join(tmpdir, "test.pt")

    model = SemanticSegmentation(2)
    model.eval()

    model = jitter(model, *args)

    torch.jit.save(model, path)
    model = torch.jit.load(path)

    out = model(torch.rand(1, 3, 32, 32))
    assert isinstance(out, torch.Tensor)
    assert out.shape == torch.Size([1, 2, 32, 32])
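The `jitter` and `args` arguments above are supplied by a `pytest.mark.parametrize` decorator that the snippet omits; a plausible (hypothetical) parametrization would be:

import pytest
import torch

# Hypothetical parametrization for test_jit above: torch.jit.script takes no
# extra arguments, while torch.jit.trace needs an example input of the shape
# the model expects.
@pytest.mark.parametrize(
    "jitter, args",
    [
        (torch.jit.script, ()),
        (torch.jit.trace, (torch.rand(1, 3, 32, 32),)),
    ],
)
def test_jit(tmpdir, jitter, args):
    ...  # body as in the example above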
Example No. 6
def test_predict_numpy():
    img = np.ones((1, 3, 64, 64))
    model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
    datamodule = SemanticSegmentationData.from_numpy(predict_data=img,
                                                     batch_size=1)
    trainer = Trainer()
    out = trainer.predict(model, datamodule=datamodule, output="labels")
    assert isinstance(out[0][0], list)
    assert len(out[0][0]) == 64
    assert len(out[0][0][0]) == 64
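With output="labels" each prediction is a plain nested Python list of class indices (height x width), as the assertions show; if a tensor is more convenient, the nested list converts back directly:

# Continuing the example above: out[0][0] is a 64x64 nested list of labels.
mask = torch.tensor(out[0][0])
assert mask.shape == (64, 64)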
Example No. 7
def test_forward(num_classes, img_shape):
    model = SemanticSegmentation(
        num_classes=num_classes,
        backbone='torchvision/fcn_resnet50',
    )

    B, C, H, W = img_shape
    img = torch.rand(B, C, H, W)

    out = model(img)
    assert out.shape == (B, num_classes, H, W)
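`num_classes` and `img_shape` are likewise parametrized outside the snippet; illustrative (hypothetical) values that satisfy this forward test could be:

import pytest

# Hypothetical parameters for test_forward above; the real test suite may use
# different values, but img_shape must unpack to (B, C, H, W).
@pytest.mark.parametrize("num_classes", [2, 21])
@pytest.mark.parametrize("img_shape", [(1, 3, 224, 192), (2, 3, 128, 128)])
def test_forward(num_classes, img_shape):
    ...  # body as in the example above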
Example No. 8
def test_forward(num_classes, img_shape):
    model = SemanticSegmentation(
        num_classes=num_classes,
        backbone="resnet50",
        head="fpn",
    )

    B, C, H, W = img_shape
    img = torch.rand(B, C, H, W)

    out = model(img)
    assert out.shape == (B, num_classes, H, W)
Example No. 9
# 2.1 Load the data
datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    image_size=(200, 200),
    num_classes=21,
)

# 2.2 Visualise the samples
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3.a List available backbones and heads
print(f"Backbones: {SemanticSegmentation.available_backbones()}")
print(f"Heads: {SemanticSegmentation.available_heads()}")

# 3.b Build the model
model = SemanticSegmentation(
    backbone="mobilenet_v3_large",
    head="fcn",
    num_classes=datamodule.num_classes,
    serializer=SegmentationLabels(visualize=False),
)

# 4. Create the trainer.
trainer = flash.Trainer(fast_dev_run=True)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Segment a few images!
predictions = model.predict([
    "data/CameraRGB/F61-1.png",
    "data/CameraRGB/F62-1.png",
    "data/CameraRGB/F63-1.png",
])
Example No. 10
# 2.1 Load the data
datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    batch_size=4,
    val_split=0.3,
    image_size=(200, 200),  # (600, 800)
    num_classes=21,
)

# 2.2 Visualise the samples
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3. Build the model
model = SemanticSegmentation(backbone="torchvision/fcn_resnet50",
                             num_classes=datamodule.num_classes,
                             serializer=SegmentationLabels(visualize=True))

# 4. Create the trainer.
trainer = flash.Trainer(
    max_epochs=1,
    fast_dev_run=1,
)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

predictions = model.predict([
    "data/CameraRGB/F61-1.png",
    "data/CameraRGB/F62-1.png",
    "data/CameraRGB/F63-1.png",
])
Example No. 11
def test_freeze():
    model = SemanticSegmentation(2)
    model.freeze()
    for p in model.backbone.parameters():
        assert p.requires_grad is False
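`freeze()` and `unfreeze()` come from LightningModule, so the inverse check works the same way; a matching test (not necessarily part of the Flash suite) could be:

# Counterpart sketch: unfreezing re-enables gradients on the backbone.
def test_unfreeze():
    model = SemanticSegmentation(2)
    model.unfreeze()
    for p in model.backbone.parameters():
        assert p.requires_grad is True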
Example No. 12
def test_non_existent_backbone():
    with pytest.raises(KeyError):
        SemanticSegmentation(2, "i am never going to implement this lol")
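Valid backbone names come from the task's registry, so the usual way to avoid this KeyError is to pick one of the names reported by available_backbones():

# Pick a registered backbone name instead of an arbitrary string.
backbones = SemanticSegmentation.available_backbones()
print(backbones)  # e.g. includes "resnet50" and "mobilenetv3_large_100"
model = SemanticSegmentation(2, backbone=backbones[0])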
Example No. 13
def test_init_train(tmpdir):
    model = SemanticSegmentation(num_classes=10)
    train_dl = torch.utils.data.DataLoader(DummyDataset())
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
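`DummyDataset` is a helper defined elsewhere in the test module; a minimal sketch that would satisfy this test, assuming the task consumes DefaultDataKeys.INPUT/TARGET dictionaries as in Example No. 21 (the import path is an assumption and may differ between Flash versions):

import torch
from flash.core.data.data_source import DefaultDataKeys


class DummyDataset(torch.utils.data.Dataset):
    """Random images and integer masks, just enough to drive fast_dev_run."""

    size = (64, 64)
    num_classes = 10

    def __getitem__(self, index):
        return {
            DefaultDataKeys.INPUT: torch.rand(3, *self.size),
            DefaultDataKeys.TARGET: torch.randint(self.num_classes, self.size),
        }

    def __len__(self):
        return 10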
Example No. 14
def test_smoke():
    model = SemanticSegmentation(num_classes=1)
    assert model is not None
Example No. 15
def test_available_pretrained_weights():
    assert SemanticSegmentation.available_pretrained_weights("resnet18") == [
        "imagenet", "ssl", "swsl"
    ]
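These weight names are typically fed back through the pretrained argument when building the task; a hedged sketch, assuming pretrained accepts one of the listed strings (as in recent Flash versions):

# Assumption: `pretrained` accepts a named weight set reported by
# available_pretrained_weights(); pass True/False for the default behaviour.
model = SemanticSegmentation(2, backbone="resnet18", pretrained="swsl")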
Example No. 16
def test_serve():
    model = SemanticSegmentation(2)
    # TODO: Currently only servable once a preprocess has been attached
    model._preprocess = SemanticSegmentationPreprocess()
    model.eval()
    model.serve()
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
    "./data",
)

datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    image_size=(256, 256),
    num_classes=21,
)

# 2. Build the task
model = SemanticSegmentation(
    backbone="mobilenetv3_large_100",
    head="fpn",
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Segment a few images!
predictions = model.predict([
    "data/CameraRGB/F61-1.png",
    "data/CameraRGB/F62-1.png",
    "data/CameraRGB/F63-1.png",
])
print(predictions)
    "./data",
)

datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    transform_kwargs=dict(image_size=(256, 256)),
    num_classes=21,
    batch_size=4,
)

# 2. Build the task
model = SemanticSegmentation(
    backbone="mobilenetv3_large_100",
    head="fpn",
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Segment a few images!
datamodule = SemanticSegmentationData.from_files(
    predict_files=[
        "data/CameraRGB/F61-1.png",
        "data/CameraRGB/F62-1.png",
        "data/CameraRGB/F63-1.png",
    ],
    batch_size=3,
)
Example No. 19
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.image import SemanticSegmentation

model = SemanticSegmentation.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt")
model.serve()
Example No. 20
def test_load_from_checkpoint_dependency_error():
    with pytest.raises(ModuleNotFoundError,
                       match=re.escape("'lightning-flash[image]'")):
        SemanticSegmentation.load_from_checkpoint("not_a_real_checkpoint.pt")
Example No. 21
    def test_map_labels(tmpdir):
        tmp_dir = Path(tmpdir)

        # create random dummy data

        images = [
            str(tmp_dir / "img1.png"),
            str(tmp_dir / "img2.png"),
            str(tmp_dir / "img3.png"),
        ]

        targets = [
            str(tmp_dir / "labels_img1.png"),
            str(tmp_dir / "labels_img2.png"),
            str(tmp_dir / "labels_img3.png"),
        ]

        labels_map: Dict[int, Tuple[int, int, int]] = {
            0: (0, 0, 0),
            1: (255, 255, 255),
        }

        num_classes: int = len(labels_map.keys())
        img_size: Tuple[int, int] = (128, 128)
        create_random_data(images, targets, img_size, num_classes)

        # instantiate the data module

        dm = SemanticSegmentationData.from_files(
            train_files=images,
            train_targets=targets,
            val_files=images,
            val_targets=targets,
            batch_size=2,
            num_workers=0,
            num_classes=num_classes,
        )
        assert dm is not None
        assert dm.train_dataloader() is not None

        # disable visualisation for testing
        assert dm.data_fetcher.block_viz_window is True
        dm.set_block_viz_window(False)
        assert dm.data_fetcher.block_viz_window is False

        dm.show_train_batch("load_sample")
        dm.show_train_batch("to_tensor_transform")

        # check training data
        data = next(iter(dm.train_dataloader()))
        imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
        assert imgs.shape == (2, 3, 128, 128)
        assert labels.shape == (2, 128, 128)
        assert labels.min().item() == 0
        assert labels.max().item() == 1
        assert labels.dtype == torch.int64

        # now train with `fast_dev_run`
        model = SemanticSegmentation(num_classes=2, backbone="resnet50", head="fpn")
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
        trainer.finetune(model, dm, strategy="freeze_unfreeze")
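`create_random_data` is another helper the snippet does not show; a minimal hypothetical implementation matching the call signature used above, writing random RGB images and integer-label masks as PNGs:

from typing import List, Tuple

import numpy as np
from PIL import Image


def create_random_data(
    image_files: List[str],
    label_files: List[str],
    size: Tuple[int, int],
    num_classes: int,
) -> None:
    """Write random RGB images and matching integer-label masks to disk."""
    height, width = size
    for image_file, label_file in zip(image_files, label_files):
        rgb = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
        Image.fromarray(rgb).save(image_file)
        mask = np.random.randint(0, num_classes, (height, width), dtype=np.uint8)
        Image.fromarray(mask).save(label_file)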
Example No. 22
def test_serve():
    model = SemanticSegmentation(2)
    model.eval()
    model.serve()
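serve() starts a local REST endpoint; a hedged client-side sketch, assuming the default host/port and the payload schema used in the Flash serving examples (both may differ between versions):

import base64

import requests

# Assumption: the server listens on 127.0.0.1:8000 and expects a base64
# encoded image under payload.inputs.data.
with open("data/CameraRGB/F61-1.png", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": img_b64}}}
response = requests.post("http://127.0.0.1:8000/predict", json=body)
print(response.json())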