def test_from_folders(tmpdir):
    """from_folders: train-only config exposes no val/test loaders; adding folders does."""
    # Tiny two-class dataset on disk: train/{a,b}/{1,2}.png
    train_dir = Path(tmpdir / "train")
    train_dir.mkdir()

    (train_dir / "a").mkdir()
    _rand_image().save(train_dir / "a" / "1.png")
    _rand_image().save(train_dir / "a" / "2.png")

    (train_dir / "b").mkdir()
    _rand_image().save(train_dir / "b" / "1.png")
    _rand_image().save(train_dir / "b" / "2.png")

    # Train-only datamodule with a custom loader.
    # NOTE(review): the 64x64 shape presumably comes from _dummy_image_loader — confirm.
    img_data = ImageClassificationData.from_folders(train_dir, train_transform=None, loader=_dummy_image_loader, batch_size=1)
    data = next(iter(img_data.train_dataloader()))
    imgs, labels = data
    assert imgs.shape == (1, 3, 64, 64)
    assert labels.shape == (1, )

    # No validation/test folders were given, so those loaders must be absent.
    assert img_data.val_dataloader() is None
    assert img_data.test_dataloader() is None

    # Re-use the train folder for validation and test to get all three loaders.
    img_data = ImageClassificationData.from_folders(
        train_dir,
        train_transform=T.ToTensor(),
        valid_folder=train_dir,
        valid_transform=T.ToTensor(),
        test_folder=train_dir,
        batch_size=1,
        num_workers=0,
    )

    data = next(iter(img_data.val_dataloader()))
    imgs, labels = data
    assert imgs.shape == (1, 3, 64, 64)
    assert labels.shape == (1, )

    data = next(iter(img_data.test_dataloader()))
    imgs, labels = data
    assert imgs.shape == (1, 3, 64, 64)
    assert labels.shape == (1, )
def test_classification_task_predict_folder_path(tmpdir):
    """ImageClassifier.predict accepts a folder path and returns one prediction per image."""
    train_dir = Path(tmpdir / "train")
    train_dir.mkdir()

    def _rand_image():
        # Random 256x256 RGB image.
        return Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype="uint8"))

    _rand_image().save(train_dir / "1.png")
    _rand_image().save(train_dir / "2.png")

    # Datamodule only used as a source for its data pipeline below.
    datamodule = ImageClassificationData.from_folders(predict_folder=train_dir)

    task = ImageClassifier(num_classes=10)
    # Predict on the folder path (as str), re-using the datamodule's pipeline.
    predictions = task.predict(str(train_dir), data_pipeline=datamodule.data_pipeline)
    # Two images saved -> two predictions back.
    assert len(predictions) == 2
def test_from_folders_only_train(tmpdir):
    """A datamodule built from a train folder alone has no val/test dataloaders."""
    # Lay out a minimal two-class dataset: train/{a,b}/{1,2}.png
    root = Path(tmpdir / "train")
    root.mkdir()
    for class_name in ("a", "b"):
        class_dir = root / class_name
        class_dir.mkdir()
        for stem in ("1", "2"):
            _rand_image().save(class_dir / (stem + ".png"))

    datamodule = ImageClassificationData.from_folders(root, train_transform=None, batch_size=1)

    batch = next(iter(datamodule.train_dataloader()))
    images, targets = batch
    assert images.shape == (1, 3, 196, 196)
    assert targets.shape == (1, )

    # Only a train folder was supplied, so the other loaders are absent.
    assert datamodule.val_dataloader() is None
    assert datamodule.test_dataloader() is None
