def test_from_pretrained(tmpdir):
    """Check that VAE and AE load their named pretrained checkpoints.

    Loads several pretrained checkpoints, then verifies that an unknown
    checkpoint name raises ``KeyError`` for *each* model separately.

    NOTE(review): a second ``test_from_pretrained`` is defined later in this
    module. That later definition rebinds the name, so pytest only collects it
    and this version never runs. Consider renaming one of the two.

    Args:
        tmpdir: pytest temporary-directory fixture (not used by this test).
    """
    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    try:
        vae = vae.from_pretrained("cifar10-resnet18")
        # try loading weights not compatible with exact architecture
        vae = vae.from_pretrained("stl10-resnet18")
        ae = ae.from_pretrained("cifar10-resnet18")
    except Exception as err:
        # Fail the test as before, but chain the real error so the report
        # shows what actually went wrong instead of hiding it.
        raise AssertionError(f"error in loading weights: {err!r}") from err

    # Check each model on its own. In a single try-block the first KeyError
    # would skip the second call, leaving AE's error path untested.
    for model, bad_name in ((vae, "abc"), (ae, "xyz")):
        keyerror = False
        try:
            model.from_pretrained(bad_name)
        except KeyError:
            keyerror = True
        assert keyerror, "KeyError not raised when provided with illegal checkpoint name"
def test_from_pretrained(datadir):
    """Load pretrained VAE/AE weights and run a forward pass with them.

    Also verifies that an unknown checkpoint name raises ``KeyError`` for each
    model separately.

    NOTE(review): this definition rebinds the name of an earlier
    ``test_from_pretrained`` in this module, so only this version is collected
    by pytest.

    Args:
        datadir: fixture giving the CIFAR-10 data directory.
    """
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=4)
    # Lightning's datamodule contract is prepare_data() (download) first, then
    # setup() (build the datasets). Calling setup() first can fail on a fresh
    # data directory.
    dm.prepare_data()
    dm.setup()
    data_loader = dm.val_dataloader()

    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    try:
        # One batch is enough to exercise forward(); labels are not needed.
        x, _ = next(iter(data_loader))

        vae = vae.from_pretrained("cifar10-resnet18")
        # test forward method on pre-trained weights (inference only)
        with torch.no_grad():
            vae(x)

        # try loading weights not compatible with exact architecture
        vae = vae.from_pretrained("stl10-resnet18")

        ae = ae.from_pretrained("cifar10-resnet18")
        # test forward method on pre-trained weights (inference only)
        with torch.no_grad():
            ae(x)
    except Exception as err:
        # Fail the test as before, but chain the real error for diagnostics.
        raise AssertionError(f"error in loading weights: {err!r}") from err

    # Check each model on its own. In a single try-block the first KeyError
    # would skip the second call, leaving AE's error path untested.
    for model, bad_name in ((vae, "abc"), (ae, "xyz")):
        keyerror = False
        try:
            model.from_pretrained(bad_name)
        except KeyError:
            keyerror = True
        assert keyerror, "KeyError not raised when provided with illegal checkpoint name"
def basic_ae_example(*, data_dir="", batch_size=32, num_workers=12, gpus=2):
    """Train a pl_bolts ``AE`` on CIFAR-10, then load its pretrained weights.

    The first part fits a freshly initialised autoencoder with DDP. The second
    part lists the available pretrained checkpoints, loads the CIFAR-10 one and
    freezes it. All keyword arguments default to the values the example
    originally hard-coded.

    Args:
        data_dir: Root passed to ``torchvision.datasets.CIFAR10`` (downloaded
            there if missing). ``""`` is a relative root.
        batch_size: Batch size for both data loaders.
        num_workers: Worker processes for each data loader.
        gpus: Number of GPUs passed to ``pl.Trainer``.
    """
    from pl_bolts.models.autoencoders import AE
    from pytorch_lightning.plugins import DDPPlugin

    # Data.
    train_dataset = torchvision.datasets.CIFAR10(
        data_dir, train=True, download=True, transform=torchvision.transforms.ToTensor())
    val_dataset = torchvision.datasets.CIFAR10(
        data_dir, train=False, download=True, transform=torchvision.transforms.ToTensor())

    # DataLoader defaults to shuffle=False; training batches should be
    # shuffled, validation order does not matter.
    # (On PyTorch >= 1.7, persistent_workers=True can avoid re-spawning the
    # worker processes every epoch.)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, num_workers=num_workers)

    # Model.
    model = AE(input_height=32)

    # Override any part of this AE to build your own variation:
    #
    #     class MyAEFlavor(AE):
    #         def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
    #             encoder = YourSuperFancyEncoder(...)
    #             return encoder
    #
    #     model = MyAEFlavor(...)

    # Fit.
    # NOTE(review): `gpus=`, `accelerator="ddp"` and `DDPPlugin` are the
    # PyTorch Lightning 1.x API; newer releases renamed them — confirm the
    # pinned version.
    trainer = pl.Trainer(
        gpus=gpus, accelerator="ddp", plugins=DDPPlugin(find_unused_parameters=False))
    trainer.fit(model, train_loader, val_loader)

    # CIFAR-10 pretrained model.
    ae = AE(input_height=32)
    print(AE.pretrained_weights_available())
    ae = ae.from_pretrained("cifar10-resnet18")
    ae.freeze()