コード例 #1
0
def test_classification_task_predict_folder_path(tmpdir):
    """Predicting on a folder path should yield one prediction per image in it."""
    image_dir = Path(tmpdir / "train")
    image_dir.mkdir()

    # Populate the folder with two random images.
    for file_name in ("1.png", "2.png"):
        _rand_image().save(image_dir / file_name)

    classifier = ImageClassifier(num_classes=10)
    outputs = classifier.predict(str(image_dir))
    assert len(outputs) == 2
コード例 #2
0
def test_classification_task_predict_folder_path(tmpdir):
    """Check folder prediction when the data pipeline is supplied explicitly.

    Two random RGB images are written to a temporary folder, a datamodule is
    built from that folder, and its data pipeline is handed to ``predict``.
    """
    image_dir = Path(tmpdir / "train")
    image_dir.mkdir()

    def _rand_image():
        # 256x256x3 array of random uint8 values drawn from [0, 255).
        pixels = np.random.randint(0, 255, (256, 256, 3), dtype="uint8")
        return Image.fromarray(pixels)

    for file_name in ("1.png", "2.png"):
        _rand_image().save(image_dir / file_name)

    datamodule = ImageClassificationData.from_folders(predict_folder=image_dir)

    classifier = ImageClassifier(num_classes=10)
    outputs = classifier.predict(
        str(image_dir),
        data_pipeline=datamodule.data_pipeline,
    )
    assert len(outputs) == 2
コード例 #3
0
def test_multilabel(tmpdir):
    """Briefly fine-tune a multi-label classifier and sanity-check its predictions."""
    num_classes = 4
    dataset = DummyMultiLabelDataset(num_classes)
    classifier = ImageClassifier(
        num_classes,
        multi_label=True,
        serializer=Probabilities(multi_label=True),
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=2)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.finetune(classifier, loader, strategy="freeze_unfreeze")

    image, label = dataset[0]
    # Add a leading batch dimension before predicting on the single sample.
    predictions = classifier.predict(image.unsqueeze(0))
    scores = torch.tensor(predictions)

    # No predicted value may fall outside [0, 1].
    assert (scores > 1).sum() == 0
    assert (scores < 0).sum() == 0
    # One value per class, matching the length of the label.
    assert len(predictions[0]) == num_classes == len(label)
    # The label contains at most two distinct values.
    assert len(torch.unique(label)) <= 2
コード例 #4
0
def test_available_backbones():
    """``available_backbones`` lists known names and is empty when ``backbones`` is None."""
    names = ImageClassifier.available_backbones()
    assert "resnet152" in names

    # A subclass that clears the ``backbones`` attribute should report nothing.
    class Foo(ImageClassifier):
        backbones = None

    reported = Foo.available_backbones()
    assert reported == []
コード例 #5
0
def test_classification(tmpdir):
    """Smoke-test freeze fine-tuning of a resnet18 classifier on two dummy file paths."""
    file_paths = ["a", "b"]
    labels = [0, 1]

    datamodule = ImageClassificationData.from_filepaths(
        train_filepaths=file_paths,
        train_labels=labels,
        train_transform=lambda x: x,  # identity transform
        loader=_dummy_image_loader,
        num_workers=0,
        batch_size=2,
    )
    classifier = ImageClassifier(2, backbone="resnet18")
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.finetune(classifier, datamodule=datamodule, strategy="freeze")
コード例 #6
0
def test_classification(tmpdir):
    """Smoke-test freeze fine-tuning on two folders that each hold one random image."""
    root = Path(tmpdir)
    class_dirs = [root / "a", root / "b"]

    for class_dir in class_dirs:
        class_dir.mkdir()
    # Both folders receive an image with the same file name.
    for class_dir in class_dirs:
        _rand_image().save(class_dir / "a_1.png")

    datamodule = ImageClassificationData.from_filepaths(
        train_filepaths=class_dirs,
        train_labels=[0, 1],
        train_transform={"per_batch_transform": lambda x: x},
        num_workers=0,
        batch_size=2,
    )
    classifier = ImageClassifier(num_classes=2, backbone="resnet18")
    trainer = Trainer(default_root_dir=root, fast_dev_run=True)
    trainer.finetune(classifier, datamodule=datamodule, strategy="freeze")
