def __init__(self, model, optimizer, hyper):
    """Set up temperature, prior placeholders, and the working distribution.

    `hyper` must provide 'temp' (relaxation temperature) and 'prior_file'.
    The (1, 1, 1, 1) zero priors are placeholders; `load_prior_values`
    replaces them with the real prior tensors at the end.
    """
    super().__init__(model=model, optimizer=optimizer, hyper=hyper)
    self.temp = tf.constant(value=hyper['temp'], dtype=tf.float32)
    self.prior_file = hyper['prior_file']
    # Broadcastable placeholder priors, overwritten by load_prior_values().
    placeholder_shape = (1, 1, 1, 1)
    self.mu_0 = tf.constant(value=0., dtype=tf.float32, shape=placeholder_shape)
    self.xi_0 = tf.constant(value=0., dtype=tf.float32, shape=placeholder_shape)
    self.ng = GaussianSoftmaxDist(mu=self.mu_0, xi=self.xi_0)
    self.load_prior_values()
def reparameterize(self, params_broad):
    """Draw a relaxed discrete sample via the Gaussian-softmax trick.

    `params_broad` is the pair (mu, xi). The freshly built distribution is
    stored on `self.ng` so later KL computations can reuse its state.
    Returns a single-element list holding the log-probabilities `log_psi`.
    """
    mu, xi = params_broad
    dist = GaussianSoftmaxDist(mu=mu, xi=xi, temp=self.temp,
                               sample_size=self.sample_size)
    dist.do_reparameterization_trick()
    self.ng = dist
    return [self.ng.log_psi]
def reparameterize(self, params_broad):
    """Sample both latent parts: a Gaussian vector and a relaxed discrete one.

    `params_broad` unpacks as (mean, log_var, mu, xi): the first pair
    parameterizes the continuous latent, the second the Gaussian-softmax
    distribution. Side effects: `self.ng` holds the new distribution and
    `self.n_required` records the discrete sample's second dimension.
    Returns [z_norm, z_discrete].
    """
    mean, log_var, mu, xi = params_broad
    z_norm = sample_normal(mean=mean, log_var=log_var)
    dist = GaussianSoftmaxDist(mu=mu, xi=xi, temp=self.temp,
                               sample_size=self.sample_size)
    dist.do_reparameterization_trick()
    self.ng = dist
    z_discrete = dist.psi
    # presumably axis 1 indexes the categories — TODO confirm upstream
    self.n_required = z_discrete.shape[1]
    return [z_norm, z_discrete]
class OptGauSoftMaxDis(OptGauSoftMax):
    """Discrete-only variant of OptGauSoftMax: no continuous latent, so the
    normal-KL term is fixed to 0 and only the Gaussian-softmax KL is computed."""

    def __init__(self, model, optimizer, hyper):
        super().__init__(model=model, optimizer=optimizer, hyper=hyper)

    def reparameterize(self, params_broad):
        """Sample the relaxed discrete latent; returns [log_psi].

        `params_broad` is (mu, xi). Stores the distribution on `self.ng`
        so the KL methods below can reuse its sampled state (`lam`).
        """
        mu, xi = params_broad
        self.ng = GaussianSoftmaxDist(mu=mu, xi=xi, temp=self.temp, sample_size=self.sample_size)
        self.ng.do_reparameterization_trick()
        z_discrete = [self.ng.log_psi]
        return z_discrete

    def compute_kl_elements(self, z, params_broad, run_analytical_kl):
        """Dispatch between closed-form and Monte-Carlo KL estimates."""
        if run_analytical_kl:
            kl_norm, kl_dis = self.compute_kl_elements_analytically(params_broad=params_broad)
        else:
            kl_norm, kl_dis = self.compute_kl_elements_via_sample(z=z, params_broad=params_broad)
        return kl_norm, kl_dis

    def compute_kl_elements_analytically(self, params_broad):
        """Closed-form KL between posterior N(μ0, e^{2ξ0}) and the prior.

        kl_norm is 0 because this variant has no continuous latent.
        Priors are sliced to the current batch size taken from the last
        sampled `lam`.
        """
        μ0, ξ0 = params_broad
        kl_norm = 0.
        current_batch_n = self.ng.lam.shape[0]
        # 3-index slicing on the 4-D prior keeps the trailing dim intact.
        ξ1 = self.xi_0[:current_batch_n, :, :]
        μ1 = self.mu_0[:current_batch_n, :, :]
        # ξ is a log-std, so log-variance = 2ξ. Reduction over axes (1, 3)
        # — presumably (category, discrete-variable); TODO confirm.
        kl_dis = calculate_kl_norm_via_general_analytical_formula(
            mean_0=μ0, log_var_0=2 * ξ0, mean_1=μ1, log_var_1=2. * ξ1, axis=(1, 3))
        return kl_norm, kl_dis

    def compute_kl_elements_via_sample(self, z, params_broad):
        """Single-sample KL estimate: log q(z|x) - log p(z) at temp·lam.

        `z` is unused here; the sample is read back from `self.ng.lam`.
        """
        kl_norm = 0.
        μ, ξ = params_broad
        current_batch_n = self.ng.lam.shape[0]
        log_pz = compute_log_normal_pdf(
            sample=self.temp * self.ng.lam,
            mean=self.mu_0[:current_batch_n, :, :, :],
            log_var=2. * self.xi_0[:current_batch_n, :, :, :])
        log_qz_x = compute_log_normal_pdf(sample=self.temp * self.ng.lam, mean=μ, log_var=2. * ξ)
        kl_dis = log_qz_x - log_pz
        # Sum over axis 2 — assumed to be the sample_size axis per the prior
        # shape built in load_prior_values; TODO confirm.
        kl_dis = tf.reduce_sum(kl_dis, axis=2)
        return kl_norm, kl_dis
