예제 #1
0
파일: vae.py 프로젝트: yngtodd/convae_pyro
class VAE(nn.Module):
    """Variational autoencoder exposed as a Pyro model/guide pair.

    The decoder parameterises the likelihood p(x|z) used by ``model`` and the
    encoder parameterises the variational posterior q(z|x) used by ``guide``.
    Images are scored as flat vectors of 784 values and reconstructed from
    single 1x28x28 inputs.

    Args:
        z_dim: dimensionality of the latent code z.
        hidden_dim: hidden size handed to both the encoder and the decoder.
        enc_kernel1: first kernel-size argument forwarded to the encoder.
        enc_kernel2: second kernel-size argument forwarded to the encoder.
        use_cuda: when True, move all parameters to the GPU.
    """

    def __init__(self,
                 z_dim=50,
                 hidden_dim=400,
                 enc_kernel1=5,
                 enc_kernel2=5,
                 use_cuda=False):
        super(VAE, self).__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim, enc_kernel1, enc_kernel2)
        self.decoder = Decoder(z_dim, hidden_dim)

        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    def model(self, x):
        """Generative model p(x|z)p(z).

        Args:
            x: mini-batch of images; it is viewed as (-1, 784) when scored.
        """
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        # hyperparameters of the unit-normal prior p(z);
        # type_as keeps them on the same device/type as x
        z_mu = ng_zeros([x.size(0), self.z_dim], type_as=x.data)
        z_sigma = ng_ones([x.size(0), self.z_dim], type_as=x.data)
        # sample from prior (value will be sampled by guide when computing the ELBO)
        z = pyro.sample("latent", dist.normal, z_mu, z_sigma)
        # decode the latent code z; call the module (not .forward) so that
        # registered hooks run
        mu_img = self.decoder(z)
        # score against actual images
        pyro.observe("obs", dist.bernoulli, x.view(-1, 784), mu_img)

    def guide(self, x):
        """Variational distribution q(z|x).

        Args:
            x: mini-batch of images fed to the encoder.
        """
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        # use the encoder to get the parameters used to define q(z|x)
        z_mu, z_sigma = self.encoder(x)
        # sample the latent code z
        pyro.sample("latent", dist.normal, z_mu, z_sigma)

    def reconstruct_img(self, x):
        """Encode a single image, sample a latent code and decode it.

        Args:
            x: one image; it is viewed as (1, 1, 28, 28) before encoding.

        Returns:
            The decoder output for the sampled latent code (no sampling is
            done in image space).
        """
        # encode image x
        x = x.view(1, 1, 28, 28)
        z_mu, z_sigma = self.encoder(x)
        # sample in latent space
        z = dist.normal(z_mu, z_sigma)
        # decode the image (note we don't sample in image space)
        mu_img = self.decoder(z)
        return mu_img

    def model_sample(self, batch_size=1):
        """Draw images from the generative model using the unit-normal prior.

        Args:
            batch_size: number of latent codes (and images) to draw.

        Returns:
            Tuple ``(xs, mu)``: Bernoulli samples and the decoder output that
            parameterises them.
        """
        # sample the handwriting style from the constant prior distribution
        prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
        prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
        if self.use_cuda:
            # the decoder was moved to the GPU in __init__, so the prior
            # parameters (and hence the sampled z) must live there as well
            prior_mu = prior_mu.cuda()
            prior_sigma = prior_sigma.cuda()
        zs = pyro.sample("z", dist.normal, prior_mu, prior_sigma)
        mu = self.decoder(zs)
        xs = pyro.sample("sample", dist.bernoulli, mu)
        return xs, mu
