def test_from_filepaths_visualise_multilabel(tmpdir):
    tmpdir = Path(tmpdir)

    (tmpdir / "a").mkdir()
    (tmpdir / "b").mkdir()

    image_a = str(tmpdir / "a" / "a_1.png")
    image_b = str(tmpdir / "b" / "b_1.png")

    _rand_image().save(image_a)
    _rand_image().save(image_b)

    dm = ImageClassificationData.from_files(
        train_files=[image_a, image_b],
        train_targets=[[0, 1, 0], [0, 1, 1]],
        val_files=[image_b, image_a],
        val_targets=[[1, 1, 0], [0, 0, 1]],
        test_files=[image_b, image_b],
        test_targets=[[0, 0, 1], [1, 1, 0]],
        batch_size=2,
        image_size=(64, 64),
    )

    # disable visualisation for testing
    assert dm.data_fetcher.block_viz_window is True
    dm.set_block_viz_window(False)
    assert dm.data_fetcher.block_viz_window is False

    # call show functions
    dm.show_train_batch()
    dm.show_train_batch("pre_tensor_transform")
    dm.show_train_batch("to_tensor_transform")
    dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
    dm.show_val_batch("per_batch_transform")
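# Most tests in this section rely on a module-level `_rand_image` helper that returns a random
# RGB PIL image (a local variant also appears in test_classification_task_predict_folder_path
# further down). A minimal sketch, assuming it accepts an optional (height, width) tuple to
# match calls such as `_rand_image((96, 96))`; the exact helper in the original module may differ:
def _rand_image(size=(196, 196)):
    # random uint8 RGB image of the requested size
    return Image.fromarray(np.random.randint(0, 255, (*size, 3), dtype="uint8"))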
def _test_learn2learning_training_strategies(gpus, accelerator, training_strategy, tmpdir):
    train_dir = Path(tmpdir / "train")
    train_dir.mkdir()

    (train_dir / "a").mkdir()
    pa_1 = train_dir / "a" / "1.png"
    pa_2 = train_dir / "a" / "2.png"
    pb_1 = train_dir / "b" / "1.png"
    pb_2 = train_dir / "b" / "2.png"
    image_size = (96, 96)
    _rand_image(image_size).save(pa_1)
    _rand_image(image_size).save(pa_2)

    (train_dir / "b").mkdir()
    _rand_image(image_size).save(pb_1)
    _rand_image(image_size).save(pb_2)

    n = 5

    dm = ImageClassificationData.from_files(
        train_files=[str(pa_1)] * n + [str(pa_2)] * n + [str(pb_1)] * n + [str(pb_2)] * n,
        train_targets=[0] * n + [1] * n + [2] * n + [3] * n,
        batch_size=1,
        num_workers=0,
        transform_kwargs=dict(image_size=image_size),
    )

    model = ImageClassifier(
        backbone="resnet18",
        training_strategy=training_strategy,
        training_strategy_kwargs={"ways": dm.num_classes, "shots": 4, "meta_batch_size": 4},
    )

    trainer = Trainer(fast_dev_run=2, gpus=gpus, accelerator=accelerator)
    trainer.fit(model, datamodule=dm)
def test_vissl_training(backbone, training_strategy, head, pretraining_transform, embedding_size):
    datamodule = ImageClassificationData.from_datasets(
        train_dataset=FakeData(16),
        predict_dataset=FakeData(8),
        batch_size=4,
    )

    embedder = ImageEmbedder(
        backbone=backbone,
        training_strategy=training_strategy,
        head=head,
        pretraining_transform=pretraining_transform,
    )

    trainer = flash.Trainer(
        max_steps=3,
        max_epochs=1,
        gpus=torch.cuda.device_count(),
    )

    trainer.fit(embedder, datamodule=datamodule)
    predictions = trainer.predict(embedder, datamodule=datamodule)
    for prediction_batch in predictions:
        for prediction in prediction_batch:
            assert prediction.size(0) == embedding_size
def test_from_filepaths_smoke(tmpdir):
    tmpdir = Path(tmpdir)

    (tmpdir / "a").mkdir()
    (tmpdir / "b").mkdir()
    _rand_image().save(tmpdir / "a_1.png")
    _rand_image().save(tmpdir / "b_1.png")

    train_images = [
        str(tmpdir / "a_1.png"),
        str(tmpdir / "b_1.png"),
    ]

    img_data = ImageClassificationData.from_files(
        train_files=train_images,
        train_targets=[1, 2],
        batch_size=2,
        num_workers=0,
    )
    assert img_data.train_dataloader() is not None
    assert img_data.val_dataloader() is None
    assert img_data.test_dataloader() is None

    data = next(iter(img_data.train_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)
    assert sorted(list(labels.numpy())) == [1, 2]
def test_classification_fiftyone(tmpdir):
    tmpdir = Path(tmpdir)

    (tmpdir / "a").mkdir()
    (tmpdir / "b").mkdir()
    _rand_image().save(tmpdir / "a_1.png")
    _rand_image().save(tmpdir / "b_1.png")

    train_images = [
        str(tmpdir / "a_1.png"),
        str(tmpdir / "b_1.png"),
    ]

    train_dataset = fo.Dataset.from_dir(str(tmpdir), dataset_type=fo.types.ImageDirectory)

    s1 = train_dataset[train_images[0]]
    s2 = train_dataset[train_images[1]]
    s1["test"] = fo.Classification(label="1")
    s2["test"] = fo.Classification(label="2")
    s1.save()
    s2.save()

    data = ImageClassificationData.from_fiftyone(
        train_dataset=train_dataset,
        label_field="test",
        batch_size=2,
        num_workers=0,
        image_size=(64, 64),
    )

    model = ImageClassifier(num_classes=2, backbone="resnet18")
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.finetune(model, datamodule=data, strategy="freeze")
def simple_datamodule(tmpdir):
    train_dir = Path(tmpdir / "train")
    train_dir.mkdir()

    (train_dir / "a").mkdir()
    pa_1 = train_dir / "a" / "1.png"
    pa_2 = train_dir / "a" / "2.png"
    pb_1 = train_dir / "b" / "1.png"
    pb_2 = train_dir / "b" / "2.png"
    image_size = (96, 96)
    _rand_image(image_size).save(pa_1)
    _rand_image(image_size).save(pa_2)

    (train_dir / "b").mkdir()
    _rand_image(image_size).save(pb_1)
    _rand_image(image_size).save(pb_2)

    n = 10

    dm = ImageClassificationData.from_files(
        train_files=[str(pa_1)] * n + [str(pa_2)] * n + [str(pb_1)] * n + [str(pb_2)] * n,
        train_targets=[0] * n + [1] * n + [2] * n + [3] * n,
        test_files=[str(pa_1)] * n,
        test_targets=[0] * n,
        batch_size=2,
        num_workers=0,
        transform_kwargs=dict(image_size=image_size),
    )
    return dm
def test_albumentations_mixup(single_target_csv):
    def mixup(batch, alpha=1.0):
        images = batch["input"]
        targets = batch["target"].float().unsqueeze(1)

        lam = np.random.beta(alpha, alpha)
        perm = torch.randperm(images.size(0))

        batch["input"] = images * lam + images[perm] * (1 - lam)
        batch["target"] = targets * lam + targets[perm] * (1 - lam)
        for e in batch["metadata"]:
            e.update({"lam": lam})
        return batch

    train_transform = {
        # applied only to images, since ApplyToKeys targets the `input` key
        "post_tensor_transform": ApplyToKeys("input", AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))),
        "per_batch_transform": mixup,
    }
    # merge the default transform for this task with the new one
    train_transform = merge_transforms(default_transforms((256, 256)), train_transform)

    img_data = ImageClassificationData.from_csv(
        "image",
        "target",
        train_file=single_target_csv,
        batch_size=2,
        num_workers=0,
        train_transform=train_transform,
    )

    batch = next(iter(img_data.train_dataloader()))
    assert "lam" in batch["metadata"][0]
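# `single_target_csv` (and the related `multi_target_csv` used further below) is a pytest
# fixture assumed by the CSV-based tests. A minimal sketch of what such a fixture could look
# like; the column layout, file names, and resolver behaviour are assumptions, not taken from
# the original test module:
@pytest.fixture
def single_target_csv(tmpdir):
    tmpdir = Path(tmpdir)
    # write two images next to the CSV so the default file resolution can find them
    _rand_image().save(tmpdir / "image_1.png")
    _rand_image().save(tmpdir / "image_2.png")
    csv_path = tmpdir / "metadata.csv"
    pd.DataFrame({"image": ["image_1", "image_2"], "target": [0, 1]}).to_csv(csv_path, index=False)
    return str(csv_path)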
def test_from_folders_train_val(tmpdir):
    train_dir = Path(tmpdir / "train")
    train_dir.mkdir()

    (train_dir / "a").mkdir()
    _rand_image().save(train_dir / "a" / "1.png")
    _rand_image().save(train_dir / "a" / "2.png")

    (train_dir / "b").mkdir()
    _rand_image().save(train_dir / "b" / "1.png")
    _rand_image().save(train_dir / "b" / "2.png")

    img_data = ImageClassificationData.from_folders(
        train_folder=train_dir,
        val_folder=train_dir,
        test_folder=train_dir,
        batch_size=2,
        num_workers=0,
    )

    data = next(iter(img_data.train_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)

    data = next(iter(img_data.val_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)
    assert list(labels.numpy()) == [0, 0]

    data = next(iter(img_data.test_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)
    assert list(labels.numpy()) == [0, 0]
def test_from_datasets():
    img_data = ImageClassificationData.from_datasets(
        train_dataset=FakeData(size=3, num_classes=2),
        val_dataset=FakeData(size=3, num_classes=2),
        test_dataset=FakeData(size=3, num_classes=2),
        batch_size=2,
        num_workers=0,
    )

    # check training data
    data = next(iter(img_data.train_dataloader()))
    imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)

    # check validation data
    data = next(iter(img_data.val_dataloader()))
    imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)

    # check test data
    data = next(iter(img_data.test_dataloader()))
    imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)
def test_from_filepaths_visualise(tmpdir):
    tmpdir = Path(tmpdir)

    (tmpdir / "e").mkdir()
    _rand_image().save(tmpdir / "e_1.png")

    train_images = [
        str(tmpdir / "e_1.png"),
        str(tmpdir / "e_1.png"),
        str(tmpdir / "e_1.png"),
    ]

    dm = ImageClassificationData.from_files(
        train_files=train_images,
        train_targets=[0, 3, 6],
        val_files=train_images,
        val_targets=[1, 4, 7],
        test_files=train_images,
        test_targets=[2, 5, 8],
        batch_size=2,
        num_workers=0,
    )

    # disable visualisation for testing
    assert dm.data_fetcher.block_viz_window is True
    dm.set_block_viz_window(False)
    assert dm.data_fetcher.block_viz_window is False

    # call show functions
    # dm.show_train_batch()
    dm.show_train_batch("pre_tensor_transform")
    dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
def test_mixup(single_target_csv):
    @dataclass
    class MyTransform(ImageClassificationInputTransform):
        alpha: float = 1.0

        def mixup(self, batch):
            images = batch["input"]
            targets = batch["target"].float().unsqueeze(1)

            lam = np.random.beta(self.alpha, self.alpha)
            perm = torch.randperm(images.size(0))

            batch["input"] = images * lam + images[perm] * (1 - lam)
            batch["target"] = targets * lam + targets[perm] * (1 - lam)
            for e in batch["metadata"]:
                e.update({"lam": lam})
            return batch

        def per_batch_transform(self):
            return self.mixup

    img_data = ImageClassificationData.from_csv(
        "image",
        "target",
        train_file=single_target_csv,
        batch_size=2,
        num_workers=0,
        transform=MyTransform,
    )

    batch = next(iter(img_data.train_dataloader()))
    assert "lam" in batch["metadata"][0]
def test_training_from_scratch(capsys):
    """Execute training for 2 epochs to check for errors."""
    Path("testModels").mkdir(parents=True, exist_ok=True)

    with capsys.disabled():
        with initialize(config_path="conf"):
            cfg = compose(config_name="classification")
            seed_everything(42, workers=cfg.trainer.workers)

            cfg.trainer.default.callbacks[0].dirpath = "/home/Develop/ai4prod_python/classification/testModels"
            cfg.trainer.default.callbacks[0].filename = MODEL_NAME
            cfg.trainer.default.max_epochs = 2

            @dataclass
            class ImageClassificationInputTransform(InputTransform):
                # transforms applied to the training inputs
                def train_input_per_sample_transform(self):
                    return instantiate(cfg.dataset.train_transform, _convert_="all")

                # convert the label to a tensor
                def target_per_sample_transform(self) -> Callable:
                    return torch.as_tensor

                # transforms applied to the validation inputs
                def val_input_per_sample_transform(self):
                    return instantiate(cfg.dataset.val_transform, _convert_="all")

            # dataset setup
            dm = ImageClassificationData.from_folders(
                train_folder=cfg.dataset.datasetPath + "train",
                train_transform=ImageClassificationInputTransform,
                val_folder=cfg.dataset.datasetPath + "val",
                val_transform=ImageClassificationInputTransform,
                batch_size=cfg.dataset.batch_size,
            )

            # model instantiation
            model = instantiate(cfg.model.image_classifier)

            if cfg.model.from_scratch:
                cfg.model.image_classifier.pretrained = False

            trainer = instantiate(cfg.trainer.default)
            trainer.fit(model=model, datamodule=dm)

    assert True
def test_from_bad_csv_no_image(bad_csv_no_image):
    with pytest.raises(ValueError, match="Found no matches"):
        img_data = ImageClassificationData.from_csv(
            "image",
            ["target"],
            train_file=bad_csv_no_image,
            batch_size=1,
            num_workers=0,
        )
        _ = next(iter(img_data.train_dataloader()))
def test_from_fiftyone(tmpdir):
    tmpdir = Path(tmpdir)

    (tmpdir / "a").mkdir()
    (tmpdir / "b").mkdir()
    _rand_image().save(tmpdir / "a_1.png")
    _rand_image().save(tmpdir / "b_1.png")

    train_images = [
        str(tmpdir / "a_1.png"),
        str(tmpdir / "b_1.png"),
    ]

    dataset = fo.Dataset.from_dir(str(tmpdir), dataset_type=fo.types.ImageDirectory)

    s1 = dataset[train_images[0]]
    s2 = dataset[train_images[1]]
    s1["test"] = fo.Classification(label="1")
    s2["test"] = fo.Classification(label="2")
    s1.save()
    s2.save()

    img_data = ImageClassificationData.from_fiftyone(
        train_dataset=dataset,
        test_dataset=dataset,
        val_dataset=dataset,
        label_field="test",
        batch_size=2,
        num_workers=0,
    )
    assert img_data.train_dataloader() is not None
    assert img_data.val_dataloader() is not None
    assert img_data.test_dataloader() is not None

    # check train data
    data = next(iter(img_data.train_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)
    assert sorted(list(labels.numpy())) == [0, 1]

    # check val data
    data = next(iter(img_data.val_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)
    assert sorted(list(labels.numpy())) == [0, 1]

    # check test data
    data = next(iter(img_data.test_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)
    assert sorted(list(labels.numpy())) == [0, 1]
def test_from_bad_csv_no_image(bad_csv_no_image):
    bad_file = os.path.join(os.path.dirname(bad_csv_no_image), "image_3")
    with pytest.raises(ValueError, match=f"File ID `image_3` resolved to `{bad_file}`, which does not exist."):
        img_data = ImageClassificationData.from_csv(
            "image",
            ["target"],
            train_file=bad_csv_no_image,
            batch_size=1,
            num_workers=0,
        )
        _ = next(iter(img_data.train_dataloader()))
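# `bad_csv_no_image` is assumed to be a pytest fixture producing a CSV that references a file ID
# with no matching image on disk, so that building the dataloader fails. A hypothetical sketch
# consistent with the two error-message variants above (column names and file layout are
# assumptions):
@pytest.fixture
def bad_csv_no_image(tmpdir):
    tmpdir = Path(tmpdir)
    csv_path = tmpdir / "metadata.csv"
    # "image_3" is never written to disk, so resolving it should raise a ValueError
    pd.DataFrame({"image": ["image_3"], "target": [1]}).to_csv(csv_path, index=False)
    return str(csv_path)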
def run(transform: Any = None):
    dm = ImageClassificationData.from_files(
        train_files=train_filepaths,
        train_targets=train_labels,
        transform=transform,
        batch_size=B,
        num_workers=0,
        val_split=val_split,
    )
    data = next(iter(dm.train_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (B, 3, H, W)
    assert labels.shape == (B,)
def test_only_embedding(backbone, embedding_size):
    datamodule = ImageClassificationData.from_datasets(
        predict_dataset=FakeData(8),
        batch_size=4,
        transform_kwargs=dict(image_size=(224, 224)),
    )

    embedder = ImageEmbedder(backbone=backbone)

    trainer = flash.Trainer()
    predictions = trainer.predict(embedder, datamodule=datamodule)
    for prediction_batch in predictions:
        for prediction in prediction_batch:
            assert prediction.size(0) == embedding_size
def test_from_csv_multi_target(multi_target_csv):
    img_data = ImageClassificationData.from_csv(
        "image",
        ["target_1", "target_2"],
        train_file=multi_target_csv,
        batch_size=2,
        num_workers=0,
    )

    # check training data
    data = next(iter(img_data.train_dataloader()))
    imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, 2)
def from_hymenoptera(
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs,
) -> ImageClassificationData:
    """Downloads and loads the Hymenoptera (Ants, Bees) data set."""
    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
    return ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        val_folder="data/hymenoptera_data/val/",
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )
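# A hypothetical call site for `from_hymenoptera` above, showing how the returned datamodule
# would typically be consumed. The backbone name, trainer limits, and finetuning strategy are
# illustrative assumptions, not taken from the original snippet:
def _example_finetune_on_hymenoptera():
    datamodule = from_hymenoptera(batch_size=8)
    model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
    trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")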
def test_from_data_frame_smoke(tmpdir):
    tmpdir = Path(tmpdir)

    df = pd.DataFrame(
        {"file": ["train.png", "valid.png", "test.png"], "split": ["train", "valid", "test"], "target": [0, 1, 1]}
    )
    for _, row in df.iterrows():
        _rand_image().save(tmpdir / row.file)

    img_data = ImageClassificationData.from_data_frame(
        "file",
        "target",
        train_images_root=str(tmpdir),
        val_images_root=str(tmpdir),
        test_images_root=str(tmpdir),
        train_data_frame=df[df.split == "train"],
        val_data_frame=df[df.split == "valid"],
        test_data_frame=df[df.split == "test"],
        predict_images_root=str(tmpdir),
        batch_size=1,
        predict_data_frame=df,
    )
    assert img_data.train_dataloader() is not None
    assert img_data.val_dataloader() is not None
    assert img_data.test_dataloader() is not None
    assert img_data.predict_dataloader() is not None

    data = next(iter(img_data.train_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (1, 3, 196, 196)
    assert labels.shape == (1,)
    assert sorted(list(labels.numpy())) == [0]

    data = next(iter(img_data.val_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (1, 3, 196, 196)
    assert labels.shape == (1,)
    assert sorted(list(labels.numpy())) == [1]

    data = next(iter(img_data.test_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (1, 3, 196, 196)
    assert labels.shape == (1,)
    assert sorted(list(labels.numpy())) == [1]

    data = next(iter(img_data.predict_dataloader()))
    imgs = data["input"]
    assert imgs.shape == (1, 3, 196, 196)
def from_movie_posters(
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs,
) -> ImageClassificationData:
    """Downloads and loads the movie posters genre classification data set."""
    download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "./data")
    return ImageClassificationData.from_csv(
        "Id",
        ["Action", "Romance", "Crime", "Thriller", "Adventure"],
        train_file="data/movie_posters/train/metadata.csv",
        val_file="data/movie_posters/val/metadata.csv",
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )
def test_classification_task_predict_folder_path(tmpdir):
    train_dir = Path(tmpdir / "train")
    train_dir.mkdir()

    def _rand_image():
        return Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype="uint8"))

    _rand_image().save(train_dir / "1.png")
    _rand_image().save(train_dir / "2.png")

    datamodule = ImageClassificationData.from_folders(predict_folder=train_dir)

    task = ImageClassifier(num_classes=10)
    predictions = task.predict(str(train_dir), data_pipeline=datamodule.data_pipeline)
    assert len(predictions) == 2
def test_from_folders_only_train(tmpdir):
    train_dir = Path(tmpdir / "train")
    train_dir.mkdir()

    (train_dir / "a").mkdir()
    _rand_image().save(train_dir / "a" / "1.png")
    _rand_image().save(train_dir / "a" / "2.png")

    (train_dir / "b").mkdir()
    _rand_image().save(train_dir / "b" / "1.png")
    _rand_image().save(train_dir / "b" / "2.png")

    img_data = ImageClassificationData.from_folders(train_dir, batch_size=1)

    data = img_data.train_dataset[0]
    imgs, labels = data["input"], data["target"]
    assert isinstance(imgs, Image.Image)
    assert labels == 0
def test_from_filepaths_list_image_paths(tmpdir):
    tmpdir = Path(tmpdir)

    (tmpdir / "e").mkdir()
    _rand_image().save(tmpdir / "e_1.png")

    train_images = [
        str(tmpdir / "e_1.png"),
        str(tmpdir / "e_1.png"),
        str(tmpdir / "e_1.png"),
    ]

    img_data = ImageClassificationData.from_files(
        train_files=train_images,
        train_targets=[0, 3, 6],
        val_files=train_images,
        val_targets=[1, 4, 7],
        test_files=train_images,
        test_targets=[2, 5, 8],
        batch_size=2,
        num_workers=0,
    )

    # check training data
    data = next(iter(img_data.train_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)
    assert labels.numpy()[0] in [0, 3, 6]  # data comes shuffled here
    assert labels.numpy()[1] in [0, 3, 6]  # data comes shuffled here

    # check validation data
    data = next(iter(img_data.val_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)
    assert list(labels.numpy()) == [1, 4]

    # check test data
    data = next(iter(img_data.test_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2,)
    assert list(labels.numpy()) == [2, 5]
def test_from_folders_only_train(tmpdir):
    train_dir = Path(tmpdir / "train")
    train_dir.mkdir()

    (train_dir / "a").mkdir()
    _rand_image().save(train_dir / "a" / "1.png")
    _rand_image().save(train_dir / "a" / "2.png")

    (train_dir / "b").mkdir()
    _rand_image().save(train_dir / "b" / "1.png")
    _rand_image().save(train_dir / "b" / "2.png")

    img_data = ImageClassificationData.from_folders(train_dir, train_transform=None, batch_size=1)

    data = next(iter(img_data.train_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (1, 3, 196, 196)
    assert labels.shape == (1,)

    assert img_data.val_dataloader() is None
    assert img_data.test_dataloader() is None
def test_from_filepaths_multilabel(tmpdir):
    tmpdir = Path(tmpdir)

    (tmpdir / "a").mkdir()
    _rand_image().save(tmpdir / "a1.png")
    _rand_image().save(tmpdir / "a2.png")

    train_images = [str(tmpdir / "a1.png"), str(tmpdir / "a2.png")]
    train_labels = [[1, 0, 1, 0], [0, 0, 1, 1]]
    valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]]
    test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]]

    dm = ImageClassificationData.from_files(
        train_files=train_images,
        train_targets=train_labels,
        val_files=train_images,
        val_targets=valid_labels,
        test_files=train_images,
        test_targets=test_labels,
        batch_size=2,
        num_workers=0,
    )

    data = next(iter(dm.train_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, 4)

    data = next(iter(dm.val_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, 4)
    torch.testing.assert_allclose(labels, torch.tensor(valid_labels))

    data = next(iter(dm.test_dataloader()))
    imgs, labels = data["input"], data["target"]
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, 4)
    torch.testing.assert_allclose(labels, torch.tensor(test_labels))
def test_classification(tmpdir):
    tmpdir = Path(tmpdir)

    (tmpdir / "a").mkdir()
    (tmpdir / "b").mkdir()

    image_a = str(tmpdir / "a" / "a_1.png")
    image_b = str(tmpdir / "b" / "b_1.png")

    _rand_image().save(image_a)
    _rand_image().save(image_b)

    data = ImageClassificationData.from_files(
        train_files=[image_a, image_b],
        train_targets=[0, 1],
        num_workers=0,
        batch_size=2,
        image_size=(64, 64),
    )

    model = ImageClassifier(num_classes=2, backbone="resnet18")
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.finetune(model, datamodule=data, strategy="freeze")
def run(self, train_folder):
    # Create a datamodule from the given dataset
    datamodule = ImageClassificationData.from_folders(
        train_folder=train_folder,
        batch_size=1,
        val_split=0.5,
    )
    # Create an image classifier task with the given backbone
    model = ImageClassifier(datamodule.num_classes, backbone=self.backbone)
    # Start a Lightning trainer, with 1 training batch and 4 validation batches
    trainer = flash.Trainer(
        max_epochs=self.max_epochs,
        limit_train_batches=1,
        limit_val_batches=4,
        callbacks=[ModelCheckpoint(monitor="val_cross_entropy")],
    )
    # Train the model
    trainer.fit(model, datamodule=datamodule)
    # Save the best model path
    self.best_model_path = trainer.checkpoint_callback.best_model_path
    # Save the best model score
    self.best_model_score = trainer.checkpoint_callback.best_model_score.item()
def from_movie_posters(
    batch_size: int = 4,
    num_workers: int = 0,
    **data_module_kwargs,
) -> ImageClassificationData:
    """Downloads and loads the movie posters genre classification data set."""
    download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "./data")

    def resolver(root, file_id):
        return os.path.join(root, f"{file_id}.jpg")

    return ImageClassificationData.from_csv(
        "Id",
        ["Action", "Romance", "Crime", "Thriller", "Adventure"],
        train_file="data/movie_posters/train/metadata.csv",
        train_resolver=resolver,
        val_file="data/movie_posters/val/metadata.csv",
        val_resolver=resolver,
        batch_size=batch_size,
        num_workers=num_workers,
        **data_module_kwargs,
    )
def test_multicrop_input_transform():
    batch_size = 8
    total_crops = 6
    num_crops = [2, 4]
    size_crops = [160, 96]
    crop_scales = [[0.4, 1], [0.05, 0.4]]

    multi_crop_transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"](
        total_crops, num_crops, size_crops, crop_scales
    )

    to_tensor_transform = ApplyToKeys(
        DefaultDataKeys.INPUT,
        multi_crop_transform,
    )
    preprocess = DefaultPreprocess(
        train_transform={
            "to_tensor_transform": to_tensor_transform,
            "collate": vissl_collate_fn,
        }
    )

    datamodule = ImageClassificationData.from_datasets(
        train_dataset=FakeData(),
        preprocess=preprocess,
        batch_size=batch_size,
    )

    train_dataloader = datamodule._train_dataloader()
    batch = next(iter(train_dataloader))

    assert len(batch[DefaultDataKeys.INPUT]) == total_crops
    assert batch[DefaultDataKeys.INPUT][0].shape == (batch_size, 3, size_crops[0], size_crops[0])
    assert batch[DefaultDataKeys.INPUT][-1].shape == (batch_size, 3, size_crops[-1], size_crops[-1])
    assert list(batch[DefaultDataKeys.TARGET].shape) == [batch_size]