def intermediate_dist(t, z, mean, logvar, zeros, batch):
    logp1 = lognormal(z, mean, logvar)  # [P,B]
    log_prior = lognormal(z, zeros, zeros)  # [P,B]
    log_likelihood = log_bernoulli(model.decode(z), batch)
    logpT = log_prior + log_likelihood
    log_intermediate_2 = (1 - float(t)) * logp1 + float(t) * logpT
    return log_intermediate_2

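# The snippets in this section assume two log-density helpers, along with the imports the
# surrounding code relies on (torch, numpy, optim, Variable, F). Their exact definitions are
# not shown here; below is a minimal sketch consistent with the shape comments
# (P = samples/particles, B = batch, Z = latent size, X = input size), assuming the decoder
# returns logits. lognormal333, used further down, is assumed to be the same Gaussian
# density with per-sample [P,B,Z] parameters, which broadcasting also covers here.
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable


def lognormal(z, mean, logvar):
    # z: [P,B,Z], mean/logvar: [B,Z] or [P,B,Z] -> log N(z; mean, diag(exp(logvar))): [P,B]
    return -0.5 * torch.sum(
        math.log(2 * math.pi) + logvar + (z - mean) ** 2 / torch.exp(logvar), 2)


def log_bernoulli(logits, target):
    # logits: [P,B,X] pre-sigmoid decoder outputs, target: [B,X] in [0,1] -> [P,B]
    return torch.sum(
        target * F.logsigmoid(logits) + (1 - target) * F.logsigmoid(-logits), 2)
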
def sample(self, mu, logvar, k):
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_())  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mu  # [P,B,Z]
    logpz = lognormal(z, Variable(torch.zeros(self.B, self.z_size)),
                      Variable(torch.zeros(self.B, self.z_size)))  # [P,B]
    logqz = lognormal(z, mu, logvar)
    return z, logpz, logqz

def sample_z(self, mu, logvar, k):
    B = mu.size()[0]
    eps = Variable(torch.FloatTensor(k, B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mu  # [P,B,Z]
    logpz = lognormal(z, Variable(torch.zeros(B, self.z_size).type(self.dtype)),
                      Variable(torch.zeros(B, self.z_size)).type(self.dtype))  # [P,B]
    logqz = lognormal(z, mu, logvar)
    return z, logpz, logqz

def sample(self, mu, logvar):
    std = logvar.mul(0.5).exp_()
    eps = Variable(torch.FloatTensor(std.size()).normal_())
    z = eps.mul(std).add_(mu)
    logpz = lognormal(z, Variable(torch.zeros(z.size())),
                      Variable(torch.zeros(z.size())))
    logqz = lognormal(z, mu, logvar)
    return z, logpz, logqz

def sample_weights(self):
    Ws = []
    log_p_W_sum = 0
    log_q_W_sum = 0
    for layer_i in range(len(self.net) - 1):
        input_size_i = self.net[layer_i] + 1  # plus 1 for bias
        output_size_i = self.net[layer_i + 1]  # size of layer i+1

        # Get variational parameters [I,O]
        W_means = self.W_means[layer_i]
        W_logvars = self.W_logvars[layer_i]

        # Sample weights [IS,OS]*[IS,OS]=[IS,OS]
        eps = Variable(torch.randn(input_size_i, output_size_i).type(self.dtype))
        W = (torch.sqrt(torch.exp(W_logvars)) * eps) + W_means

        # Compute probs of samples [1]
        flat_w = W.view(input_size_i * output_size_i)  # [IS*OS]
        flat_W_means = W_means.view(input_size_i * output_size_i)  # [IS*OS]
        flat_W_logvars = W_logvars.view(input_size_i * output_size_i)  # [IS*OS]

        log_p_W_sum += lognormal(
            flat_w,
            Variable(torch.zeros([input_size_i * output_size_i]).type(self.dtype)),
            Variable(torch.log(torch.ones([input_size_i * output_size_i]).type(self.dtype))))
        log_q_W_sum += lognormal(flat_w, flat_W_means, flat_W_logvars)

        Ws.append(W)

    return Ws, log_p_W_sum, log_q_W_sum

def forward(self, k, x, logposterior):
    '''
    k: number of samples
    x: [B,X]
    logposterior(z) -> [P,B]
    '''
    self.B = x.size()[0]

    # Encode
    out = x
    for i in range(len(self.encoder_weights) - 1):
        out = self.act_func(self.encoder_weights[i](out))
    out = self.encoder_weights[-1](out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    # Sample
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz = lognormal(z, mean, logvar)  # [P,B]

    return z, logqz

def return_current_state(self, x, a, k):
    self.B = x.size()[0]
    self.T = x.size()[1]
    self.k = k
    a = a.float()
    x = x.float()

    states = []
    prev_z = Variable(torch.zeros(k, self.B, self.z_size))
    for t in range(self.T):
        current_x = x[:, t]  # [B,X]
        current_a = a[:, t]  # [B,A]

        # Encode
        mu, logvar = self.encode(current_x, current_a, prev_z)  # [P,B,Z]

        # Sample
        z, logqz = self.sample(mu, logvar)  # [P,B,Z], [P,B]

        # Decode
        x_hat = self.decode(z)  # [P,B,X]
        logpx = log_bernoulli(x_hat, current_x)  # [P,B]

        # Transition/prior prob
        prior_mean, prior_log_var = self.transition_prior(prev_z, current_a)  # [P,B,Z]
        logpz = lognormal(z, prior_mean, prior_log_var)  # [P,B]

        prev_z = z
        states.append(z)

    return states

def forward(self, k, x, logposterior):
    '''
    k: number of samples
    x: [B,X]
    logposterior(z) -> [P,B]
    '''
    self.B = x.size()[0]
    self.P = k

    # Encode
    out = self.act_func(self.fc1(x))
    out = self.act_func(self.fc2(out))
    out = self.fc3(out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    # Sample
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz = lognormal(z, mean, logvar)  # [P,B]

    logdetsum = 0.
    for i in range(self.n_flows):
        z, logdet = self.norm_flow(self.params[i], z)
        logdetsum += logdet

    return z, logqz - logdetsum

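# The norm_flow used by these q-distributions is not shown in this section. As a
# self-contained placeholder with the same (params, z) -> (z, logdet) interface, a single
# planar flow step (Rezende & Mohamed, 2015) could look like the sketch below; the actual
# flow used by this code may differ.
def planar_flow(params, z):
    # params: (u, w, b) with u, w of shape [Z] and scalar b; z: [P,B,Z]
    u, w, b = params
    pre = torch.sum(z * w, 2, keepdim=True) + b               # [P,B,1]
    h = torch.tanh(pre)                                       # [P,B,1]
    z_new = z + u * h                                         # [P,B,Z]
    psi = (1 - h ** 2) * w                                     # [P,B,Z]
    logdet = torch.log(torch.abs(1 + torch.sum(psi * u, 2)))   # [P,B]
    return z_new, logdet
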
def forward(self, x, k, warmup=1.):
    self.B = x.size()[0]  # batch size
    self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))
    self.logposterior = lambda aa: (lognormal(aa, self.zeros, self.zeros)
                                    + log_bernoulli(self.decode(aa), x))

    z, logqz = self.q_dist.forward(k, x, self.logposterior)
    logpxz = self.logposterior(z)

    # Compute elbo
    elbo = logpxz - (warmup * logqz)  # [P,B]
    if k > 1:
        max_ = torch.max(elbo, 0)[0]  # [B]
        elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_  # [B]
    elbo = torch.mean(elbo)  # [1]

    logpxz = torch.mean(logpxz)  # [1]
    logqz = torch.mean(logqz)

    return elbo, logpxz, logqz

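# The k > 1 branch above is the importance-weighted (IWAE) bound,
# log (1/k) sum_i exp(elbo_i), computed with the usual max-subtraction for numerical
# stability. The same reduction recurs throughout this section; pulled out as a helper
# (for illustration only, not a function these snippets define):
def log_mean_exp(x, dim=0):
    max_ = torch.max(x, dim)[0]
    return torch.log(torch.mean(torch.exp(x - max_), dim)) + max_
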
def sample(self, mu, logvar, k):
    B = mu.size()[0]
    eps = Variable(torch.FloatTensor(k, B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mu  # [P,B,Z]
    logqz = lognormal(z, mu, logvar)  # [P,B]

    if self.flow_bool:
        z, logdet = self.q_dist.forward(z)  # [P,B,Z], [P,B]
        logqz = logqz - logdet

    logpz = lognormal(z, Variable(torch.zeros(B, self.z_size).type(self.dtype)),
                      Variable(torch.zeros(B, self.z_size)).type(self.dtype))  # [P,B]
    return z, logpz, logqz

def logposterior_func(self, x, z):
    self.B = x.size()[0]  # batch size
    self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))
    z = Variable(z).type(self.dtype)
    z = z.view(-1, self.B, self.z_size)  # [P,B,Z]
    return lognormal(z, self.zeros, self.zeros) + log_bernoulli(self.generator.decode(z), x)

def logprob(self, z, mean, logvar):
    logqz = lognormal(z, mean, logvar)  # [P,B]
    return logqz

def sample(self, mean, logvar, k):
    self.B = mean.size()[0]
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz = lognormal(z, mean, logvar)  # [P,B]
    return z, logqz

def logposterior_func(self, x, z):
    self.B = x.size()[0]  # batch size
    self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))
    z = Variable(z).type(self.dtype)
    z = z.view(-1, self.B, self.z_size)  # [P,B,Z]
    return lognormal(z, self.zeros, self.zeros) + log_bernoulli(self.decode(z), x)

def mh_step(z0, v0, z, v, step_size, intermediate_dist_func):
    logpv0 = lognormal(v0, zeros, zeros)  # [P,B]
    hamil_0 = intermediate_dist_func(z0) + logpv0
    logpvT = lognormal(v, zeros, zeros)  # [P,B]
    hamil_T = intermediate_dist_func(z) + logpvT
    accept_prob = torch.exp(hamil_T - hamil_0)

    if torch.cuda.is_available():
        rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_(),
                            volatile=volatile_, requires_grad=requires_grad).cuda()
    else:
        rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_())
    accept = accept_prob > rand_uni

    if torch.cuda.is_available():
        accept = accept.type(torch.FloatTensor).cuda()
    else:
        accept = accept.type(torch.FloatTensor)
    accept = accept.view(k, model.B, 1)
    z = (accept * z) + ((1 - accept) * z0)

    # Adapt step size
    avg_acceptance_rate = torch.mean(accept)
    if avg_acceptance_rate.cpu().data.numpy() > .7:
        step_size = 1.02 * step_size
    else:
        step_size = .98 * step_size
    if step_size < 0.0001:
        step_size = 0.0001
    if step_size > 0.5:
        step_size = 0.5

    return z, step_size

def sample_q(self, x, k):
    self.B = x.size()[0]  # batch size
    self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))
    self.logposterior = lambda aa: (lognormal(aa, self.zeros, self.zeros)
                                    + log_bernoulli(self.decode(aa), x))
    z, logqz = self.q_dist.forward(k=k, x=x, logposterior=self.logposterior)
    return z

def sample(self, mu, logvar, k):
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mu  # [P,B,Z]
    logpz = lognormal(z, Variable(torch.zeros(self.B, self.z_size).type(self.dtype)),
                      Variable(torch.zeros(self.B, self.z_size)).type(self.dtype))  # [P,B]
    logqz = lognormal(z, mu, logvar)
    return z, logpz, logqz

def sample_q(self, x, k):
    self.B = x.size()[0]  # batch size
    self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))
    self.logposterior = lambda aa: (lognormal(aa, self.zeros, self.zeros)
                                    + log_bernoulli(self.generator.decode(aa), x))
    z, logqz = self.q_dist.forward(k=k, x=x, logposterior=self.logposterior)
    return z

def mh_step(z0, v0, z, v, step_size, intermediate_dist_func):
    logpv0 = lognormal(v0, zeros, zeros)  # [P,B]
    hamil_0 = intermediate_dist_func(z0) + logpv0
    logpvT = lognormal(v, zeros, zeros)  # [P,B]
    hamil_T = intermediate_dist_func(z) + logpvT
    accept_prob = torch.exp(hamil_T - hamil_0)

    if torch.cuda.is_available():
        rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_(),
                            volatile=volatile_, requires_grad=requires_grad).cuda()
    else:
        rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_())
    accept = accept_prob > rand_uni

    if torch.cuda.is_available():
        accept = accept.type(torch.FloatTensor).cuda()
    else:
        accept = accept.type(torch.FloatTensor)
    accept = accept.view(k, model.B, 1)
    z = (accept * z) + ((1 - accept) * z0)

    # Adapt step size
    avg_acceptance_rate = torch.mean(accept)
    if avg_acceptance_rate.cpu().data.numpy() > .65:
        step_size = 1.02 * step_size
    else:
        step_size = .98 * step_size
    if step_size < 0.0001:
        step_size = 0.0001
    if step_size > 0.5:
        step_size = 0.5

    return z, step_size

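# Hedged sketch (not a function defined in the original code): the leapfrog updates that
# pair with mh_step above, mirroring the update order inlined in test_ais() further down.
# grad_fn(z) is assumed to return the gradient of intermediate_dist_func(z) with respect
# to z, e.g. via torch.autograd.grad with grad_outputs of shape [P,B] as test_ais does.
def leapfrog(z, v, step_size, n_steps, grad_fn):
    v = v + .5 * step_size * grad_fn(z)
    z = z + step_size * v
    for _ in range(n_steps):
        v = v + step_size * grad_fn(z)
        z = z + step_size * v
    v = v + .5 * step_size * grad_fn(z)
    return z, v
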
def forward(self, x, k=1, warmup=1.):
    self.B = x.size()[0]  # batch size
    self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))  # [B,Z]
    self.logposterior = lambda aa: (lognormal(aa, self.zeros, self.zeros)
                                    + log_bernoulli(self.decode(aa), x))

    z, logqz = self.q_dist.forward(k, x, self.logposterior)
    logpxz = self.logposterior(z)

    # Compute elbo
    elbo = logpxz - logqz  # [P,B]
    if k > 1:
        max_ = torch.max(elbo, 0)[0]  # [B]
        elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_  # [B]
    elbo = torch.mean(elbo)  # [1]

    logpxz = torch.mean(logpxz)  # [1]
    logqz = torch.mean(logqz)

    return elbo, logpxz, logqz

def sample(k):
    P = k

    # Sample
    eps = Variable(torch.FloatTensor(k, B, model.z_size).normal_().type(model.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz = lognormal(z, mean, logvar)  # [P,B]

    logdetsum = 0.
    for i in range(n_flows):
        z, logdet = norm_flow(params[i], z)
        logdetsum += logdet

    logq = logqz - logdetsum
    return z, logq

def forward2(self, x, k):
    self.B = x.size()[0]  # batch size
    self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))
    self.logposterior = lambda aa: (lognormal(aa, self.zeros, self.zeros)
                                    + log_bernoulli(self.decode(aa), x))

    z, logqz = self.q_dist.forward(k, x, self.logposterior)
    logpxz = self.logposterior(z)

    # Compute elbo (no importance weighting over the k samples here)
    elbo = logpxz - logqz  # [P,B]
    elbo = torch.mean(elbo)  # [1]

    logpxz = torch.mean(logpxz)  # [1]
    logqz = torch.mean(logqz)

    return elbo, logpxz, logqz

def forward(self, k, x, logposterior):
    '''
    k: number of samples
    x: [B,X]
    logposterior(z) -> [P,B]
    '''
    self.B = x.size()[0]

    # Encode
    out = self.act_func(self.fc1(x))
    out = self.act_func(self.fc2(out))
    out = self.fc3(out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    # Sample
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz = lognormal(z, mean, logvar)  # [P,B]

    return z, logqz

def forward(self, k, x, logposterior):
    '''
    k: number of samples
    x: [B,X]
    logposterior(z) -> [P,B]
    '''
    self.B = x.size()[0]

    # Encode (conv net for 3x32x32 inputs)
    x = x.view(-1, 3, 32, 32)
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = x.view(-1, 250)
    h1 = F.relu(self.fc1(x))
    h2 = self.fc2(h1)
    mean = h2[:, :self.z_size]
    logvar = h2[:, self.z_size:]

    # Sample
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz = lognormal(z, mean, logvar)  # [P,B]

    return z, logqz

def forward(self, x, a, k=1, current_state=None):
    '''
    x: [B,T,X]
    a: [B,T,A]
    output: elbo scalar
    '''
    self.B = x.size()[0]
    self.T = x.size()[1]
    self.k = k
    a = a.float()
    x = x.float()

    logpxs = []
    logpzs = []
    logqzs = []
    weights = Variable(torch.ones(k, self.B) / k)
    prev_z = Variable(torch.zeros(k, self.B, self.z_size))

    for t in range(self.T):
        current_x = x[:, t]  # [B,X]
        current_a = a[:, t]  # [B,A]

        # Encode
        mu, logvar = self.encode(current_x, current_a, prev_z)  # [P,B,Z]

        # Sample
        z, logqz = self.sample(mu, logvar)  # [P,B,Z], [P,B]

        # Decode
        x_hat = self.decode(z)  # [P,B,X]
        logpx = log_bernoulli(x_hat, current_x)  # [P,B]

        # Transition/prior prob
        prior_mean, prior_log_var = self.transition_prior(prev_z, current_a)  # [P,B,Z]
        logpz = lognormal(z, prior_mean, prior_log_var)  # [P,B]

        # Update particle weights
        log_alpha_t = logpx + logpz - logqz  # [P,B]
        log_weights_tmp = torch.log(weights * torch.exp(log_alpha_t))
        max_ = torch.max(log_weights_tmp, 0)[0]  # [B]
        log_p_hat = torch.log(torch.sum(torch.exp(log_weights_tmp - max_), 0)) + max_  # [B]
        normalized_alpha_t = log_weights_tmp - log_p_hat  # [P,B]
        weights = torch.exp(normalized_alpha_t)  # [P,B]

        # Resample every second step
        if t % 2 == 0:
            # [B,P] indices of the particles for each batch
            sampled_indices = torch.multinomial(torch.t(weights), k, replacement=True).detach()
            new_z = []
            for b in range(self.B):
                tmp = z[:, b]  # [P,Z]
                z_b = tmp[sampled_indices[b]]  # [P,Z]
                new_z.append(z_b)
            new_z = torch.stack(new_z, 1)  # [P,B,Z]
            weights = Variable(torch.ones(k, self.B) / k)
            z = new_z

        logpxs.append(logpx)
        logpzs.append(logpz)
        logqzs.append(logqz)
        prev_z = z

    logpxs = torch.stack(logpxs)
    logpzs = torch.stack(logpzs)
    logqzs = torch.stack(logqzs)  # [T,P,B]

    logws = logpxs + logpzs - logqzs  # [T,P,B]
    logws = torch.mean(logws, 0)  # [P,B]

    if k > 1:
        max_ = torch.max(logws, 0)[0]  # [B]
        elbo = torch.log(torch.mean(torch.exp(logws - max_), 0)) + max_  # [B]
        elbo = torch.mean(elbo)  # over batch
    else:
        elbo = torch.mean(logws)

    # For printing
    logpx = torch.mean(logpxs)
    logpz = torch.mean(logpzs)
    logqz = torch.mean(logqzs)

    return elbo, logpx, logpz, logqz

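# Aside (an assumed alternative, not part of the original code): the per-batch resampling
# loop above can be vectorized with gather. With sampled_indices of shape [B,P] and z of
# shape [P,B,Z]:
#     idx = torch.t(sampled_indices).unsqueeze(2).expand(z.size(0), z.size(1), z.size(2))
#     new_z = torch.gather(z, 0, idx)   # new_z[p,b,:] = z[sampled_indices[b,p], b, :]
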
def optimize_local_gaussian_mean_logvar2(logposterior, model, x):
    B = x.size()[0]  # batch size
    # input to logposterior is z: [P,B,Z]

    mean = Variable(torch.zeros(B, model.z_size).type(model.dtype), requires_grad=True)
    logvar = Variable(torch.zeros(B, model.z_size).type(model.dtype), requires_grad=True)
    optimizer = optim.Adam([mean, logvar], lr=.001)

    P = 50
    last_100 = []
    best_last_100_avg = -1
    consecutive_worse = 0

    # Round 1: over-weighted log-posterior term
    for epoch in range(1, 99999):
        if quick:
            break

        # Sample
        eps = Variable(torch.FloatTensor(P, B, model.z_size).normal_().type(model.dtype))  # [P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
        logqz = lognormal(z, mean, logvar)  # [P,B]
        logpx = logposterior(z)

        loss = -(torch.mean(1.5 * logpx - logqz))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_np = loss.data.cpu().numpy()
        last_100.append(loss_np)

        if epoch % 100 == 0:
            last_100_avg = np.mean(last_100)
            if last_100_avg < best_last_100_avg or best_last_100_avg == -1:
                consecutive_worse = 0
                best_last_100_avg = last_100_avg
            else:
                consecutive_worse += 1
                if consecutive_worse > 10:
                    break
            print(epoch, last_100_avg, consecutive_worse, mean)
            last_100 = []

        if epoch % 1000 == 0:
            print(torch.mean(logpx))
            print(torch.mean(logqz))
            print(torch.std(logpx))
            print(torch.std(logqz))

    # Round 2: standard ELBO objective
    last_100 = []
    best_last_100_avg = -1
    consecutive_worse = 0
    for epoch in range(1, 99999):
        if quick:
            break

        # Sample
        eps = Variable(torch.FloatTensor(P, B, model.z_size).normal_().type(model.dtype))  # [P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
        logqz = lognormal(z, mean, logvar)  # [P,B]
        logpx = logposterior(z)

        loss = -(torch.mean(logpx - logqz))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_np = loss.data.cpu().numpy()
        last_100.append(loss_np)

        if epoch % 100 == 0:
            last_100_avg = np.mean(last_100)
            if last_100_avg < best_last_100_avg or best_last_100_avg == -1:
                consecutive_worse = 0
                best_last_100_avg = last_100_avg
            else:
                consecutive_worse += 1
                if consecutive_worse > 10:
                    break
            print(epoch, last_100_avg, consecutive_worse, mean, '2')
            last_100 = []

        if epoch % 1000 == 0:
            print(torch.mean(logpx))
            print(torch.mean(logqz))
            print(torch.std(logpx))
            print(torch.std(logqz))

    return mean, logvar, z

def forward(self, k, x, logposterior):
    '''
    k: number of samples
    x: [B,X]
    logposterior(z) -> [P,B]
    '''
    self.B = x.size()[0]
    self.P = k

    # q(v|x)
    out = x
    for i in range(len(self.qv_weights) - 1):
        out = self.act_func(self.qv_weights[i](out))
    out = self.qv_weights[-1](out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    # Sample v0
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    v = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqv0 = lognormal(v, mean, logvar)  # [P,B]

    v = v.view(-1, self.z_size)  # [PB,Z]
    x_tiled = x.repeat(k, 1)  # [PB,X]
    xv = torch.cat((x_tiled, v), 1)  # [PB,X+Z]

    # q(z|x,v)
    out = xv
    for i in range(len(self.qz_weights) - 1):
        out = self.act_func(self.qz_weights[i](out))
    out = self.qz_weights[-1](out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    # Sample z0
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    mean = mean.contiguous().view(self.P, self.B, self.z_size)
    logvar = logvar.contiguous().view(self.P, self.B, self.z_size)
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz0 = lognormal333(z, mean, logvar)  # [P,B]

    z = z.view(-1, self.z_size)  # [PB,Z]

    logdetsum = 0.
    for i in range(self.n_flows):
        z, v, logdet = self.norm_flow(self.params[i], z, v)
        logdetsum += logdet

    xz = torch.cat((x_tiled, z), 1)

    # r(vT|x,zT)
    out = xz
    for i in range(len(self.rv_weights) - 1):
        out = self.act_func(self.rv_weights[i](out))
    out = self.rv_weights[-1](out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]
    mean = mean.contiguous().view(self.P, self.B, self.z_size)
    logvar = logvar.contiguous().view(self.P, self.B, self.z_size)

    v = v.view(k, self.B, self.z_size)
    logrvT = lognormal333(v, mean, logvar)  # [P,B]

    z = z.view(k, self.B, self.z_size)
    logdetsum = logdetsum.view(k, self.B)

    return z, logqz0 + logqv0 - logdetsum - logrvT

def optimize_local_gaussian(logposterior, model, x):
    B = x.size()[0]  # batch size
    # input to logposterior is z: [P,B,Z]

    mean = Variable(torch.zeros(B, model.z_size).type(model.dtype), requires_grad=True)
    logvar = Variable(torch.zeros(B, model.z_size).type(model.dtype), requires_grad=True)
    optimizer = optim.Adam([mean, logvar], lr=.001)

    P = 50
    last_100 = []
    best_last_100_avg = -1
    consecutive_worse = 0
    for epoch in range(1, 999999):
        # Sample
        eps = Variable(torch.FloatTensor(P, B, model.z_size).normal_().type(model.dtype))  # [P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
        logqz = lognormal(z, mean, logvar)  # [P,B]
        logpx = logposterior(z)

        optimizer.zero_grad()
        loss = -(torch.mean(logpx - logqz))
        loss_np = loss.data.cpu().numpy()
        loss.backward()
        optimizer.step()

        last_100.append(loss_np)
        if epoch % 100 == 0:
            last_100_avg = np.mean(last_100)
            if last_100_avg < best_last_100_avg or best_last_100_avg == -1:
                consecutive_worse = 0
                best_last_100_avg = last_100_avg
            else:
                consecutive_worse += 1
                if consecutive_worse > 10:
                    break
            if epoch % 2000 == 0:
                print(epoch, last_100_avg, consecutive_worse)
            last_100 = []

    # Compute VAE and IWAE bounds
    eps = Variable(torch.FloatTensor(1000, B, model.z_size).normal_().type(model.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz = lognormal(z, mean, logvar)  # [P,B]
    logpx = logposterior(z)

    elbo = logpx - logqz  # [P,B]
    vae = torch.mean(elbo)

    max_ = torch.max(elbo, 0)[0]  # [B]
    elbo_ = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_  # [B]
    iwae = torch.mean(elbo_)

    return vae, iwae

def forward(self, x, a, k=1, current_state=None):
    '''
    x: [B,T,X]
    a: [B,T,A]
    output: elbo scalar
    '''
    self.B = x.size()[0]
    self.T = x.size()[1]
    self.k = k
    a = a.float()
    x = x.float()

    logpxs = []
    logpzs = []
    logqzs = []
    prev_z = Variable(torch.zeros(k, self.B, self.z_size))

    for t in range(self.T):
        current_x = x[:, t]  # [B,X]
        current_a = a[:, t]  # [B,A]

        # Encode
        mu, logvar = self.encode(current_x, current_a, prev_z)  # [P,B,Z]

        # Sample
        z, logqz = self.sample(mu, logvar)  # [P,B,Z], [P,B]

        # Decode
        x_hat = self.decode(z)  # [P,B,X]
        logpx = log_bernoulli(x_hat, current_x)  # [P,B]

        # Transition/prior prob
        prior_mean, prior_log_var = self.transition_prior(prev_z, current_a)  # [P,B,Z]
        logpz = lognormal(z, prior_mean, prior_log_var)  # [P,B]

        logpxs.append(logpx)
        logpzs.append(logpz)
        logqzs.append(logqz)
        prev_z = z

    logpxs = torch.stack(logpxs)
    logpzs = torch.stack(logpzs)
    logqzs = torch.stack(logqzs)  # [T,P,B]

    logws = logpxs + logpzs - logqzs  # [T,P,B]
    logws = torch.mean(logws, 0)  # [P,B]

    if k > 1:
        max_ = torch.max(logws, 0)[0]  # [B]
        elbo = torch.log(torch.mean(torch.exp(logws - max_), 0)) + max_  # [B]
        elbo = torch.mean(elbo)  # over batch
    else:
        elbo = torch.mean(logws)

    # For printing
    logpx = torch.mean(logpxs)
    logpz = torch.mean(logpzs)
    logqz = torch.mean(logqzs)

    return elbo, logpx, logpz, logqz

def sample(k):
    P = k

    # Sample v0
    eps = Variable(torch.FloatTensor(k, B, z_size).normal_().type(model.dtype))  # [P,B,Z]
    v = eps.mul(torch.exp(.5 * logvar_v)) + mean_v  # [P,B,Z]
    logqv0 = lognormal(v, mean_v, logvar_v)  # [P,B]

    v = v.view(-1, model.z_size)  # [PB,Z]

    # q(z|v)
    out = v
    for i in range(len(qz_weights) - 1):
        out = act_func(qz_weights[i](out))
    out = qz_weights[-1](out)
    mean = out[:, :z_size]
    logvar = out[:, z_size:] + 5.

    # Sample z0
    eps = Variable(torch.FloatTensor(k, B, z_size).normal_().type(model.dtype))  # [P,B,Z]
    mean = mean.contiguous().view(P, B, model.z_size)
    logvar = logvar.contiguous().view(P, B, model.z_size)
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz0 = lognormal333(z, mean, logvar)  # [P,B]

    z = z.view(-1, z_size)  # [PB,Z]

    logdetsum = 0.
    for i in range(n_flows):
        z, v, logdet = norm_flow(params[i], z, v)
        logdetsum += logdet

    # r(vT|zT)
    out = z
    for i in range(len(rv_weights) - 1):
        out = act_func(rv_weights[i](out))
    out = rv_weights[-1](out)
    mean = out[:, :model.z_size]
    logvar = out[:, model.z_size:]
    mean = mean.contiguous().view(P, B, model.z_size)
    logvar = logvar.contiguous().view(P, B, model.z_size)

    v = v.view(k, B, model.z_size)
    logrvT = lognormal333(v, mean, logvar)  # [P,B]

    z = z.view(k, B, model.z_size)
    logq = logqz0 + logqv0 - logdetsum - logrvT
    return z, logq

def optimize_local_gaussian(logposterior, model, x):
    B = x.size()[0]  # batch size
    # input to logposterior is z: [P,B,Z]

    mean = Variable(torch.zeros(B, model.z_size).type(model.dtype), requires_grad=True)
    logvar = Variable(torch.zeros(B, model.z_size).type(model.dtype), requires_grad=True)
    optimizer = optim.Adam([mean, logvar], lr=.001)

    P = 50
    last_100 = []
    best_last_100_avg = -1
    consecutive_worse = 0
    for epoch in range(1, 999999):
        # Sample
        eps = Variable(torch.FloatTensor(P, B, model.z_size).normal_().type(model.dtype))  # [P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
        logqz = lognormal(z, mean, logvar)  # [P,B]
        logpx = logposterior(z)

        optimizer.zero_grad()
        loss = -(torch.mean(logpx - logqz))
        loss_np = loss.data.cpu().numpy()
        loss.backward()
        optimizer.step()

        last_100.append(loss_np)
        if epoch % 100 == 0:
            last_100_avg = np.mean(last_100)
            if last_100_avg < best_last_100_avg or best_last_100_avg == -1:
                consecutive_worse = 0
                best_last_100_avg = last_100_avg
            else:
                consecutive_worse += 1
                if consecutive_worse > 10:
                    break
            print(epoch, last_100_avg, consecutive_worse)
            last_100 = []

    # Compute VAE and IWAE bounds
    eps = Variable(torch.FloatTensor(1000, B, model.z_size).normal_().type(model.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz = lognormal(z, mean, logvar)  # [P,B]
    logpx = logposterior(z)

    elbo = logpx - logqz  # [P,B]
    vae = torch.mean(elbo)

    max_ = torch.max(elbo, 0)[0]  # [B]
    elbo_ = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_  # [B]
    iwae = torch.mean(elbo_)

    return vae, iwae

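# Hypothetical usage sketch (variable names assumed): fit a per-datapoint Gaussian q(z)
# against the same log-posterior closure the models above construct, then report both bounds.
#     x = Variable(torch.from_numpy(batch)).type(model.dtype)   # [B,X]
#     zeros = Variable(torch.zeros(x.size()[0], model.z_size).type(model.dtype))
#     logposterior = lambda z: lognormal(z, zeros, zeros) + log_bernoulli(model.decode(z), x)
#     vae_bound, iwae_bound = optimize_local_gaussian(logposterior, model, x)
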
def forward(self, k, x, logposterior):
    '''
    k: number of samples
    x: [B,X]
    logposterior(z) -> [P,B]
    '''
    self.B = x.size()[0]
    self.P = k

    # q(v|x)
    out = self.act_func(self.fc1(x))
    out = self.act_func(self.fc2(out))
    out = self.fc3(out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    # Sample v0
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    v = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqv0 = lognormal(v, mean, logvar)  # [P,B]

    v = v.view(-1, self.z_size)  # [PB,Z]
    x_tiled = x.repeat(k, 1)  # [PB,X]
    xv = torch.cat((x_tiled, v), 1)  # [PB,X+Z]

    # q(z|x,v)
    out = self.act_func(self.fc4(xv))
    out = self.act_func(self.fc5(out))
    out = self.fc6(out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    # Sample z0
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz0 = lognormal(z, mean, logvar)  # [P,B]

    z = z.view(-1, self.z_size)  # [PB,Z]

    logdetsum = 0.
    for i in range(self.n_flows):
        z, v, logdet = self.norm_flow(self.params[i], z, v)
        logdetsum += logdet

    xz = torch.cat((x_tiled, z), 1)

    # r(vT|x,zT)
    out = self.act_func(self.fc7(xz))
    out = self.act_func(self.fc8(out))
    out = self.fc9(out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    v = v.view(k, self.B, self.z_size)
    logrvT = lognormal(v, mean, logvar)  # [P,B]

    z = z.view(k, self.B, self.z_size)

    return z, logqz0 + logqv0 - logdetsum - logrvT

def forward(self, k, x, logposterior):
    '''
    k: number of samples
    x: [B,X]
    logposterior(z) -> [P,B]
    '''
    self.B = x.size()[0]
    self.P = k

    if torch.cuda.is_available():
        self.grad_outputs = torch.ones(k, self.B).cuda()
    else:
        self.grad_outputs = torch.ones(k, self.B)

    # q(v|x)
    out = x
    for i in range(len(self.qv_weights) - 1):
        out = self.act_func(self.qv_weights[i](out))
    out = self.qv_weights[-1](out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    # Sample v0
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    v = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqv0 = lognormal(v, mean, logvar)  # [P,B]

    v = v.view(-1, self.z_size)  # [PB,Z]
    x_tiled = x.repeat(k, 1)  # [PB,X]
    xv = torch.cat((x_tiled, v), 1)  # [PB,X+Z]

    # q(z|x,v)
    out = xv
    for i in range(len(self.qz_weights) - 1):
        out = self.act_func(self.qz_weights[i](out))
    out = self.qz_weights[-1](out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    # Sample z0
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    mean = mean.contiguous().view(self.P, self.B, self.z_size)
    logvar = logvar.contiguous().view(self.P, self.B, self.z_size)
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    mean = mean.contiguous().view(self.P * self.B, self.z_size)
    logvar = logvar.contiguous().view(self.P * self.B, self.z_size)
    logqz0 = lognormal(z, mean, logvar)  # [P,B]

    z = z.view(-1, self.z_size)  # [PB,Z]

    logdetsum = 0.
    for i in range(self.n_flows):
        z, v, logdet = self.norm_flow(self.params[i], z, v, logposterior)
        logdetsum += logdet

    xz = torch.cat((x_tiled, z), 1)

    # r(vT|x,zT)
    out = xz
    for i in range(len(self.rv_weights) - 1):
        out = self.act_func(self.rv_weights[i](out))
    out = self.rv_weights[-1](out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    v = v.view(k, self.B, self.z_size)
    logrvT = lognormal(v, mean, logvar)  # [P,B]

    z = z.view(k, self.B, self.z_size)

    return z, logqz0 + logqv0 - logdetsum - logrvT

def forward(self, k, x, logposterior):
    '''
    k: number of samples
    x: [B,X]
    logposterior(z) -> [P,B]
    '''
    self.B = x.size()[0]
    self.P = k

    if torch.cuda.is_available():
        self.grad_outputs = torch.ones(k, self.B).cuda()
    else:
        self.grad_outputs = torch.ones(k, self.B)

    # q(v|x)
    out = self.act_func(self.fc1(x))
    out = self.act_func(self.fc2(out))
    out = self.fc3(out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    # Sample v0
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    v = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqv0 = lognormal(v, mean, logvar)  # [P,B]

    v = v.view(-1, self.z_size)  # [PB,Z]
    x_tiled = x.repeat(k, 1)  # [PB,X]
    xv = torch.cat((x_tiled, v), 1)  # [PB,X+Z]

    # q(z|x,v)
    out = self.act_func(self.fc4(xv))
    out = self.act_func(self.fc5(out))
    out = self.fc6(out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    # Sample z0
    eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype))  # [P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mean  # [P,B,Z]
    logqz0 = lognormal(z, mean, logvar)  # [P,B]

    z = z.view(-1, self.z_size)  # [PB,Z]

    logdetsum = 0.
    for i in range(self.n_flows):
        z, v, logdet = self.norm_flow(self.params[i], z, v, logposterior)
        logdetsum += logdet

    xz = torch.cat((x_tiled, z), 1)

    # r(vT|x,zT)
    out = self.act_func(self.fc7(xz))
    out = self.act_func(self.fc8(out))
    out = self.fc9(out)
    mean = out[:, :self.z_size]
    logvar = out[:, self.z_size:]

    v = v.view(k, self.B, self.z_size)
    logrvT = lognormal(v, mean, logvar)  # [P,B]

    z = z.view(k, self.B, self.z_size)

    return z, logqz0 + logqv0 - logdetsum - logrvT

def test_ais(model, data_x, path_to_load_variables='', batch_size=20, display_epoch=4, k=10):

    def intermediate_dist(t, z, mean, logvar, zeros, batch):
        logp1 = lognormal(z, mean, logvar)  # [P,B]
        log_prior = lognormal(z, zeros, zeros)  # [P,B]
        log_likelihood = log_bernoulli(model.decode(z), batch)
        logpT = log_prior + log_likelihood
        log_intermediate_2 = (1 - float(t)) * logp1 + float(t) * logpT
        return log_intermediate_2

    n_intermediate_dists = 25
    n_HMC_steps = 5
    step_size = .1
    retain_graph = False
    volatile_ = False
    requires_grad = False

    if path_to_load_variables != '':
        model.load_state_dict(torch.load(path_to_load_variables,
                                         map_location=lambda storage, loc: storage))
        print('loaded variables ' + path_to_load_variables)

    logws = []
    data_index = 0
    for i in range(len(data_x) // batch_size):
        print(i)

        # AIS temperature schedule
        schedule = np.linspace(0., 1., n_intermediate_dists)

        model.B = batch_size
        batch = data_x[data_index:data_index + batch_size]
        data_index += batch_size

        if torch.cuda.is_available():
            batch = Variable(batch, volatile=volatile_, requires_grad=requires_grad).cuda()
            zeros = Variable(torch.zeros(model.B, model.z_size),
                             volatile=volatile_, requires_grad=requires_grad).cuda()  # [B,Z]
            logw = Variable(torch.zeros(k, model.B),
                            volatile=volatile_, requires_grad=requires_grad).cuda()
            grad_outputs = torch.ones(k, model.B).cuda()
        else:
            batch = Variable(batch)
            zeros = Variable(torch.zeros(model.B, model.z_size))  # [B,Z]
            logw = Variable(torch.zeros(k, model.B))
            grad_outputs = torch.ones(k, model.B)

        # Encode x
        mean, logvar = model.encode(batch)  # [B,Z]

        # Init z
        z, logpz, logqz = model.sample(mean, logvar, k=k)  # [P,B,Z], [P,B], [P,B]

        for (t0, t1) in zip(schedule[:-1], schedule[1:]):

            # Compute intermediate distribution log prob:
            # (1-t)*logp1(z) + (t)*logpT(z)
            logp1 = lognormal(z, mean, logvar)  # [P,B]
            log_prior = lognormal(z, zeros, zeros)  # [P,B]
            log_likelihood = log_bernoulli(model.decode(z), batch)
            logpT = log_prior + log_likelihood

            # log pt-1(zt-1)
            log_intermediate_1 = (1 - float(t0)) * logp1 + float(t0) * logpT
            # log pt(zt-1)
            log_intermediate_2 = (1 - float(t1)) * logp1 + float(t1) * logpT

            logw += log_intermediate_2 - log_intermediate_1

            # HMC
            if torch.cuda.is_available():
                v = Variable(torch.FloatTensor(z.size()).normal_(),
                             volatile=volatile_, requires_grad=requires_grad).cuda()
            else:
                v = Variable(torch.FloatTensor(z.size()).normal_())
            v0 = v
            z0 = z

            gradients = torch.autograd.grad(outputs=log_intermediate_2, inputs=z,
                                            grad_outputs=grad_outputs,
                                            create_graph=True, retain_graph=retain_graph,
                                            only_inputs=True)[0]
            v = v + .5 * step_size * gradients
            z = z + step_size * v

            for LF_step in range(n_HMC_steps):
                log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch)
                gradients = torch.autograd.grad(outputs=log_intermediate_2, inputs=z,
                                                grad_outputs=grad_outputs,
                                                create_graph=True, retain_graph=retain_graph,
                                                only_inputs=True)[0]
                v = v + step_size * gradients
                z = z + step_size * v

            log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch)
            gradients = torch.autograd.grad(outputs=log_intermediate_2, inputs=z,
                                            grad_outputs=grad_outputs,
                                            create_graph=True, retain_graph=retain_graph,
                                            only_inputs=True)[0]
            v = v + .5 * step_size * gradients

            # MH step
            log_intermediate_2 = intermediate_dist(t1, z0, mean, logvar, zeros, batch)
            logpv0 = lognormal(v0, zeros, zeros)  # [P,B]
            hamil_0 = log_intermediate_2 + logpv0

            log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch)
            logpvT = lognormal(v, zeros, zeros)  # [P,B]
            hamil_T = log_intermediate_2 + logpvT

            accept_prob = torch.exp(hamil_T - hamil_0)

            if torch.cuda.is_available():
                rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_(),
                                    volatile=volatile_, requires_grad=requires_grad).cuda()
            else:
                rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_())
            accept = accept_prob > rand_uni

            if torch.cuda.is_available():
                accept = accept.type(torch.FloatTensor).cuda()
            else:
                accept = accept.type(torch.FloatTensor)
            accept = accept.view(k, model.B, 1)
            z = (accept * z) + ((1 - accept) * z0)

            # Adapt step size
            avg_acceptance_rate = torch.mean(accept)
            if avg_acceptance_rate.cpu().data.numpy() > .7:
                step_size = 1.02 * step_size
            else:
                step_size = .98 * step_size
            if step_size < 0.0001:
                step_size = 0.0001
            if step_size > 0.5:
                step_size = 0.5

        # Log-mean-exp over the k AIS chains
        max_ = torch.max(logw, 0)[0]  # [B]
        logw = torch.log(torch.mean(torch.exp(logw - max_), 0)) + max_  # [B]
        logws.append(torch.mean(logw.cpu()).data.numpy())

        if i % display_epoch == 0:
            print(i, len(data_x) // batch_size, np.mean(logws))

    return np.mean(logws)

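# Hypothetical usage (names assumed): estimate the average test log-likelihood with AIS,
# where test_x is a FloatTensor of (binarized) inputs shaped [N,X].
#     avg_logpx = test_ais(model, test_x, path_to_load_variables='', batch_size=20, k=10)
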