예제 #2
0
class VAE(Module):
    '''
    Conditional variational autoencoder expressed as a Pyro model/guide pair.

    The model is the generative side: a unit-normal prior p(z) over the
    latent code and a Bernoulli likelihood over the image whose parameters
    come from the decoder, conditioned on the label. The guide is the
    variational posterior q(z|x, label) whose parameters come from the
    encoder.

    Inputs:
    :latents_sizes: number of categories of each label factor
    :latents_names: name of each label factor (used in Pyro site names)
    :img_dim: dimension of image vector
    :label_dim: dimension of label vector
    :latent_dim: dimension of Z space, output
    :use_CUDA: if True, move the module parameters to the GPU
    '''
    def __init__(self,
                 latents_sizes,
                 latents_names,
                 img_dim=4096,
                 label_dim=114,
                 latent_dim=200,
                 use_CUDA=False):
        super(VAE, self).__init__()
        # creating networks
        self.encoder = Encoder(img_dim, label_dim, latent_dim)
        self.decoder = Decoder(img_dim, label_dim, latent_dim)
        self.img_dim = img_dim
        self.label_dim = label_dim
        self.latent_dim = latent_dim
        self.latents_sizes = latents_sizes
        self.latents_names = latents_names
        if use_CUDA:
            self.cuda()
        self.use_CUDA = use_CUDA

    def label_variable(self, label):
        '''
        One-hot encode every label factor and register each one as an
        observed Pyro sample site with a uniform categorical prior.

        :label: tensor indexed as label[:, i], one column of class indices
                per entry of latents_sizes

        Returns the one-hot encodings concatenated along the last axis, as
        a float32 tensor on the device of label.
        '''
        new_label = []
        options = {'device': label.device, 'dtype': label.dtype}
        for i, length in enumerate(self.latents_sizes):
            # uniform prior over the categories of factor i
            prior = torch.ones(label.shape[0], length, **
                               options) / (1.0 * length)
            # .to() instead of tensor(): copy-constructing a tensor from a
            # tensor triggers a UserWarning in PyTorch
            observed = one_hot(label[:, i].to(torch.int64), int(length))
            new_label.append(
                pyro.sample("label_" + str(self.latents_names[i]),
                            OneHotCategorical(prior),
                            obs=observed))
        new_label = torch.cat(new_label, -1)
        return new_label.to(torch.float32).to(label.device)

    def model(self, img, label):
        '''
        Generative model: z ~ Normal(0, 1), img ~ Bernoulli(decoder(z, label)).

        :img: batch of flattened images, observed at site "obs"
        :label: batch of labels, encoded through label_variable
        '''
        pyro.module("decoder", self.decoder)
        options = {'device': img.device, 'dtype': img.dtype}
        with pyro.plate("data", img.shape[0]):
            z_loc = torch.zeros(img.shape[0], self.latent_dim, **options)
            # second argument of Normal is a standard deviation, not a variance
            z_scale = torch.ones(img.shape[0], self.latent_dim, **options)
            z_sample = pyro.sample("latent",
                                   Normal(z_loc, z_scale).to_event(1))
            image = self.decoder(z_sample, self.label_variable(label))
            pyro.sample("obs", Bernoulli(image).to_event(1), obs=img)

    def guide(self, img, label):
        '''
        Variational posterior: z ~ Normal(*encoder(img, label)).

        :img: batch of flattened images
        :label: batch of labels, encoded through label_variable
        '''
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", img.shape[0]):
            z_loc, z_scale = self.encoder(
                img, self.label_variable(label))
            pyro.sample("latent", Normal(z_loc, z_scale).to_event(1))

    def run_img(self, img, label, num=1):
        '''
        Show one image followed by num reconstructions, each decoded from a
        fresh sample of the encoder's latent distribution.

        :img: a single image with 4096 values (displayed as 64x64)
        :label: label of the image; reshaped to a single row
        :num: number of reconstructions to draw
        '''
        # -1 is the only documented "infer this axis" value for reshape
        label = label.reshape(1, -1)
        # dummy_from_label is defined elsewhere; presumably it builds the
        # one-hot label expected by the networks -- TODO confirm
        dummy_label = dummy_from_label(label)
        img = tensor(img.reshape(-1, 4096)).to(torch.float32)
        mean, var = self.encoder(img, dummy_label)

        plt.figure(figsize=(4, num * 5))
        original_axes = plt.subplot(num + 1, 1, 1)
        original_axes.set_title('Original image')
        plt.imshow(img.reshape(64, 64))

        # the grid has num + 1 rows: row 1 is the original, rows 2..num + 1
        # hold the num sampled reconstructions
        for i in range(1, num + 1):
            z_sample = Normal(mean, var).sample()
            vae_img = self.decoder(z_sample, dummy_label)
            sample_axes = plt.subplot(num + 1, 1, i + 1)
            sample_axes.set_title(str(i) + ' - sample of latent space')
            # .cpu() keeps the conversion valid when the decoder is on the GPU
            plt.imshow(vae_img.detach().cpu().numpy().reshape(64, 64))
        plt.show()

    def change_attribute(self, img, label, attribute=1):
        '''
        Show an image, a reconstruction with its own label and a
        reconstruction where one label factor was replaced by a different,
        randomly chosen value.

        :img: a single image with 4096 values (displayed as 64x64)
        :label: label of the image; reshaped to a single row
        :attribute: index of the label factor to change

        Raises ValueError when the chosen factor has no alternative value.
        '''
        # -1 is the only documented "infer this axis" value for reshape
        label = label.reshape(1, -1)
        new_label = np.copy(label)
        current = new_label[0, attribute]
        # choose directly among the values that differ from the current one;
        # a retry loop would never finish for a factor with a single category
        candidates = [value
                      for value in range(self.latents_sizes[attribute])
                      if value != current]
        if not candidates:
            raise ValueError('attribute ' + str(attribute) +
                             ' has no alternative value to switch to')
        new_label[0, attribute] = np.random.choice(candidates)
        print('Attribute changed was ' + str(self.latents_names[attribute]))

        dummy_label = dummy_from_label(label)
        new_dummy = dummy_from_label(new_label)
        img = tensor(img.reshape(-1, 4096)).to(torch.float32)
        mean, var = self.encoder(img, dummy_label)

        plt.figure(figsize=(4, 15))

        original_axes = plt.subplot(3, 1, 1)
        original_axes.set_title('Original image')
        plt.imshow(img.reshape(64, 64))

        z_sample = Normal(mean, var).sample()
        vae_img = self.decoder(z_sample, dummy_label)
        same_axes = plt.subplot(3, 1, 2)
        same_axes.set_title('Sample with original attribute')
        plt.imshow(vae_img.detach().cpu().numpy().reshape(64, 64))

        z_sample = Normal(mean, var).sample()
        vae_img = self.decoder(z_sample, new_dummy)
        changed_axes = plt.subplot(3, 1, 3)
        changed_axes.set_title('Sample with changed attribute')
        plt.imshow(vae_img.detach().cpu().numpy().reshape(64, 64))
        plt.show()