Esempio n. 1
0
def test_classification_task_predict_folder_path(tmpdir):
    """Predicting on a folder path yields one prediction per image inside it."""
    folder = Path(tmpdir / "train")
    folder.mkdir()

    def _rand_image():
        # Random RGB noise image — content is irrelevant, the loader just
        # needs a readable file on disk.
        pixels = np.random.randint(0, 255, (256, 256, 3), dtype="uint8")
        return Image.fromarray(pixels)

    for name in ("1.png", "2.png"):
        _rand_image().save(folder / name)

    datamodule = ImageClassificationData.from_folders(predict_folder=folder)

    task = ImageClassifier(num_classes=10)
    predictions = task.predict(str(folder), data_pipeline=datamodule.data_pipeline)
    # Two images were written, so exactly two predictions come back.
    assert len(predictions) == 2
Esempio n. 2
0
def test_classification(tmpdir):
    """Smoke-test finetuning an ImageClassifier on two one-image classes."""
    root = Path(tmpdir)

    # One folder and one image per class.
    for sub in ("a", "b"):
        (root / sub).mkdir()
        _rand_image().save(root / sub / "a_1.png")

    data = ImageClassificationData.from_filepaths(
        train_filepaths=[root / "a", root / "b"],
        train_labels=[0, 1],
        train_transform={"per_batch_transform": lambda x: x},
        num_workers=0,
        batch_size=2,
    )
    model = ImageClassifier(num_classes=2, backbone="resnet18")
    trainer = Trainer(default_root_dir=root, fast_dev_run=True)
    trainer.finetune(model, datamodule=data, strategy="freeze")
Esempio n. 3
0
def test_from_filepaths_list_image_paths(tmpdir):
    """Build train/val/test splits from explicit lists of image file paths."""
    root = Path(tmpdir)

    (root / "e").mkdir()
    _rand_image().save(root / "e_1.png")

    # The same image is listed three times; only the labels differ per split.
    train_images = [str(root / "e_1.png")] * 3

    img_data = ImageClassificationData.from_filepaths(
        train_filepaths=train_images,
        train_labels=[0, 3, 6],
        val_filepaths=train_images,
        val_labels=[1, 4, 7],
        test_filepaths=train_images,
        test_labels=[2, 5, 8],
        batch_size=2,
        num_workers=0,
    )

    # Training batches are shuffled, so only label membership can be asserted.
    imgs, labels = next(iter(img_data.train_dataloader()))
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, )
    for value in labels.numpy():
        assert value in [0, 3, 6]

    # Validation keeps the original order.
    imgs, labels = next(iter(img_data.val_dataloader()))
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, )
    assert list(labels.numpy()) == [1, 4]

    # So does test.
    imgs, labels = next(iter(img_data.test_dataloader()))
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, )
    assert list(labels.numpy()) == [2, 5]
Esempio n. 4
0
def test_classification(tmpdir):
    """Finetune an ImageClassifier from explicit file lists via ``from_files``."""
    root = Path(tmpdir)

    (root / "a").mkdir()
    (root / "b").mkdir()

    image_a = str(root / "a" / "a_1.png")
    image_b = str(root / "b" / "b_1.png")

    for path in (image_a, image_b):
        _rand_image().save(path)

    data = ImageClassificationData.from_files(
        train_files=[image_a, image_b],
        train_targets=[0, 1],
        num_workers=0,
        batch_size=2,
        image_size=(64, 64),
    )
    model = ImageClassifier(num_classes=2, backbone="resnet18")
    trainer = Trainer(default_root_dir=root, fast_dev_run=True)
    trainer.finetune(model, datamodule=data, strategy="freeze")
