def test_jit(tmpdir):
    """Round-trip a traced ``StyleTransfer`` through TorchScript save/load.

    The reloaded module must still map a ``(1, 3, 32, 32)`` input to a tensor
    of the same shape.
    """
    checkpoint_path = os.path.join(tmpdir, "test.pt")

    style_model = StyleTransfer()
    style_model.eval()
    # torch.jit.script doesn't work with pystiche, so tracing is used instead.
    traced = torch.jit.trace(style_model, torch.rand(1, 3, 32, 32))
    torch.jit.save(traced, checkpoint_path)

    restored = torch.jit.load(checkpoint_path)
    result = restored(torch.rand(1, 3, 32, 32))

    assert isinstance(result, torch.Tensor)
    assert result.shape == torch.Size([1, 3, 32, 32])
def test_style_transfer_task():
    """Check that the constructor arguments end up on the perceptual-loss components."""
    layer = "relu1_2"
    model = StyleTransfer(
        backbone="vgg11",
        content_layer=layer,
        content_weight=10,
        style_layers=layer,
        style_weight=11,
    )

    content_loss = model.perceptual_loss.content_loss
    assert content_loss.encoder.layer == layer
    assert content_loss.score_weight == 10

    style_loss = model.perceptual_loss.style_loss
    # The requested style layer must appear among the style loss's named submodules.
    assert any(name == layer for name, _ in style_loss.named_modules())
    assert style_loss.score_weight == 11
def test_style_transfer_task_import():
    """Constructing ``StyleTransfer`` without its optional extras must raise ``ModuleNotFoundError``.

    ``pytest.raises(match=...)`` applies ``re.search`` with the given pattern. The
    previous pattern ``"[image_style_transfer]"`` was a regex *character class*. It
    matched any single one of those letters, so virtually any error message passed.
    ``re.escape`` makes the brackets literal, so the install hint is really checked.
    """
    import re  # local import keeps this test self-contained

    with pytest.raises(ModuleNotFoundError, match=re.escape("[image_style_transfer]")):
        StyleTransfer()
def test_load_from_checkpoint_dependency_error():
    """Loading a checkpoint without the image extras must point at ``'lightning-flash[image]'``."""
    # Escape the pattern: pytest treats ``match`` as a regex, and the brackets must be literal.
    install_hint = re.escape("'lightning-flash[image]'")
    with pytest.raises(ModuleNotFoundError, match=install_hint):
        StyleTransfer.load_from_checkpoint("not_a_real_checkpoint.pt")
# `os` is required for os.path.join below; it was previously used without being imported.
import os

import torch

import flash
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData

# 1. Create the DataModule
download_data(
    "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip",
    "./data",
)

datamodule = StyleTransferData.from_folders(
    train_folder="data/coco128/images/train2017",
    batch_size=1,
)

# 2. Build the task, using the style image shipped with flash's assets.
model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))

# 3. Create the trainer and train the model (uses every visible GPU, or CPU if none).
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Apply style transfer to a few images!
datamodule = StyleTransferData.from_files(
    predict_files=[
        "data/coco128/images/train2017/000000000625.jpg",
        "data/coco128/images/train2017/000000000626.jpg",
        "data/coco128/images/train2017/000000000629.jpg",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule)
# `os` is required for os.path.join below; it was previously used without being imported.
import os

import torch

import flash
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData

# 1. Create the DataModule
download_data(
    "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip",
    "./data",
)

datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images/train2017")

# 2. Build the task, using the style image shipped with flash's assets.
model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))

# 3. Create the trainer and train the model (uses every visible GPU, or CPU if none).
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Apply style transfer to a few images!
predictions = model.predict(
    [
        "data/coco128/images/train2017/000000000625.jpg",
        "data/coco128/images/train2017/000000000626.jpg",
        "data/coco128/images/train2017/000000000629.jpg",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("style_transfer_model.pt")
# NOTE(review): this chunk begins inside an `if` guard whose header is outside this view
# (presumably a pystiche-availability check, given the `else` branch below). `sys`,
# `flash` and `download_data` are also expected to be imported above it. Confirm.
    import pystiche.demo

    from flash.image.style_transfer import StyleTransfer, StyleTransferData
else:
    # pystiche is required for this example: print an install hint and exit non-zero.
    print("Please, run `pip install pystiche`")
    sys.exit(1)

# 1. Download the data
download_data(
    "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

# 2. Load the data
# NOTE(review): the other examples point at "data/coco128/images/train2017". Confirm that
# from_folders finds images nested one directory below this folder.
datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4)

# 3. Load the style image
# Uses the "paint" entry of pystiche.demo.images(), read with size=256.
style_image = pystiche.demo.images()["paint"].read(size=256)

# 4. Build the model
model = StyleTransfer(style_image)

# 5. Create the trainer
trainer = flash.Trainer(max_epochs=2)

# 6. Train the model
trainer.fit(model, datamodule=datamodule)

# 7. Save it!
trainer.save_checkpoint("style_transfer_model.pt")
def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx) -> None: """ Implement the logic to save a given batch of predictions. torch.save({"preds": prediction, "batch_indices": batch_indices}, "prediction_{batch_idx}.pt") """ # 1. Download the data download_data( "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") # 2. Load the model from a checkpoint model = StyleTransfer.load_from_checkpoint( "https://flash-weights.s3.amazonaws.com/style_transfer_model.pt") # 3. Generate predictions (stylize images) for the whole folder! datamodule = StyleTransferData.from_folders( predict_folder="data/coco128/images/train2017", batch_size=4) trainer = flash.Trainer(max_epochs=2, callbacks=StyleTransferWriter(), limit_predict_batches=1) predictions = trainer.predict(model, datamodule=datamodule) # 4. Display the first stylized image image_prediction = torch.stack(predictions[0])[0].numpy() if _MATPLOTLIB_AVAILABLE and not flash._IS_TESTING: import matplotlib.pyplot as plt