def __call__(self, x):
    """Return the negative importance-weighted (IWAE) bound for minibatch ``x``.

    For each of ``self.num_zsamples`` draws z ~ q(z|x), the log importance
    weight ``log p(x|z) + T * (log p(z) - log q(z|x))`` is stored. T is the
    annealing temperature, capped at 1.0. The per-example bound is the
    log-mean-exp of these weights over the samples.

    Side effects: sets ``self.w_holder``, ``self.obj_batch``, ``self.obj``,
    ``self.kl`` and ``self.logp`` (sample averages, reporting only), and
    ``self.timing_info`` ([encode seconds, mean decode seconds]). It also
    advances ``self.temperature['value']`` by one increment.

    Args:
        x: minibatch of observations. Axis 0 of the per-example bound is
            treated as the batch axis.

    Returns:
        Minus the summed bound divided by the batch size.
    """
    # Obtain parameters for q(z|x). encode() presumably stores them on
    # self.qmu / self.qln_var, which are read below -- TODO confirm.
    start = time.time()
    self.encode(x)
    encoding_time = float(time.time() - start)

    decoding_time_average = 0.
    # NOTE(review): not used in this method; kept in case other code reads it.
    self.importance_weights = 0
    self.w_holder = []
    self.kl = 0
    self.logp = 0

    # Read the temperature once and advance it once per call, so every
    # sample in this batch is weighted with the same value.
    current_temperature = min(self.temperature['value'], 1.0)
    self.temperature['value'] += self.temperature['increment']

    for _ in range(self.num_zsamples):
        # Sample z ~ q(z|x) and compute log q(z|x).
        z = F.gaussian(self.qmu, self.qln_var)
        encoder_log = gaussian_logp(z, self.qmu, self.qln_var)

        # Obtain parameters for p(x|z). decode() presumably stores them on
        # self.p_ber_prob_logit -- TODO confirm.
        start = time.time()
        self.decode(z)
        decoding_time_average += time.time() - start

        decoder_log = bernoulli_logp(x, self.p_ber_prob_logit)  # log p(x|z)
        prior_log = gaussian_logp0(z)  # log p(z)

        # Log importance weight for this sample; only the prior/posterior
        # ratio is scaled by the temperature.
        self.w_holder.append(
            decoder_log + current_temperature * (prior_log - encoder_log))

        # Reporting only: unscaled KL and reconstruction terms.
        self.kl += encoder_log - prior_log
        self.logp += decoder_log

    # log((1/K) * sum_k w_k) via logsumexp over the stacked sample axis.
    logps = F.stack(self.w_holder)
    self.obj_batch = F.logsumexp(logps, axis=0) - np.log(self.num_zsamples)

    self.kl /= self.num_zsamples
    self.logp /= self.num_zsamples
    decoding_time_average /= self.num_zsamples

    batch_size = self.obj_batch.shape[0]
    self.obj = -F.sum(self.obj_batch) / batch_size
    self.timing_info = np.array([encoding_time, decoding_time_average])
    return self.obj
def __call__(self, x):
    """Return the negative (annealed) ELBO of a standard VAE for minibatch ``x``.

    The per-example bound is ``E[log p(x|z)] - T * E[log q(z|x) - log p(z)]``.
    Both expectations are estimated with ``self.num_zsamples`` samples
    z ~ q(z|x). T is the annealing temperature, capped at 1.0.

    Side effects: sets ``self.kl`` and ``self.logp`` (sample averages),
    ``self.obj_batch``, ``self.obj`` and ``self.timing_info`` ([encode
    seconds, mean decode seconds]). It also advances
    ``self.temperature['value']`` by one increment.

    Returns:
        Minus the summed bound divided by the batch size (axis 0 of
        ``self.obj_batch``).
    """
    # Compute q(z|x). encode() presumably stores self.qmu / self.qln_var,
    # which are read below -- TODO confirm.
    start = time.time()
    self.encode(x)
    encoding_time = float(time.time() - start)

    decoding_time_average = 0.
    self.kl = 0
    self.logp = 0

    # Read and advance the temperature once per call. Hoisting it out of the
    # loop also keeps it defined when num_zsamples == 0.
    current_temperature = min(self.temperature['value'], 1.0)
    self.temperature['value'] += self.temperature['increment']

    for _ in range(self.num_zsamples):
        # z ~ q(z|x) and log q(z|x).
        z = F.gaussian(self.qmu, self.qln_var)
        encoder_log = gaussian_logp(z, self.qmu, self.qln_var)

        # Compute p(x|z). decode() presumably stores self.p_ber_prob_logit --
        # TODO confirm.
        start = time.time()
        self.decode(z)
        decoding_time_average += time.time() - start

        # Compute log p(z), then accumulate the KL and reconstruction terms.
        prior_log = gaussian_logp0(z)
        self.kl += encoder_log - prior_log
        self.logp += bernoulli_logp(x, self.p_ber_prob_logit)

    decoding_time_average /= self.num_zsamples
    self.logp /= self.num_zsamples
    self.kl /= self.num_zsamples

    self.obj_batch = self.logp - current_temperature * self.kl
    self.timing_info = np.array([encoding_time, decoding_time_average])
    batch_size = self.obj_batch.shape[0]
    self.obj = -F.sum(self.obj_batch) / batch_size
    return self.obj
