def model_fn(self):
    # regression in latent space
    w = yield JDCRoot(
        Independent(
            tfd.Normal(loc=tf.zeros([self.num_factors, self.k]),
                       scale=tf.fill([self.num_factors, self.k], 10.0))))

    z_scale = yield JDCRoot(
        Independent(
            tfd.HalfCauchy(loc=tf.zeros([self.num_factors, self.k]),
                           scale=1.0)))

    F_test = yield JDCRoot(
        Independent(
            tfd.OneHotCategorical(logits=tf.zeros([
                self.num_testing_samples,
                self.num_factors - self.num_confounders
            ]))))

    F_full = tf.concat([tf.expand_dims(self.F, 0), F_test], axis=-2)

    z = yield Independent(
        tfd.Normal(loc=tf.matmul(F_full, w),
                   scale=tf.matmul(F_full, z_scale)))

    x_bias = yield JDCRoot(
        Independent(
            tfd.Normal(loc=tf.fill([self.num_features],
                                   np.float32(self.x_bias_loc0)),
                       scale=np.float32(self.x_bias_scale0))))

    # decoded log-expression space
    x_loc = x_bias + self.decoder(z) - self.sample_scales

    x_scale_concentration_c = yield JDCRoot(
        Independent(
            tfd.HalfCauchy(loc=tf.zeros([self.kernel_regression_degree]),
                           scale=1.0)))

    x_scale_mode_c = yield JDCRoot(
        Independent(
            tfd.HalfCauchy(loc=tf.zeros([self.kernel_regression_degree]),
                           scale=1.0)))

    weights = kernel_regression_weights(self.kernel_regression_bandwidth,
                                        x_bias, self.x_scale_hinges)

    x_scale = yield Independent(
        mean_variance_model(weights, x_scale_concentration_c,
                            x_scale_mode_c))

    # log expression distribution
    x = yield Independent(tfd.StudentT(df=1.0, loc=x_loc, scale=x_scale))

    if not self.use_point_estimates:
        rnaseq_reads = yield tfd.Independent(
            rnaseq_approx_likelihood_from_vars(self.vars, x))
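# A minimal, self-contained analogue of the coroutine pattern used in model_fn
# above (Root-wrapped hyperpriors, downstream conditional yields). This is only
# a sketch of the TFP JointDistributionCoroutine API with toy shapes, not the
# actual model, whose decoder, kernel regression, and RNA-seq likelihood live
# on the class.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
Root = tfd.JointDistributionCoroutine.Root


def toy_model_fn():
    # global scale with a half-Cauchy prior, as for z_scale above
    scale = yield Root(tfd.HalfCauchy(loc=0., scale=1.))
    loc = yield Root(tfd.Normal(loc=0., scale=10.))
    # observation distribution conditioned on the two root draws
    yield tfd.Normal(loc=loc, scale=scale)


toy_joint = tfd.JointDistributionCoroutine(toy_model_fn)
toy_sample = toy_joint.sample()
toy_log_prob = toy_joint.log_prob(toy_sample)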
def eight_schools_log_pdf(z, centered=EIGHT_SCHOOL_CENTERED):
    prior_mu = tfd.Normal(loc=0, scale=5)
    prior_tau = tfd.HalfCauchy(loc=0, scale=5)
    mu, log_tau = z[:, -2], z[:, -1]
    # Adapt the size of mu and tau.
    mu = tf.transpose(eight_schools_replicate * mu)
    log_tau = tf.transpose(eight_schools_replicate * log_tau)
    if centered:
        # shapes: thetas=(N, K); mu and log_tau replicated across the K schools
        thetas = z[:, 0:eight_schools_K]
        likelihood = tfd.Normal(loc=thetas,
                                scale=eight_schools_sigma[0:eight_schools_K])
        prior_theta = tfd.Normal(loc=mu, scale=tf.exp(log_tau))
        # kept log(exp()) for mathematical understanding
        log_det_jac = tf.math.log(tf.exp(log_tau))
        return (likelihood.log_prob(eight_schools_y[0:eight_schools_K]) +
                prior_theta.log_prob(thetas) + prior_mu.log_prob(mu) +
                prior_tau.log_prob(tf.exp(log_tau)) + log_det_jac)
    else:
        # shapes: thetas_tilde=(N, K); mu and log_tau replicated across the K schools
        thetas_tilde = z[:, 0:eight_schools_K]
        zeros = tf.zeros(mu.shape)
        ones = tf.ones(log_tau.shape)
        thetas = mu + thetas_tilde * tf.exp(log_tau)
        likelihood = tfd.Normal(loc=thetas,
                                scale=eight_schools_sigma[0:eight_schools_K])
        prior_theta = tfd.Normal(loc=zeros, scale=ones)
        # kept log(exp()) for mathematical understanding
        log_det_jac = tf.math.log(tf.exp(log_tau))
        return (likelihood.log_prob(eight_schools_y[0:eight_schools_K]) +
                prior_theta.log_prob(thetas_tilde) + prior_mu.log_prob(mu) +
                prior_tau.log_prob(tf.exp(log_tau)) + log_det_jac)
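# Usage sketch for eight_schools_log_pdf. The module-level constants below are
# assumptions about how the surrounding script defines the classic eight-schools
# data; in particular, the (K, 1) shape of eight_schools_replicate is inferred
# from the broadcasting and transpose above and is not confirmed by the source.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

eight_schools_K = 8
eight_schools_y = np.array([28., 8., -3., 7., -1., 1., 18., 12.], dtype=np.float32)
eight_schools_sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.], dtype=np.float32)
eight_schools_replicate = tf.ones([eight_schools_K, 1], dtype=tf.float32)  # assumed helper
EIGHT_SCHOOL_CENTERED = True

# z packs (theta_1..theta_K, mu, log_tau) per sample; evaluate a random batch.
z_batch = tf.random.normal([4, eight_schools_K + 2])
log_pdf_terms = eight_schools_log_pdf(z_batch, centered=True)
print(log_pdf_terms.shape)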
def pdf_2D(z, density_name=''):
    assert density_name in AVAILABLE_2D_DISTRIBUTIONS, "Incorrect density name."
    if density_name == '':
        return 1
    elif density_name == 'banana':
        z1, z2 = z[:, 0], z[:, 1]
        mu = np.array([0.5, 0.5], dtype='float32')
        cov = np.array([[0.06, 0.055], [0.055, 0.06]], dtype='float32')
        scale = tf.linalg.cholesky(cov)
        p = tfd.MultivariateNormalTriL(loc=mu, scale_tril=scale)
        z2 = z1**2 + z2
        z1, z2 = tf.expand_dims(z1, 1), tf.expand_dims(z2, 1)
        z = tf.concat([z1, z2], axis=1)
        return p.prob(z)
    elif density_name == 'circle':
        z1, z2 = z[:, 0], z[:, 1]
        norm = (z1**2 + z2**2)**0.5
        exp1 = tf.exp(-0.2 * ((z1 - 2) / 0.8)**2)
        exp2 = tf.exp(-0.2 * ((z1 + 2) / 0.8)**2)
        u = 0.5 * ((norm - 4) / 0.4)**2 - tf.math.log(exp1 + exp2)
        return tf.exp(-u)
    elif density_name == 'eight_schools':
        y_i = 0
        sigma_i = 10
        thetas, mu, log_tau = z[:, 0], z[:, 1], z[:, 2]
        likelihood = tfd.Normal(loc=thetas, scale=sigma_i)
        prior_theta = tfd.Normal(loc=mu, scale=tf.exp(log_tau))
        prior_mu = tfd.Normal(loc=0, scale=5)
        prior_tau = tfd.HalfCauchy(loc=0, scale=5)
        return (likelihood.prob(y_i) * prior_theta.prob(thetas) *
                prior_mu.prob(mu) * prior_tau.prob(tf.exp(log_tau)) *
                tf.exp(log_tau))
    elif density_name == 'figure_eight':
        mu1 = 1 * np.array([-1, -1], dtype='float32')
        mu2 = 1 * np.array([1, 1], dtype='float32')
        scale = 0.45 * np.array([1, 1], dtype='float32')
        pi = 0.5
        comp1 = tfd.MultivariateNormalDiag(loc=mu1, scale_diag=scale)
        comp2 = tfd.MultivariateNormalDiag(loc=mu2, scale_diag=scale)
        return (1 - pi) * comp1.prob(z) + pi * comp2.prob(z)
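# Usage sketch for pdf_2D: evaluate the 'figure_eight' mixture density on a
# batch of points. AVAILABLE_2D_DISTRIBUTIONS is assumed to be defined next to
# pdf_2D in the original module; the list used here is only illustrative.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
AVAILABLE_2D_DISTRIBUTIONS = ['', 'banana', 'circle', 'eight_schools', 'figure_eight']

xs, ys = np.meshgrid(np.linspace(-3., 3., 50), np.linspace(-3., 3., 50))
grid = tf.convert_to_tensor(
    np.stack([xs.ravel(), ys.ravel()], axis=1).astype(np.float32))
densities = pdf_2D(grid, density_name='figure_eight')  # one density per grid point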
def _init_distribution(conditions, **kwargs):
    scale = conditions["scale"]
    return tfd.HalfCauchy(loc=0, scale=scale, **kwargs)
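# Illustrative call of _init_distribution, assuming the framework passes the
# resolved `conditions` mapping (here just `scale`) and forwards extra kwargs
# to tfd.HalfCauchy unchanged.
import tensorflow_probability as tfp

tfd = tfp.distributions

half_cauchy = _init_distribution({"scale": 2.0}, name="hc")
samples = half_cauchy.sample(3)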
def create_distributions(self):
    """Create prior and surrogate distribution objects."""
    self.bijectors = {
        'u': tfb.Softplus(),
        'v': tfb.Softplus(),
        'u_eta': tfb.Softplus(),
        'u_tau': tfb.Softplus(),
        's': tfb.Softplus(),
        's_eta': tfb.Softplus(),
        's_tau': tfb.Softplus(),
        'w': tfb.Softplus()
    }
    symmetry_breaking_decay = self.symmetry_breaking_decay**tf.cast(
        tf.range(self.latent_dim), self.dtype)[tf.newaxis, ...]

    # Prior distributions
    distribution_dict = {
        'v': tfd.Independent(
            tfd.HalfNormal(
                scale=0.1 * tf.ones((self.latent_dim, self.feature_dim),
                                    dtype=self.dtype)),
            reinterpreted_batch_ndims=2),
        'w': tfd.Independent(
            tfd.HalfNormal(
                scale=tf.ones((1, self.feature_dim), dtype=self.dtype)),
            reinterpreted_batch_ndims=2)
    }

    if self.horseshoe_plus:
        distribution_dict = {
            **distribution_dict,
            'u': lambda u_eta, u_tau: tfd.Independent(
                tfd.HalfNormal(scale=u_eta * u_tau * symmetry_breaking_decay),
                reinterpreted_batch_ndims=2),
            'u_eta': tfd.Independent(
                tfd.HalfCauchy(
                    loc=tf.zeros((self.feature_dim, self.latent_dim),
                                 dtype=self.dtype),
                    scale=tf.ones((self.feature_dim, self.latent_dim),
                                  dtype=self.dtype)),
                reinterpreted_batch_ndims=2),
            'u_tau': tfd.Independent(
                tfd.HalfCauchy(
                    loc=tf.zeros((1, self.latent_dim), dtype=self.dtype),
                    scale=tf.ones((1, self.latent_dim), dtype=self.dtype) *
                    self.u_tau_scale),
                reinterpreted_batch_ndims=2),
        }
        distribution_dict['s'] = lambda s_eta, s_tau: tfd.Independent(
            tfd.HalfNormal(scale=s_eta * s_tau),
            reinterpreted_batch_ndims=2)
        distribution_dict['s_eta'] = tfd.Independent(
            tfd.HalfCauchy(
                loc=tf.zeros((2, self.feature_dim), dtype=self.dtype),
                scale=tf.ones((2, self.feature_dim), dtype=self.dtype)),
            reinterpreted_batch_ndims=2)
        distribution_dict['s_tau'] = tfd.Independent(
            tfd.HalfCauchy(
                loc=tf.zeros((1, self.feature_dim), dtype=self.dtype),
                scale=tf.ones((1, self.feature_dim), dtype=self.dtype) *
                self.s_tau_scale),
            reinterpreted_batch_ndims=2)

        self.bijectors['u_eta_a'] = tfb.Softplus()
        self.bijectors['u_tau_a'] = tfb.Softplus()
        self.bijectors['s_eta_a'] = tfb.Softplus()
        self.bijectors['s_tau_a'] = tfb.Softplus()

        distribution_dict['u_eta'] = lambda u_eta_a: tfd.Independent(
            SqrtInverseGamma(
                concentration=0.5 * tf.ones(
                    (self.feature_dim, self.latent_dim), dtype=self.dtype),
                scale=1.0 / u_eta_a),
            reinterpreted_batch_ndims=2)
        distribution_dict['u_eta_a'] = tfd.Independent(
            tfd.InverseGamma(
                concentration=0.5 * tf.ones(
                    (self.feature_dim, self.latent_dim), dtype=self.dtype),
                scale=tf.ones((self.feature_dim, self.latent_dim),
                              dtype=self.dtype)),
            reinterpreted_batch_ndims=2)
        distribution_dict['u_tau'] = lambda u_tau_a: tfd.Independent(
            SqrtInverseGamma(
                concentration=0.5 * tf.ones(
                    (1, self.latent_dim), dtype=self.dtype),
                scale=1.0 / u_tau_a),
            reinterpreted_batch_ndims=2)
        distribution_dict['u_tau_a'] = tfd.Independent(
            tfd.InverseGamma(
                concentration=0.5 * tf.ones(
                    (1, self.latent_dim), dtype=self.dtype),
                scale=tf.ones((1, self.latent_dim), dtype=self.dtype) /
                self.u_tau_scale**2),
            reinterpreted_batch_ndims=2)
        distribution_dict['s_eta'] = lambda s_eta_a: tfd.Independent(
            SqrtInverseGamma(
                concentration=0.5 * tf.ones(
                    (2, self.feature_dim), dtype=self.dtype),
                scale=1.0 / s_eta_a),
            reinterpreted_batch_ndims=2)
        distribution_dict['s_eta_a'] = tfd.Independent(
            tfd.InverseGamma(
                concentration=0.5 * tf.ones(
                    (2, self.feature_dim), dtype=self.dtype),
                scale=tf.ones((2, self.feature_dim), dtype=self.dtype)),
            reinterpreted_batch_ndims=2)
        distribution_dict['s_tau'] = lambda s_tau_a: tfd.Independent(
            SqrtInverseGamma(
                concentration=0.5 * tf.ones(
                    (1, self.feature_dim), dtype=self.dtype),
                scale=1.0 / s_tau_a),
            reinterpreted_batch_ndims=2)
        distribution_dict['s_tau_a'] = tfd.Independent(
            tfd.InverseGamma(
                concentration=0.5 * tf.ones(
                    (1, self.feature_dim), dtype=self.dtype),
                scale=tf.ones((1, self.feature_dim), dtype=self.dtype) /
                self.s_tau_scale**2),
            reinterpreted_batch_ndims=2)
    else:
        distribution_dict = {
            **distribution_dict,
            'u': tfd.Independent(
                AbsHorseshoe(
                    scale=(self.u_tau_scale * symmetry_breaking_decay *
                           tf.ones((self.feature_dim, self.latent_dim),
                                   dtype=self.dtype))),
                reinterpreted_batch_ndims=2),
            's': tfd.Independent(
                AbsHorseshoe(
                    scale=self.s_tau_scale * tf.ones(
                        (1, self.feature_dim), dtype=self.dtype)),
                reinterpreted_batch_ndims=2)
        }

    self.prior_distribution = tfd.JointDistributionNamed(distribution_dict)

    # Surrogate (variational) distributions
    surrogate_dict = {
        'v': self.bijectors['v'](
            build_trainable_normal_dist(
                -6. * tf.ones((self.latent_dim, self.feature_dim),
                              dtype=self.dtype),
                5e-4 * tf.ones((self.latent_dim, self.feature_dim),
                               dtype=self.dtype),
                2,
                strategy=self.strategy)),
        'w': self.bijectors['w'](
            build_trainable_normal_dist(
                -6 * tf.ones((1, self.feature_dim), dtype=self.dtype),
                5e-4 * tf.ones((1, self.feature_dim), dtype=self.dtype),
                2,
                strategy=self.strategy))
    }

    if self.horseshoe_plus:
        surrogate_dict = {
            **surrogate_dict,
            'u': self.bijectors['u'](
                build_trainable_normal_dist(
                    -6. * tf.ones((self.feature_dim, self.latent_dim),
                                  dtype=self.dtype),
                    5e-4 * tf.ones((self.feature_dim, self.latent_dim),
                                   dtype=self.dtype),
                    2,
                    strategy=self.strategy)),
            'u_eta': self.bijectors['u_eta'](
                build_trainable_InverseGamma_dist(
                    3 * tf.ones((self.feature_dim, self.latent_dim),
                                dtype=self.dtype),
                    tf.ones((self.feature_dim, self.latent_dim),
                            dtype=self.dtype),
                    2,
                    strategy=self.strategy)),
            'u_tau': self.bijectors['u_tau'](
                build_trainable_InverseGamma_dist(
                    3 * tf.ones((1, self.latent_dim), dtype=self.dtype),
                    tf.ones((1, self.latent_dim), dtype=self.dtype),
                    2,
                    strategy=self.strategy)),
        }
        surrogate_dict['s_eta'] = self.bijectors['s_eta'](
            build_trainable_InverseGamma_dist(
                tf.ones((2, self.feature_dim), dtype=self.dtype),
                tf.ones((2, self.feature_dim), dtype=self.dtype),
                2,
                strategy=self.strategy))
        surrogate_dict['s_tau'] = self.bijectors['s_tau'](
            build_trainable_InverseGamma_dist(
                1 * tf.ones((1, self.feature_dim), dtype=self.dtype),
                tf.ones((1, self.feature_dim), dtype=self.dtype),
                2,
                strategy=self.strategy))
        surrogate_dict['s'] = self.bijectors['s'](
            build_trainable_normal_dist(
                tf.ones((2, self.feature_dim), dtype=self.dtype) *
                tf.cast([[-2.], [-1.]], dtype=self.dtype),
                1e-3 * tf.ones((2, self.feature_dim), dtype=self.dtype),
                2,
                strategy=self.strategy))

        self.bijectors['u_eta_a'] = tfb.Softplus()
        self.bijectors['u_tau_a'] = tfb.Softplus()
        surrogate_dict['u_eta_a'] = self.bijectors['u_eta_a'](
            build_trainable_InverseGamma_dist(
                2. * tf.ones((self.feature_dim, self.latent_dim),
                             dtype=self.dtype),
                tf.ones((self.feature_dim, self.latent_dim),
                        dtype=self.dtype),
                2,
                strategy=self.strategy))
        surrogate_dict['u_tau_a'] = self.bijectors['u_tau_a'](
            build_trainable_InverseGamma_dist(
                2. * tf.ones((1, self.latent_dim), dtype=self.dtype),
                tf.ones((1, self.latent_dim), dtype=self.dtype) /
                self.u_tau_scale**2,
                2,
                strategy=self.strategy))

        self.bijectors['s_eta_a'] = tfb.Softplus()
        self.bijectors['s_tau_a'] = tfb.Softplus()
        surrogate_dict['s_eta_a'] = self.bijectors['s_eta_a'](
            build_trainable_InverseGamma_dist(
                2. * tf.ones((2, self.feature_dim), dtype=self.dtype),
                tf.ones((2, self.feature_dim), dtype=self.dtype),
                2,
                strategy=self.strategy))
        surrogate_dict['s_tau_a'] = self.bijectors['s_tau_a'](
            build_trainable_InverseGamma_dist(
                2. * tf.ones((1, self.feature_dim), dtype=self.dtype),
                (tf.ones((1, self.feature_dim), dtype=self.dtype) /
                 self.s_tau_scale**2),
                2,
                strategy=self.strategy))
    else:
        surrogate_dict = {
            **surrogate_dict,
            's': self.bijectors['s'](
                build_trainable_normal_dist(
                    tf.ones((2, self.feature_dim), dtype=self.dtype) *
                    tf.cast([[-2.], [-1.]], dtype=self.dtype),
                    1e-3 * tf.ones((2, self.feature_dim), dtype=self.dtype),
                    2,
                    strategy=self.strategy)),
            'u': self.bijectors['u'](
                build_trainable_normal_dist(
                    -9. * tf.ones((self.feature_dim, self.latent_dim),
                                  dtype=self.dtype),
                    5e-4 * tf.ones((self.feature_dim, self.latent_dim),
                                   dtype=self.dtype),
                    2,
                    strategy=self.strategy))
        }

    self.surrogate_distribution = tfd.JointDistributionNamed(surrogate_dict)
    self.surrogate_vars = self.surrogate_distribution.variables
    self.var_list = list(surrogate_dict.keys())
    self.set_calibration_expectations()
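# Minimal analogue of the prior construction above: tfd.JointDistributionNamed
# resolves lambda arguments by key, so 'u' below is drawn conditionally on
# 'u_eta' and 'u_tau' (the horseshoe-style local/global scales). Shapes are toy
# values, not the model's feature/latent dimensions.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

toy_prior = tfd.JointDistributionNamed({
    'u_eta': tfd.Independent(
        tfd.HalfCauchy(loc=tf.zeros([3, 2]), scale=tf.ones([3, 2])),
        reinterpreted_batch_ndims=2),
    'u_tau': tfd.Independent(
        tfd.HalfCauchy(loc=tf.zeros([1, 2]), scale=tf.ones([1, 2])),
        reinterpreted_batch_ndims=2),
    'u': lambda u_eta, u_tau: tfd.Independent(
        tfd.HalfNormal(scale=u_eta * u_tau),
        reinterpreted_batch_ndims=2),
})
toy_draw = toy_prior.sample()
toy_lp = toy_prior.log_prob(toy_draw)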
def estimate_gmm_precision(qx_loc,
                           qx_scale,
                           fixed_expression=False,
                           profile_trace=False,
                           tensorboard_summaries=False,
                           batch_size=100,
                           err_scale=0.2,
                           edge_cutoff=0.7):
    num_samples = qx_loc.shape[0]
    n = qx_loc.shape[1]
    batch_size = min(batch_size, n)

    # [num_samples, n]
    if fixed_expression:
        qx = qx_loc
    else:
        qx = ed.Normal(loc=qx_loc, scale=qx_scale, name="qx")

    b = np.mean(qx_loc, axis=0)

    # variational estimate of w
    # -------------------------
    qw_loc_init = tf.placeholder(tf.float32, (batch_size, n), name="qw_loc_init")
    qw_loc_init_value = np.zeros((batch_size, n), dtype=np.float32)
    qw_loc = tf.Variable(qw_loc_init, name="qw_loc")
    qw = qw_loc

    # variational estimate of w_scale
    # -------------------------------
    qw_scale_loc_init_value = np.full((batch_size, n), -3.0, dtype=np.float32)
    qw_scale_loc_init = tf.placeholder(tf.float32, (batch_size, n), name="qw_scale_loc_init")
    qw_scale_loc = tf.Variable(qw_scale_loc_init, name="qw_scale_loc")
    qw_scale = tf.nn.softplus(qw_scale_loc)

    # estimate of b
    # -------------
    by_init_value = np.zeros((batch_size,), dtype=np.float32)
    by_init = tf.placeholder(tf.float32, (batch_size,), name="by_init")
    by = tf.Variable(by_init, name="by", trainable=False)  # [batch_size]

    # w
    # -
    w_scale_prior = tfd.HalfCauchy(loc=0.0, scale=1.0, name="w_scale_prior")

    # qw_scale can be shrunk all the way to zero, producing NaNs
    qw_scale = tf.clip_by_value(qw_scale, 1e-4, 10000.0)

    scale_tau = 0.1
    w_prior = tfd.Normal(loc=0.0, scale=qw_scale * scale_tau, name="w_prior")

    # [n, batch_size]
    mask_init = tf.placeholder(tf.float32, (batch_size, n), name="mask_init")
    mask_init_value = np.empty([batch_size, n], dtype=np.float32)
    mask = tf.Variable(mask_init, name="mask", trainable=False)

    qw_masked = qw * mask  # [batch_size, n]

    qx_std = qx - b  # [num_samples, n]

    # CONDITIONAL CORRELATION
    # qxqw = tf.matmul(qx_std, qw_masked, transpose_b=True)  # [num_samples, batch_size]
    # y_dist_loc = qxqw + by

    # UNCONDITIONAL CORRELATION
    qxqw = tf.expand_dims(qx_std, 1) * tf.expand_dims(qw_masked, 0)  # [num_samples, num_batches, n]
    y_dist_loc = tf.expand_dims(tf.expand_dims(by, 0), -1) + qxqw  # [num_samples, num_batches, n]

    y_dist = tfd.StudentT(loc=y_dist_loc, scale=err_scale, df=10.0)

    y_slice_start_init = tf.placeholder(tf.int32, 2, name="y_slice_start_init")  # set to [0, j]
    y_slice_start = tf.Variable(y_slice_start_init, name="y_slice_start", trainable=False)
    y = tf.slice(qx, y_slice_start, [num_samples, batch_size])  # [num_samples, batch_size]

    # y = tf.Print(y, [tf.square(y_dist_loc - tf.expand_dims(y, -1))], "y", summarize=16)

    # objective function
    # ------------------
    y = tf.expand_dims(y, -1)
    y_log_prob = tf.reduce_sum(y_dist.log_prob(y))
    w_log_prob = tf.reduce_sum(w_prior.log_prob(qw_masked))
    w_scale_log_prob = tf.reduce_sum(w_scale_prior.log_prob(qw_scale))

    log_posterior = y_log_prob + w_log_prob + w_scale_log_prob
    elbo = log_posterior

    optimizer = tf.train.AdamOptimizer(learning_rate=1e-2)
    train = optimizer.minimize(-elbo)

    sess = tf.Session()

    niter = 1000
    feed_dict = dict()
    feed_dict[qw_scale_loc_init] = qw_scale_loc_init_value
    feed_dict[qw_loc_init] = qw_loc_init_value
    feed_dict[mask_init] = mask_init_value
    feed_dict[by_init] = by_init_value

    qx_loc_means = np.mean(qx_loc, axis=0)

    # check_ops = tf.add_check_numerics_ops()

    if tensorboard_summaries:
        # tf.summary.histogram("qw_loc_param", qw_loc)
        # tf.summary.histogram("qw_scale_param", qw_scale_param)
        tf.summary.scalar("y_log_prob", y_log_prob)
        tf.summary.scalar("w_log_prob", w_log_prob)
        tf.summary.scalar("w_scale_log_prob", w_scale_log_prob)
        tf.summary.scalar("qw min", tf.reduce_min(qw))
        tf.summary.scalar("qw max", tf.reduce_max(qw))
        tf.summary.scalar("qw_scale min", tf.reduce_min(qw_scale))
        tf.summary.scalar("qw_scale max", tf.reduce_max(qw_scale))
        # tf.summary.histogram("qw_scale_loc_param", qw_scale_loc)
        # tf.summary.histogram("qw_scale_scale_param", qw_scale_scale)

    edges = dict()
    count = 0
    num_batches = math.ceil(n / batch_size)
    for batch_num in range(num_batches):
        # deal with n not necessarily being divisible by batch_size
        if batch_num == num_batches - 1:
            start_j = n - batch_size
        else:
            start_j = batch_num * batch_size

        fillmask(mask_init_value, start_j, batch_size)
        feed_dict[y_slice_start_init] = np.array([0, start_j], dtype=np.int32)

        for k in range(batch_size):
            by_init_value[k] = b[start_j + k]

        sess.run(tf.global_variables_initializer(), feed_dict=feed_dict)

        # if requested, just benchmark one run of the training operation and return
        if profile_trace:
            print("WRITING PROFILING DATA")
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            sess.run(train, options=options, run_metadata=run_metadata)
            fetched_timeline = timeline.Timeline(run_metadata.step_stats)
            chrome_trace = fetched_timeline.generate_chrome_trace_format()
            with open('log/timeline.json', 'w') as f:
                f.write(chrome_trace)
            break

        if tensorboard_summaries:
            train_writer = tf.summary.FileWriter(
                "log/" + "batch-" + str(batch_num), sess.graph)
            tf.summary.scalar("elbo", elbo)
            merged_summary = tf.summary.merge_all()

        for t in range(niter):
            # _, elbo_val = sess.run([train, elbo])
            # _, entropy_val, log_posterior_val, elbo_val = sess.run([train, entropy, log_posterior, elbo])
            _, y_log_prob_value, w_log_prob_value, w_scale_log_prob_value = sess.run(
                [train, y_log_prob, w_log_prob, w_scale_log_prob])
            if t % 100 == 0:
                # print((t, elbo_val, log_posterior_val, entropy_val))
                print((y_log_prob_value, w_log_prob_value, w_scale_log_prob_value))
                # print((t, elbo_val))
            if tensorboard_summaries:
                train_writer.add_summary(sess.run(merged_summary), t)

        print("")
        print("batch")
        print(start_j)

        # qw_scale_min, qw_scale_mean, qw_scale_max = sess.run(
        #     [tf.reduce_min(qw_scale), tf.reduce_mean(qw_scale), tf.reduce_max(qw_scale)])
        # print(("qw_scale span", qw_scale_min, qw_scale_mean, qw_scale_max))

        # lower_credible = sess.run(qw.distribution.quantile(0.01))
        # upper_credible = sess.run(qw.distribution.quantile(0.99))
        lower_credible = upper_credible = sess.run(qw)

        print("credible span")
        print(np.max(lower_credible))
        print(np.min(upper_credible))

        print("nonzeros")
        print(np.sum((lower_credible > edge_cutoff)))
        print(np.sum((upper_credible < -edge_cutoff)))

        for k in range(batch_size):
            neighbors = []
            for j in range(n):
                if lower_credible[k, j] > edge_cutoff or upper_credible[k, j] < -edge_cutoff:
                    neighbors.append((j, lower_credible[k, j], upper_credible[k, j]))
            edges[start_j + k] = neighbors

        count += 1
        if count > 4:
            break

    return edges
def _base_dist(self, beta: TensorLike, *args, **kwargs):
    return tfd.HalfCauchy(loc=0, scale=beta)