コード例 #7
0
def test_classification(tmpdir):
    """Smoke-test freeze fine-tuning on two image files passed via ``from_files``."""
    root = Path(tmpdir)

    for class_name in ("a", "b"):
        (root / class_name).mkdir()

    # File paths are handed over as plain strings.
    image_files = [
        str(root / "a" / "a_1.png"),
        str(root / "b" / "b_1.png"),
    ]
    for image_file in image_files:
        _rand_image().save(image_file)

    datamodule = ImageClassificationData.from_files(
        train_files=image_files,
        train_targets=[0, 1],
        num_workers=0,
        batch_size=2,
        image_size=(64, 64),
    )
    classifier = ImageClassifier(num_classes=2, backbone="resnet18")
    trainer = Trainer(default_root_dir=root, fast_dev_run=True)
    trainer.finetune(classifier, datamodule=datamodule, strategy="freeze")
コード例 #8
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash import Trainer
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the hymenoptera (ants / bees) dataset archive into ``data/``
download_data(
    "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
    "data/",
)

# 2. Restore the classifier from a checkpoint referenced by URL
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/image_classification_model.pt"
)

# 3a. Classify a handful of individual images: ants or bees?
sample_images = [
    "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
    "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
    "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
]
predictions = model.predict(sample_images)
print(predictions)

# 3b. Alternatively, build a datamodule from a folder and predict through a Trainer
datamodule = ImageClassificationData.from_folders(
    predict_folder="data/hymenoptera_data/predict/",
)
predict_trainer = Trainer()
predictions = predict_trainer.predict(model, datamodule=datamodule)
print(predictions)
コード例 #9
0
if __name__ == "__main__":

    # 1. Download the dataset archive into ``data/``
    download_data(
        "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
        "data/",
    )

    # 2. Build the datamodule from the train / validation / test folders
    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/",
        test_folder="data/hymenoptera_data/test/",
    )

    # 3. Build the model, sized to the number of classes found in the data
    model = ImageClassifier(num_classes=datamodule.num_classes)

    # 4. Create the trainer, limited to two epochs
    trainer = flash.Trainer(max_epochs=2)

    # 5. Fine-tune with the FreezeUnfreeze strategy configured with unfreeze_epoch=1
    finetune_strategy = FreezeUnfreeze(unfreeze_epoch=1)
    trainer.finetune(model, datamodule=datamodule, strategy=finetune_strategy)

    # 6. Evaluate on the test data
    trainer.test()

    # 7. Persist the fine-tuned model as a checkpoint
    trainer.save_checkpoint("image_classification_model.pt")
コード例 #10
0

# 3.a Optional: Register a custom backbone
# This is useful to create new backbone and make them accessible from `ImageClassifier`
@ImageClassifier.backbones(name="resnet18")
def fn_resnet(pretrained: bool = True):
    """Build a ResNet-18 based backbone registered under the name ``"resnet18"``.

    Args:
        pretrained: Passed positionally to ``torchvision.models.resnet18``.

    Returns:
        A ``(backbone, num_features)`` tuple, where ``backbone`` is an
        ``nn.Sequential`` of all but the last two child modules of the
        ResNet, and ``num_features`` is the input width of its ``fc`` layer.
    """
    resnet = torchvision.models.resnet18(pretrained)
    # The head is sized from the input width of the original ``fc`` layer.
    num_features = resnet.fc.in_features
    # Drop the last two child modules and wrap the remainder in a Sequential.
    kept_layers = list(resnet.children())[:-2]
    backbone = nn.Sequential(*kept_layers)
    return backbone, num_features


# 3.b Optional: print the backbones reported by ``ImageClassifier``
print(ImageClassifier.available_backbones())

# 4. Build the model on the "resnet18" backbone, sized to the data's classes
model = ImageClassifier(
    backbone="resnet18",
    num_classes=datamodule.num_classes,
)

# 5. Create the trainer: one epoch, one training batch, one validation batch
trainer = flash.Trainer(
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
)

# 6. Fine-tune the model
# NOTE(review): with max_epochs=1, unfreeze_epoch=1 may never be reached — confirm.
finetune_strategy = FreezeUnfreeze(unfreeze_epoch=1)
trainer.finetune(model, datamodule=datamodule, strategy=finetune_strategy)
コード例 #11
0
def test_unfreeze():
    """After ``unfreeze()``, every backbone parameter must require gradients."""
    classifier = ImageClassifier(2)
    classifier.unfreeze()
    assert all(
        param.requires_grad is True
        for param in classifier.backbone.parameters()
    )
コード例 #12
0
def test_non_existent_backbone():
    """Constructing a classifier with an unknown backbone name must raise ``ValueError``."""
    unknown_backbone = "i am never going to implement this lol"
    with pytest.raises(ValueError):
        ImageClassifier(2, unknown_backbone)
