Example #1
0
    def __init__(self):
        """Build the encoder/decoder for a Bayesian VAE (784-dim input, z_size 20)."""
        super(BVAE, self).__init__()

        # Select the float tensor type once so later sampling code is device-agnostic.
        self.dtype = (torch.cuda.FloatTensor if torch.cuda.is_available()
                      else torch.FloatTensor)

        self.z_size = 20        # latent dimensionality Z
        self.input_size = 784   # flattened 28x28 input X

        # Encoder: x -> 200 hidden -> [mean | logvar], hence 2 * z_size outputs.
        self.fc1 = nn.Linear(self.input_size, 200)
        self.fc2 = nn.Linear(200, 2 * self.z_size)
        # Decoder: Bayesian network mapping z back to 784 pixel logits.
        self.decoder = BNN([self.z_size, 200, 784], [F.relu, F.relu])
Example #2
0
    def __init__(self, qW_weight, seed=-1):
        """Build a Bayesian VAE (784-dim input, z_size 2) with a weighted q(W) term.

        Args:
            qW_weight: scalar multiplier applied to log q(W) in the objective.
            seed: if != -1, seeds torch's global RNG for reproducibility.
        """
        super(BVAE, self).__init__()

        if seed != -1:
            torch.manual_seed(seed)

        # Pick the float tensor type once, up front (CPU vs. CUDA).
        self.dtype = (torch.cuda.FloatTensor if torch.cuda.is_available()
                      else torch.FloatTensor)

        self.qW_weight = qW_weight

        self.z_size = 2         # latent dimensionality Z
        self.input_size = 784   # flattened 28x28 input X

        # Encoder: mean and logvar are stacked along fc2's output axis (2 * Z).
        self.fc1 = nn.Linear(self.input_size, 200)
        self.fc2 = nn.Linear(200, 2 * self.z_size)
        # Decoder: Bayesian network z -> pixel logits.
        self.decoder = BNN([self.z_size, 200, 784], [F.relu, F.relu])
Example #3
0
    def __init__(self):
        """Set up encoder and Bayesian decoder; 784-dim inputs, 20-dim latent."""
        super(BVAE, self).__init__()

        # Device-agnostic tensor type, chosen once at construction time.
        self.dtype = (torch.cuda.FloatTensor if torch.cuda.is_available()
                      else torch.FloatTensor)

        self.z_size = 20        # latent dimensionality
        self.input_size = 784   # flattened 28x28 image

        # Encoder layers; fc2 produces mean and logvar side by side (2 * z_size).
        self.fc1 = nn.Linear(self.input_size, 200)
        self.fc2 = nn.Linear(200, 2 * self.z_size)
        # Decoder is a Bayesian network over [z_size, 200, 784] with ReLU activations.
        self.decoder = BNN([self.z_size, 200, 784], [F.relu, F.relu])
    def __init__(self, qW_weight, seed=-1):
        """Construct the model.

        Args:
            qW_weight: weight on the log q(W) term of the training objective.
            seed: optional RNG seed; -1 (the default) leaves the RNG untouched.
        """
        super(BVAE, self).__init__()

        if seed != -1:
            torch.manual_seed(seed)

        # One-time CPU/GPU tensor-type selection.
        self.dtype = (torch.cuda.FloatTensor if torch.cuda.is_available()
                      else torch.FloatTensor)

        self.qW_weight = qW_weight

        self.z_size = 2         # latent dimensionality
        self.input_size = 784   # flattened 28x28 image

        # Encoder: fc2's 2 * z_size outputs carry [mean | logvar].
        self.fc1 = nn.Linear(self.input_size, 200)
        self.fc2 = nn.Linear(200, 2 * self.z_size)
        # Decoder: Bayesian network z -> pixel logits.
        self.decoder = BNN([self.z_size, 200, 784], [F.relu, F.relu])