Esempio n. 5
0
def test_from_filepaths_multilabel(tmpdir):
    """Multi-label targets (one binary vector per image) flow through all splits."""
    root = Path(tmpdir)

    (root / "a").mkdir()
    _rand_image().save(root / "a1.png")
    _rand_image().save(root / "a2.png")

    images = [str(root / "a1.png"), str(root / "a2.png")]
    train_labels = [[1, 0, 1, 0], [0, 0, 1, 1]]
    valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]]
    test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]]

    dm = ImageClassificationData.from_filepaths(
        train_filepaths=images,
        train_labels=train_labels,
        val_filepaths=images,
        val_labels=valid_labels,
        test_filepaths=images,
        test_labels=test_labels,
        batch_size=2,
        num_workers=0,
    )

    def _check_split(loader, expected):
        # Every split yields (2, 4) label tensors: 2 images x 4 classes.
        imgs, labels = next(iter(loader))
        assert imgs.shape == (2, 3, 196, 196)
        assert labels.shape == (2, 4)
        if expected is not None:
            torch.testing.assert_allclose(labels, torch.tensor(expected))

    _check_split(dm.train_dataloader(), None)  # shuffled: shapes only
    _check_split(dm.val_dataloader(), valid_labels)
    _check_split(dm.test_dataloader(), test_labels)
Esempio n. 6
0
def test_from_filepaths_smoke(tmpdir):
    """Train-only construction: val/test loaders stay ``None``."""
    root = Path(tmpdir)

    (root / "a").mkdir()
    (root / "b").mkdir()
    _rand_image().save(root / "a_1.png")
    _rand_image().save(root / "b_1.png")

    img_data = ImageClassificationData.from_filepaths(
        train_filepaths=[root / "a_1.png", root / "b_1.png"],
        train_labels=[1, 2],
        batch_size=2,
        num_workers=0,
    )
    # Only a train split was configured.
    assert img_data.train_dataloader() is not None
    assert img_data.val_dataloader() is None
    assert img_data.test_dataloader() is None

    imgs, labels = next(iter(img_data.train_dataloader()))
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, )
    # Shuffled order, so compare the sorted label set.
    assert sorted(list(labels.numpy())) == [1, 2]
Esempio n. 7
0
def test_from_folders_only_train(tmpdir):
    """Folder-per-class layout with a train split only."""
    train_dir = Path(tmpdir / "train")
    train_dir.mkdir()

    # Two classes ("a" and "b"), two images each.
    for cls in ("a", "b"):
        (train_dir / cls).mkdir()
        for name in ("1.png", "2.png"):
            _rand_image().save(train_dir / cls / name)

    img_data = ImageClassificationData.from_folders(train_dir,
                                                    train_transform=None,
                                                    batch_size=1)

    imgs, labels = next(iter(img_data.train_dataloader()))
    assert imgs.shape == (1, 3, 196, 196)
    assert labels.shape == (1, )

    # No folders were given for the other splits.
    assert img_data.val_dataloader() is None
    assert img_data.test_dataloader() is None
Esempio n. 8
0
def test_from_folders_train_val(tmpdir):
    """Reuse the same class-folder tree for the train, val and test splits."""
    train_dir = Path(tmpdir / "train")
    train_dir.mkdir()

    for cls in ("a", "b"):
        (train_dir / cls).mkdir()
        for name in ("1.png", "2.png"):
            _rand_image().save(train_dir / cls / name)

    img_data = ImageClassificationData.from_folders(
        train_dir,
        val_folder=train_dir,
        test_folder=train_dir,
        batch_size=2,
        num_workers=0,
    )

    imgs, labels = next(iter(img_data.train_dataloader()))
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, )

    # Val/test are not shuffled: the first batch is two images of class 0 ("a").
    for loader in (img_data.val_dataloader(), img_data.test_dataloader()):
        imgs, labels = next(iter(loader))
        assert imgs.shape == (2, 3, 196, 196)
        assert labels.shape == (2, )
        assert list(labels.numpy()) == [0, 0]
import flash
from flash.core.data import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.vision import ImageClassificationData, ImageClassifier

