def basic_ae_example():
    """Train a basic ``pl_bolts`` autoencoder on CIFAR-10, then load pretrained weights.

    NOTE(review): downloads CIFAR-10 into the current working directory and
    requests 2 GPUs with DDP -- confirm the environment before running.
    """
    from pl_bolts.models.autoencoders import AE
    from pytorch_lightning.plugins import DDPPlugin

    # Data.
    train_dataset = torchvision.datasets.CIFAR10(
        "", train=True, download=True, transform=torchvision.transforms.ToTensor())
    val_dataset = torchvision.datasets.CIFAR10(
        "", train=False, download=True, transform=torchvision.transforms.ToTensor())
    # NOTE(review): with 12 workers, persistent_workers=True would avoid
    # re-spawning worker processes every epoch -- consider enabling it.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, num_workers=12)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=12)

    # Model.
    model = AE(input_height=32)
    # Override any part of this AE to build your own variation, e.g.:
    #
    #   class MyAEFlavor(AE):
    #       def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
    #           return YourSuperFancyEncoder(...)
    #
    #   model = MyAEFlavor(...)

    # Fit.
    trainer = pl.Trainer(
        gpus=2,
        accelerator="ddp",
        plugins=DDPPlugin(find_unused_parameters=False),
    )
    trainer.fit(model, train_loader, val_loader)

    # --------------------
    # CIFAR-10 pretrained model:
    ae = AE(input_height=32)
    print(AE.pretrained_weights_available())
    ae = ae.from_pretrained("cifar10-resnet18")
    ae.freeze()
def test_ae(tmpdir):
    """Smoke-test that an AE can fit and test in a single fast dev run."""
    reset_seed()

    autoencoder = AE(data_dir=tmpdir, batch_size=2)
    trainer = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=True)

    trainer.fit(autoencoder)
    trainer.test(autoencoder)
def test_from_pretrained(tmpdir):
    """Pretrained weights load for known checkpoint names and raise KeyError otherwise.

    Uses explicit flag variables rather than ``pytest.raises`` so a failure
    message can name which phase broke.
    """
    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    # Both model families must advertise at least one pretrained checkpoint.
    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    exception_raised = False
    try:
        vae = vae.from_pretrained('cifar10-resnet18')
        # Loading weights not compatible with the exact architecture should
        # still succeed (partial / non-strict load).
        vae = vae.from_pretrained('stl10-resnet18')
        ae = ae.from_pretrained('cifar10-resnet18')
    except Exception:  # deliberately broad: any failure means loading is broken
        exception_raised = True
    assert exception_raised is False, "error in loading weights"

    keyerror = False
    try:
        vae = vae.from_pretrained('abc')
        ae = ae.from_pretrained('xyz')
    except KeyError:
        keyerror = True
    assert keyerror is True, "KeyError not raised when provided with illegal checkpoint name"
def test_ae(tmpdir, datadir, dm_cls):
    """Fast-dev-run an AE against the parametrized datamodule class on CPU."""
    seed_everything()

    datamodule = dm_cls(data_dir=datadir, batch_size=4)
    # Input height is taken from the last dimension reported by the datamodule.
    autoencoder = AE(input_height=datamodule.size()[-1])

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=None)
    trainer.fit(autoencoder, datamodule=datamodule)
def test_ae(tmpdir, datadir, dm_cls):
    """Fast-dev-run an AE on CPU and check the fit call reports success."""
    seed_everything()

    datamodule = dm_cls(data_dir=datadir, batch_size=4)
    autoencoder = AE(input_height=datamodule.size()[-1])

    trainer = pl.Trainer(
        fast_dev_run=True,
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=None,
    )
    # NOTE(review): relies on the old Lightning behavior of fit() returning 1
    # on success -- newer versions return None; verify the pinned PL version.
    result = trainer.fit(autoencoder, datamodule)
    assert result == 1
def test_from_pretrained(datadir):
    """Pretrained VAE/AE weights load and run a forward pass; bad names raise KeyError."""
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=4)
    # Fix: prepare_data() (download) must run before setup() (dataset creation)
    # per the LightningDataModule lifecycle; the original called them reversed.
    dm.prepare_data()
    dm.setup()
    data_loader = dm.val_dataloader()

    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    exception_raised = False
    try:
        vae = vae.from_pretrained("cifar10-resnet18")
        # Test the forward method on pre-trained weights (one batch is enough).
        for x, _ in data_loader:
            vae(x)
            break
        # Loading weights not compatible with the exact architecture should
        # still succeed (partial / non-strict load).
        vae = vae.from_pretrained("stl10-resnet18")

        ae = ae.from_pretrained("cifar10-resnet18")
        # Forward pass without gradient tracking.
        with torch.no_grad():
            for x, _ in data_loader:
                ae(x)
                break
    except Exception:  # deliberately broad: any failure means loading is broken
        exception_raised = True
    assert exception_raised is False, "error in loading weights"

    keyerror = False
    try:
        vae.from_pretrained("abc")
        ae.from_pretrained("xyz")
    except KeyError:
        keyerror = True
    assert keyerror is True, "KeyError not raised when provided with illegal checkpoint name"
import argparse

import pytorch_lightning as pl
from pl_bolts.models.autoencoders import AE
from pytorch_lightning.callbacks import ModelCheckpoint

from ae_datasets import iCLEVRDataModule


def main():
    """Parse CLI args, build the AE model and iCLEVR datamodule, and train.

    The arg parser aggregates model-, dataset-, and Trainer-specific flags so
    everything is configurable from the command line.
    """
    parser = argparse.ArgumentParser()
    parser = AE.add_model_specific_args(parser)
    parser = iCLEVRDataModule.add_dataset_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    opt = parser.parse_args()

    model = AE(**vars(opt))
    data = iCLEVRDataModule(**vars(opt))

    trainer = pl.Trainer.from_argparse_args(
        opt,
        log_gpu_memory='all',
        checkpoint_callback=ModelCheckpoint(filepath="{epoch:02d}.ckpt"))
    trainer.fit(model, data)


# Fix: guard the entry point so importing this module no longer kicks off
# argument parsing and training as a side effect.
if __name__ == "__main__":
    main()