コード例 #13
0
def test_init_train(tmpdir, backbone):
    """Smoke-test a fast-dev-run fine-tuning pass for the given ``backbone``."""
    classifier = ImageClassifier(10, backbone=backbone)
    loader = torch.utils.data.DataLoader(DummyDataset())
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.finetune(classifier, loader, strategy="freeze_unfreeze")
コード例 #14
0
from flash.core.finetuning import FreezeUnfreeze
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the dataset archive into ``data/``
download_data(
    "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
    "data/",
)

# 2. Build the datamodule from the train / validation / test folders
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model, sized to the number of classes found in the data
model = ImageClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer: one epoch, one training batch, one validation batch
trainer = flash.Trainer(
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
)

# 5. Fine-tune the model
# NOTE(review): with max_epochs=1, unfreeze_epoch=1 may never be reached — confirm.
finetune_strategy = FreezeUnfreeze(unfreeze_epoch=1)
trainer.finetune(model, datamodule=datamodule, strategy=finetune_strategy)

# 3a. Predict what's on a few images! ants or bees?
predictions = model.predict([
    "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
    "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
コード例 #15
0
def test_non_existent_backbone():
    """An unknown backbone name must raise ``MisconfigurationException``."""
    unknown_backbone = "i am never going to implement this lol"
    with pytest.raises(MisconfigurationException):
        ImageClassifier(2, unknown_backbone)
コード例 #16
0
ファイル: fine_tune.py プロジェクト: zlapp/medium
from flash import Trainer
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier


# 1. Download the dataset archive into ``data/``
download_data(
    "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
    "data/",
)

# 2. Restore the classifier from a checkpoint file
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")

# 3a. Classify a handful of individual images: ants or bees?
ant_images = [
    "data/hymenoptera_data/test/ants/8124241_36b290d372.jpg",
    "data/hymenoptera_data/test/ants/147542264_79506478c2.jpg",
    "data/hymenoptera_data/test/ants/212100470_b485e7b7b9.jpg",
]
predictions = model.predict(ant_images)
print(predictions)

# 3b. Build a datamodule from a whole folder and predict through a Trainer
datamodule = ImageClassificationData.from_folder(
    folder="data/hymenoptera_data/test/ants/",
)
predict_trainer = Trainer()
predictions = predict_trainer.predict(model, datamodule=datamodule)
print(predictions)
コード例 #17
0
train_filepaths, train_labels = load_data('train')
test_filepaths, test_labels = load_data('test')

# Preprocessing configured with a 128x128 image size.
preprocess = ImageClassificationPreprocess(image_size=(128, 128))

datamodule = ImageClassificationData.from_filepaths(
    train_filepaths=train_filepaths,
    train_labels=train_labels,
    test_filepaths=test_filepaths,
    test_labels=test_labels,
    preprocess=preprocess,
    val_split=0.1,  # Hold out 10 % of the train dataset for validation.
)

# 3. Build a multi-label resnet18 model with an F1 metric, one class per genre
num_genres = len(genres)
model = ImageClassifier(
    backbone="resnet18",
    num_classes=num_genres,
    multi_label=True,
    metrics=F1(num_classes=num_genres),
)

# 4. Create the trainer: one epoch, one training batch, one validation batch.
trainer = flash.Trainer(
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
)

# 5. Fine-tune the model with the "freeze" strategy
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Predict what's on a few images!

# Serialize predictions as labels; the low threshold surfaces more predictions.
model.serializer = Labels(genres, multi_label=True, threshold=0.25)

predictions = model.predict([
コード例 #18
0
ファイル: predict.py プロジェクト: zlapp/medium
import flash

from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the dataset archive into ``data/``
download_data(
    "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
    "data/",
)

# 2. Build the datamodule from the train / validation / test folders
# NOTE(review): backbone="resnet34" is passed here while the model below uses
# "resnet18" — confirm whether the datamodule argument is intended.
datamodule = ImageClassificationData.from_folders(
    backbone="resnet34",
    num_workers=8,
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 3. Build a resnet18 model, sized to the number of classes found in the data
model = ImageClassifier(
    num_classes=datamodule.num_classes,
    backbone="resnet18",
)

# 4. Create the trainer, limited to four epochs
trainer = flash.Trainer(max_epochs=4)

# 5. Fine-tune the model with the "freeze_unfreeze" strategy
trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")

# 6. Persist the fine-tuned model as a checkpoint
trainer.save_checkpoint("image_classification_model.pt")