Example #5
0
class BVAE(nn.Module):
    """Bayesian VAE: Gaussian recognition network q(z|x) plus a Bayesian
    decoder (BNN) whose weights W are themselves sampled and scored under a
    prior p(W) and variational posterior q(W).

    Operates on flattened 784-dim (28x28) inputs with a 2-dim latent.
    Shape-comment convention: P = z particles (k), S = W samples (s),
    B = batch, Z = z_size, X = input_size.
    """

    def __init__(self, qW_weight, seed=-1):
        """qW_weight scales the log q(W) term in forward(); seed != -1 seeds torch's RNG."""
        super(BVAE, self).__init__()

        if seed != -1:
            torch.manual_seed(seed)

        # Tensor type chosen once so sampling code below is device-agnostic.
        if torch.cuda.is_available():
            self.dtype = torch.cuda.FloatTensor
        else:
            self.dtype = torch.FloatTensor

        self.qW_weight = qW_weight

        self.z_size = 2        # latent dimensionality Z
        self.input_size = 784  # flattened 28x28 input X

        #Encoder
        self.fc1 = nn.Linear(self.input_size, 200)
        # fc2 emits mean and logvar stacked along the feature axis (2*Z).
        self.fc2 = nn.Linear(200, self.z_size * 2)
        #Decoder
        self.decoder = BNN([self.z_size, 200, 784], [F.relu, F.relu])

    def encode(self, x):
        """Map x [B,X] to Gaussian posterior parameters (mean, logvar), each [B,Z]."""
        h1 = F.relu(self.fc1(x))
        h2 = self.fc2(h1)
        # First Z columns are the mean, remaining Z the log-variance.
        mean = h2[:, :self.z_size]
        logvar = h2[:, self.z_size:]
        return mean, logvar

    def sample_z(self, mu, logvar, k):
        """Draw k reparameterized samples z ~ q(z|x).

        Returns (z [P,B,Z], log p(z) [P,B], log q(z|x) [P,B]).
        """
        B = mu.size()[0]
        # eps ~ N(0, I): one draw per particle and batch element.
        eps = Variable(
            torch.FloatTensor(k, B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]
        # Reparameterization trick: z = mu + sigma * eps.
        z = eps.mul(torch.exp(.5 * logvar)) + mu  #[P,B,Z]
        # Prior is a standard normal: zero mean and zero log-variance.
        logpz = lognormal(
            z, Variable(torch.zeros(B, self.z_size).type(self.dtype)),
            Variable(torch.zeros(B, self.z_size)).type(self.dtype))  #[P,B]

        logqz = lognormal(z, mu, logvar)
        return z, logpz, logqz

    def sample_W(self):
        """Sample one set of decoder weights; returns (Ws, sum log p(W), sum log q(W))."""

        Ws, log_p_W_sum, log_q_W_sum = self.decoder.sample_weights()
        return Ws, log_p_W_sum, log_q_W_sum

    def decode(self, Ws, z):
        """Run the Bayesian decoder with weights Ws on z [P,B,Z]; returns logits [P,B,X]."""
        k = z.size()[0]
        B = z.size()[1]
        # Flatten particles into the batch axis so the decoder sees a 2-D input.
        z = z.view(-1, self.z_size)
        x = self.decoder.forward(Ws, z)
        x = x.view(k, B, self.input_size)
        return x

    def forward(self, x, k, s):
        """Estimate the ELBO with k z-particles and s weight samples.

        Returns (elbo, mean logpx, mean logpz, mean logqz, mean logpW, mean logqW).
        """

        self.B = x.size()[0]  #batch size
        # self.k = k  #number of z samples aka particles P
        # self.s = s  #number of W samples

        elbo1s = []
        # One inner list per tracked quantity: logpx, logpz, logqz, logpW, logqW.
        logprobs = [[] for _ in range(5)]
        for i in range(s):

            Ws, logpW, logqW = self.sample_W()  #_ , [1], [1]

            mu, logvar = self.encode(x)  #[B,Z]
            z, logpz, logqz = self.sample_z(mu, logvar, k=k)  #[P,B,Z], [P,B]

            x_hat = self.decode(Ws, z)  #[P,B,X]
            logpx = log_bernoulli(x_hat, x)  #[P,B]

            elbo = logpx + logpz - logqz  #[P,B]
            if k > 1:
                # Numerically stable log-mean-exp over the particle axis.
                max_ = torch.max(elbo, 0)[0]  #[B]
                elbo1 = torch.log(torch.mean(torch.exp(elbo - max_),
                                             0)) + max_  #[B]
            # NOTE(review): elbo1 computed above is never used below — the
            # un-collapsed `elbo` is what gets the W terms and is appended.
            # Confirm whether discarding the importance-weighted bound is intended.
            # logpW/logqW are scalars, so this broadcasts over [P,B] despite the #[B] tag.
            elbo = elbo + (logpW * .000001) - (logqW * self.qW_weight
                                               )  #[B], logp(x|W)p(w)/q(w)
            elbo1s.append(elbo)
            logprobs[0].append(torch.mean(logpx))
            logprobs[1].append(torch.mean(logpz))
            logprobs[2].append(torch.mean(logqz))
            logprobs[3].append(torch.mean(logpW))
            logprobs[4].append(torch.mean(logqW))

        elbo1s = torch.stack(elbo1s)  #[S,B]
        if s > 1:
            # Log-mean-exp over W samples.
            # NOTE(review): this elbo1 is also unused — the mean below is taken
            # over the raw stack. Verify against the intended objective.
            max_ = torch.max(elbo1s, 0)[0]  #[B]
            elbo1 = torch.log(torch.mean(torch.exp(elbo1s - max_),
                                         0)) + max_  #[B]

        elbo = torch.mean(elbo1s)  #[1]

        #for printing
        # logpx = torch.mean(logpx)
        # logpz = torch.mean(logpz)
        # logqz = torch.mean(logqz)
        # self.x_hat_sigmoid = F.sigmoid(x_hat)

        # Average each tracked quantity over the s W-samples.
        logprobs2 = [torch.mean(torch.stack(aa)) for aa in logprobs]

        return elbo, logprobs2[0], logprobs2[1], logprobs2[2], logprobs2[
            3], logprobs2[4]

    def reconstruct(self, x):
        """Encode x, sample one z and one W, and return sigmoid pixel probabilities [1,B,X]."""

        Ws, logpW, logqW = self.sample_W()  #_ , [1], [1]

        mu, logvar = self.encode(x)  #[B,Z]
        z, logpz, logqz = self.sample_z(mu, logvar, k=1)  #[P,B,Z], [P,B]

        x_hat = self.decode(Ws, z)  #[P,B,X]

        return F.sigmoid(x_hat)

    def predictive_elbo(self, x, k, s):
        """ELBO estimate that omits the p(W)/q(W) terms (evaluation-time bound)."""
        # No pW or qW

        self.B = x.size()[0]  #batch size
        # self.k = k  #number of z samples aka particles P
        # self.s = s  #number of W samples

        elbo1s = []
        for i in range(s):

            Ws, logpW, logqW = self.sample_W()  #_ , [1], [1]

            mu, logvar = self.encode(x)  #[B,Z]
            z, logpz, logqz = self.sample_z(mu, logvar, k=k)  #[P,B,Z], [P,B]

            x_hat = self.decode(Ws, z)  #[P,B,X]
            logpx = log_bernoulli(x_hat, x)  #[P,B]

            elbo = logpx + logpz - logqz  #[P,B]
            if k > 1:
                # Stable log-mean-exp over particles (here the result IS kept).
                max_ = torch.max(elbo, 0)[0]  #[B]
                elbo = torch.log(torch.mean(torch.exp(elbo - max_),
                                            0)) + max_  #[B]
            # elbo1 = elbo1 #+ (logpW - logqW)*.00000001 #[B], logp(x|W)p(w)/q(w)
            elbo1s.append(elbo)

        elbo1s = torch.stack(elbo1s)  #[S,B]
        if s > 1:
            # NOTE(review): elbo1 computed here is unused; the mean below is
            # over the raw stack. Confirm intent.
            max_ = torch.max(elbo1s, 0)[0]  #[B]
            elbo1 = torch.log(torch.mean(torch.exp(elbo1s - max_),
                                         0)) + max_  #[B]

        elbo = torch.mean(elbo1s)  #[1]
        return elbo  #, logprobs2[0], logprobs2[1], logprobs2[2], logprobs2[3], logprobs2[4]
Example #6
0
class BVAE(nn.Module):
    """VAE with a Bayesian decoder: Gaussian q(z|x) encoder, and a BNN decoder
    whose weights are resampled on every decode() call.

    784-dim flattened inputs, 20-dim latent. Shape comments: P = particles (k),
    B = batch, Z = z_size, X = input_size.
    """

    def __init__(self):
        """Build encoder layers and the Bayesian decoder."""
        super(BVAE, self).__init__()

        # Choose CPU/GPU float tensor type once, up front.
        if torch.cuda.is_available():
            self.dtype = torch.cuda.FloatTensor
        else:
            self.dtype = torch.FloatTensor

        self.z_size = 20       # latent dimensionality Z
        self.input_size = 784  # flattened 28x28 input X

        #Encoder
        self.fc1 = nn.Linear(self.input_size, 200)
        # fc2 emits mean and logvar stacked along the feature axis (2*Z).
        self.fc2 = nn.Linear(200, self.z_size * 2)
        #Decoder
        # self.fc3 = nn.Linear(self.z_size, 200)
        # self.fc4 = nn.Linear(200, 784)
        # self.decoder = BNN([self.z_size, 200, 784], [torch.nn.Softplus, torch.nn.Softplus])
        self.decoder = BNN([self.z_size, 200, 784], [F.relu, F.relu])

        # self.add_module('BNN', self.decoder)
        # for idx, m in enumerate(self.modules()):
        #     print(idx, '->', m)
        # fsdf

    def encode(self, x):
        """Map x [B,X] to Gaussian posterior parameters (mean, logvar), each [B,Z]."""
        h1 = F.relu(self.fc1(x))
        h2 = self.fc2(h1)
        # First Z columns: mean; remaining Z columns: log-variance.
        mean = h2[:, :self.z_size]
        logvar = h2[:, self.z_size:]
        return mean, logvar

    def sample(self, mu, logvar, k):
        """Draw k reparameterized z samples; reads self.B (set by forward()).

        Returns (z [P,B,Z], log p(z) [P,B], log q(z|x) [P,B]).
        """

        # if torch.cuda.is_available():
        #     eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_()).cuda() #[P,B,Z]
        # else:
        # eps ~ N(0, I) per particle and batch element.
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]

        # Reparameterization trick: z = mu + sigma * eps.
        z = eps.mul(torch.exp(.5 * logvar)) + mu  #[P,B,Z]

        # if torch.cuda.is_available():
        #     logpz = lognormal(z, Variable(torch.zeros(self.B, self.z_size).cuda()),
        #                     Variable(torch.zeros(self.B, self.z_size)).cuda())  #[P,B]
        # else:
        # Prior is a standard normal (zero mean, zero log-variance).
        logpz = lognormal(
            z, Variable(torch.zeros(self.B, self.z_size).type(self.dtype)),
            Variable(torch.zeros(self.B,
                                 self.z_size)).type(self.dtype))  #[P,B]

        logqz = lognormal(z, mu, logvar)
        return z, logpz, logqz

    def decode(self, z):
        """Sample decoder weights and decode z.

        Reads self.k and self.B (set by forward(), so call forward first).
        Returns (logits [P,B,X], sum log p(W), sum log q(W)).
        """
        # h3 = F.relu(self.fc3(z))
        # return self.fc4(h3)

        # Fold particles into the batch axis for the 2-D decoder.
        z = z.view(-1, self.z_size)

        Ws, log_p_W_sum, log_q_W_sum = self.decoder.sample_weights()

        x = self.decoder.forward(Ws, z)

        x = x.view(self.k, self.B, self.input_size)

        return x, log_p_W_sum, log_q_W_sum

    def forward(self, x, k=1):
        """Estimate the ELBO with k particles (single W sample per call).

        Returns (elbo, mean logpx, mean logpz, mean logqz, logpW, logqW);
        also caches self.x_hat_sigmoid for inspection.
        """
        self.k = k
        self.B = x.size()[0]
        mu, logvar = self.encode(x)
        z, logpz, logqz = self.sample(mu, logvar, k=k)
        x_hat, logpW, logqW = self.decode(z)

        logpx = log_bernoulli(x_hat, x)  #[P,B]

        # NOTE(review): the W KL term is down-weighted by 1e-8 — presumably a
        # deliberate annealing/scaling choice; confirm against the training setup.
        elbo = logpx + logpz - logqz + (logpW - logqW) * .00000001  #[P,B]

        if k > 1:
            # Numerically stable log-mean-exp over the particle axis.
            max_ = torch.max(elbo, 0)[0]  #[B]
            elbo = torch.log(torch.mean(torch.exp(elbo - max_),
                                        0)) + max_  #[B]

        elbo = torch.mean(elbo)  #[1]

        #for printing
        logpx = torch.mean(logpx)
        logpz = torch.mean(logpz)
        logqz = torch.mean(logqz)
        self.x_hat_sigmoid = F.sigmoid(x_hat)

        return elbo, logpx, logpz, logqz, logpW, logqW