if __name__ == "__main__":

    # 1. Download the data
    download_data(
        "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

    # 2. Load the data: one folder per class for each split.
    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/",
        test_folder="data/hymenoptera_data/test/",
    )

    # 3. Build a classifier sized to the number of discovered classes.
    model = ImageClassifier(num_classes=datamodule.num_classes)

    # 4. Finetune for two epochs, unfreezing the backbone after the first.
    trainer = flash.Trainer(max_epochs=2)
    trainer.finetune(model,
                     datamodule=datamodule,
                     strategy=FreezeUnfreeze(unfreeze_epoch=1))

    # 5. Evaluate on the test split.
    trainer.test()
Esempio n. 10
0
def test_categorical_csv_labels(tmpdir):
    """Labels parsed from one-hot CSV files drive ``from_filepaths`` end to end.

    Builds train/valid/test image folders plus matching categorical CSVs,
    converts the CSVs to label mappings with ``labels_from_categorical_csv``
    and checks each dataloader yields full batches of 2.
    """
    train_dir = Path(tmpdir / "some_dataset")
    train_dir.mkdir()

    (train_dir / "train").mkdir()
    _rand_image().save(train_dir / "train" / "train_1.png")
    _rand_image().save(train_dir / "train" / "train_2.png")

    (train_dir / "valid").mkdir()
    _rand_image().save(train_dir / "valid" / "val_1.png")
    _rand_image().save(train_dir / "valid" / "val_2.png")

    (train_dir / "test").mkdir()
    _rand_image().save(train_dir / "test" / "test_1.png")
    _rand_image().save(train_dir / "test" / "test_2.png")

    def _write_csv(path, content):
        # Context manager guarantees the handle is closed even on failure
        # (the original used bare open()/close() pairs with no try/finally).
        with open(path, 'w') as csv_file:
            csv_file.write(content)

    train_csv = os.path.join(tmpdir, 'some_dataset', 'train.csv')
    _write_csv(
        train_csv,
        'my_id,label_a,label_b,label_c\n"train_1.png", 0, 1, 0\n"train_2.png", 0, 0, 1\n"train_2.png", 1, 0, 0\n'
    )

    val_csv = os.path.join(tmpdir, 'some_dataset', 'valid.csv')
    _write_csv(
        val_csv,
        'my_id,label_a,label_b,label_c\n"val_1.png", 0, 1, 0\n"val_2.png", 0, 0, 1\n"val_3.png", 1, 0, 0\n'
    )

    test_csv = os.path.join(tmpdir, 'some_dataset', 'test.csv')
    _write_csv(
        test_csv,
        'my_id,label_a,label_b,label_c\n"test_1.png", 0, 1, 0\n"test_2.png", 0, 0, 1\n"test_3.png", 1, 0, 0\n'
    )

    def index_col_collate_fn(x):
        # Strip the file extension so CSV ids match the on-disk file stems.
        return os.path.splitext(x)[0]

    def _labels(csv_path):
        # Shared parsing options for all three splits (was triplicated inline).
        return labels_from_categorical_csv(
            csv_path,
            'my_id',
            feature_cols=['label_a', 'label_b', 'label_c'],
            index_col_collate_fn=index_col_collate_fn)

    train_labels = _labels(train_csv)
    val_labels = _labels(val_csv)
    test_labels = _labels(test_csv)

    data = ImageClassificationData.from_filepaths(
        batch_size=2,
        train_transform=None,
        val_transform=None,
        test_transform=None,
        train_filepaths=os.path.join(tmpdir, 'some_dataset', 'train'),
        train_labels=train_labels.values(),
        val_filepaths=os.path.join(tmpdir, 'some_dataset', 'valid'),
        val_labels=val_labels.values(),
        test_filepaths=os.path.join(tmpdir, 'some_dataset', 'test'),
        test_labels=test_labels.values(),
    )

    # Each split holds exactly two images, so every batch must be full.
    for (x, y) in data.train_dataloader():
        assert len(x) == 2

    for (x, y) in data.val_dataloader():
        assert len(x) == 2

    for (x, y) in data.test_dataloader():
        assert len(x) == 2