def __call__(self, x):
    """Return the negative (annealed) ELBO with a Householder flow for ``x``.

    For each of ``self.num_zsamples`` samples, z_0 ~ q(z|x) is drawn and
    passed through ``self.house_transform`` to get z_T. The observation is
    scored under p(x|z_T). The per-example bound is
    ``mean(log p(x|z_T)) - T * mean(kl)``, where ``kl`` comes from
    ``gaussian_kl_divergence(z_0, qmu, qln_var, z_T)`` and T is the
    annealing temperature, capped at 1.0.

    Side effects: sets ``self.kl``, ``self.logp``, ``self.obj_batch``,
    ``self.obj`` and ``self.timing_info`` ([encode seconds, mean
    transform+decode seconds]). It also advances
    ``self.temperature['value']`` by one increment.

    Returns:
        Minus the summed bound divided by the batch size (axis 0 of
        ``self.obj_batch``).
    """
    # Obtain parameters for q(z|x). qh_vec_0 is not used directly here;
    # house_transform presumably reads it from self -- TODO confirm.
    start = time.time()
    qmu, qln_var, qh_vec_0 = self.encode(x)
    encoding_time = float(time.time() - start)

    decoding_time_average = 0.
    self.kl = 0
    self.logp = 0

    for _ in range(self.num_zsamples):
        # z_0 ~ q(z|x)
        z_0 = F.gaussian(qmu, qln_var)

        # Householder flow transformation, then parameters for p(x|z_T).
        # Both count towards the decoding time.
        start = time.time()
        z_T = self.house_transform(z_0)
        p_ber_prob_logit = self.decode(z_T)
        decoding_time_average += time.time() - start

        # Score with the logits decode() just returned for z_T. Previously a
        # possibly stale self.p_ber_prob_logit was read instead.
        self.logp += bernoulli_logp(x, p_ber_prob_logit)
        self.kl += gaussian_kl_divergence(z_0, qmu, qln_var, z_T)

    decoding_time_average /= self.num_zsamples
    self.logp /= self.num_zsamples
    self.kl /= self.num_zsamples

    current_temperature = min(self.temperature['value'], 1.0)
    self.obj_batch = self.logp - current_temperature * self.kl
    self.temperature['value'] += self.temperature['increment']

    self.timing_info = np.array([encoding_time, decoding_time_average])
    batch_size = self.obj_batch.shape[0]
    self.obj = -F.sum(self.obj_batch) / batch_size
    return self.obj