def test_from_folders_train_val(tmpdir):
    """from_folders with val/test folders yields working loaders over those folders."""
    # Tiny two-class dataset: train/{a,b}/{1,2}.png
    train_dir = Path(tmpdir / "train")
    train_dir.mkdir()

    (train_dir / "a").mkdir()
    _rand_image().save(train_dir / "a" / "1.png")
    _rand_image().save(train_dir / "a" / "2.png")

    (train_dir / "b").mkdir()
    _rand_image().save(train_dir / "b" / "1.png")
    _rand_image().save(train_dir / "b" / "2.png")

    # The train folder doubles as the val and test folder here.
    img_data = ImageClassificationData.from_folders(
        train_dir,
        val_folder=train_dir,
        test_folder=train_dir,
        batch_size=2,
        num_workers=0,
    )

    data = next(iter(img_data.train_dataloader()))
    imgs, labels = data
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, )

    data = next(iter(img_data.val_dataloader()))
    imgs, labels = data
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, )
    # First batch holds both images of class "a" (label 0) — presumably the eval
    # loaders preserve folder order (no shuffling); verify if this test flakes.
    assert list(labels.numpy()) == [0, 0]

    data = next(iter(img_data.test_dataloader()))
    imgs, labels = data
    assert imgs.shape == (2, 3, 196, 196)
    assert labels.shape == (2, )
    assert list(labels.numpy()) == [0, 0]
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from flash import Trainer from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageClassifier # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Load the model from a checkpoint model = ImageClassifier.load_from_checkpoint( "https://flash-weights.s3.amazonaws.com/image_classification_model.pt") # 3a. Predict what's on a few images! ants or bees? predictions = model.predict([ "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", ]) print(predictions) # 3b. Or generate predictions with a whole folder! datamodule = ImageClassificationData.from_folders( predict_folder="data/hymenoptera_data/predict/") predictions = Trainer().predict(model, datamodule=datamodule) print(predictions)
# See the License for the specific language governing permissions and # limitations under the License. from flash import Trainer from flash.data.utils import download_data from flash.vision import ImageClassificationData, ImageClassifier # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Load the model from a checkpoint model = ImageClassifier.load_from_checkpoint( "https://flash-weights.s3.amazonaws.com/image_classification_model.pt") # 3a. Predict what's on a few images! ants or bees? predictions = model.predict([ "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", ]) print(predictions) # 3b. Or generate predictions with a whole folder! datamodule = ImageClassificationData.from_folders( predict_folder="data/hymenoptera_data/predict/", preprocess=model.preprocess, ) predictions = Trainer().predict(model, datamodule=datamodule) print(predictions)
import flash
from flash.core.data import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.vision import ImageClassificationData, ImageClassifier

# Example: finetune an image classifier on the hymenoptera (ants vs bees) dataset.

if __name__ == "__main__":

    # 1. Download the data
    download_data(
        "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

    # 2. Load the data
    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/",
        test_folder="data/hymenoptera_data/test/",
    )

    # 3. Build the model (number of classes is inferred from the train folder)
    model = ImageClassifier(num_classes=datamodule.num_classes)

    # 4. Create the trainer. Run twice on data
    trainer = flash.Trainer(max_epochs=2)

    # 5. Train the model with the FreezeUnfreeze schedule (unfreeze at epoch 1)
    trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))

    # 6. Test the model
    # NOTE(review): called without model/datamodule — presumably re-uses what was
    # attached during finetune; confirm this is intended.
    trainer.test()
# FIX: `torchvision` and `nn` are used below but were never imported — add them.
import torchvision
from torch import nn

import flash
from flash import Trainer
from flash.core.classification import Labels
from flash.core.finetuning import FreezeUnfreeze
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)


# 3.a Optional: Register a custom backbone
# This is useful to create new backbone and make them accessible from `ImageClassifier`
@ImageClassifier.backbones(name="resnet18")
def fn_resnet(pretrained: bool = True):
    """Return a headless torchvision ResNet-18 and its output feature width."""
    model = torchvision.models.resnet18(pretrained)
    # remove the last two layers & turn it into a Sequential model
    backbone = nn.Sequential(*list(model.children())[:-2])
    num_features = model.fc.in_features
    # backbones need to return the num_features to build the head
    return backbone, num_features
import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# Example: finetune a classifier on hymenoptera data and save a checkpoint.

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
# NOTE(review): `backbone` looks like a model argument, not a datamodule argument —
# verify that ImageClassificationData.from_folders accepts it; it also disagrees with
# the `backbone="resnet18"` passed to the model below.
datamodule = ImageClassificationData.from_folders(
    backbone="resnet34",
    num_workers=8,
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model (class count inferred from the train folder)
model = ImageClassifier(num_classes=datamodule.num_classes, backbone="resnet18")

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=4)