Esempio n. 11
0
def test_from_filepaths(tmpdir):
    """``from_filepaths`` with train only, then again with all three splits."""
    root = Path(tmpdir)

    # Two training class folders, two images each.
    for cls in ("a", "b"):
        (root / cls).mkdir()
        _rand_image().save(root / cls / "a_1.png")
        _rand_image().save(root / cls / "a_2.png")

    img_data = ImageClassificationData.from_filepaths(
        train_filepaths=[root / "a", root / "b"],
        train_transform=None,
        train_labels=[0, 1],
        batch_size=2,
        num_workers=0,
    )
    imgs, labels = next(iter(img_data.train_dataloader()))
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, )

    # No val/test folders were configured yet.
    assert img_data.val_dataloader() is None
    assert img_data.test_dataloader() is None

    # Dedicated folders for validation ("c"/"d") and test ("e"/"f").
    for cls in ("c", "d", "e", "f"):
        (root / cls).mkdir()
        _rand_image().save(root / cls / (cls + "_1.png"))
        _rand_image().save(root / cls / (cls + "_2.png"))

    img_data = ImageClassificationData.from_filepaths(
        train_filepaths=[root / "a", root / "b"],
        train_labels=[0, 1],
        train_transform=None,
        val_filepaths=[root / "c", root / "d"],
        val_labels=[0, 1],
        val_transform=None,
        test_transform=None,
        test_filepaths=[root / "e", root / "f"],
        test_labels=[0, 1],
        batch_size=1,
        num_workers=0,
    )

    # Both extra splits yield single-image batches of the expected shape.
    for loader in (img_data.val_dataloader(), img_data.test_dataloader()):
        imgs, labels = next(iter(loader))
        assert imgs.shape == (1, 3, 196, 196)
        assert labels.shape == (1, )
        root: str = 'data/movie_posters') -> Tuple[List[str], List[List[int]]]:
    metadata = pd.read_csv(osp.join(root, data, "metadata.csv"))
    return ([
        osp.join(root, data, row['Id'] + ".jpg")
        for _, row in metadata.iterrows()
    ], [[int(row[genre]) for genre in genres]
        for _, row in metadata.iterrows()])


# Load file paths and multi-hot genre targets for both splits.
# NOTE(review): `load_data` is truncated above this chunk — presumably it
# returns (image paths, per-image genre vectors); confirm in the full file.
train_files, train_targets = load_data('train')
test_files, test_targets = load_data('test')

datamodule = ImageClassificationData.from_files(
    train_files=train_files,
    train_targets=train_targets,
    test_files=test_files,
    test_targets=test_targets,
    val_split=0.1,  # Use 10 % of the train dataset to generate validation one.
    image_size=(128, 128),
)

# 3. Build the model: a multi-label classifier with one output per genre,
# tracked with an F1 metric sized to the genre count.
model = ImageClassifier(
    backbone="resnet18",
    num_classes=len(genres),
    multi_label=True,
    metrics=F1(num_classes=len(genres)),
)

# 4. Create the trainer. Train on 2 gpus for 10 epochs.
trainer = flash.Trainer(max_epochs=1,
                        limit_train_batches=1,
Esempio n. 13
0
        image = make_grid(images, nrow=2)
        image = T.to_pil_image(image, 'RGB')
        image.show()


# 3. Load the multi-label movie-poster classifier from a remote checkpoint.
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/image_classification_multi_label_model.pt",
)

# 4a. Predict the genres of a few movie posters!
poster_paths = [
    "data/movie_posters/predict/tt0085318.jpg",
    "data/movie_posters/predict/tt0089461.jpg",
    "data/movie_posters/predict/tt0097179.jpg",
]
predictions = model.predict(poster_paths)
print(predictions)

# 4b. Or generate predictions with a whole folder!
datamodule = ImageClassificationData.from_folders(
    predict_folder="data/movie_posters/predict/",
    data_fetcher=CustomViz(),
    image_size=(128, 128),
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)

# 5. Visualise one batch of the predict split.
datamodule.show_predict_batch("per_batch_transform")
        image = make_grid(images, nrow=2)
        image = T.to_pil_image(image, 'RGB')
        image.show()