def __call__(self, x):
    """Return the negative (annealed) ELBO with planar flows for minibatch ``x``.

    For each of ``self.num_zsamples`` samples, z_0 ~ q(z_0|x) is drawn and
    mapped through ``self.planar_flows`` to z_K. The per-sample loss is

        T*log q(z_0|x) - T*sum_i log(1 + det_i) - log p(x|z_K) - T*log p(z_K)

    that is, log q(z_K|x) - log p(x, z_K), with the density terms scaled by
    the annealing temperature T (capped at 1.0). The loss is averaged over
    the samples.

    Side effects: sets ``self.obj_batch`` (per-example loss),
    ``self.obj``, and ``self.kl`` / ``self.logp`` (sample averages,
    reporting only). It also sets ``self.timing_info`` ([encode seconds,
    mean flow+decode seconds]) and advances ``self.temperature['value']``
    by one increment.

    Returns:
        The summed per-example loss divided by the batch size (axis 0 of
        ``self.obj_batch``). Note: not negated, unlike the sibling
        estimators, because the loss here is already -ELBO.
    """
    # Compute q(z_0|x) and use the returned parameters directly.
    start = time.time()
    qmu, qln_var = self.encode(x)
    encoding_time = float(time.time() - start)

    decoding_time_average = 0.
    self.kl = 0
    self.logp = 0
    obj_batch = 0

    current_temperature = min(self.temperature['value'], 1.0)
    self.temperature['value'] += self.temperature['increment']

    for _ in range(self.num_zsamples):
        # Sample z_0 ~ q(z_0|x).
        z_0 = F.gaussian(qmu, qln_var)

        # Planar flow mappings, then parameters for p(x|z_K).
        start = time.time()
        z_K = self.planar_flows(z_0)
        p_ber_prob_logit = self.decode(z_K)
        decoding_time_average += time.time() - start

        # log q(z_0|x): density of z_0 under the encoder Gaussian it was drawn
        # from. The standard-normal gaussian_logp0 used here previously is
        # log p(z), not log q(z_0|x).
        q_prior_log = current_temperature * gaussian_logp(z_0, qmu, qln_var)
        decoder_log = bernoulli_logp(x, p_ber_prob_logit)  # log p(x|z_K)
        p_prior_log = current_temperature * gaussian_logp0(z_K)  # log p(z_K)
        joint_log = decoder_log + p_prior_log  # log p(x, z_K)

        # Sum of the per-flow log-determinant terms subtracted from
        # log q(z_0|x). det_term is presumably u_i^T psi_i(z) for flow i --
        # TODO confirm.
        # NOTE(review): the planar-flow log-det is log|1 + u^T psi|. No abs is
        # taken, so this relies on 1 + u^T psi > 0 (e.g. a constraint on u).
        q_K_log = 0
        for i in range(self.num_trans):
            det_term = F.sum(self['flow_u_' + str(i)](self.phi[i]), axis=1)
            q_K_log += F.log(1 + det_term)
        q_K_log *= current_temperature

        # Accumulate every sample's loss. Previously only the last sample's
        # loss was used, and it was then divided by num_zsamples.
        obj_batch += (q_prior_log - joint_log) - q_K_log

        # Reporting only. KL = log q(z_K|x) - log p(z_K)
        #                    = q_prior_log - q_K_log - p_prior_log.
        self.logp += decoder_log
        self.kl += q_prior_log - q_K_log - p_prior_log

    decoding_time_average /= self.num_zsamples
    self.obj_batch = obj_batch / self.num_zsamples
    batch_size = self.obj_batch.shape[0]
    self.obj = F.sum(self.obj_batch) / batch_size
    self.kl /= self.num_zsamples
    self.logp /= self.num_zsamples
    # Report the mean decoding time, not the last sample's.
    self.timing_info = np.array([encoding_time, decoding_time_average])
    return self.obj
def __call__(self, x):
    """Return the negative (annealed) ELBO of an auxiliary-variable VAE for ``x``.

    One auxiliary sample a ~ q(a|x) is drawn, followed by
    ``self.num_zsamples`` samples z ~ q(z|x, a). The per-example bound is
    the sample average of

        log p(a|x,z) + log p(x|z) + T*log p(z) - log q(a|x) - T*log q(z|x,a)

    where T is the annealing temperature, capped at 1.0.

    Side effects: sets ``self.obj_batch``, ``self.obj``, ``self.kl``
    (``T*(log q(z|x,a) - log p(z))``), ``self.logp`` (``log p(x|z)``) and
    ``self.timing_info`` ([total encode seconds, mean decode seconds]). It
    also advances ``self.temperature['value']`` by one increment.

    Returns:
        Minus the summed bound divided by the batch size (axis 0 of
        ``self.obj_batch``).
    """
    # Parameters for q(a|x), then a single auxiliary sample.
    start = time.time()
    qmu_a, qln_var_a = self.encode_a(x)
    encoding_time = float(time.time() - start)
    a_enc = F.gaussian(qmu_a, qln_var_a)

    # Parameters for q(z|x, a).
    start = time.time()
    qmu_z, qln_var_z = self.encode_z(x, a_enc)
    encoding_time += float(time.time() - start)

    current_temperature = min(self.temperature['value'], 1.0)
    self.temperature['value'] += self.temperature['increment']

    # log q(a|x) does not depend on z, so compute it once. Averaging K
    # identical copies of it would give the same value.
    logq_a_x = gaussian_logp(a_enc, qmu_a, qln_var_a)

    decoding_time_average = 0.
    logp_a_xz = 0
    logp_x_z = 0
    logp_z = 0
    logq_z_ax = 0
    for _ in range(self.num_zsamples):
        # z ~ q(z|x, a), drawn from the same parameters that log q(z|x,a)
        # is evaluated under below.
        z = F.gaussian(qmu_z, qln_var_z)

        # Compute p(a|x,z) and p(x|z).
        start = time.time()
        pmu_a, pln_var_a = self.decode_a(z, x)
        p_ber_prob_logit = self.decode(z)
        decoding_time_average += time.time() - start

        logp_a_xz += gaussian_logp(a_enc, pmu_a, pln_var_a)
        logp_x_z += bernoulli_logp(x, p_ber_prob_logit)
        logp_z += current_temperature * gaussian_logp0(z)
        logq_z_ax += current_temperature * gaussian_logp(
            z, qmu_z, qln_var_z)

    logp_a_xz /= self.num_zsamples
    logp_x_z /= self.num_zsamples
    logp_z /= self.num_zsamples
    logq_z_ax /= self.num_zsamples
    decoding_time_average /= self.num_zsamples

    self.obj_batch = logp_a_xz + logp_x_z + logp_z - logq_a_x - logq_z_ax
    # Reporting only.
    self.kl = logq_z_ax - logp_z
    self.logp = logp_x_z
    self.timing_info = np.array([encoding_time, decoding_time_average])
    batch_size = self.obj_batch.shape[0]
    self.obj = -F.sum(self.obj_batch) / batch_size
    return self.obj
