def elbo_h_gaussian_x_bernoulli(d, z, mean_q, log_var_q, mean_prior, log_var_prior,
                                p_out, x, mask, settings, temperature_KL=1.0, test=False):
    if not test:
        mask_sum = T.sum(mask, axis=1)
        # We sum over output_dim and take the mean over batch_size and sequence_length
        log_p_x_given_h = log_bernoulli(
            x=x, p=p_out, eps=settings.tolerance_softmax) * mask.dimshuffle(0, 1, 'x')
        log_p_x_given_h = log_p_x_given_h.sum(axis=(1, 2)) / mask_sum
        log_p_x_given_h_tot = log_p_x_given_h.mean()

        kl_divergence = kl_normal2_normal2(mean_q, log_var_q, mean_prior, log_var_prior)
        # kl_divergence has size (batch_size, sequence_length, output_dim)
        kl_divergence_tmp = kl_divergence * mask.dimshuffle(0, 1, 'x')
        kl_divergence_tmp = kl_divergence_tmp.sum(axis=(1, 2)) / mask_sum
        kl_divergence_tot = T.mean(kl_divergence_tmp)

        lower_bound = log_p_x_given_h_tot - temperature_KL * kl_divergence_tot
        return lower_bound
    else:
        mask_sum = T.sum(mask, axis=1)
        log_p_x_given_h = log_bernoulli(
            x=x, p=p_out, eps=settings.tolerance_softmax) * mask.dimshuffle(0, 1, 'x')
        log_p_x_given_h_tot = log_p_x_given_h.sum(axis=(1, 2)) / mask_sum

        kl_divergence = kl_normal2_normal2(mean_q, log_var_q, mean_prior, log_var_prior)
        # kl_divergence has size (batch_size, sequence_length, output_dim)
        kl_divergence_seq = T.reshape(
            kl_divergence, (settings.batch_size, -1, settings.latent_size_z))
        kl_divergence_seq = kl_divergence_seq * mask.dimshuffle(0, 1, 'x')
        kl_divergence_tot = kl_divergence_seq.sum(axis=(1, 2)) / mask_sum

        lower_bound = log_p_x_given_h_tot - temperature_KL * kl_divergence_tot
        return lower_bound

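# NOTE: The snippets in this file assume `import theano.tensor as T` (and NumPy as
# `np`) at module level. `kl_normal2_normal2` is not defined here; a minimal sketch
# of the element-wise KL between two diagonal Gaussians parameterized by mean and
# log-variance (the signature assumed above, e.g. Parmesan-style) could look like
# this. The caller masks and sums/averages over the desired axes.
import theano.tensor as T

def kl_normal2_normal2(mean1, log_var1, mean2, log_var2):
    # KL( N(mean1, exp(log_var1)) || N(mean2, exp(log_var2)) ), element-wise:
    # log(sigma2/sigma1) + (sigma1^2 + (mean1 - mean2)^2) / (2*sigma2^2) - 1/2
    return (0.5 * log_var2 - 0.5 * log_var1
            + (T.exp(log_var1) + (mean1 - mean2) ** 2) / (2 * T.exp(log_var2))
            - 0.5)
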
def log_likelihood(z, z_mu, z_log_var, x_mu, x, analytic_kl_term):
    if analytic_kl_term:
        kl_term = kl_normal2_stdnormal(z_mu, z_log_var).sum(axis=1)
        log_px_given_z = log_bernoulli(x, x_mu, eps=1e-6).sum(axis=1)
        LL = T.mean(-kl_term + log_px_given_z)
    else:
        log_qz_given_x = log_normal2(z, z_mu, z_log_var).sum(axis=1)
        log_pz = log_stdnormal(z).sum(axis=1)
        log_px_given_z = log_bernoulli(x, x_mu, eps=1e-6).sum(axis=1)
        LL = T.mean(log_pz + log_px_given_z - log_qz_given_x)
    return LL

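# NOTE: The log-density and KL helpers used throughout (`log_bernoulli`,
# `log_normal2`, `log_stdnormal`, `kl_normal2_stdnormal`) are assumed to come from
# a Parmesan-style distributions module. Minimal element-wise sketches consistent
# with how they are called in these snippets:
import math

_C = -0.5 * math.log(2 * math.pi)  # log-normalization constant of a unit Gaussian

def log_bernoulli(x, p, eps=0.0):
    # Element-wise log p(x) of a Bernoulli with success probability p.
    p = T.clip(p, eps, 1.0 - eps)
    return x * T.log(p) + (1.0 - x) * T.log(1.0 - p)

def log_normal2(x, mean, log_var):
    # Element-wise log density of N(mean, exp(log_var)).
    return _C - log_var / 2 - (x - mean) ** 2 / (2 * T.exp(log_var))

def log_stdnormal(x):
    # Element-wise log density of N(0, 1).
    return _C - x ** 2 / 2

def kl_normal2_stdnormal(mean, log_var):
    # Element-wise KL( N(mean, exp(log_var)) || N(0, 1) ).
    return -0.5 * (1 + log_var - mean ** 2 - T.exp(log_var))
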
def latent_gaussian_x_bernoulli(z, z_mu, psi_u_list, z_log_var, x_mu, x,
                                eq_samples, iw_samples, epsilon=1e-6):
    """
    Latent z : gaussian with standard normal prior
    decoder output : bernoulli

    When the output is bernoulli then the output from the decoder
    should be sigmoid.

    The sizes of the inputs are
    z: (batch_size*eq_samples*iw_samples, num_latent)
    z_mu: (batch_size, num_latent)
    z_log_var: (batch_size, num_latent)
    x_mu: (batch_size*eq_samples*iw_samples, num_features)
    x: (batch_size, num_features)

    Reference: Burda et al. 2015 "Importance Weighted Autoencoders"
    """
    # reshape the variables so batch_size, eq_samples and iw_samples are separate dimensions
    z = z.reshape((-1, eq_samples, iw_samples, latent_size))
    x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features))

    # dimshuffle x since we need to broadcast it when calculating the binary
    # cross-entropy
    x = x.dimshuffle(0, 'x', 'x', 1)  # x: (batch_size, eq_samples, iw_samples, num_features)

    for i in range(len(psi_u_list)):
        psi_u_list[i] = psi_u_list[i].reshape((-1, eq_samples, iw_samples))

    # calculate LL components, note that we sum over the feature/num_unit dimension
    z_mu = z_mu.dimshuffle(0, 'x', 'x', 1)            # mean: (batch_size, 1, 1, num_latent)
    z_log_var = z_log_var.dimshuffle(0, 'x', 'x', 1)  # logvar: (batch_size, 1, 1, num_latent)
    log_qz_given_x = log_normal2(z, z_mu, z_log_var).sum(axis=3)
    log_pz = log_stdnormal(z).sum(axis=3)
    log_px_given_z = log_bernoulli(x, T.clip(x_mu, epsilon, 1 - epsilon)).sum(axis=3)

    # normalizing flow loss
    sum_log_psiu = 0
    for psi_u in psi_u_list:
        sum_log_psiu += T.log(T.abs_(1 + psi_u))

    # all log_*** should have dimension (batch_size, eq_samples, iw_samples)
    # Calculate the LL using log-sum-exp to avoid underflow
    a = log_pz + log_px_given_z - log_qz_given_x + sum_log_psiu  # size: (batch_size, eq_samples, iw_samples)
    a_max = T.max(a, axis=2, keepdims=True)                      # size: (batch_size, eq_samples, 1)
    # LL is calculated using Eq (8) in Burda et al.
    # Working from the inside out of the calculation below:
    #
    # T.exp(a-a_max): (batch_size, eq_samples, iw_samples)
    # -> subtract a_max to avoid overflow. a_max is specific for each set of
    #    importance samples and is broadcasted over the last dimension.
    #
    # T.log( T.mean(T.exp(a-a_max), axis=2) ): (batch_size, eq_samples)
    # -> This is the log of the mean over the importance weighted samples.
    #
    # Lastly we add T.mean(a_max) to correct for the log-sum-exp trick.
    LL = T.mean(a_max) + T.mean(T.log(T.mean(T.exp(a - a_max), axis=2)))

    return LL, T.mean(log_qz_given_x), T.mean(log_pz), T.mean(log_px_given_z)

def latent_gaussian_x_bernoulli(z, z_mu, z_log_var, x, x_mu, analytic_kl_term):
    """
    Latent z : gaussian with standard normal prior
    decoder output : bernoulli

    When the output is bernoulli then the output from the decoder
    should be sigmoid.
    """
    if analytic_kl_term:
        kl_term = kl_normal2_stdnormal(z_mu, z_log_var).sum(axis=1)
        log_px_given_z = log_bernoulli(x, x_mu).sum(axis=1)
        LL = T.mean(-kl_term + log_px_given_z)
    else:
        log_qz_given_x = log_normal2(z, z_mu, z_log_var).sum(axis=1)
        log_pz = log_stdnormal(z).sum(axis=1)
        log_px_given_z = log_bernoulli(x, x_mu).sum(axis=1)
        LL = T.mean(log_pz + log_px_given_z - log_qz_given_x)
    return LL

def lower_bound(z, z_mu, x_mu, x, eq_samples, iw_samples, epsilon=1e-6):
    from theano.gradient import disconnected_grad as dg

    # reshape the variables so batch_size, eq_samples and iw_samples are
    # separate dimensions
    z = z.reshape((-1, eq_samples, iw_samples, latent_size))
    x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features))

    # prepare x, z for broadcasting
    # size: (batch_size, eq_samples, iw_samples, num_features)
    x = x.dimshuffle(0, 'x', 'x', 1)
    # size: (batch_size, eq_samples, iw_samples, num_latent)
    z_mu = z_mu.dimshuffle(0, 'x', 'x', 1)

    log_qz_given_x = log_bernoulli(z, z_mu, eps=epsilon).sum(axis=3)
    z_prior = T.ones_like(z) * np.float32(0.5)
    log_pz = log_bernoulli(z, z_prior).sum(axis=3)
    log_px_given_z = log_bernoulli(x, x_mu, eps=epsilon).sum(axis=3)

    # Calculate the LL using log-sum-exp to avoid underflow
    log_pxz = log_pz + log_px_given_z

    # L is (bs, mc); see the definition of L in appendix eq. 14
    L = log_sum_exp(log_pxz - log_qz_given_x, axis=2) + \
        T.log(1.0 / T.cast(iw_samples, 'float32'))
    grads_model = T.grad(-L.mean(), p_params)

    # L_corr should correspond to equation 10 in the paper
    L_corr = L.dimshuffle(0, 1, 'x') - get_vimco_baseline(log_pxz - log_qz_given_x)
    g_lb_inference = T.mean(T.sum(dg(L_corr) * log_qz_given_x) + L)
    grads_inference = T.grad(-g_lb_inference, q_params)

    grads = grads_model + grads_inference

    LL = log_mean_exp(log_pz + log_px_given_z - log_qz_given_x, axis=2)
    return (LL, T.mean(log_qz_given_x), T.mean(log_pz),
            T.mean(log_px_given_z), grads)

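# NOTE: `log_sum_exp` and `log_mean_exp` are assumed helpers, not defined in these
# snippets. Numerically stable sketches that reduce over `axis`:
def log_sum_exp(a, axis):
    # log(sum(exp(a), axis)) computed via the max-subtraction trick.
    a_max = T.max(a, axis=axis, keepdims=True)
    return T.max(a, axis=axis) + T.log(T.sum(T.exp(a - a_max), axis=axis))

def log_mean_exp(a, axis):
    # log(mean(exp(a), axis)) = log_sum_exp(a, axis) - log(K).
    a_max = T.max(a, axis=axis, keepdims=True)
    return T.max(a, axis=axis) + T.log(T.mean(T.exp(a - a_max), axis=axis))
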
def latent_gaussian_x_bernoulli(z, z_mu, z_log_var, x_mu, x, eq_samples,
                                iw_samples, epsilon=1e-6):
    """
    Latent z : gaussian with standard normal prior
    decoder output : bernoulli

    When the output is bernoulli then the output from the decoder
    should be sigmoid.

    The sizes of the inputs are
    z: (batch_size*eq_samples*iw_samples, num_latent)
    z_mu: (batch_size, num_latent)
    z_log_var: (batch_size, num_latent)
    x_mu: (batch_size*eq_samples*iw_samples, num_features)
    x: (batch_size, num_features)

    Reference: Burda et al. 2015 "Importance Weighted Autoencoders"
    """
    # reshape the variables so batch_size, eq_samples and iw_samples are separate dimensions
    z = z.reshape((-1, eq_samples, iw_samples, latent_size))
    x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features))

    # dimshuffle x, z_mu and z_log_var since we need to broadcast them when calculating the pdfs
    x = x.dimshuffle(0, 'x', 'x', 1)                  # size: (batch_size, eq_samples, iw_samples, num_features)
    z_mu = z_mu.dimshuffle(0, 'x', 'x', 1)            # size: (batch_size, eq_samples, iw_samples, num_latent)
    z_log_var = z_log_var.dimshuffle(0, 'x', 'x', 1)  # size: (batch_size, eq_samples, iw_samples, num_latent)

    # calculate LL components, note that the log_xyz() functions return log prob. for independent
    # components separately, so we sum over feature/latent dimensions for multivariate pdfs
    log_qz_given_x = log_normal2(z, z_mu, z_log_var).sum(axis=3)
    log_pz = log_stdnormal(z).sum(axis=3)
    log_px_given_z = log_bernoulli(x, x_mu, eps=epsilon).sum(axis=3)

    # all log_*** should have dimension (batch_size, eq_samples, iw_samples)
    # Calculate the LL using log-sum-exp to avoid underflow
    a = log_pz + log_px_given_z - log_qz_given_x  # size: (batch_size, eq_samples, iw_samples)
    a_max = T.max(a, axis=2, keepdims=True)       # size: (batch_size, eq_samples, 1)
    # LL is calculated using Eq (8) in Burda et al.
    # Working from the inside out of the calculation below:
    #
    # T.exp(a-a_max): (batch_size, eq_samples, iw_samples)
    # -> subtract a_max to avoid overflow. a_max is specific for each set of
    #    importance samples and is broadcasted over the last dimension.
    #
    # T.log( T.mean(T.exp(a-a_max), axis=2) ): (batch_size, eq_samples)
    # -> This is the log of the mean over the importance weighted samples.
    #
    # The outer T.mean() computes the mean over eq_samples and batch_size.
    #
    # Lastly we add T.mean(a_max) to correct for the log-sum-exp trick.
    LL = T.mean(a_max) + T.mean(T.log(T.mean(T.exp(a - a_max), axis=2)))

    return LL, T.mean(log_qz_given_x), T.mean(log_pz), T.mean(log_px_given_z)

def latent_gaussian_x_bernoulli(z, z_mu, z_log_var, x_mu, x, eq_samples,
                                iw_samples, epsilon=1e-6):
    """
    Latent z : gaussian with standard normal prior
    decoder output : bernoulli

    When the output is bernoulli then the output from the decoder
    should be sigmoid.

    The sizes of the inputs are
    z: (batch_size*eq_samples*iw_samples, num_latent)
    z_mu: (batch_size, num_latent)
    z_log_var: (batch_size, num_latent)
    x_mu: (batch_size*eq_samples*iw_samples, num_features)
    x: (batch_size, num_features)

    Reference: Burda et al. 2015 "Importance Weighted Autoencoders"
    """
    # reshape the variables so batch_size, eq_samples and iw_samples are separate dimensions
    z = z.reshape((-1, eq_samples, iw_samples, latent_size))
    x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features))

    # dimshuffle x, z_mu and z_log_var since we need to broadcast them when calculating the pdfs
    x = x.dimshuffle(0, 'x', 'x', 1)                  # size: (batch_size, eq_samples, iw_samples, num_features)
    z_mu = z_mu.dimshuffle(0, 'x', 'x', 1)            # size: (batch_size, eq_samples, iw_samples, num_latent)
    z_log_var = z_log_var.dimshuffle(0, 'x', 'x', 1)  # size: (batch_size, eq_samples, iw_samples, num_latent)

    # calculate LL components, note that the log_xyz() functions return log prob. for independent
    # components separately, so we sum over feature/latent dimensions for multivariate pdfs
    log_qz_given_x = log_normal2(z, z_mu, z_log_var).sum(axis=3)
    log_pz = log_stdnormal(z).sum(axis=3)
    log_px_given_z = log_bernoulli(x, T.clip(x_mu, epsilon, 1 - epsilon)).sum(axis=3)

    # all log_*** should have dimension (batch_size, eq_samples, iw_samples)
    # Calculate the LL using log-sum-exp to avoid underflow
    a = log_pz + log_px_given_z - log_qz_given_x  # size: (batch_size, eq_samples, iw_samples)
    a_max = T.max(a, axis=2, keepdims=True)       # size: (batch_size, eq_samples, 1)
    # LL is calculated using Eq (8) in Burda et al.
    # Working from the inside out of the calculation below:
    #
    # T.exp(a-a_max): (batch_size, eq_samples, iw_samples)
    # -> subtract a_max to avoid overflow. a_max is specific for each set of
    #    importance samples and is broadcasted over the last dimension.
    #
    # T.log( T.mean(T.exp(a-a_max), axis=2) ): (batch_size, eq_samples)
    # -> This is the log of the mean over the importance weighted samples.
    #
    # The outer T.mean() computes the mean over eq_samples and batch_size.
    #
    # Lastly we add T.mean(a_max) to correct for the log-sum-exp trick.
    LL = T.mean(a_max) + T.mean(T.log(T.mean(T.exp(a - a_max), axis=2)))

    return LL, T.mean(log_qz_given_x), T.mean(log_pz), T.mean(log_px_given_z)

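# NOTE: The explicit a_max computation above is the standard log-sum-exp identity,
# log(mean_k exp(a_k)) = a_max + log(mean_k exp(a_k - a_max)). A quick NumPy check
# of why the stabilization matters:
import numpy as np

_a = np.array([-1000.0, -1001.0, -1002.0])  # exp() underflows to 0.0 in float64
_naive = np.log(np.mean(np.exp(_a)))        # -> -inf due to underflow
_a_max = _a.max()
_stable = _a_max + np.log(np.mean(np.exp(_a - _a_max)))  # -> approx -1000.69
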
def lower_bound(z, z_mu, x_mu, x, eq_samples, iw_samples, epsilon=1e-6):
    from theano.gradient import disconnected_grad as dg

    # reshape the variables so batch_size, eq_samples and iw_samples are
    # separate dimensions
    z = z.reshape((-1, eq_samples, iw_samples, latent_size))
    x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features))

    # prepare x, z for broadcasting
    # size: (batch_size, eq_samples, iw_samples, num_features)
    x = x.dimshuffle(0, 'x', 'x', 1)
    # size: (batch_size, eq_samples, iw_samples, num_latent)
    z_mu = z_mu.dimshuffle(0, 'x', 'x', 1)

    log_qz_given_x = log_bernoulli(z, z_mu, eps=epsilon).sum(axis=3)
    z_prior = T.ones_like(z) * np.float32(0.5)
    log_pz = log_bernoulli(z, z_prior).sum(axis=3)
    log_px_given_z = log_bernoulli(x, x_mu, eps=epsilon).sum(axis=3)

    # Calculate the LL using log-sum-exp to avoid underflow
    log_pxz = log_pz + log_px_given_z

    # L is (bs, mc); see the definition of L in appendix eq. 14
    L = log_sum_exp(log_pxz - log_qz_given_x, axis=2) + \
        T.log(1.0 / T.cast(iw_samples, 'float32'))
    grads_model = T.grad(-L.mean(), p_params)

    # L_corr should correspond to equation 10 in the paper
    L_corr = L.dimshuffle(0, 1, 'x') - get_vimco_baseline(log_pxz - log_qz_given_x)
    g_lb_inference = T.mean(T.sum(dg(L_corr) * log_qz_given_x) + L)
    grads_inference = T.grad(-g_lb_inference, q_params)

    grads = grads_model + grads_inference

    LL = log_mean_exp(log_pz + log_px_given_z - log_qz_given_x, axis=2)
    return (LL, T.mean(log_qz_given_x), T.mean(log_pz),
            T.mean(log_px_given_z), grads)

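# NOTE: `get_vimco_baseline`, `p_params` and `q_params` above are assumed globals.
# A sketch of the VIMCO per-sample baseline (Mnih & Rezende 2016, "Variational
# inference for Monte Carlo objectives"): for each importance sample j, its
# log-weight is replaced by the mean of the other log-weights (i.e. the geometric
# mean of the weights) before taking log-mean-exp over the sample axis.
def get_vimco_baseline(log_ws, axis=2):
    # log_ws: (batch_size, eq_samples, iw_samples) log importance weights
    K = T.cast(log_ws.shape[axis], 'float32')
    # leave-one-out mean of the log-weights for each sample j
    loo_mean = (T.sum(log_ws, axis=axis, keepdims=True) - log_ws) / (K - 1)
    # stable log-mean-exp with element j swapped for its leave-one-out mean;
    # loo_mean <= m element-wise, so the exp terms cannot overflow
    m = T.max(log_ws, axis=axis, keepdims=True)
    sum_exp = T.sum(T.exp(log_ws - m), axis=axis, keepdims=True)
    swapped = sum_exp - T.exp(log_ws - m) + T.exp(loo_mean - m)
    return m + T.log(swapped / K)
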
def latent_gaussian_x_bernoulli(z, z_mu, z_log_var, x_mu, x, analytic_kl_term):
    """
    Latent z : gaussian with standard normal prior
    decoder output : bernoulli

    When the output is bernoulli then the output from the decoder
    should be sigmoid.

    The sizes of the inputs are
    z: (batch_size, num_latent)
    z_mu: (batch_size, num_latent)
    z_log_var: (batch_size, num_latent)
    x_mu: (batch_size, num_features)
    x: (batch_size, num_features)
    """
    if analytic_kl_term:
        kl_term = kl_normal2_stdnormal(z_mu, z_log_var).sum(axis=1)
        log_px_given_z = log_bernoulli(x, x_mu).sum(axis=1)
        LL = T.mean(-kl_term + log_px_given_z)
    else:
        log_qz_given_x = log_normal2(z, z_mu, z_log_var).sum(axis=1)
        log_pz = log_stdnormal(z).sum(axis=1)
        log_px_given_z = log_bernoulli(x, x_mu).sum(axis=1)
        LL = T.mean(log_pz + log_px_given_z - log_qz_given_x)
    return LL

def latent_gaussian_x_bernoulli(z, z_mu, z_log_var, x_mu, x, analytic_kl_term):
    """
    Latent z : gaussian with standard normal prior
    decoder output : bernoulli

    When the output is bernoulli then the output from the decoder
    should be sigmoid.

    The sizes of the inputs are
    z: (batch_size, num_latent)
    z_mu: (batch_size, num_latent)
    z_log_var: (batch_size, num_latent)
    x_mu: (batch_size, num_features)
    x: (batch_size, num_features)
    """
    if analytic_kl_term:
        kl_term = kl_normal2_stdnormal(z_mu, z_log_var).sum(axis=1)
        log_px_given_z = log_bernoulli(x, x_mu, eps=1e-6).sum(axis=1)
        LL = T.mean(-kl_term + log_px_given_z)
    else:
        log_qz_given_x = log_normal2(z, z_mu, z_log_var).sum(axis=1)
        log_pz = log_stdnormal(z).sum(axis=1)
        log_px_given_z = log_bernoulli(x, x_mu, eps=1e-6).sum(axis=1)
        LL = T.mean(log_pz + log_px_given_z - log_qz_given_x)
    return LL

def latent_gaussian_x_bernoulli(z0, zk, z0_mu, z0_log_var, logdet_J_list, x_mu, x,
                                eq_samples, iw_samples, epsilon=1e-6):
    """
    Latent z : gaussian with standard normal prior
    decoder output : bernoulli

    When the output is bernoulli then the output from the decoder
    should be sigmoid.

    The sizes of the inputs are
    z0: (batch_size*eq_samples*iw_samples, num_latent)
    zk: (batch_size*eq_samples*iw_samples, num_latent)
    z0_mu: (batch_size, num_latent)
    z0_log_var: (batch_size, num_latent)
    logdet_J_list: list of `nflows` elements, each with shape (batch_size*eq_samples*iw_samples)
    x_mu: (batch_size*eq_samples*iw_samples, num_features)
    x: (batch_size, num_features)

    Reference: Burda et al. 2015 "Importance Weighted Autoencoders"
    """
    # reshape the variables so batch_size, eq_samples and iw_samples are separate dimensions
    z0 = z0.reshape((-1, eq_samples, iw_samples, latent_size))
    zk = zk.reshape((-1, eq_samples, iw_samples, latent_size))
    x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features))
    for i in range(len(logdet_J_list)):
        logdet_J_list[i] = logdet_J_list[i].reshape((-1, eq_samples, iw_samples))

    # dimshuffle x, z_mu and z_log_var since we need to broadcast them when calculating the pdfs
    x = x.dimshuffle(0, 'x', 'x', 1)                    # size: (batch_size, eq_samples, iw_samples, num_features)
    z0_mu = z0_mu.dimshuffle(0, 'x', 'x', 1)            # size: (batch_size, eq_samples, iw_samples, num_latent)
    z0_log_var = z0_log_var.dimshuffle(0, 'x', 'x', 1)  # size: (batch_size, eq_samples, iw_samples, num_latent)

    # calculate LL components, note that the log_xyz() functions return log prob. for independent
    # components separately, so we sum over feature/latent dimensions for multivariate pdfs
    log_q0z0_given_x = log_normal2(z0, z0_mu, z0_log_var).sum(axis=3)
    log_pzk = log_stdnormal(zk).sum(axis=3)
    log_px_given_zk = log_bernoulli(x, x_mu, epsilon).sum(axis=3)

    # normalizing flow loss
    sum_logdet_J = 0
    for logdet_J_k in logdet_J_list:
        sum_logdet_J += logdet_J_k

    # Calculate the LL using log-sum-exp to avoid underflow;
    # all log_*** have shape (batch_size, eq_samples, iw_samples)
    LL = log_mean_exp(log_pzk + log_px_given_zk - log_q0z0_given_x + sum_logdet_J,
                      axis=2)  # log-mean-exp over iw_samples -> (batch_size, eq_samples)
    LL = T.mean(LL)  # average over eq_samples and batch_size -> ()

    return LL, T.mean(log_q0z0_given_x), T.mean(sum_logdet_J), T.mean(log_pzk), T.mean(log_px_given_zk)

def ELBO(z, z_mu, z_log_var, x_mu, x):
    """
    Latent z : gaussian with standard normal prior
    decoder output : bernoulli

    When the output is bernoulli then the output from the decoder
    should be sigmoid.

    The sizes of the inputs are
    z: (batch_size, num_latent)
    z_mu: (batch_size, num_latent)
    z_log_var: (batch_size, num_latent)
    x_mu: (batch_size, num_features)
    x: (batch_size, num_features)
    """
    kl_term = kl_normal2_stdnormal(z_mu, z_log_var).sum(axis=1)
    log_px_given_z = log_bernoulli(x, x_mu, eps=1e-6).sum(axis=1)
    LL = T.mean(-kl_term + log_px_given_z)
    return LL

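# NOTE: Usage sketch for the ELBO above, compiled as a Theano training step with
# Lasagne-style updates. `get_model_outputs` and `l_out` are hypothetical names
# standing in for the model-specific encoder/decoder wiring; only the
# theano/lasagne calls themselves are standard API.
import theano
import lasagne

sym_x = T.matrix('x')
z, z_mu, z_log_var, x_mu = get_model_outputs(sym_x)  # hypothetical helper

LL = ELBO(z, z_mu, z_log_var, x_mu, sym_x)
params = lasagne.layers.get_all_params(l_out, trainable=True)  # hypothetical output layer
updates = lasagne.updates.adam(-LL, params, learning_rate=1e-3)  # maximize LL

train_model = theano.function([sym_x], LL, updates=updates)
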
def lowerbound_for_reinforce(z, z_mu, z_log_var, x_mu, x, num_features,
                             num_labelled, num_classes, epsilon=1e-6):
    x = x.reshape((-1, num_features))
    x_mu = x_mu.reshape((-1, num_features))

    log_qz_given_xy = log_normal2(z, z_mu, z_log_var).sum(axis=1)
    log_pz = log_stdnormal(z).sum(axis=1)
    log_py = T.log(1.0 / num_classes)
    log_px_given_zy = log_bernoulli(x, T.clip(x_mu, epsilon, 1 - epsilon)).sum(axis=1)
    ll_xy = log_px_given_zy + log_pz + log_py - log_qz_given_xy
    return ll_xy[num_labelled:]

def latent_gaussian_x_bernoulli(z, z_mu, z_log_var, x_mu, x, eq_samples,
                                iw_samples, epsilon=1e-6):
    """
    Latent z : gaussian with standard normal prior
    decoder output : bernoulli

    When the output is bernoulli then the output from the decoder
    should be sigmoid.

    The sizes of the inputs are
    z: (batch_size*eq_samples*iw_samples, num_latent)
    z_mu: (batch_size, num_latent)
    z_log_var: (batch_size, num_latent)
    x_mu: (batch_size*eq_samples*iw_samples, num_features)
    x: (batch_size, num_features)

    Reference: Burda et al. 2015 "Importance Weighted Autoencoders"
    """
    # reshape the variables so batch_size, eq_samples and iw_samples are separate dimensions
    z = z.reshape((-1, eq_samples, iw_samples, latent_size))
    x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features))

    # dimshuffle x, z_mu and z_log_var since we need to broadcast them when calculating the pdfs
    x = x.dimshuffle(0, 'x', 'x', 1)                  # size: (batch_size, eq_samples, iw_samples, num_features)
    z_mu = z_mu.dimshuffle(0, 'x', 'x', 1)            # size: (batch_size, eq_samples, iw_samples, num_latent)
    z_log_var = z_log_var.dimshuffle(0, 'x', 'x', 1)  # size: (batch_size, eq_samples, iw_samples, num_latent)

    # calculate LL components, note that the log_xyz() functions return log prob. for independent
    # components separately, so we sum over feature/latent dimensions for multivariate pdfs
    log_qz_given_x = log_normal2(z, z_mu, z_log_var).sum(axis=3)
    log_pz = log_stdnormal(z).sum(axis=3)
    log_px_given_z = log_bernoulli(x, x_mu, epsilon).sum(axis=3)

    # Calculate the LL using log-sum-exp to avoid underflow;
    # all log_*** have shape (batch_size, eq_samples, iw_samples)
    LL = log_mean_exp(log_pz + log_px_given_z - log_qz_given_x,
                      axis=2)  # log-mean-exp over iw_samples -> (batch_size, eq_samples)
    LL = T.mean(LL)  # average over eq_samples and batch_size -> ()

    return LL, T.mean(log_qz_given_x), T.mean(log_pz), T.mean(log_px_given_z)

    k_max_indices = mean == k_max
    mean[k_max_indices] += meanOfNegativeBinomialDistribution(
        r[k_max_indices], log_r[k_max_indices])

    return mean


reconstruction_distributions = {
    "bernoulli": {
        "parameters": ["p"],
        "activation functions": {
            "p": sigmoid,
        },
        "function": lambda x, x_theta, eps=0.0: \
            log_bernoulli(x, x_theta["p"], eps),
        "mean": lambda x_theta: x_theta["p"],
        # TODO Consider switching to Bernoulli sampling
        "preprocess": lambda x: (x != 0).astype('float32')
    },

    "negative_binomial": {
        "parameters": ["p", "log_r"],
        "activation functions": {
            "p": sigmoid,
            "log_r": lambda x: T.clip(x, -10, 10)
        },
        "function": lambda x, x_theta, eps=0.0: \
            log_negative_binomial(x, x_theta["p"], x_theta["log_r"], eps),
        "mean": lambda x_theta: \
            meanOfNegativeBinomialDistribution(x_theta["p"], x_theta["log_r"]),

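# NOTE: `meanOfNegativeBinomialDistribution` is referenced above but not defined in
# this fragment, and the dict itself is truncated in the source. Assuming the
# (p, r) parameterization implied by log_negative_binomial (success probability p,
# r = exp(log_r) failures), the mean would be r * p / (1 - p); a NumPy sketch:
import numpy as np

def meanOfNegativeBinomialDistribution(p, log_r):
    # E[x] = r * p / (1 - p) for x successes before r = exp(log_r) failures
    # (assumed parameterization; verify against log_negative_binomial).
    return p * np.exp(log_r) / (1 - p)
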
def latent_gaussian_x_bernoulli(z0, zk, z0_mu, z0_log_var, logdet_J_list, x_mu, x,
                                eq_samples, iw_samples, epsilon=1e-6):
    """
    Latent z : gaussian with standard normal prior
    decoder output : bernoulli

    When the output is bernoulli then the output from the decoder
    should be sigmoid.

    The sizes of the inputs are
    z0: (batch_size*eq_samples*iw_samples, num_latent)
    zk: (batch_size*eq_samples*iw_samples, num_latent)
    z0_mu: (batch_size, num_latent)
    z0_log_var: (batch_size, num_latent)
    logdet_J_list: list of `nflows` elements, each with shape (batch_size*eq_samples*iw_samples)
    x_mu: (batch_size*eq_samples*iw_samples, num_features)
    x: (batch_size, num_features)

    Reference: Burda et al. 2015 "Importance Weighted Autoencoders"
    """
    # reshape the variables so batch_size, eq_samples and iw_samples are separate dimensions
    z0 = z0.reshape((-1, eq_samples, iw_samples, latent_size))
    zk = zk.reshape((-1, eq_samples, iw_samples, latent_size))
    x_mu = x_mu.reshape((-1, eq_samples, iw_samples, num_features))
    for i in range(len(logdet_J_list)):
        logdet_J_list[i] = logdet_J_list[i].reshape((-1, eq_samples, iw_samples))

    # dimshuffle x, z_mu and z_log_var since we need to broadcast them when calculating the pdfs
    x = x.dimshuffle(0, 'x', 'x', 1)                    # size: (batch_size, eq_samples, iw_samples, num_features)
    z0_mu = z0_mu.dimshuffle(0, 'x', 'x', 1)            # size: (batch_size, eq_samples, iw_samples, num_latent)
    z0_log_var = z0_log_var.dimshuffle(0, 'x', 'x', 1)  # size: (batch_size, eq_samples, iw_samples, num_latent)

    # calculate LL components, note that the log_xyz() functions return log prob. for independent
    # components separately, so we sum over feature/latent dimensions for multivariate pdfs
    log_q0z0_given_x = log_normal2(z0, z0_mu, z0_log_var).sum(axis=3)
    log_pzk = log_stdnormal(zk).sum(axis=3)
    log_px_given_zk = log_bernoulli(x, x_mu, epsilon).sum(axis=3)

    # normalizing flow loss
    sum_logdet_J = 0
    for logdet_J_k in logdet_J_list:
        sum_logdet_J += logdet_J_k

    # Calculate the LL using log-sum-exp to avoid underflow;
    # all log_*** have shape (batch_size, eq_samples, iw_samples)
    LL = log_mean_exp(log_pzk + log_px_given_zk - log_q0z0_given_x + sum_logdet_J,
                      axis=2)  # log-mean-exp over iw_samples -> (batch_size, eq_samples)
    LL = T.mean(LL)  # average over eq_samples and batch_size -> ()

    return LL, T.mean(log_q0z0_given_x), T.mean(sum_logdet_J), T.mean(log_pzk), T.mean(log_px_given_zk)