def test_predict_numpy():
    """Run ``predict`` on an in-memory NumPy array and check the first result.

    The prediction is requested with ``data_source="numpy"`` and an explicit
    ``DataPipeline`` wrapping a ``SemanticSegmentationPreprocess``. The first
    returned item must be a ``torch.Tensor`` of shape ``(196, 196)``.
    """
    # A single all-ones sample laid out as (1, 3, 10, 20).
    img = np.ones((1, 3, 10, 20))

    # The positional ``2`` is presumably the number of classes -- TODO confirm.
    segmenter = SemanticSegmentation(2)
    preprocess = SemanticSegmentationPreprocess()
    pipeline = DataPipeline(preprocess=preprocess)

    predictions = segmenter.predict(
        img,
        data_source="numpy",
        data_pipeline=pipeline,
    )

    first = predictions[0]
    assert isinstance(first, torch.Tensor)
    # NOTE(review): the expected size differs from the 10x20 input; presumably
    # the preprocess resizes to its default image size -- confirm.
    assert first.shape == (196, 196)
def test_forward(num_classes, img_shape):
    """Check that a forward pass keeps batch and spatial size.

    Args:
        num_classes: Number of classes the model is built with; also the
            expected size of dimension 1 of the output.
        img_shape: Four-element shape unpacked as ``(B, C, H, W)``.

    Both arguments are presumably supplied by a parametrisation that is not
    visible in this chunk.
    """
    model = SemanticSegmentation(
        backbone='torchvision/fcn_resnet50',
        num_classes=num_classes,
    )

    batch, channels, height, width = img_shape
    random_input = torch.rand(batch, channels, height, width)

    logits = model(random_input)

    expected_shape = (batch, num_classes, height, width)
    assert logits.shape == expected_shape
def test_unfreeze():
    """After ``unfreeze()``, every backbone parameter must require gradients."""
    segmenter = SemanticSegmentation(2)
    segmenter.unfreeze()

    # ``is True`` is kept on purpose: the flag must be the boolean itself,
    # not merely a truthy value.
    assert all(
        param.requires_grad is True
        for param in segmenter.backbone.parameters()
    )
def test_non_existent_backbone():
    """Constructing the model with an unknown backbone name raises ``KeyError``."""
    unknown_backbone = "i am never going to implement this lol"

    with pytest.raises(KeyError):
        SemanticSegmentation(2, unknown_backbone)
def test_init_train(tmpdir, backbone):
    """Fine-tune a 10-class model for a fast dev run on a dummy dataset.

    Args:
        tmpdir: Directory passed to the trainer as ``default_root_dir``.
        backbone: Backbone identifier forwarded to ``SemanticSegmentation``;
            presumably injected by a parametrisation not visible here.

    There is no explicit assertion: the test passes when ``finetune`` with
    the ``"freeze_unfreeze"`` strategy completes without raising.
    """
    segmenter = SemanticSegmentation(num_classes=10, backbone=backbone)

    dummy_data = DummyDataset()
    loader = torch.utils.data.DataLoader(dummy_data)

    runner = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    runner.finetune(segmenter, loader, strategy="freeze_unfreeze")
def test_smoke():
    """Smoke test: a single-class model can be constructed."""
    segmenter = SemanticSegmentation(num_classes=1)

    # In practice the meaningful check is that the constructor above did not
    # raise; the assertion only guards against a ``None`` result.
    assert segmenter is not None
datamodule = SemanticSegmentationData.from_folders( train_folder="data/CameraRGB", train_target_folder="data/CameraSeg", batch_size=4, val_split=0.3, image_size=(200, 200), # (600, 800) ) # 2.2 Visualise the samples labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) datamodule.set_labels_map(labels_map) datamodule.show_train_batch(["load_sample", "post_tensor_transform"]) # 3. Build the model model = SemanticSegmentation( backbone="torchvision/fcn_resnet50", num_classes=21, ) # 4. Create the trainer. trainer = flash.Trainer( max_epochs=1, fast_dev_run=1, ) # 5. Train the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 6. Predict what's on a few images! model.serializer = SegmentationLabels(labels_map, visualize=True) predictions = model.predict([
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.data.utils import download_data
from flash.vision import SemanticSegmentation
from flash.vision.segmentation.serialization import SegmentationLabels

# 1. Fetch the data.
# The archive is a semantic-segmentation dataset produced with the CARLA
# self-driving simulator as part of the Lyft Udacity Challenge.
# Background: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
dataset_url = (
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip"
)
download_data(dataset_url, "data/")

# 2. Restore the model from a published checkpoint and attach a serializer
#    created with ``visualize=True``.
checkpoint_url = "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
model = SemanticSegmentation.load_from_checkpoint(checkpoint_url)
model.serializer = SegmentationLabels(visualize=True)

# 3. Run prediction on a few of the downloaded RGB frames and visualize!
sample_images = [
    "data/CameraRGB/F61-1.png",
    "data/CameraRGB/F62-1.png",
    "data/CameraRGB/F63-1.png",
]
predictions = model.predict(sample_images)
def test_map_labels(self, tmpdir):
    """Exercise label-map visualisation and a fast fine-tuning run.

    Random image/target files are written under ``tmpdir``, loaded through
    ``SemanticSegmentationData.from_files`` (the same files serve as train
    and validation data), visualised with a two-entry labels map, checked
    for batch shape, value range and dtype, and finally used to fine-tune a
    two-class model with ``fast_dev_run=True``.

    Args:
        tmpdir: Temporary directory for the generated files; also used as
            the trainer's ``default_root_dir``.
    """
    tmp_dir = Path(tmpdir)

    # Paths for the random dummy data.
    image_names = ("img1.png", "img2.png", "img3.png")
    target_names = ("labels_img1.png", "labels_img2.png", "labels_img3.png")
    images = [str(tmp_dir / name) for name in image_names]
    targets = [str(tmp_dir / name) for name in target_names]

    # NOTE(review): the annotation says Tuple but the values are lists; the
    # values are left as-is here -- confirm which form the library expects.
    labels_map: Dict[int, Tuple[int, int, int]] = {
        0: [0, 0, 0],
        1: [255, 255, 255],
    }
    num_classes: int = len(labels_map)
    img_size: Tuple[int, int] = (196, 196)
    create_random_data(images, targets, img_size, num_classes)

    # Build the data module, reusing the same files for train and val.
    dm = SemanticSegmentationData.from_files(
        train_files=images,
        train_targets=targets,
        val_files=images,
        val_targets=targets,
        batch_size=2,
        num_workers=0,
    )
    assert dm is not None
    assert dm.train_dataloader() is not None

    # The fetcher's ``block_viz_window`` flag starts out True; switch it off
    # before showing batches during the test.
    assert dm.data_fetcher.block_viz_window is True
    dm.set_block_viz_window(False)
    assert dm.data_fetcher.block_viz_window is False

    dm.set_labels_map(labels_map)
    for hook_name in ("load_sample", "to_tensor_transform"):
        dm.show_train_batch(hook_name)

    # Inspect one training batch.
    batch = next(iter(dm.train_dataloader()))
    imgs = batch[DefaultDataKeys.INPUT]
    labels = batch[DefaultDataKeys.TARGET]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, 196, 196)
    assert labels.min().item() == 0
    assert labels.max().item() == 1
    assert labels.dtype == torch.int64

    # Finally, train with `fast_dev_run`.
    model = SemanticSegmentation(
        num_classes=2,
        backbone="torchvision/fcn_resnet50",
    )
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.finetune(model, dm, strategy="freeze_unfreeze")