def determine_distribution(model_type, params, temp, samples_n):
    """Build the latent distribution object matching `model_type`.

    Args:
        model_type: 'GSMDis' for the Gaussian-softmax distribution or
            'ExpGSDis' for the exponential Gumbel-softmax distribution.
        params: parameter sequence; (mu, xi) for 'GSMDis', (log_pi,) for
            'ExpGSDis'.
        temp: relaxation temperature passed to the distribution.
        samples_n: number of Monte-Carlo samples per data point.

    Returns:
        The constructed distribution instance.

    Raises:
        RuntimeError: if `model_type` is not one of the supported names.
    """
    if model_type == 'GSMDis':
        dist = GaussianSoftmaxDist(mu=params[0], xi=params[1],
                                   sample_size=samples_n, temp=temp)
    elif model_type == 'ExpGSDis':
        dist = ExpGSDist(log_pi=params[0], sample_size=samples_n, temp=temp)
    else:
        # Fail loudly: a bare `raise RuntimeError` gave no hint which
        # model type was rejected. Same exception type, now with context.
        raise RuntimeError(f'Unsupported model_type: {model_type!r}')
    return dist
def reparameterize(self, params_broad):
    """Sample the relaxed discrete latent through a planar normalizing flow.

    Manual reparameterization: draws eps ~ N(0, I), shifts/scales it with
    (mu, exp(xi)), pushes the result through the model's planar flow, and
    normalizes in log-space to obtain `log_psi`. The distribution with its
    sampled state is kept on `self.ng`. Returns [log_psi].
    """
    mu, xi = params_broad
    noise = tf.random.normal(shape=mu.shape)
    self.ng = GaussianSoftmaxDist(mu=mu, xi=xi, temp=self.temp,
                                  sample_size=self.sample_size)
    # xi is a log-std, so exp(xi) is the standard deviation.
    std = tf.math.exp(xi)
    flowed = self.model.planar_flow(mu + std * noise)
    self.ng.lam = flowed
    # Log-softmax over axis 1 (categories) without materializing psi.
    self.ng.log_psi = flowed - tf.math.reduce_logsumexp(flowed, axis=1,
                                                        keepdims=True)
    return [self.ng.log_psi]