class BVAE(nn.Module):
    """Bayesian VAE: Gaussian q(z|x) encoder plus a Bayesian decoder (BNN)
    whose weights W are sampled and scored under p(W) and q(W).

    784-dim flattened inputs, 2-dim latent. Shape comments: P = particles (k),
    S = W samples (s), B = batch, Z = z_size, X = input_size.
    """

    def __init__(self, qW_weight, seed=-1):
        """qW_weight scales log q(W) in forward(); seed != -1 seeds torch's RNG."""
        super(BVAE, self).__init__()

        if seed != -1:
            torch.manual_seed(seed)

        # One-time CPU/GPU float tensor-type selection.
        if torch.cuda.is_available():
            self.dtype = torch.cuda.FloatTensor
        else:
            self.dtype = torch.FloatTensor

        self.qW_weight = qW_weight

        self.z_size = 2        # latent dimensionality Z
        self.input_size = 784  # flattened 28x28 input X

        #Encoder
        self.fc1 = nn.Linear(self.input_size, 200)
        # fc2 carries mean and logvar side by side (2*Z outputs).
        self.fc2 = nn.Linear(200, self.z_size*2)
        #Decoder
        self.decoder = BNN([self.z_size, 200, 784], [F.relu, F.relu])

    def encode(self, x):
        """Map x [B,X] to Gaussian posterior parameters (mean, logvar), each [B,Z]."""
        h1 = F.relu(self.fc1(x))
        h2 = self.fc2(h1)
        # First Z columns: mean; remaining Z columns: log-variance.
        mean = h2[:,:self.z_size]
        logvar = h2[:,self.z_size:]
        return mean, logvar

    def sample_z(self, mu, logvar, k):
        """Draw k reparameterized z samples; returns (z [P,B,Z], logpz [P,B], logqz [P,B])."""
        B = mu.size()[0]
        # eps ~ N(0, I); reparameterize as z = mu + sigma * eps.
        eps = Variable(torch.FloatTensor(k, B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mu  #[P,B,Z]
        # Prior is a standard normal (zero mean, zero log-variance).
        logpz = lognormal(z, Variable(torch.zeros(B, self.z_size).type(self.dtype)),
                            Variable(torch.zeros(B, self.z_size)).type(self.dtype))  #[P,B]

        logqz = lognormal(z, mu, logvar)
        return z, logpz, logqz

    def sample_W(self):
        """Sample one set of decoder weights; returns (Ws, sum log p(W), sum log q(W))."""

        Ws, log_p_W_sum, log_q_W_sum = self.decoder.sample_weights()
        return Ws, log_p_W_sum, log_q_W_sum


    def decode(self, Ws, z):
        """Decode z [P,B,Z] with weights Ws; returns logits [P,B,X]."""
        k = z.size()[0]
        B = z.size()[1]
        # Fold particles into the batch axis for the 2-D decoder, then restore.
        z = z.view(-1, self.z_size)
        x = self.decoder.forward(Ws, z)
        x = x.view(k, B, self.input_size)
        return x


    def forward(self, x, k, s):
        """Estimate the ELBO with k z-particles and s weight samples.

        Returns (elbo, mean logpx, mean logpz, mean logqz, mean logpW, mean logqW).
        """

        self.B = x.size()[0] #batch size
        # self.k = k  #number of z samples aka particles P
        # self.s = s  #number of W samples

        elbo1s = []
        # One inner list per tracked quantity: logpx, logpz, logqz, logpW, logqW.
        logprobs = [[] for _ in range(5)]
        for i in range(s):

            Ws, logpW, logqW = self.sample_W()  #_ , [1], [1]

            mu, logvar = self.encode(x)  #[B,Z]
            z, logpz, logqz = self.sample_z(mu, logvar, k=k) #[P,B,Z], [P,B]

            x_hat = self.decode(Ws, z) #[P,B,X]
            logpx = log_bernoulli(x_hat, x)  #[P,B]

            elbo = logpx + logpz - logqz #[P,B]
            if k>1:
                # Numerically stable log-mean-exp over the particle axis.
                max_ = torch.max(elbo, 0)[0] #[B]
                elbo1 = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B]
            # NOTE(review): elbo1 above is never used — the un-collapsed `elbo`
            # receives the W terms and is appended. Confirm whether discarding
            # the importance-weighted bound is intentional.
            # logpW/logqW are scalars, so this broadcasts over [P,B] despite the #[B] tag.
            elbo = elbo + (logpW*.000001) - (logqW*self.qW_weight) #[B], logp(x|W)p(w)/q(w)
            elbo1s.append(elbo)
            logprobs[0].append(torch.mean(logpx))
            logprobs[1].append(torch.mean(logpz))
            logprobs[2].append(torch.mean(logqz))
            logprobs[3].append(torch.mean(logpW))
            logprobs[4].append(torch.mean(logqW))




        elbo1s = torch.stack(elbo1s) #[S,B]
        if s>1:
            # NOTE(review): this elbo1 is also unused — the mean below is taken
            # over the raw stack. Verify against the intended objective.
            max_ = torch.max(elbo1s, 0)[0] #[B]
            elbo1 = torch.log(torch.mean(torch.exp(elbo1s - max_), 0)) + max_ #[B]

        elbo = torch.mean(elbo1s) #[1]

        #for printing
        # logpx = torch.mean(logpx)
        # logpz = torch.mean(logpz)
        # logqz = torch.mean(logqz)
        # self.x_hat_sigmoid = F.sigmoid(x_hat)

        # Average each tracked quantity over the s W-samples.
        logprobs2 = [torch.mean(torch.stack(aa)) for aa in logprobs]

        return elbo, logprobs2[0], logprobs2[1], logprobs2[2], logprobs2[3], logprobs2[4]

    def reconstruct(self, x):
        """Encode x, sample one z and one W, and return sigmoid pixel probabilities [1,B,X]."""

        Ws, logpW, logqW = self.sample_W()  #_ , [1], [1]

        mu, logvar = self.encode(x)  #[B,Z]
        z, logpz, logqz = self.sample_z(mu, logvar, k=1) #[P,B,Z], [P,B]

        x_hat = self.decode(Ws, z) #[P,B,X]

        return F.sigmoid(x_hat)


    def predictive_elbo(self, x, k, s):
        """ELBO estimate that omits the p(W)/q(W) terms (evaluation-time bound)."""
        # No pW or qW

        self.B = x.size()[0] #batch size
        # self.k = k  #number of z samples aka particles P
        # self.s = s  #number of W samples

        elbo1s = []
        for i in range(s):

            Ws, logpW, logqW = self.sample_W()  #_ , [1], [1]

            mu, logvar = self.encode(x)  #[B,Z]
            z, logpz, logqz = self.sample_z(mu, logvar, k=k) #[P,B,Z], [P,B]

            x_hat = self.decode(Ws, z) #[P,B,X]
            logpx = log_bernoulli(x_hat, x)  #[P,B]

            elbo = logpx + logpz - logqz #[P,B]
            if k>1:
                # Stable log-mean-exp over particles (the result IS kept here).
                max_ = torch.max(elbo, 0)[0] #[B]
                elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B]
            # elbo1 = elbo1 #+ (logpW - logqW)*.00000001 #[B], logp(x|W)p(w)/q(w)
            elbo1s.append(elbo)

        elbo1s = torch.stack(elbo1s) #[S,B]
        if s>1:
            # NOTE(review): elbo1 computed here is unused; the mean below is
            # over the raw stack. Confirm intent.
            max_ = torch.max(elbo1s, 0)[0] #[B]
            elbo1 = torch.log(torch.mean(torch.exp(elbo1s - max_), 0)) + max_ #[B]

        elbo = torch.mean(elbo1s) #[1]
        return elbo#, logprobs2[0], logprobs2[1], logprobs2[2], logprobs2[3], logprobs2[4]
