Esempio n. 1
0
def test_from_pretrained(tmpdir):
    """Check that pretrained VAE/AE weights are advertised, loadable and name-checked.

    The test asserts three things:
      1. both ``VAE`` and ``AE`` report at least one available pretrained weight set;
      2. loading the known checkpoints does not raise;
      3. an unknown checkpoint name raises ``KeyError`` for *each* model.

    Args:
        tmpdir: fixture argument; not used by the test body.
    """
    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    # Keep the exception itself (not just a flag) so a failure reports its cause.
    load_error = None

    try:
        vae = vae.from_pretrained('cifar10-resnet18')
        vae = vae.from_pretrained(
            'stl10-resnet18'
        )  # try loading weights not compatible with exact architecture

        ae = ae.from_pretrained('cifar10-resnet18')
    except Exception as e:
        load_error = e

    assert load_error is None, f"error in loading weights: {load_error!r}"

    # Check every model on its own: with both calls inside a single ``try`` the
    # first KeyError would skip the second call, leaving it unverified.
    for model, bad_name in ((vae, 'abc'), (ae, 'xyz')):
        keyerror = False

        try:
            model.from_pretrained(bad_name)
        except KeyError:
            keyerror = True

        assert keyerror is True, "KeyError not raised when provided with illegal checkpoint name"
Esempio n. 2
0
def simple_vae_example():
    """Fit a pl_bolts ``VAE`` on CIFAR-10, then load and freeze a pretrained one.

    The function downloads CIFAR-10 (root directory ``""``), trains a fresh
    ``VAE(input_height=32)`` with a ``pl.Trainer`` configured for ``gpus=2`` and
    the ``"ddp"`` accelerator, and finally loads the ``"cifar10-resnet18"``
    pretrained weights into a second VAE and freezes it. Nothing is returned.
    """
    from pl_bolts.models.autoencoders import VAE
    from pytorch_lightning.plugins import DDPPlugin

    def cifar10_split(is_train):
        # Build one CIFAR-10 split with images converted to tensors.
        return torchvision.datasets.CIFAR10(
            "",
            train=is_train,
            download=True,
            transform=torchvision.transforms.ToTensor(),
        )

    # Data.
    train_set = cifar10_split(True)
    val_set = cifar10_split(False)

    # Both loaders share the same settings.
    # Alternative: also pass persistent_workers=True to the DataLoader calls.
    loader_options = {"batch_size": 32, "num_workers": 12}
    train_batches = torch.utils.data.DataLoader(train_set, **loader_options)
    val_batches = torch.utils.data.DataLoader(val_set, **loader_options)

    # Model.
    model = VAE(input_height=32)
    # Override any part of this AE to build your own variation, e.g.:
    #
    #   class MyVAEFlavor(VAE):
    #       def get_posterior(self, mu, std):
    #           # Do something other than the default.
    #           P = self.get_distribution(self.prior, loc=torch.zeros_like(mu), scale=torch.ones_like(std))
    #           return P
    #
    #   model = MyVAEFlavor(...)

    # Fit.
    ddp_plugin = DDPPlugin(find_unused_parameters=False)
    trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins=ddp_plugin)
    trainer.fit(model, train_batches, val_batches)

    # --------------------
    # CIFAR-10 pretrained model:
    pretrained = VAE(input_height=32)
    print(VAE.pretrained_weights_available())
    pretrained = pretrained.from_pretrained("cifar10-resnet18")

    pretrained.freeze()
Esempio n. 3
0
def test_from_pretrained(datadir):
    """Check that pretrained VAE/AE weights load, run a forward pass and are name-checked.

    The test asserts three things:
      1. both ``VAE`` and ``AE`` report at least one available pretrained weight set;
      2. loading the known checkpoints and running one validation batch through
         each model does not raise;
      3. an unknown checkpoint name raises ``KeyError`` for *each* model.

    Args:
        datadir: directory passed to ``CIFAR10DataModule`` as ``data_dir``.
    """
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=4)
    # prepare_data() must run before setup(): the data has to be in place
    # before the splits are built from it.
    dm.prepare_data()
    dm.setup()

    data_loader = dm.val_dataloader()

    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    # Keep the exception itself (not just a flag) so a failure reports its cause.
    load_error = None

    try:
        vae = vae.from_pretrained("cifar10-resnet18")

        # test forward method on pre-trained weights (inference only, no gradients needed)
        with torch.no_grad():
            for x, _ in data_loader:
                vae(x)
                break

        vae = vae.from_pretrained(
            "stl10-resnet18"
        )  # try loading weights not compatible with exact architecture

        ae = ae.from_pretrained("cifar10-resnet18")

        # test forward method on pre-trained weights
        with torch.no_grad():
            for x, _ in data_loader:
                ae(x)
                break

    except Exception as e:
        load_error = e

    assert load_error is None, f"error in loading weights: {load_error!r}"

    # Check every model on its own: with both calls inside a single ``try`` the
    # first KeyError would skip the second call, leaving it unverified.
    for model, bad_name in ((vae, "abc"), (ae, "xyz")):
        keyerror = False

        try:
            model.from_pretrained(bad_name)
        except KeyError:
            keyerror = True

        assert keyerror is True, "KeyError not raised when provided with illegal checkpoint name"