Example #1
0
def test_load_from_checkpoint_dependency_error():
    """``StyleTransfer.load_from_checkpoint`` must raise ``ModuleNotFoundError``.

    The error message is required to mention ``'lightning-flash[image]'``.
    ``re.escape`` is applied because ``pytest.raises(match=...)`` treats the
    pattern as a regex, and the square brackets would otherwise be read as a
    character class.
    """
    # NOTE(review): this only passes when the load actually raises; in an
    # environment where the required dependencies are installed a skip
    # condition may be needed -- confirm.
    expected_hint = re.escape("'lightning-flash[image]'")
    with pytest.raises(ModuleNotFoundError, match=expected_hint):
        StyleTransfer.load_from_checkpoint("not_a_real_checkpoint.pt")
    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices,
                           batch, batch_idx, dataloader_idx) -> None:
        """Placeholder hook for saving one batch of predictions.

        The body consists only of this docstring, so calling it writes nothing
        and returns ``None``. Replace it with the logic that persists
        ``prediction``; for example::

            torch.save(
                {"preds": prediction, "batch_indices": batch_indices},
                f"prediction_{batch_idx}.pt",
            )

        The ``f`` prefix on the file name matters: without it every batch is
        saved to the same literal path ``prediction_{batch_idx}.pt`` and
        overwrites the previous one.

        Args:
            trainer: Trainer driving prediction (unused by this placeholder).
            pl_module: Module that produced the predictions (unused).
            prediction: Predictions for the current batch.
            batch_indices: Sample indices for the current batch.
            batch: The input batch (unused).
            batch_idx: Index of the current batch; used in the example file
                name so each batch gets its own file.
            dataloader_idx: Index of the dataloader the batch came from
                (unused).
        """


# Step 1 -- fetch the COCO128 sample archive (target directory: "data/").
download_data(
    "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip",
    "data/",
)

# Step 2 -- restore a pretrained StyleTransfer model from its published checkpoint.
model = StyleTransfer.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/style_transfer_model.pt",
)

# Step 3 -- stylize images from the train2017 folder.
# `limit_predict_batches=1` restricts prediction to a single batch (batch_size=4),
# so this does not cover the whole folder; StyleTransferWriter is registered as
# the trainer's callback.
datamodule = StyleTransferData.from_folders(
    batch_size=4,
    predict_folder="data/coco128/images/train2017",
)
trainer = flash.Trainer(
    limit_predict_batches=1,
    callbacks=StyleTransferWriter(),
    max_epochs=2,
)
predictions = trainer.predict(model, datamodule=datamodule)

# Step 4 -- stack the first batch's per-image outputs and keep the first image
# as a NumPy array for display.
image_prediction = torch.stack(predictions[0])[0].numpy()

if _MATPLOTLIB_AVAILABLE and not flash._IS_TESTING:
    import matplotlib.pyplot as plt