def __call__(self, x):
    """Return the negative (annealed) ELBO with inverse autoregressive flows.

    For each of ``self.num_zsamples`` samples, z ~ q(z|x) is drawn and pushed
    through ``self.num_trans`` IAF steps. Each step's ``delta_logq`` is added
    to log q(z|x). The resulting log q is scaled by the annealing temperature
    T (capped at 1.0). The per-example bound is the sample average of
    ``log p(x|z) + T*log p(z) - T*log q(z|x)``, evaluated at the final z.

    Side effects: sets ``self.logp_xz`` and ``self.logq`` (sums over
    samples), ``self.logp`` and ``self.kl`` (sample averages, reporting
    only), ``self.obj_batch``, ``self.obj`` and ``self.timing_info``
    ([encode seconds, mean flow+decode seconds]). It also advances
    ``self.temperature['value']`` by one increment.

    Returns:
        Minus the summed bound divided by the batch size (axis 0 of
        ``self.obj_batch``).
    """
    # Obtain parameters for q(z|x). encode() presumably stores self.qmu,
    # self.qln_var and self.qh, which are read below -- TODO confirm.
    start = time.time()
    self.encode(x)
    encoding_time = float(time.time() - start)

    self.logp_xz = 0
    self.logq = 0
    # For reporting purposes only.
    self.logp = 0
    self.kl = 0
    decoding_time_average = 0.

    current_temperature = min(self.temperature['value'], 1.0)
    self.temperature['value'] += self.temperature['increment']

    for _ in range(self.num_zsamples):
        # z ~ q(z|x)
        z = F.gaussian(self.qmu, self.qln_var)
        start = time.time()

        # log q for THIS sample only. Previously the running total over all
        # samples was rescaled by the temperature on every iteration and
        # re-added to the KL.
        logq_sample = gaussian_logp(z, self.qmu, self.qln_var)  # log q(z|x)
        for i in range(self.num_trans):
            a_layer_name = 'qiaf_a' + str(i + 1)
            b_layer_name = 'qiaf_b' + str(i + 1)
            z, delta_logq = self.iaf(z, self.qh, self[a_layer_name],
                                     self[b_layer_name])
            logq_sample = logq_sample + delta_logq
        logq_sample = current_temperature * logq_sample
        self.logq += logq_sample

        # Compute p(x|z). decode() presumably stores self.p_ber_prob_logit --
        # TODO confirm.
        self.decode(z)
        decoding_time_average += time.time() - start

        logx_given_z = bernoulli_logp(x, self.p_ber_prob_logit)  # log p(x|z)
        logz = current_temperature * gaussian_logp0(z)  # log p(z)
        self.logp_xz += logx_given_z + logz  # log p(x, z)

        # For reporting purposes only.
        self.logp += logx_given_z
        self.kl += logq_sample - logz

    decoding_time_average /= self.num_zsamples
    self.logp /= self.num_zsamples
    self.kl /= self.num_zsamples

    # ELBO of the form log p(x,z) - log q(z|x), averaged over samples.
    self.obj_batch = (self.logp_xz - self.logq) / self.num_zsamples
    self.timing_info = np.array([encoding_time, decoding_time_average])
    batch_size = self.obj_batch.shape[0]
    self.obj = -F.sum(self.obj_batch) / batch_size
    return self.obj