# 3. Load the model from a checkpoint
# (pretrained multi-label movie-poster classifier hosted on S3).
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/image_classification_multi_label_model.pt",
)

# 4a. Predict the genres of a few movie posters!
predictions = model.predict([
    "data/movie_posters/predict/tt0085318.jpg",
    "data/movie_posters/predict/tt0089461.jpg",
    "data/movie_posters/predict/tt0097179.jpg",
])
print(predictions)

# 4b. Or generate predictions with a whole folder!
# NOTE(review): `CustomViz` is defined above this chunk — presumably a
# visualization data fetcher; confirm in the full file.
datamodule = ImageClassificationData.from_folders(
    predict_folder="data/movie_posters/predict/",
    data_fetcher=CustomViz(),
    preprocess=model.preprocess,  # reuse the model's own preprocessing
)

predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)

# 5. Show some data (unless we're just testing)!
datamodule.show_predict_batch("per_batch_transform")
Esempio n. 15
0
from flash import Trainer
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier


# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the model from a checkpoint
# NOTE(review): this expects a local "image_classification_model.pt" produced
# by a previous finetuning run — it is not downloaded here.
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")

# 3a. Predict what's on a few images! ants or bees?
predictions = model.predict([
    "data/hymenoptera_data/test/ants/8124241_36b290d372.jpg",
    "data/hymenoptera_data/test/ants/147542264_79506478c2.jpg",
    "data/hymenoptera_data/test/ants/212100470_b485e7b7b9.jpg",
])
print(predictions)

# 3b. Generate predictions with a whole folder
# NOTE(review): `from_folder` (singular) — other snippets in this file use
# `from_folders`; confirm which name this flash version exposes.
datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/test/ants/")
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
import flash
from flash import Trainer
from flash.core.classification import Labels
from flash.core.finetuning import FreezeUnfreeze
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
              "data/")

# 2. Load the data: one folder per class for each split.
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)


# 3.a Optional: Register a custom backbone
# This is useful to create new backbone and make them accessible from `ImageClassifier`
# NOTE(review): `torchvision` and `nn` are not imported in the visible import
# block above — confirm they are imported elsewhere in the full file.
@ImageClassifier.backbones(name="resnet18")
def fn_resnet(pretrained: bool = True):
    # Registered under "resnet18"; invoked when the backbone is requested.
    model = torchvision.models.resnet18(pretrained)
    # remove the last two layers & turn it into a Sequential model
    backbone = nn.Sequential(*list(model.children())[:-2])
    num_features = model.fc.in_features
    # backbones need to return the num_features to build the head
    return backbone, num_features
# The fixed set of genre columns expected in every metadata.csv.
genres = ["Action", "Romance", "Crime", "Thriller", "Adventure"]


def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], List[List[int]]]:
    """Return (image paths, multi-hot genre targets) for one dataset split.

    Reads ``<root>/<data>/metadata.csv`` and pairs each row's ``Id`` with the
    matching ``.jpg`` path and its 0/1 vector over ``genres``.
    """
    metadata = pd.read_csv(osp.join(root, data, "metadata.csv"))
    filepaths = []
    targets = []
    for _, row in metadata.iterrows():
        filepaths.append(osp.join(root, data, row['Id'] + ".jpg"))
        targets.append([int(row[genre]) for genre in genres])
    return filepaths, targets


# Load file lists and multi-hot targets for the train and test splits.
train_filepaths, train_labels = load_data('train')
test_filepaths, test_labels = load_data('test')

datamodule = ImageClassificationData.from_filepaths(
    train_filepaths=train_filepaths,
    train_labels=train_labels,
    test_filepaths=test_filepaths,
    test_labels=test_labels,
    preprocess=ImageClassificationPreprocess(image_size=(128, 128)),
    val_split=0.1,  # Use 10 % of the train dataset to generate validation one.
)

# 3. Build the model: multi-label head (one output per genre), F1-tracked.
model = ImageClassifier(
    backbone="resnet18",
    num_classes=len(genres),
    multi_label=True,
    metrics=F1(num_classes=len(genres)),
)

