Пример #1
0
def test_jit(tmpdir):
    """Round-trip a traced StyleTransfer model through TorchScript save/load.

    The reloaded module must still map a (1, 3, 32, 32) input to a tensor of
    the same shape.
    """
    input_shape = (1, 3, 32, 32)

    eager_model = StyleTransfer()
    eager_model.eval()

    # torch.jit.script doesn't work with pystiche, so the model is traced instead.
    traced_model = torch.jit.trace(eager_model, torch.rand(*input_shape))

    checkpoint_path = os.path.join(tmpdir, "test.pt")
    torch.jit.save(traced_model, checkpoint_path)
    loaded_model = torch.jit.load(checkpoint_path)

    output = loaded_model(torch.rand(*input_shape))
    assert isinstance(output, torch.Tensor)
    assert output.shape == torch.Size(input_shape)
Пример #2
0
def test_style_transfer_task():
    """StyleTransfer must pass its layer and weight arguments to the perceptual loss."""
    model = StyleTransfer(
        backbone="vgg11",
        content_layer="relu1_2",
        content_weight=10,
        style_layers="relu1_2",
        style_weight=11,
    )
    content_loss = model.perceptual_loss.content_loss
    style_loss = model.perceptual_loss.style_loss

    assert content_loss.encoder.layer == "relu1_2"
    assert content_loss.score_weight == 10
    # The requested style layer should show up among the style loss's submodule names.
    assert any(name == "relu1_2" for name, _ in style_loss.named_modules())
    assert style_loss.score_weight == 11
Пример #3
0
def test_style_transfer_task_import():
    """Constructing StyleTransfer without its optional extras raises ModuleNotFoundError.

    The error message is expected to name the ``image_style_transfer`` extra.
    """
    # `match` is a regular expression applied with re.search. Unescaped, the
    # brackets form a character class that matches any message containing one
    # of those letters, so they are escaped to require the literal extra name.
    with pytest.raises(ModuleNotFoundError, match=r"\[image_style_transfer\]"):
        StyleTransfer()
Пример #4
0
def test_load_from_checkpoint_dependency_error():
    """Loading a StyleTransfer checkpoint without the image extras must fail.

    The ModuleNotFoundError message must mention the 'lightning-flash[image]'
    install target; the text is escaped so it is matched literally.
    """
    expected_message = re.escape("'lightning-flash[image]'")
    with pytest.raises(ModuleNotFoundError, match=expected_message):
        StyleTransfer.load_from_checkpoint("not_a_real_checkpoint.pt")
# Example: train a StyleTransfer model on COCO128, then stylize three images
# with `trainer.predict`.
import os

import torch

import flash
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData

# 1. Create the DataModule
download_data(
    "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip",
    "./data")

datamodule = StyleTransferData.from_folders(
    train_folder="data/coco128/images/train2017", batch_size=1)

# 2. Build the task
# The style image is starry_night.jpg from flash's bundled assets directory
# (this is why `os` must be imported above).
model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))

# 3. Create the trainer and train the model
# `gpus` is set to however many CUDA devices torch can see.
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Apply style transfer to a few images!
datamodule = StyleTransferData.from_files(
    predict_files=[
        "data/coco128/images/train2017/000000000625.jpg",
        "data/coco128/images/train2017/000000000626.jpg",
        "data/coco128/images/train2017/000000000629.jpg",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule)
Пример #6
0
# Example: train a StyleTransfer model on COCO128, stylize three images with
# `model.predict`, and save a checkpoint.
import os

import torch

import flash
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData

# 1. Create the DataModule
download_data(
    "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip",
    "./data")

datamodule = StyleTransferData.from_folders(
    train_folder="data/coco128/images/train2017")

# 2. Build the task
# The style image is starry_night.jpg from flash's bundled assets directory
# (this is why `os` must be imported above).
model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))

# 3. Create the trainer and train the model
# `gpus` is set to however many CUDA devices torch can see.
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Apply style transfer to a few images!
predictions = model.predict([
    "data/coco128/images/train2017/000000000625.jpg",
    "data/coco128/images/train2017/000000000626.jpg",
    "data/coco128/images/train2017/000000000629.jpg",
])
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("style_transfer_model.pt")
Пример #7
0
    import pystiche.demo

    from flash.image.style_transfer import StyleTransfer, StyleTransferData
else:
    print("Please, run `pip install pystiche`")
    sys.exit(1)

# Train a StyleTransfer model on COCO128 using pystiche's "paint" demo image as
# the style, then save a checkpoint. Steps are order-dependent: the data must be
# downloaded before the DataModule is built, and training must finish before saving.

# 1. Download the data
download_data(
    "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip",
    "data/")

# 2. Load the data
# NOTE(review): train_folder here is the parent "images" directory, while the
# other examples point at "images/train2017" -- confirm from_folders handles
# the nested directory as intended.
datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images",
                                            batch_size=4)

# 3. Load the style image
# Reads the "paint" entry of pystiche's demo images; size=256 presumably sets
# the loaded resolution (verify against pystiche's docs).
style_image = pystiche.demo.images()["paint"].read(size=256)

# 4. Build the model
model = StyleTransfer(style_image)

# 5. Create the trainer
trainer = flash.Trainer(max_epochs=2)

# 6. Train the model
trainer.fit(model, datamodule=datamodule)

# 7. Save it!
trainer.save_checkpoint("style_transfer_model.pt")
Пример #8
0
    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices,
                           batch, batch_idx, dataloader_idx) -> None:
        """Save a batch of predictions -- intentionally left as a no-op in this example.

        As written this hook does nothing and returns ``None``. Presumably it is
        invoked by the enclosing prediction-writer callback after each predict
        batch (verify against the class definition). To actually persist
        predictions, replace the body with something like::

            torch.save(
                {"preds": prediction, "batch_indices": batch_indices},
                f"prediction_{batch_idx}.pt",
            )

        Note the ``f`` prefix on the file name: without it ``{batch_idx}`` is not
        interpolated, so every batch would be written to the same literal file
        and overwrite the previous one.
        """


# Stylize COCO128 images with a pretrained StyleTransfer checkpoint and take the
# first stylized image for display. Steps are order-dependent (download, load,
# predict, then index the results).

# 1. Download the data
download_data(
    "https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip",
    "data/")

# 2. Load the model from a checkpoint
model = StyleTransfer.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/style_transfer_model.pt")

# 3. Generate predictions (stylize images) for the folder. Note that
# limit_predict_batches=1 below restricts this to the first batch (4 images).
datamodule = StyleTransferData.from_folders(
    predict_folder="data/coco128/images/train2017", batch_size=4)

# NOTE(review): max_epochs looks irrelevant for a predict-only run -- confirm.
trainer = flash.Trainer(max_epochs=2,
                        callbacks=StyleTransferWriter(),
                        limit_predict_batches=1)
predictions = trainer.predict(model, datamodule=datamodule)

# 4. Display the first stylized image
# Assumes predictions[0] is a list of per-image tensors for the first batch
# (TODO confirm); stacking and taking [0] selects its first image. .numpy()
# assumes the tensor is on the CPU.
image_prediction = torch.stack(predictions[0])[0].numpy()

if _MATPLOTLIB_AVAILABLE and not flash._IS_TESTING:
    import matplotlib.pyplot as plt