class OptGauSoftMax(OptVAE):
    """VAE optimizer with a continuous Gaussian latent plus a relaxed
    discrete latent parameterized by a Gaussian-softmax distribution."""

    def __init__(self, model, optimizer, hyper):
        """Requires hyper['temp'] and hyper['prior_file']; the (1,1,1,1)
        zero priors are placeholders replaced by load_prior_values()."""
        super().__init__(model=model, optimizer=optimizer, hyper=hyper)
        self.temp = tf.constant(value=hyper['temp'], dtype=tf.float32)
        self.prior_file = hyper['prior_file']
        self.mu_0 = tf.constant(value=0., dtype=tf.float32, shape=(1, 1, 1, 1))
        self.xi_0 = tf.constant(value=0., dtype=tf.float32, shape=(1, 1, 1, 1))
        self.ng = GaussianSoftmaxDist(mu=self.mu_0, xi=self.xi_0)
        self.load_prior_values()

    def reparameterize(self, params_broad):
        """Sample both latents from (mean, log_var, mu, xi).

        Side effects: `self.ng` holds the new distribution and
        `self.n_required` records dim 1 of the discrete sample.
        Returns [z_norm, z_discrete].
        """
        mean, log_var, mu, xi = params_broad
        z_norm = sample_normal(mean=mean, log_var=log_var)
        self.ng = GaussianSoftmaxDist(mu=mu, xi=xi, temp=self.temp, sample_size=self.sample_size)
        self.ng.do_reparameterization_trick()
        z_discrete = self.ng.psi
        self.n_required = z_discrete.shape[1]
        z = [z_norm, z_discrete]
        return z

    def compute_kl_elements(self, z, params_broad, run_analytical_kl):
        """Dispatch between closed-form and Monte-Carlo KL estimates."""
        if run_analytical_kl:
            kl_norm, kl_dis = self.compute_kl_elements_analytically(params_broad=params_broad)
        else:
            kl_norm, kl_dis = self.compute_kl_elements_via_sample(z=z, params_broad=params_broad)
        return kl_norm, kl_dis

    def compute_kl_elements_analytically(self, params_broad):
        """Closed-form KL for both latents.

        The continuous part uses the standard Gaussian KL; the discrete
        part uses the general two-Gaussian formula with priors sliced to
        the current batch size (taken from the last sampled `lam`).
        """
        mean, log_var, μ0, ξ0 = params_broad
        kl_norm = calculate_kl_norm_via_analytical_formula(mean=mean, log_var=log_var)
        current_batch_n = self.ng.lam.shape[0]
        # 3-index slicing of the 4-D priors keeps the trailing dim intact.
        ξ1 = self.xi_0[:current_batch_n, :, :]
        μ1 = self.mu_0[:current_batch_n, :, :]
        # ξ is a log-std, so log-variance = 2ξ. No `axis` argument here,
        # unlike the OptGauSoftMaxDis subclass — presumably relying on the
        # formula's default reduction; TODO confirm intended.
        kl_dis = calculate_kl_norm_via_general_analytical_formula(
            mean_0=μ0, log_var_0=2 * ξ0, mean_1=μ1, log_var_1=2.
            * ξ1)
        return kl_norm, kl_dis

    def compute_kl_elements_via_sample(self, z, params_broad):
        """Monte-Carlo KL: sampled Gaussian KL plus sampled discrete KL."""
        mean, log_var, μ, ξ = params_broad
        z_norm, z_discrete = z
        kl_norm = sample_kl_norm(z_norm=z_norm, mean=mean, log_var=log_var)
        kl_dis = self.sample_kl_sb()
        return kl_norm, kl_dis

    def sample_kl_sb(self):
        """Single-sample discrete KL: log q(z|x) - log p(z) at temp·lam.

        Priors are sliced to the current batch and to the first
        `n_required` categories.
        """
        current_batch_n = self.ng.lam.shape[0]
        # NOTE(review): xi is passed directly as log_var here, whereas the
        # analytical path and OptGauSoftMaxDis use 2·xi (treating xi as a
        # log-std). Inconsistent — verify which convention is intended.
        log_pz = compute_log_normal_pdf(
            sample=self.temp * self.ng.lam,
            mean=self.mu_0[:current_batch_n, :self.ng.n_required, :],
            log_var=self.xi_0[:current_batch_n, :self.ng.n_required, :])
        log_qz_x = compute_log_normal_pdf(sample=self.temp * self.ng.lam, mean=self.ng.mu, log_var=self.ng.xi)
        kl_sb = log_qz_x - log_pz
        return kl_sb

    def load_prior_values(self):
        """Replace the placeholder priors with logistic-initialized tensors
        of shape (batch, disc_latent_n, sample_size, disc_var_num)."""
        shape = (self.model.batch_size, self.model.disc_latent_n, self.sample_size, self.model.disc_var_num)
        self.mu_0, self.xi_0 = initialize_mu_and_xi_for_logistic(shape=shape)