import os

import pytest
import torch
from torchvision.datasets import FakeData

import flash
from flash.image import ImageClassificationData, ImageEmbedder


def test_jit(tmpdir, jitter, args):
    path = os.path.join(tmpdir, "test.pt")

    model = ImageEmbedder(embedding_dim=128)
    model.eval()

    model = jitter(model, *args)

    torch.jit.save(model, path)
    model = torch.jit.load(path)

    out = model(torch.rand(1, 3, 32, 32))
    assert isinstance(out, torch.Tensor)
    assert out.shape == torch.Size([1, 128])
def test_vissl_training(backbone, training_strategy, head,
                        pretraining_transform, embedding_size):
    datamodule = ImageClassificationData.from_datasets(
        train_dataset=FakeData(16),
        predict_dataset=FakeData(8),
        batch_size=4,
    )

    embedder = ImageEmbedder(
        backbone=backbone,
        training_strategy=training_strategy,
        head=head,
        pretraining_transform=pretraining_transform,
    )

    trainer = flash.Trainer(
        max_steps=3,
        max_epochs=1,
        gpus=torch.cuda.device_count(),
    )

    trainer.fit(embedder, datamodule=datamodule)
    predictions = trainer.predict(embedder, datamodule=datamodule)
    for prediction_batch in predictions:
        for prediction in prediction_batch:
            assert prediction.size(0) == embedding_size
def test_not_implemented_steps():
    embedder = ImageEmbedder(backbone="resnet18")

    with pytest.raises(NotImplementedError):
        embedder.training_step([], 0)
    with pytest.raises(NotImplementedError):
        embedder.validation_step([], 0)
    with pytest.raises(NotImplementedError):
        embedder.test_step([], 0)
def test_vissl_training_with_wrong_arguments(backbone, training_strategy, head,
                                             pretraining_transform,
                                             expected_exception):
    with pytest.raises(expected_exception):
        ImageEmbedder(
            backbone=backbone,
            training_strategy=training_strategy,
            head=head,
            pretraining_transform=pretraining_transform,
        )
def test_only_embedding(backbone, embedding_size):
    datamodule = ImageClassificationData.from_datasets(
        predict_dataset=FakeData(8),
        batch_size=4,
        transform_kwargs=dict(image_size=(224, 224)),
    )

    embedder = ImageEmbedder(backbone=backbone)
    trainer = flash.Trainer()

    predictions = trainer.predict(embedder, datamodule=datamodule)
    for prediction_batch in predictions:
        for prediction in prediction_batch:
            assert prediction.size(0) == embedding_size
import fiftyone as fo
import fiftyone.brain as fob
import numpy as np
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder

# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")

# 2 Load data into FiftyOne
dataset = fo.Dataset.from_dir(
    "data/hymenoptera_data/test/",
    fo.types.ImageClassificationDirectoryTree,
)
datamodule = ImageClassificationData.from_files(
    predict_files=dataset.values("filepath"),
    batch_size=16,
)

# 3 Load model
embedder = ImageEmbedder(backbone="resnet18")

# 4 Generate embeddings
trainer = flash.Trainer(gpus=torch.cuda.device_count())
embedding_batches = trainer.predict(embedder, datamodule=datamodule)
embeddings = np.stack(sum(embedding_batches, []))

# 5 Visualize in FiftyOne App
results = fob.compute_visualization(dataset, embeddings=embeddings)
session = fo.launch_app(dataset)
plot = results.visualize(labels="ground_truth.label")
plot.show()

# Optional: block execution until App is closed
session.wait()
import torch
from torchvision.datasets import CIFAR10

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder

# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
    train_dataset=CIFAR10(".", download=True),
    batch_size=4,
)

# 2. Build the task
embedder = ImageEmbedder(
    backbone="resnet18",
    training_strategy="barlow_twins",
    head="barlow_twins_head",
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 128},
    pretraining_transform_kwargs={"size_crops": [32]},
)

# 3. Create the trainer and pre-train the encoder
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.fit(embedder, datamodule=datamodule)

# 4. Save the model!
trainer.save_checkpoint("image_embedder_model.pt")

# 5. Download the downstream prediction dataset and generate embeddings
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
              "data/")
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.data.utils import download_data
from flash.image import ImageEmbedder

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Build the task
embedder = ImageEmbedder(backbone="resnet101")

# 3. Generate an embedding from an image path.
embeddings = embedder.predict(["data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg"])
print(embeddings)
import fiftyone as fo
import fiftyone.brain as fob
import numpy as np
import torch

from flash.core.data.utils import download_data
from flash.image import ImageEmbedder

# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")

# 2 Load data into FiftyOne
dataset = fo.Dataset.from_dir(
    "data/hymenoptera_data/test/",
    fo.types.ImageClassificationDirectoryTree,
)

# 3 Load model
embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128)

# 4 Generate embeddings
filepaths = dataset.values("filepath")
embeddings = np.stack(embedder.predict(filepaths))

# 5 Visualize in FiftyOne App
results = fob.compute_visualization(dataset, embeddings=embeddings)

session = fo.launch_app(dataset)

plot = results.visualize(labels="ground_truth.label")
plot.show()

# Only when running this in a script
# Block until the FiftyOne App is closed
session.wait()
import torch

from flash.core.data.utils import download_data
from flash.image import ImageEmbedder

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
              "data/")

# 2. Create an ImageEmbedder with swav trained on imagenet.
# Check out SWAV: https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#swav
embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128)

# 3. Generate an embedding from an image path.
embeddings = embedder.predict(
    ["data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg"])

# 4. Print embeddings shape
print(embeddings[0].shape)

# 5. Create a random image tensor
random_image = torch.randn(1, 3, 244, 244)

# 6. Generate an embedding from this random image.
embeddings = embedder.predict(random_image, data_source="tensors")

# 7. Print embeddings shape
print(embeddings[0].shape)
import fiftyone as fo
import fiftyone.brain as fob
import numpy as np

from flash.core.data.utils import download_data
from flash.image import ImageEmbedder

# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")

# 2 Load data into FiftyOne
dataset = fo.Dataset.from_dir(
    "data/hymenoptera_data/test/",
    fo.types.ImageClassificationDirectoryTree,
)

# 3 Load model
embedder = ImageEmbedder(backbone="resnet101")

# 4 Generate embeddings
filepaths = dataset.values("filepath")
embeddings = np.stack(embedder.predict(filepaths))

# 5 Visualize in FiftyOne App
results = fob.compute_visualization(dataset, embeddings=embeddings)
session = fo.launch_app(dataset)
plot = results.visualize(labels="ground_truth.label")
plot.show()

# Optional: block execution until App is closed
session.wait()
def test_load_from_checkpoint_dependency_error():
    with pytest.raises(ModuleNotFoundError,
                       match=re.escape("'lightning-flash[image]'")):
        ImageEmbedder.load_from_checkpoint("not_a_real_checkpoint.pt")