Example #1
# NOTE: import paths below are the assumed Lightning Flash locations for these classes.
import numpy as np
from flash.core.data.data_pipeline import DataPipeline
from flash.image import SemanticSegmentation
from flash.image.segmentation.data import SemanticSegmentationPreprocess


def test_predict_numpy():
    # Predict on a NumPy image; the output is a nested list of per-pixel class labels.
    img = np.ones((1, 3, 64, 64))
    model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
    assert isinstance(out[0], list)
    assert len(out[0]) == 64
    assert len(out[0][0]) == 64
Example #2
def test_predict_numpy():
    # Same imports as Example #1, plus `import torch` for the output check.
    # Here the prediction comes back as a torch.Tensor of per-pixel labels.
    img = np.ones((1, 3, 10, 20))
    model = SemanticSegmentation(2)
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
    assert isinstance(out[0], torch.Tensor)
    assert out[0].shape == (10, 20)
Example #3
def test_predict_numpy():
    # Same imports as Example #1. The prediction is a nested list matching the
    # input's spatial size (10 rows of 20 labels).
    img = np.ones((1, 3, 10, 20))
    model = SemanticSegmentation(2)
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
    assert isinstance(out[0], list)
    assert len(out[0]) == 10
    assert len(out[0][0]) == 20
Example #4
def test_predict_tensor():
    # Same imports as Example #1, plus `import torch` for the input tensor.
    # Predicting on an in-memory tensor uses the "tensors" data source instead of "numpy".
    img = torch.rand(1, 3, 10, 20)
    model = SemanticSegmentation(2)
    data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    out = model.predict(img, data_source="tensors", data_pipeline=data_pipe)
    assert isinstance(out[0], list)
    assert len(out[0]) == 10
    assert len(out[0][0]) == 20
# 2.2 Visualise the samples
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3.a List available backbones and heads
print(f"Backbones: {SemanticSegmentation.available_backbones()}")
print(f"Heads: {SemanticSegmentation.available_heads()}")

# 3.b Build the model
model = SemanticSegmentation(
    backbone="mobilenet_v3_large",
    head="fcn",
    num_classes=datamodule.num_classes,
    serializer=SegmentationLabels(visualize=False),
)

# 4. Create the trainer
trainer = flash.Trainer(fast_dev_run=True)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Segment a few images!
predictions = model.predict([
    "data/CameraRGB/F61-1.png",
    "data/CameraRGB/F62-1.png",
    "data/CameraRGB/F63-1.png",
])

# 7. Save it!
trainer.save_checkpoint("semantic_segmentation_model.pt")
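
The saved checkpoint can later be restored for inference. The snippet below is a minimal sketch, assuming the checkpoint written in step 7 and the same predict call shown in step 6; load_from_checkpoint is inherited from PyTorch Lightning's LightningModule.

# 8. (Optional) Reload the checkpoint and predict again.
# Sketch only: assumes the checkpoint from step 7 and the predict API used in step 6.
from flash.image import SemanticSegmentation

model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")
predictions = model.predict(["data/CameraRGB/F61-1.png"])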