# 4. Create the trainer. Train on 2 gpus for 10 epochs.
# NOTE(review): the limits below actually run a single batch for one epoch,
# contradicting the comment above — likely a smoke-test configuration.
trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)
Esempio n. 18
0
import flash

from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
              'data/')

# 2. Load the data
# NOTE(review): `backbone` is a model argument in the other snippets, not a
# datamodule one — confirm this flash version's from_folders accepts it.
datamodule = ImageClassificationData.from_folders(
    backbone="resnet34",
    num_workers=8,
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model, sized to the number of discovered classes.
model = ImageClassifier(num_classes=datamodule.num_classes,
                        backbone="resnet18")

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=4)

# 5. Finetune: backbone frozen first, then unfrozen on a schedule.
trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")

# 7. Save it!
trainer.save_checkpoint("image_classification_model.pt")
Esempio n. 19
0
# See the License for the specific language governing permissions and
# limitations under the License.
from flash import Trainer
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the example dataset.
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
              "data/")

# 2. Load the pretrained ants/bees classifier from a remote checkpoint.
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 3a. Predict what's on a few images! ants or bees?
sample_images = [
    "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
    "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
    "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
]
predictions = model.predict(sample_images)
print(predictions)

# 3b. Or generate predictions with a whole folder!
datamodule = ImageClassificationData.from_folders(
    predict_folder="data/hymenoptera_data/predict/",
    preprocess=model.preprocess,
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
Esempio n. 20
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash import Trainer
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip",
              "data/")

# 2. Load the model from a checkpoint
# (pretrained ants/bees classifier hosted on S3).
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 3a. Predict what's on a few images! ants or bees?
predictions = model.predict([
    "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
    "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
    "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
])
print(predictions)

# 3b. Or generate predictions with a whole folder!
# Unlike the variant above, no preprocess is passed — the datamodule default applies.
datamodule = ImageClassificationData.from_folders(
    predict_folder="data/hymenoptera_data/predict/")

predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
Esempio n. 21
0
def cli_main():
    """CLI entry point: parse args, train, test and write predictions to disk.

    Trains ``LitClassifier`` on an image-folder dataset (default ``cifar5``),
    evaluates on the test split and writes one prediction per line to
    ``predictions.txt`` in the current working directory.
    """
    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='cifar5')

    # add trainer args (gpus=x, precision=...)
    parser = pl.Trainer.add_argparse_args(parser)

    # add model args (batch_size, hidden_dim, etc...) from add_model_specific_args
    parser = LitClassifier.add_model_specific_args(parser)
    args = parser.parse_args()
    print(args)

    # ------------
    # data
    # ------------
    # CIFAR-style per-channel normalisation statistics.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4913, 0.482, 0.446], std=[0.247, 0.243, 0.261])
    ])

    # in real life you would have a separate validation split
    datamodule = ImageClassificationData.from_folders(
        train_folder=args.data_dir + '/train',
        valid_folder=args.data_dir + '/test',
        test_folder=args.data_dir + '/test',
        batch_size=args.batch_size,
        transform=transform
    )

    # ------------
    # model
    # ------------
    model = LitClassifier(
        backbone=args.backbone,
        learning_rate=args.learning_rate,
        hidden_dim=args.hidden_dim
    )

    # ------------
    # training
    # ------------
    print('training')
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())

    # ------------
    # testing
    # ------------
    print('testing')
    result = trainer.test(model, test_dataloaders=datamodule.test_dataloader())
    print(result)

    # ------------
    # predicting
    # ------------
    print('predicting')
    # trainer.predict returns one tensor per batch; flatten to one value per line.
    preds = trainer.predict(model, datamodule.test_dataloader())

    # Use os.path.join instead of string concatenation for portability.
    path = os.path.join(os.getcwd(), 'predictions.txt')
    with open(path, 'w') as f:
        preds_str = [str(x) for lst in preds for x in lst]
        f.write('\n'.join(preds_str))