def test_predict_numpy():
    """Predict on a NumPy batch using the ``mobilenetv3_large_100`` backbone.

    The prediction for the single image in the batch is expected to come back
    as a nested Python list whose outer/inner lengths match the input's
    height/width.

    NOTE(review): this module defines ``test_predict_numpy`` more than once;
    each later ``def`` rebinds the name, so pytest only collects the last
    definition and this body never runs. Consider giving it a distinct name.
    """
    height, width = 64, 64
    batch = np.ones((1, 3, height, width))
    segmenter = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
    pipeline = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    predictions = segmenter.predict(batch, data_source="numpy", data_pipeline=pipeline)

    mask = predictions[0]
    assert isinstance(mask, list)
    assert len(mask) == height
    assert len(mask[0]) == width
def test_predict_numpy():
    """Predict on a 10x20 NumPy batch and expect a tensor mask of that size.

    NOTE(review): another ``test_predict_numpy`` is defined after this one in
    the module and rebinds the name, so pytest never collects this version.
    That later definition expects a nested list rather than a
    ``torch.Tensor``; the two expectations cannot both hold for the same
    library version — confirm which output type is current and drop the other.
    """
    image = np.ones((1, 3, 10, 20))
    _, _, rows, cols = image.shape
    net = SemanticSegmentation(2)
    pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    result = net.predict(image, data_source="numpy", data_pipeline=pipe)

    first = result[0]
    assert isinstance(first, torch.Tensor)
    assert first.shape == (rows, cols)
def test_predict_numpy():
    """Predict on a 10x20 NumPy batch and check the nested-list mask size.

    The first prediction must be a list with one row per input pixel row,
    and each row must have one entry per input pixel column.
    """
    image = np.ones((1, 3, 10, 20))
    _, _, rows, cols = image.shape
    net = SemanticSegmentation(2)
    pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    result = net.predict(image, data_source="numpy", data_pipeline=pipe)

    first = result[0]
    assert isinstance(first, list)
    assert len(first) == rows
    assert len(first[0]) == cols
def test_predict_tensor():
    """Predict on a random 10x20 ``torch`` batch and check the mask size.

    Uses the ``"tensors"`` data source; the first prediction must be a nested
    list matching the input's height (outer) and width (inner).
    """
    # Sample the input before building the model so the global RNG is
    # consumed in the same order as before.
    inputs = torch.rand(1, 3, 10, 20)
    _, _, rows, cols = inputs.shape
    net = SemanticSegmentation(2)
    pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
    result = net.predict(inputs, data_source="tensors", data_pipeline=pipe)

    first = result[0]
    assert isinstance(first, list)
    assert len(first) == rows
    assert len(first[0]) == cols
# 2.2 Visualise the samples datamodule.show_train_batch(["load_sample", "post_tensor_transform"]) # 3.a List available backbones and heads print(f"Backbones: {SemanticSegmentation.available_backbones()}") print(f"Heads: {SemanticSegmentation.available_heads()}") # 3.b Build the model model = SemanticSegmentation( backbone="mobilenet_v3_large", head="fcn", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=False), ) # 4. Create the trainer. trainer = flash.Trainer(fast_dev_run=True) # 5. Train the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 6. Segment a few images! predictions = model.predict([ "data/CameraRGB/F61-1.png", "data/CameraRGB/F62-1.png", "data/CameraRGB/F63-1.png", ]) # 7. Save it! trainer.save_checkpoint("semantic_segmentation_model.pt")