# 5. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")

# 6. Save it!
trainer.save_checkpoint("image_classification_model.pt")
# Display the input posters as a 2-per-row grid.
# (`images`, `make_grid`, `T`, `CustomViz`, etc. are defined earlier in this file.)
image = make_grid(images, nrow=2)
image = T.to_pil_image(image, 'RGB')
image.show()

# 3. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/image_classification_multi_label_model.pt",
)

# 4a. Predict the genres of a few movie posters!
predictions = model.predict([
    "data/movie_posters/predict/tt0085318.jpg",
    "data/movie_posters/predict/tt0089461.jpg",
    "data/movie_posters/predict/tt0097179.jpg",
])
print(predictions)

# 4b. Or generate predictions with a whole folder!
# NOTE(review): image_size is hard-coded to (128, 128) — confirm it matches the
# resolution the checkpoint was trained at.
datamodule = ImageClassificationData.from_folders(
    predict_folder="data/movie_posters/predict/",
    data_fetcher=CustomViz(),
    image_size=(128, 128),
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)

# 5. Show some data (unless we're just testing)!
datamodule.show_predict_batch("per_batch_transform")
def cli_main():
    """CLI entry point: parse args, train LitClassifier on an image-folder dataset,
    test it, and write per-sample predictions to ./predictions.txt."""
    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='cifar5')
    # add trainer args (gpus=x, precision=...)
    parser = pl.Trainer.add_argparse_args(parser)
    # add model args (batch_size, hidden_dim, etc...), anything defined in add_model_specific_args
    parser = LitClassifier.add_model_specific_args(parser)
    args = parser.parse_args()
    print(args)

    # ------------
    # data
    # ------------
    transform = transforms.Compose([
        transforms.ToTensor(),
        # per-channel normalization stats (CIFAR-like) — TODO confirm they match this dataset
        transforms.Normalize(mean=[0.4913, 0.482, 0.446], std=[0.247, 0.243, 0.261])
    ])
    # in real life you would have a separate validation split
    datamodule = ImageClassificationData.from_folders(
        train_folder=args.data_dir + '/train',
        valid_folder=args.data_dir + '/test',
        test_folder=args.data_dir + '/test',
        batch_size=args.batch_size,
        transform=transform
    )

    # ------------
    # model
    # ------------
    model = LitClassifier(
        backbone=args.backbone,
        learning_rate=args.learning_rate,
        hidden_dim=args.hidden_dim
    )

    # ------------
    # training
    # ------------
    print('training')
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())

    # ------------
    # testing
    # ------------
    print('testing')
    result = trainer.test(model, test_dataloaders=datamodule.test_dataloader())
    print(result)

    # ------------
    # predicting
    # ------------
    print('predicting')
    # trainer.predict returns one output per batch; flatten across batches
    # (the last batch may be shorter than the others).
    preds = trainer.predict(model, datamodule.test_dataloader())

    path = os.path.join(os.getcwd(), 'predictions.txt')
    with open(path, 'w') as f:
        preds_str = [str(x) for lst in preds for x in lst]
        f.write('\n'.join(preds_str))
# Display the input posters as a 2-per-row grid.
# (`images`, `make_grid`, `T`, `CustomViz`, etc. are defined earlier in this file.)
image = make_grid(images, nrow=2)
image = T.to_pil_image(image, 'RGB')
image.show()

# 3. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/image_classification_multi_label_model.pt",
)

# 4a. Predict the genres of a few movie posters!
predictions = model.predict([
    "data/movie_posters/predict/tt0085318.jpg",
    "data/movie_posters/predict/tt0089461.jpg",
    "data/movie_posters/predict/tt0097179.jpg",
])
print(predictions)

# 4b. Or generate predictions with a whole folder!
# The model's own preprocess is re-used, presumably so folder inference applies the
# same transforms as the loaded checkpoint.
datamodule = ImageClassificationData.from_folders(
    predict_folder="data/movie_posters/predict/",
    data_fetcher=CustomViz(),
    preprocess=model.preprocess,
)
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)

# 5. Show some data (unless we're just testing)!
datamodule.show_predict_batch("per_batch_transform")