Example #8
0
class BVAE(nn.Module):
    """VAE with a Bayesian decoder: Gaussian q(z|x) encoder, and a BNN decoder
    whose weights are resampled on every decode() call.

    784-dim flattened inputs, 20-dim latent. Shape comments: P = particles (k),
    B = batch, Z = z_size, X = input_size.
    """

    def __init__(self):
        """Build encoder layers and the Bayesian decoder."""
        super(BVAE, self).__init__()

        # Choose CPU/GPU float tensor type once, up front.
        if torch.cuda.is_available():
            self.dtype = torch.cuda.FloatTensor
        else:
            self.dtype = torch.FloatTensor


        self.z_size = 20       # latent dimensionality Z
        self.input_size = 784  # flattened 28x28 input X

        #Encoder
        self.fc1 = nn.Linear(self.input_size, 200)
        # fc2 emits mean and logvar stacked along the feature axis (2*Z).
        self.fc2 = nn.Linear(200, self.z_size*2)
        #Decoder
        # self.fc3 = nn.Linear(self.z_size, 200)
        # self.fc4 = nn.Linear(200, 784)
        # self.decoder = BNN([self.z_size, 200, 784], [torch.nn.Softplus, torch.nn.Softplus])
        self.decoder = BNN([self.z_size, 200, 784], [F.relu, F.relu])

        # self.add_module('BNN', self.decoder)
        # for idx, m in enumerate(self.modules()):
        #     print(idx, '->', m)
        # fsdf

    def encode(self, x):
        """Map x [B,X] to Gaussian posterior parameters (mean, logvar), each [B,Z]."""
        h1 = F.relu(self.fc1(x))
        h2 = self.fc2(h1)
        # First Z columns: mean; remaining Z columns: log-variance.
        mean = h2[:,:self.z_size]
        logvar = h2[:,self.z_size:]
        return mean, logvar

    def sample(self, mu, logvar, k):
        """Draw k reparameterized z samples; reads self.B (set by forward()).

        Returns (z [P,B,Z], log p(z) [P,B], log q(z|x) [P,B]).
        """

        # if torch.cuda.is_available():
        #     eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_()).cuda() #[P,B,Z]
        # else:
        # eps ~ N(0, I); reparameterize as z = mu + sigma * eps.
        eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]

        z = eps.mul(torch.exp(.5*logvar)) + mu  #[P,B,Z]

        # if torch.cuda.is_available():
        #     logpz = lognormal(z, Variable(torch.zeros(self.B, self.z_size).cuda()),
        #                     Variable(torch.zeros(self.B, self.z_size)).cuda())  #[P,B]
        # else:
        # Prior is a standard normal (zero mean, zero log-variance).
        logpz = lognormal(z, Variable(torch.zeros(self.B, self.z_size).type(self.dtype)),
                            Variable(torch.zeros(self.B, self.z_size)).type(self.dtype))  #[P,B]


        logqz = lognormal(z, mu, logvar)
        return z, logpz, logqz

    def decode(self, z):
        """Sample decoder weights and decode z.

        Reads self.k and self.B (set by forward(), so call forward first).
        Returns (logits [P,B,X], sum log p(W), sum log q(W)).
        """
        # h3 = F.relu(self.fc3(z))
        # return self.fc4(h3)

        # Fold particles into the batch axis for the 2-D decoder.
        z = z.view(-1, self.z_size)

        Ws, log_p_W_sum, log_q_W_sum = self.decoder.sample_weights()

        x = self.decoder.forward(Ws, z)

        x = x.view(self.k, self.B, self.input_size)

        return x, log_p_W_sum, log_q_W_sum


    def forward(self, x, k=1):
        """Estimate the ELBO with k particles (single W sample per call).

        Returns (elbo, mean logpx, mean logpz, mean logqz, logpW, logqW);
        also caches self.x_hat_sigmoid for inspection.
        """
        self.k = k
        self.B = x.size()[0]
        mu, logvar = self.encode(x)
        z, logpz, logqz = self.sample(mu, logvar, k=k)
        x_hat, logpW, logqW = self.decode(z)

        logpx = log_bernoulli(x_hat, x)  #[P,B]


        # NOTE(review): the W KL term is down-weighted by 1e-8 — presumably a
        # deliberate annealing/scaling choice; confirm against the training setup.
        elbo = logpx + logpz - logqz + (logpW - logqW)*.00000001  #[P,B]

        if k>1:
            # Numerically stable log-mean-exp over the particle axis.
            max_ = torch.max(elbo, 0)[0] #[B]
            elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B]

        elbo = torch.mean(elbo) #[1]

        #for printing
        logpx = torch.mean(logpx)
        logpz = torch.mean(logpz)
        logqz = torch.mean(logqz)
        self.x_hat_sigmoid = F.sigmoid(x_hat)

        return elbo, logpx, logpz, logqz, logpW, logqW