def __call__(self, features, dtype=None):
    x = features
    for index in range(self._layers):
        kw = {}
        if index == self._layers - 1 and self._outscale:
            kw['kernel_initializer'] = tf.keras.initializers.VarianceScaling(
                self._outscale)
        x = self.get(f'h{index}', tfkl.Dense, self._units, self._act, **kw)(x)
    if self._dist == 'tanh_normal':
        # https://www.desmos.com/calculator/rcmcf5jwe7
        x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
        if dtype:
            x = tf.cast(x, dtype)
        mean, std = tf.split(x, 2, -1)
        mean = tf.tanh(mean)
        std = tf.nn.softplus(std + self._init_std) + self._min_std
        dist = tfd.Normal(mean, std)
        dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
        dist = tfd.Independent(dist, 1)
        dist = tools.SampleDist(dist)
    elif self._dist == 'tanh_normal_5':
        x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
        if dtype:
            x = tf.cast(x, dtype)
        mean, std = tf.split(x, 2, -1)
        mean = 5 * tf.tanh(mean / 5)
        std = tf.nn.softplus(std + 5) + 5
        dist = tfd.Normal(mean, std)
        dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
        dist = tfd.Independent(dist, 1)
        dist = tools.SampleDist(dist)
    elif self._dist == 'normal':
        x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
        if dtype:
            x = tf.cast(x, dtype)
        mean, std = tf.split(x, 2, -1)
        std = tf.nn.softplus(std + self._init_std) + self._min_std
        dist = tfd.Normal(mean, std)
        dist = tfd.Independent(dist, 1)
    elif self._dist == 'normal_1':
        mean = self.get('hout', tfkl.Dense, self._size)(x)
        if dtype:
            mean = tf.cast(mean, dtype)
        dist = tfd.Normal(mean, 1)
        dist = tfd.Independent(dist, 1)
    elif self._dist == 'trunc_normal':
        # https://www.desmos.com/calculator/mmuvuhnyxo
        x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
        x = tf.cast(x, tf.float32)
        mean, std = tf.split(x, 2, -1)
        mean = tf.tanh(mean)
        std = 2 * tf.nn.sigmoid(std / 2) + self._min_std
        dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
        dist = tools.DtypeDist(dist, dtype)
        dist = tfd.Independent(dist, 1)
    elif self._dist == 'onehot':
        x = self.get('hout', tfkl.Dense, self._size)(x)
        x = tf.cast(x, tf.float32)
        dist = tools.OneHotDist(x, dtype=dtype)
        dist = tools.DtypeDist(dist, dtype)
    elif self._dist == 'onehot_gumble':
        x = self.get('hout', tfkl.Dense, self._size)(x)
        if dtype:
            x = tf.cast(x, dtype)
        temp = self._temp
        dist = tools.GumbleDist(temp, x, dtype=dtype)
    else:
        raise NotImplementedError(self._dist)
    return dist
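# Hedged sketch (not from the snippet above): the 'tanh_normal' branch can be
# reproduced with stock TFP pieces, using tfp.bijectors.Tanh in place of the
# project-specific tools.TanhBijector and plain sampling in place of
# tools.SampleDist. Shapes and the init_std/min_std values are examples.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

raw = tf.random.normal([8, 2 * 4])       # batch of 8, action size 4
mean, std = tf.split(raw, 2, -1)
mean = tf.tanh(mean)
std = tf.nn.softplus(std + 5.0) + 1e-4   # init_std=5.0, min_std=1e-4 assumed
dist = tfd.Independent(
    tfd.TransformedDistribution(tfd.Normal(mean, std), tfb.Tanh()), 1)
action = dist.sample()                   # values squashed into (-1, 1)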
reshaped_dim = [-1, ENCODED_SIZE]
inputs_decoder = ENCODED_SIZE

# Dense: output = activation(dot(input, kernel) + bias)

def encoder(X_in, rate):
    activation = tf.nn.leaky_relu
    with tf.name_scope("encoder_sequence"):
        e_in = tf.reshape(X_in, shape=[-1, INPUT_SHAPE])
        e1 = tf.layers.dense(e_in, units=N_LATENT, activation=activation)
        # NOTE: in TF1, the second positional argument of tf.nn.dropout is
        # keep_prob, so `rate` here is actually the keep probability.
        e2 = tf.nn.dropout(e1, rate)
        e3 = tf.layers.flatten(e2)
        # e4 = prior(ENCODED_SIZE, "prior")
        with tf.name_scope("prior"):
            prior = tfd.Independent(
                tfd.Normal(loc=tf.zeros(ENCODED_SIZE), scale=1),
                reinterpreted_batch_ndims=1)
        with tf.name_scope("multivariate_norm"):
            multivarnorm = tfpl.MultivariateNormalTriL(
                ENCODED_SIZE,
                activity_regularizer=tfpl.KLDivergenceRegularizer(prior))
        mn = tf.layers.dense(e3, units=N_LATENT)
        stdev = 0.5 * tf.layers.dense(e3, units=N_LATENT)
        epsilon = tf.random_normal(tf.stack([tf.shape(e3)[0], N_LATENT]))
        # Reparameterization trick: z = mu + eps * exp(log_std)
        e_z = mn + tf.multiply(epsilon, tf.exp(stdev))
    # return z, mn, stdev

DECODED_SIZE = [-1, 4]
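# Hedged sketch: in the snippet above, `multivarnorm` is built but never
# applied. The usual TFP pattern feeds it a dense layer sized by
# params_size(); something like this would sit inside encoder(), after e3:
params = tf.layers.dense(
    e3, units=tfpl.MultivariateNormalTriL.params_size(ENCODED_SIZE))
z_dist = multivarnorm(params)  # a distribution-valued layer output
z_sample = z_dist.sample()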
def probit(x):
    """Applies the CDF of the standard Normal distribution as the
    activation, rather than a sigmoid."""
    from tensorflow_probability import distributions
    return distributions.Normal(0, 1).cdf(x)
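# Hedged usage sketch: probit as a drop-in activation. Like a sigmoid, its
# outputs lie in (0, 1), but the tails decay faster. Layer sizes are
# arbitrary examples.
import tensorflow as tf
layer = tf.keras.layers.Dense(16, activation=probit)
out = layer(tf.random.normal([4, 8]))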
def model_wrapped():
    return ed.as_random_variable(tfd.Normal(1., 0.1, name="x"))
def Z(self):
    """Variational posterior for the logit Psi"""
    return tfd.Normal(self.Z_loc, self.Z_std)
def sp_shift(x):
    # shifted softplus, used below to keep parameters positive
    # (the function header is reconstructed; only the body appeared here)
    return tf.nn.softplus(x) + 1e-4

def periodic_kernel(x1, x2):
    # periodic kernel with the period parameter set to encode a daily
    # activity pattern (period = 24 hours)
    return tfp.math.psd_kernels.ExpSinSquared(x1, x2, np.float64(24.0))

# transform for parameters to ensure positivity
transforms = [sp_shift, sp_shift]
# transforms = [sp_shift]

# diffuse priors on parameters
lpriors = [tfd.Normal(loc=np.float64(0), scale=np.float64(1)),
           tfd.Normal(loc=np.float64(0), scale=np.float64(10.))]
apriors = [tfd.Normal(loc=np.float64(0.), scale=np.float64(1)),
           tfd.Normal(loc=np.float64(0), scale=np.float64(10.))]

lparams_init = [0.0, 0.0]
aparams_init = [0.0, 0.0]

# create the model
# 2880
mover = moveNS(T, X, Z,
def estimate_transcript_expression_dropout(init_feed_dict, num_samples, n,
                                           vars, x0_log):
    log_prior = 0.0

    x_dropout_loc = tf.Variable([-15.0], name="x_dropout_prior_loc")
    x_dropout_scale = tf.nn.softplus(
        tf.Variable([2.0], name="x_dropout_prior_scale"))

    x_dropout_loc = tf.Print(x_dropout_loc, [x_dropout_loc], "x_dropout_loc")
    x_dropout_scale = tf.Print(
        x_dropout_scale, [x_dropout_scale], "x_dropout_scale")

    # x_dropout_dist = tfd.MultivariateNormalDiag(
    #     loc=tf.ones([1, n]) * x_dropout_loc,
    #     scale_diag=tf.ones([1, n]) * x_dropout_scale)
    x_dropout_dist = tfd.Normal(loc=x_dropout_loc, scale=x_dropout_scale)

    # x_dropout_prob = tf.Variable(tf.zeros([1, 2]), name="x_dropout_prob", trainable=False)
    x_dropout_prob = tf.sigmoid(tf.Variable(0.0, name="x_dropout_prob"))
    # x_dropout_prob = tf.Variable([[-4.0, 4.0]], name="x_dropout_prob", trainable=False)
    # x_dropout_prob = tf.Print(x_dropout_prob, [x_dropout_prob], "x_dropout_prob")
    x_dropout_prob = tf.Print(
        x_dropout_prob, [x_dropout_prob], "x_dropout_prob")

    # The question here is whether the non-dropout prior
    # should be pooled across transcripts or what.

    # NOTE: x_non_dropout_scale is referenced here but never defined in this
    # excerpt; it presumably comes from surrounding code.
    x_non_dropout_scale = tf.Print(x_non_dropout_scale, [
        tf.reduce_min(x_non_dropout_scale),
        tf.reduce_max(x_non_dropout_scale)
    ], "x_non_dropout_scale span")

    # TODO: No! We can't set the prior based on x0_log values!
    x_non_dropout_loc_prior = tfd.Normal(
        loc=tf.constant(-8.0, dtype=tf.float32), scale=2.0)

    x_non_dropout_loc = tf.Variable(
        # tf.expand_dims(np.mean(x0_log, 0), 0),
        tf.fill([1, n], np.float32(np.quantile(x0_log, 0.95))),
        dtype=tf.float32,
        name="x_non_dropout_loc")

    log_prior += tf.reduce_sum(
        x_non_dropout_loc_prior.log_prob(x_non_dropout_loc))

    x_non_dropout_loc = tf.Print(
        x_non_dropout_loc,
        [tf.reduce_min(x_non_dropout_loc), tf.reduce_max(x_non_dropout_loc)],
        "x_non_dropout_loc span")

    # x_non_dropout_dist = tfd.MultivariateNormalDiag(
    #     loc=x_non_dropout_loc,
    #     scale_diag=x_non_dropout_scale)
    x_non_dropout_dist = tfd.Normal(
        loc=x_non_dropout_loc, scale=x_non_dropout_scale)

    # print(x_dropout_dist.batch_shape)
    # print(x_non_dropout_dist.batch_shape)
    # print(x_dropout_dist.event_shape)
    # print(x_non_dropout_dist.event_shape)
    # print(tfd.Categorical(logits=x_dropout_prob).event_shape)
    # print(tfd.Categorical(logits=x_dropout_prob).batch_shape)

    # x_prior = tfd.Mixture(
    #     cat=tfd.Categorical(logits=x_dropout_prob),
    #     components=[
    #         x_dropout_dist,
    #         x_non_dropout_dist])
    # x_prior = x_non_dropout_dist

    x = tf.Variable(x0_log, dtype=tf.float32, name="x")
    # x = tf.Print(x, [tf.reduce_min(x), tf.reduce_max(x)], "x span")

    # TODO: I'm afraid this may not work as a mixture due to numerical issues
    dropout_log_prob = tf.log(x_dropout_prob)
    non_dropout_log_prob = tf.log(1.0 - x_dropout_prob)

    x_dropout_log_prob = x_dropout_dist.log_prob(x) + dropout_log_prob
    x_non_dropout_log_prob = \
        x_non_dropout_dist.log_prob(x) + non_dropout_log_prob

    x_dropout_log_prob = tf.Print(
        x_dropout_log_prob, [x_dropout_log_prob], "x_dropout_log_prob")
    x_non_dropout_log_prob = tf.Print(
        x_non_dropout_log_prob, [x_non_dropout_log_prob],
        "x_non_dropout_log_prob")

    # Two-component mixture computed manually via log-sum-exp.
    x_log_prob = tf.reduce_logsumexp(
        tf.stack([x_dropout_log_prob, x_non_dropout_log_prob]), 0)
    x_log_prob = tf.Print(x_log_prob, [x_log_prob], "x_log_prob")

    log_prior += tf.reduce_sum(x_log_prob)
    # log_prior += tf.reduce_sum(x_non_dropout_log_prob)

    # x_non_dropout_log_prob =
    # log_prior += tf.reduce_sum(
    #     x_prior.log_prob(x))
    # log_prior += tf.reduce_sum(
    #     x_non_dropout_dist.log_prob(x))

    # # manual calculation
    # distribution_log_probs = [d.log_prob(x) for d in x_prior.components]
    # distribution_log_probs[0] = tf.Print(distribution_log_probs[0], [distribution_log_probs[0]], "comp log prob 0")
    # distribution_log_probs[1] = tf.Print(distribution_log_probs[1], [distribution_log_probs[1]], "comp log prob 1")
    # cat_log_probs = tf.unstack(tf.nn.softmax(x_dropout_prob), axis=1)
    # final_log_probs = tf.stack([dlp + clp for (dlp, clp) in zip(distribution_log_probs, cat_log_probs)])
    # # final_log_probs = tf.Print(final_log_probs, [final_log_probs], "final_log_probs")
    # x_prior_prob = tf.reduce_logsumexp(final_log_probs, 0)
    # x_prior_prob = tf.Print(x_prior_prob, [x_prior_prob], "x_prior_prob")
    # log_prior += tf.reduce_sum(x_prior_prob)

    log_likelihood = rnaseq_approx_likelihood_from_vars(vars, x)
    log_posterior = log_likelihood + log_prior

    sess = tf.Session()
    train(sess, -log_posterior, init_feed_dict, 100, 5e-2)
def __call__(self, x):
    mean = self.net(x)
    # Fixed unit scale; note this is a standard deviation, not a variance,
    # so the original name `var` was misleading.
    scale = tf.ones(tf.shape(mean), dtype=tf.float32)
    return tfd.Normal(loc=mean, scale=scale)
def estimate_splicing_code(
        qx_feature_loc, qx_feature_scale,
        donor_seqs, acceptor_seqs, alt_donor_seqs, alt_acceptor_seqs,
        donor_cons, acceptor_cons, alt_donor_cons, alt_acceptor_cons,
        tissues):

    num_samples = len(tissues)
    num_tissues = np.max(tissues)

    tissue_matrix = np.zeros((num_samples, num_tissues), dtype=np.float32)
    for (i, j) in enumerate(tissues):
        tissue_matrix[i, j - 1] = 1

    seqs = np.hstack(
        [donor_seqs, acceptor_seqs, alt_donor_seqs, alt_acceptor_seqs])

    # [num_features, seq_length, 4]
    cons = np.hstack(
        [donor_cons, acceptor_cons, alt_donor_cons, alt_acceptor_cons])

    seqs = np.concatenate((seqs, np.expand_dims(cons, 2)), axis=2)
    print(seqs.shape)
    # sys.exit()

    num_features = seqs.shape[0]

    # split into testing and training data
    shuffled_feature_idxs = np.arange(num_features)
    np.random.shuffle(shuffled_feature_idxs)
    seqs_train_len = int(np.floor(0.75 * num_features))
    seqs_test_len = num_features - seqs_train_len

    print(num_features)
    print(seqs_train_len)
    print(seqs_test_len)
    print(qx_feature_loc.shape)
    print(qx_feature_scale.shape)

    train_idxs = shuffled_feature_idxs[:seqs_train_len]
    test_idxs = shuffled_feature_idxs[seqs_train_len:]

    seqs_train = seqs[train_idxs]
    seqs_test = seqs[test_idxs]

    qx_feature_loc_train = qx_feature_loc[:, train_idxs]
    qx_feature_scale_train = qx_feature_scale[:, train_idxs]
    qx_feature_loc_test = qx_feature_loc[:, test_idxs]
    qx_feature_scale_test = qx_feature_scale[:, test_idxs]

    # invented data to test my intuition
    # seqs_train = np.array(
    #     [[[1.0, 0.0],
    #       [1.0, 0.0],
    #       [1.0, 0.0],
    #       [1.0, 0.0]],
    #      [[0.0, 1.0],
    #       [0.0, 1.0],
    #       [0.0, 1.0],
    #       [0.0, 1.0]]],
    #     dtype=np.float32)
    # seqs_test = np.copy(seqs_train)
    # tissue_matrix = np.array(
    #     [[1],
    #      [1],
    #      [1]],
    #     dtype=np.float32)
    # qx_feature_loc_train = np.array(
    #     [[-1.0, 1.0],
    #      [-1.1, 1.1],
    #      # [-0.5, 0.5]],
    #      [0.9, -0.9]],
    #     dtype=np.float32)
    # qx_feature_scale_train = np.array(
    #     [[0.1, 0.1],
    #      [0.1, 0.1],
    #      # [0.1, 0.1]],
    #      [1.0, 1.0]],
    #     dtype=np.float32)
    # qx_feature_loc_test = np.copy(qx_feature_loc_train)
    # qx_feature_scale_test = np.copy(qx_feature_scale_train)
    # num_tissues = 1
    # num_samples = qx_feature_loc_train.shape[0]
    # seqs_train_len = 2
    # print(qx_feature_loc_train)
    # print(qx_feature_scale_train)
    # sys.exit()

    keep_prob = tf.placeholder(tf.float32)

    # model
    lyr0_input = tf.placeholder(
        tf.float32, (None, seqs_train.shape[1], seqs_train.shape[2]))
    # lyr0 = tf.layers.flatten(lyr0_input)
    lyr0 = lyr0_input
    print(lyr0)

    training_flag = tf.placeholder(tf.bool)

    conv1 = tf.layers.conv1d(
        inputs=lyr0,
        filters=32,
        kernel_size=4,
        activation=tf.nn.leaky_relu,
        kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-1),
        name="conv1")

    conv1_dropout = tf.layers.dropout(
        inputs=conv1,
        rate=0.5,
        training=training_flag,
        name="conv1_dropout")

    pool1 = tf.layers.max_pooling1d(
        inputs=conv1_dropout,
        pool_size=2,
        strides=2,
        name="pool1")

    conv2 = tf.layers.conv1d(
        inputs=pool1,
        filters=64,
        kernel_size=4,
        activation=tf.nn.leaky_relu,
        kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-1),
        name="conv2")

    pool2 = tf.layers.max_pooling1d(
        inputs=conv2,
        pool_size=2,
        strides=2,
        name="pool2")

    pool2_flat = tf.layers.flatten(
        pool2, name="pool2_flat")
    # pool2_flat = tf.layers.flatten(conv1_dropout)

    dense1 = tf.layers.dense(
        inputs=pool2_flat,
        units=256,
        activation=tf.nn.leaky_relu,
        kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-1),
        name="dense1")

    # dropout1 = tf.layers.dropout(
    #     inputs=dense1,
    #     rate=0.5,
    #     training=training_flag,
    #     name="dropout1")

    prediction_layer = tf.layers.dense(
        # inputs=dropout1,
        inputs=dense1,
        units=num_tissues,
        activation=tf.identity)  # [num_features, num_tissues]

    # TODO: eventually this should be a latent variable
    # x_scale = 0.2
    x_scale_prior = tfd.InverseGamma(
        concentration=0.001,
        rate=0.001,
        name="x_scale_prior")
    x_scale = tf.nn.softplus(tf.Variable(tf.fill([seqs_train_len], -3.0)))
    # x_scale = tf.constant(0.1)

    print(tissue_matrix.shape)

    x_mu = tf.matmul(
        tf.constant(tissue_matrix),
        tf.transpose(prediction_layer))  # [num_samples, num_features]

    x_prior = tfd.Normal(
        loc=x_mu,
        # loc=0.0,
        scale=x_scale,
        name="x_prior")
    # x_prior = tfd.StudentT(
    #     loc=x_mu,
    #     scale=x_scale,
    #     df=2.0,
    #     name="x_prior")

    x_likelihood_loc = tf.placeholder(tf.float32, [num_samples, None])
    x_likelihood_scale = tf.placeholder(tf.float32, [num_samples, None])
    x_likelihood = ed.Normal(
        loc=x_likelihood_loc,
        scale=x_likelihood_scale,
        name="x_likelihood")

    # x = x_likelihood
    x = tf.Variable(
        qx_feature_loc_train,
        # tf.random_normal(qx_feature_loc_train.shape),
        # tf.zeros(qx_feature_loc_train.shape),
        # qx_feature_loc_train + qx_feature_scale_train * np.float32(np.random.randn(*qx_feature_loc_train.shape)),
        # trainable=False,
        name="x")

    print("X")
    print(x)

    # x_delta = tf.Variable(
    #     # qx_feature_loc_train,
    #     # tf.random_normal(qx_feature_loc_train.shape),
    #     tf.zeros(qx_feature_loc_train.shape),
    #     # trainable=False,
    #     name="x")

    # x_delta = tf.Print(x_delta,
    #     [tf.reduce_min(x_delta), tf.reduce_max(x_delta)], "x_delta span")

    # x = tf.Print(x,
    #     [tf.reduce_min(x - qx_feature_loc_train), tf.reduce_max(x - qx_feature_loc_train)],
    #     "x deviance from init")

    # print(x_prior.log_prob(x))
    # print(x_likelihood.log_prob(x))
    # sys.exit()

    # log_prior = tf.reduce_sum(x_prior.log_prob(x_delta))
    # log_likelihood = tf.reduce_sum(x_likelihood.distribution.log_prob(x_mu + x_delta))

    log_prior = tf.reduce_sum(x_prior.log_prob(x)) + \
        tf.reduce_sum(x_scale_prior.log_prob(x_scale))
    log_likelihood = tf.reduce_sum(x_likelihood.distribution.log_prob(x))
    log_posterior = log_prior + log_likelihood
    # log_posterior = x_likelihood.distribution.log_prob(x_mu)

    sess = tf.Session()
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train = optimizer.minimize(-log_posterior)
    sess.run(tf.global_variables_initializer())

    # dropout doesn't seem to do much....
    train_feed_dict = {
        training_flag: True,
        # training_flag: False,
        lyr0_input: seqs_train,
        x_likelihood_loc: qx_feature_loc_train,
        x_likelihood_scale: qx_feature_scale_train
    }
    test_feed_dict = {
        training_flag: False,
        lyr0_input: seqs_test,
        x_likelihood_loc: qx_feature_loc_test,
        x_likelihood_scale: qx_feature_scale_test
    }

    n_iter = 1000
    mad_sample = median_absolute_deviance_sample(x_mu, x_likelihood)
    for iter in range(n_iter):
        # _, log_prior_value, log_likelihood_value = sess.run(
        #     [train, log_prior, log_likelihood],
        #     feed_dict=train_feed_dict)
        sess.run([train], feed_dict=train_feed_dict)
        # print((log_prior_value, log_likelihood_value))

        if iter % 100 == 0:
            # print(iter)
            # print("x")
            # print(sess.run(x))
            # print("x likelihood")
            # print(sess.run(x_likelihood.distribution.log_prob(x), feed_dict=train_feed_dict))
            # print("x_mu")
            # print(sess.run(x_mu, feed_dict=train_feed_dict))
            # print(sess.run(x_mu, feed_dict=test_feed_dict))
            # print("x_mu likelihood")
            # print(sess.run(x_likelihood.distribution.log_prob(x_mu), feed_dict=train_feed_dict))
            # print(sess.run(x_likelihood.distribution.log_prob(x_mu), feed_dict=test_feed_dict))
            print(sess.run(
                tf.reduce_sum(x_likelihood.distribution.log_prob(x_mu)),
                feed_dict=train_feed_dict))
            print(sess.run(
                tf.reduce_sum(x_likelihood.distribution.log_prob(x_mu)),
                feed_dict=test_feed_dict))
            print(sess.run(
                tfp.distributions.percentile(
                    x_likelihood.distribution.log_prob(x_mu), 50.0),
                feed_dict=train_feed_dict))
            print(sess.run(
                tfp.distributions.percentile(
                    x_likelihood.distribution.log_prob(x_mu), 50.0),
                feed_dict=test_feed_dict))
            print(est_expected_median_absolute_deviance(
                sess, mad_sample, train_feed_dict))
            print(est_expected_median_absolute_deviance(
                sess, mad_sample, test_feed_dict))

    print(est_expected_median_absolute_deviance(
        sess, mad_sample, train_feed_dict))
    print(est_expected_median_absolute_deviance(
        sess, mad_sample, test_feed_dict))
def transform(dist):
    mean = config.postprocess_fn(dist.mean())
    mean = tf.clip_by_value(mean, 0.0, 1.0)
    return tfd.Independent(tfd.Normal(mean, 1.0), len(dist.event_shape))
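# Hedged sketch of what tfd.Independent does here: it moves batch dimensions
# into the event, so log_prob sums over pixels instead of returning one value
# per pixel. Shapes below are illustrative.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([2, 64, 64, 3]), scale=1.0)
print(base.log_prob(tf.zeros([2, 64, 64, 3])).shape)  # (2, 64, 64, 3)
ind = tfd.Independent(base, reinterpreted_batch_ndims=3)
print(ind.log_prob(tf.zeros([2, 64, 64, 3])).shape)   # (2,)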
def feed_forward(features, data_shape, num_layers=2, activation=tf.nn.relu,
                 mean_activation=None, stop_gradient=False, trainable=True,
                 units=100, std=1.0, low=-1.0, high=1.0, dist='normal',
                 min_std=1e-2, init_std=1.0):
    hidden = features
    if stop_gradient:
        hidden = tf.stop_gradient(hidden)
    for _ in range(num_layers):
        hidden = tf.layers.dense(hidden, units, activation,
                                 trainable=trainable)
    mean = tf.layers.dense(hidden, int(np.prod(data_shape)), mean_activation,
                           trainable=trainable)
    mean = tf.reshape(mean, tools.shape(features)[:-1] + data_shape)
    if std == 'learned':
        std = tf.layers.dense(hidden, int(np.prod(data_shape)), None,
                              trainable=trainable)
        init_std = np.log(np.exp(init_std) - 1)
        std = tf.nn.softplus(std + init_std) + min_std
        std = tf.reshape(std, tools.shape(features)[:-1] + data_shape)
    if dist == 'normal':
        dist = tfd.Normal(mean, std)
        dist = tfd.Independent(dist, len(data_shape))
    elif dist == 'deterministic':
        dist = tfd.Deterministic(mean)
        dist = tfd.Independent(dist, len(data_shape))
    elif dist == 'binary':
        dist = tfd.Bernoulli(mean)
        dist = tfd.Independent(dist, len(data_shape))
    elif dist == 'trunc_normal':
        # https://www.desmos.com/calculator/rnksmhtgui
        dist = tfd.TruncatedNormal(mean, std, low, high)
        dist = tfd.Independent(dist, len(data_shape))
    elif dist == 'tanh_normal':
        # https://www.desmos.com/calculator/794s8kf0es
        dist = distributions.TanhNormal(mean, std)
    elif dist == 'tanh_normal_tanh':
        # https://www.desmos.com/calculator/794s8kf0es
        mean = 5.0 * tf.tanh(mean / 5.0)
        dist = distributions.TanhNormal(mean, std)
    elif dist == 'onehot_score':
        dist = distributions.OneHot(mean, gradient='score')
    elif dist == 'onehot_straight':
        dist = distributions.OneHot(mean, gradient='straight')
    else:
        raise NotImplementedError(dist)
    return dist
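# Hedged usage sketch for feed_forward (assumes the surrounding TF1
# graph-mode setup and the project's tools helper): a learned-std diagonal
# Normal over a 10-dimensional target; placeholder shapes are illustrative.
features = tf.placeholder(tf.float32, [None, 32])
target = tf.placeholder(tf.float32, [None, 10])
pred = feed_forward(features, data_shape=[10], std='learned')
nll = -tf.reduce_mean(pred.log_prob(target))  # minimize this to train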
def VariationalNormal(name, shape, constraint=None):
    # NOTE: `shape` is unused; both parameters are hard-coded to shape [1].
    means = tf.get_variable(name + '_mean',
                            initializer=tf.ones([1]),
                            constraint=constraint)
    stds = tf.get_variable(name + '_std',
                           initializer=-1.0 * tf.ones([1]))
    return tfd.Normal(loc=means, scale=tf.nn.softplus(stds))
def bnn(args, X, y, Xval, yval):
    import tensorflow as tf
    import tensorflow_probability as tfp
    from tensorflow_probability import distributions as tfd

    tf.reset_default_graph()

    y, y_mean, y_std = normalize_y(y)

    if args.dataset == 'protein' or args.dataset == 'year_prediction':
        n_neurons = 100
    else:
        n_neurons = 50

    def VariationalNormal(name, shape, constraint=None):
        means = tf.get_variable(name + '_mean',
                                initializer=tf.ones([1]),
                                constraint=constraint)
        stds = tf.get_variable(name + '_std',
                               initializer=-1.0 * tf.ones([1]))
        return tfd.Normal(loc=means, scale=tf.nn.softplus(stds))

    x_p = tf.placeholder(tf.float32, shape=(None, X.shape[1]))
    y_p = tf.placeholder(tf.float32, shape=(None, 1))

    with tf.name_scope('model', values=[x_p]):
        layer1 = tfp.layers.DenseFlipout(
            units=n_neurons,
            activation='relu',
            kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(),
            bias_posterior_fn=tfp.layers.default_mean_field_normal_fn())
        layer2 = tfp.layers.DenseFlipout(
            units=1,
            activation='linear',
            kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(),
            bias_posterior_fn=tfp.layers.default_mean_field_normal_fn())
        predictions = layer2(layer1(x_p))
        noise = VariationalNormal(
            'noise', [1], constraint=tf.keras.constraints.NonNeg())
        pred_distribution = tfd.Normal(loc=predictions, scale=noise.sample())
        neg_log_prob = -tf.reduce_mean(pred_distribution.log_prob(y_p))
        kl_div = sum(layer1.losses + layer2.losses) / X.shape[0]
        elbo_loss = neg_log_prob + kl_div

    with tf.name_scope("train"):
        optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)
        train_op = optimizer.minimize(elbo_loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        it = 0
        progressBar = tqdm(desc='Training BNN', total=args.iters, unit='iter')
        batches = batchify(X, y, batch_size=args.batch_size,
                           shuffel=args.shuffel)
        while it < args.iters:
            data, label = next(batches)
            _, l = sess.run([train_op, elbo_loss],
                            feed_dict={x_p: data,
                                       y_p: label.reshape(-1, 1)})
            progressBar.update()
            progressBar.set_postfix({'loss': l})
            it += 1
        progressBar.close()

        W0_samples = layer1.kernel_posterior.sample(1000)
        b0_samples = layer1.bias_posterior.sample(1000)
        W1_samples = layer2.kernel_posterior.sample(1000)
        b1_samples = layer2.bias_posterior.sample(1000)
        noise_samples = noise.sample(1000)

        W0, b0, W1, b1, n = sess.run(
            [W0_samples, b0_samples, W1_samples, b1_samples, noise_samples])

    def sample_net(x, W0, b0, W1, b1, n):
        h = np.maximum(
            np.matmul(x[np.newaxis], W0) + b0[:, np.newaxis, :], 0.0)
        return np.matmul(h, W1) + b1[:, np.newaxis, :] \
            + n[:, np.newaxis, :] * np.random.randn()

    samples = sample_net(Xval, W0, b0, W1, b1, n)
    m = samples.mean(axis=0)
    v = samples.var(axis=0)
    m = m * y_std + y_mean
    v = v * y_std**2
    log_probs = normal_log_prob(yval, m, v)
    rmse = math.sqrt(((m.flatten() - yval)**2).mean())
    return log_probs.mean(), rmse
def __init__(self, mean, std, samples=100):
    dist = tfd.Normal(mean, std)
    dist = tfd.TransformedDistribution(dist, TanhBijector())
    dist = tfd.Independent(dist, 1)
    self._dist = dist
    self._samples = samples
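# Hedged sketch (an assumption, not shown in the snippet): a tanh-transformed
# Normal has no closed-form mean or mode, which is presumably why a sample
# count is stored. Methods like these estimate statistics from samples;
# shapes assume draws of [num_samples, batch, action_dim].
def mean(self):
    samples = self._dist.sample(self._samples)
    return tf.reduce_mean(samples, 0)

def mode(self):
    samples = self._dist.sample(self._samples)               # [S, B, A]
    logprob = self._dist.log_prob(samples)                   # [S, B]
    idx = tf.argmax(logprob, axis=0, output_type=tf.int32)   # [B]
    samples_bf = tf.transpose(samples, [1, 0, 2])            # [B, S, A]
    return tf.gather(samples_bf, idx, axis=1, batch_dims=1)  # [B, A]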
def call(self, inputs):
    return tfd.Normal(*self.get_loc_and_scale(inputs))
def estimate_splicing_code_from_kmers(
        qx_feature_loc, qx_feature_scale, kmer_usage_matrix, tissues):

    num_samples = len(tissues)
    num_tissues = np.max(tissues)

    tissue_matrix = np.zeros((num_samples, num_tissues), dtype=np.float32)
    for (i, j) in enumerate(tissues):
        tissue_matrix[i, j - 1] = 1

    num_features = kmer_usage_matrix.shape[0]
    num_kmers = kmer_usage_matrix.shape[1]

    # split into testing and training data
    shuffled_feature_idxs = np.arange(num_features)
    np.random.shuffle(shuffled_feature_idxs)
    seqs_train_len = int(np.floor(0.75 * num_features))
    seqs_test_len = num_features - seqs_train_len

    train_idxs = shuffled_feature_idxs[:seqs_train_len]
    test_idxs = shuffled_feature_idxs[seqs_train_len:]

    kmer_usage_matrix_train = kmer_usage_matrix[train_idxs]
    kmer_usage_matrix_test = kmer_usage_matrix[test_idxs]

    qx_feature_loc_train = qx_feature_loc[:, train_idxs]
    qx_feature_scale_train = qx_feature_scale[:, train_idxs]
    qx_feature_loc_test = qx_feature_loc[:, test_idxs]
    qx_feature_scale_test = qx_feature_scale[:, test_idxs]

    W0 = tf.Variable(
        tf.random_normal([num_kmers, 1], mean=0.0, stddev=0.01),
        name="W0")

    # B = tf.Variable(
    #     tf.random_normal([1, num_tissues], mean=0.0, stddev=0.01),
    #     name="B")

    W_prior = tfd.Normal(
        loc=0.0,
        scale=0.1,
        name="W_prior")

    W = tf.Variable(
        tf.random_normal([num_kmers, num_tissues], mean=0.0, stddev=0.01),
        name="W")

    X = tf.placeholder(tf.float32, shape=(None, num_kmers), name="X")

    # Y = B + tf.matmul(X, W0 + W)
    Y = tf.matmul(X, W0 + W)
    print(Y)

    x_scale_prior = tfd.InverseGamma(
        concentration=0.001,
        rate=0.001,
        name="x_scale_prior")
    x_scale = tf.nn.softplus(tf.Variable(tf.fill([seqs_train_len], -3.0)))

    x_mu = tf.matmul(
        tf.constant(tissue_matrix),
        tf.transpose(Y))  # [num_samples, num_features]
    print(x_mu)

    x_prior = tfd.Normal(
        loc=x_mu,
        scale=x_scale,
        name="x_prior")

    x_likelihood_loc = tf.placeholder(tf.float32, [num_samples, None])
    x_likelihood_scale = tf.placeholder(tf.float32, [num_samples, None])
    x_likelihood = ed.Normal(
        loc=x_likelihood_loc,
        scale=x_likelihood_scale,
        name="x_likelihood")

    # Using likelihood
    x = tf.Variable(
        qx_feature_loc_train,
        name="x")
    # x = x_likelihood_loc
    # x = x_mu

    log_prior = \
        tf.reduce_sum(x_prior.log_prob(x)) + \
        tf.reduce_sum(x_scale_prior.log_prob(x_scale)) + \
        tf.reduce_sum(W_prior.log_prob(W))
    log_likelihood = tf.reduce_sum(x_likelihood.distribution.log_prob(x))
    log_posterior = log_prior + log_likelihood

    # Using point estimates
    # x = qx_feature_loc_train
    # log_prior = \
    #     tf.reduce_sum(x_prior.log_prob(x)) + \
    #     tf.reduce_sum(x_scale_prior.log_prob(x_scale)) + \
    #     tf.reduce_sum(W_prior.log_prob(W))
    # log_posterior = log_prior

    sess = tf.Session()
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train = optimizer.minimize(-log_posterior)
    sess.run(tf.global_variables_initializer())

    train_feed_dict = {
        X: kmer_usage_matrix_train,
        x_likelihood_loc: qx_feature_loc_train,
        x_likelihood_scale: qx_feature_scale_train
    }
    test_feed_dict = {
        X: kmer_usage_matrix_test,
        x_likelihood_loc: qx_feature_loc_test,
        x_likelihood_scale: qx_feature_scale_test
    }

    n_iter = 1000
    mad_sample = median_absolute_deviance_sample(x_mu, x_likelihood)
    for iter in range(n_iter):
        # _, log_prior_value, log_likelihood_value = sess.run(
        #     [train, log_prior, log_likelihood],
        #     feed_dict=train_feed_dict)
        sess.run([train], feed_dict=train_feed_dict)
        # print((log_prior_value, log_likelihood_value))

        if iter % 100 == 0:
            print(iter)
            print(est_expected_median_absolute_deviance(
                sess, mad_sample, train_feed_dict))
            print(est_expected_median_absolute_deviance(
                sess, mad_sample, test_feed_dict))
            print(sess.run(tf.reduce_min(x_scale)))
            print(sess.run(tf.reduce_max(x_scale)))
            # print(sess.run(log_prior, feed_dict=train_feed_dict))
            # print(sess.run(log_likelihood, feed_dict=train_feed_dict))

    return sess.run(W0), sess.run(W)
def base_dist(self, loc, scale):
    return tfd.Normal(loc, scale)
def fn(scale):
    d = tfd.Normal(loc=0, scale=[scale])
    x = d.sample()
    return d, x, d.log_prob(x)
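# Hedged sketch of exercising fn under eager execution: tfd.Normal samples by
# reparameterization, so gradients of log_prob(x) flow back to `scale`.
import tensorflow as tf
scale = tf.Variable(1.0)
with tf.GradientTape() as tape:
    d, x, lp = fn(scale)
print(tape.gradient(lp, scale))  # a finite gradient, not None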
def _init_distribution(conditions):
    loc, scale = conditions["loc"], conditions["scale"]
    return tfd.Normal(loc=loc, scale=scale)
# %%
optimizer = gpflow.optimizers.Scipy()
optimizer.minimize(model.training_loss, model.trainable_variables)
print(f"log posterior density at optimum: {model.log_posterior_density()}")

# %% [markdown]
# Thirdly, we add priors to the hyperparameters.

# %%
# tfp.distributions dtype is inferred from parameters - so convert to 64-bit
model.kernel.lengthscales.prior = tfd.Gamma(f64(1.0), f64(1.0))
model.kernel.variance.prior = tfd.Gamma(f64(1.0), f64(1.0))
model.likelihood.variance.prior = tfd.Gamma(f64(1.0), f64(1.0))
model.mean_function.A.prior = tfd.Normal(f64(0.0), f64(10.0))
model.mean_function.b.prior = tfd.Normal(f64(0.0), f64(10.0))

gpflow.utilities.print_summary(model)

# %% [markdown]
# We now sample from the posterior using HMC.

# %%
num_burnin_steps = ci_niter(300)
num_samples = ci_niter(500)

# Note that here we need model.trainable_parameters, not trainable_variables -
# only parameters can have priors!
hmc_helper = gpflow.optimizers.SamplingHelper(
    model.log_posterior_density, model.trainable_parameters)
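# %%
# Hedged sketch of the step that usually follows in this GPflow workflow,
# using the standard tfp.mcmc API; step size and leapfrog count are example
# values, not tuned settings from this notebook.
hmc = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=hmc_helper.target_log_prob,
    num_leapfrog_steps=10, step_size=0.01)
adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    hmc, num_adaptation_steps=10, target_accept_prob=f64(0.75),
    adaptation_rate=0.1)
samples, _ = tfp.mcmc.sample_chain(
    num_results=num_samples,
    num_burnin_steps=num_burnin_steps,
    current_state=hmc_helper.current_state,
    kernel=adaptive_hmc)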
def mix(eta, loc, scale):
    return tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=eta),
        components_distribution=tfd.Normal(loc=loc, scale=scale))
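# Hedged usage sketch: a two-component Normal mixture; the last axis of
# eta/loc/scale indexes the components.
gmm = mix(eta=[0.3, 0.7], loc=[-1.0, 2.0], scale=[0.5, 1.0])
print(gmm.mean())         # 0.3 * (-1.0) + 0.7 * 2.0 = 1.1
print(gmm.log_prob(0.0))  # scalar log-density
print(gmm.sample(5))      # five scalar draws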
def tensorflow_model(self, pars):
    """Output tensorflow probability model object, to be combined together
    and sampled from.

    pars - dictionary of signal and nuisance parameters
           (tensors, constant or Variable)
    """
    # Need to construct these shapes to match the event_shape, batch_shape,
    # sample_shape semantics of tensorflow_probability.
    cov_order = self.get_cov_order()
    small = 1e-10
    # print("pars:", pars)
    tfds = {}

    # Determine which SRs participate in the covariance matrix
    if self.cov is not None:
        cov = tf.constant(self.cov, dtype=c.TFdtype)
        cov_diag = tf.constant(
            [self.cov[k][k] for k in range(len(self.cov))])
        # Select which systematic to use, depending on whether the SR
        # participates in the covariance matrix
        bsys_tmp = [
            np.sqrt(self.cov_diag[cov_order.index(sr)])
            if self.in_cov[i] else self.SR_b_sys[i]
            for i, sr in enumerate(self.SR_names)
        ]
    else:
        bsys_tmp = self.SR_b_sys[:]

    # Prepare input parameters
    # print("input pars:", pars)
    s = pars['s'] * self.s_scaling  # We "scan" normalised versions of s, to help optimizer
    theta = pars['theta'] * self.theta_scaling  # We "scan" normalised versions of theta, to help optimizer
    # print("de-scaled pars: s    :", s)
    # print("de-scaled pars: theta:", theta)

    theta_safe = theta
    # print("theta_safe:", theta_safe)
    # print("rate:", s + b + theta_safe)

    # Expand dims of internal parameters to match input pars. Right-most
    # dimension is the 'event' dimension, i.e. parameters for each
    # independent Poisson distribution. The rest go into batch_shape.
    if s.shape == ():
        n_batch_dims = 0
    else:
        n_batch_dims = len(s.shape) - 1
    new_dims = [1 for i in range(n_batch_dims)]
    b = tf.constant(self.SR_b, dtype=c.TFdtype)
    bsys = tf.constant(bsys_tmp, dtype=c.TFdtype)
    if n_batch_dims > 0:
        b = tf.reshape(b, new_dims + list(b.shape))
        bsys = tf.reshape(bsys, new_dims + list(bsys.shape))

    # Poisson model
    poises0 = tfd.Poisson(
        rate=tf.abs(s + b + theta_safe) + c.reallysmall
    )  # Abs works to constrain rate to be positive. Might be confusing to interpret BF parameters though.
    # Treat SR batch dims as event dims
    poises0i = tfd.Independent(distribution=poises0,
                               reinterpreted_batch_ndims=1)
    tfds["n"] = poises0i

    # Multivariate background constraints
    if self.cov is not None:
        # print("theta_safe:", theta_safe)
        # print("covi:", self.covi)
        theta_cov = tf.gather(theta_safe, self.covi, axis=-1)
        # print("theta_cov:", theta_cov)
        cov_nuis = tfd.MultivariateNormalFullCovariance(
            loc=theta_cov, covariance_matrix=cov)
        tfds["x_cov"] = cov_nuis
        # print("str(cov_nuis):", str(cov_nuis))

        # Remaining uncorrelated background constraints
        if np.sum(~self.in_cov) > 0:
            nuis0 = tfd.Normal(loc=theta_safe[..., ~self.in_cov],
                               scale=bsys[..., ~self.in_cov])
            # Treat SR batch dims as event dims
            nuis0i = tfd.Independent(distribution=nuis0,
                                     reinterpreted_batch_ndims=1)
            tfds["x_nocov"] = nuis0i
    else:
        # Only have uncorrelated background constraints
        nuis0 = tfd.Normal(loc=theta_safe, scale=bsys)
        # Treat SR batch dims as event dims
        nuis0i = tfd.Independent(distribution=nuis0,
                                 reinterpreted_batch_ndims=1)
        tfds["x"] = nuis0i
    # print("hello3")

    return tfds  # , sample_layout, sample_count
def transform(dist):
    mean = config.postprocess_fn(dist.mean())
    dist = tfd.Independent(tfd.Normal(mean, 1.0), len(dist.event_shape))
    return dist
def update_networks(self):
    (states, actions, rewards, next_states,
     dones) = self.replay_buffer.get_minibatch(batch_size=self.batch_size)

    B, M = self.batch_size, self.n_samples

    # [B, obs_dim] -> [B, obs_dim * M] -> [B * M, obs_dim]
    next_states_tiled = tf.reshape(
        tf.tile(next_states, multiples=(1, M)), shape=(B * M, -1))

    target_mu, target_sigma = self.target_policy(next_states_tiled)

    # For MultivariateGaussianPolicy
    # target_dist = tfd.MultivariateNormalFullCovariance(
    #     loc=target_mu, covariance_matrix=target_sigma)
    # For IndependentGaussianPolicy
    target_dist = tfd.Independent(
        tfd.Normal(loc=target_mu, scale=target_sigma),
        reinterpreted_batch_ndims=1)

    sampled_actions = target_dist.sample()  # [B * M, action_dim]
    # sampled_actions = tf.clip_by_value(sampled_actions, -1.0, 1.0)

    # Update Q-network:
    sampled_qvalues = tf.reshape(
        self.target_critic(next_states_tiled, sampled_actions),
        shape=(B, M, -1))
    mean_qvalues = tf.reduce_mean(sampled_qvalues, axis=1)
    TQ = rewards + self.gamma * (1.0 - dones) * mean_qvalues

    with tf.GradientTape() as tape1:
        Q = self.critic(states, actions)
        loss_critic = tf.reduce_mean(tf.square(TQ - Q))

    variables = self.critic.trainable_variables
    grads = tape1.gradient(loss_critic, variables)
    grads, _ = tf.clip_by_global_norm(grads, 40.)
    self.critic_optimizer.apply_gradients(zip(grads, variables))

    # E-step:
    # Obtain η* by minimising g(η)
    with tf.GradientTape() as tape2:
        temperature = tf.math.softplus(self.log_temperature)
        q_logsumexp = tf.math.reduce_logsumexp(
            sampled_qvalues / temperature, axis=1)
        loss_temperature = temperature * (
            self.eps + tf.reduce_mean(q_logsumexp, axis=0))

    grad = tape2.gradient(loss_temperature, self.log_temperature)
    if tf.math.is_nan(grad).numpy().sum() != 0:
        print("NAN GRAD in TEMPERATURE !!!!!!!!!")
        import pdb
        pdb.set_trace()
    else:
        self.temperature_optimizer.apply_gradients(
            [(grad, self.log_temperature)])

    # Obtain sample-based variational distribution q(a|s)
    temperature = tf.math.softplus(self.log_temperature)

    # M-step: Optimize the lower bound J with respect to θ
    weights = tf.squeeze(
        tf.math.softmax(sampled_qvalues / temperature, axis=1),
        axis=2)  # [B, M, 1] -> [B, M]

    if tf.math.is_nan(weights).numpy().sum() != 0:
        print("NAN in weights !!!!!!!!!")
        import pdb
        pdb.set_trace()

    with tf.GradientTape(persistent=True) as tape3:
        online_mu, online_sigma = self.policy(next_states_tiled)

        # For MultivariateGaussianPolicy
        # online_dist = tfd.MultivariateNormalFullCovariance(
        #     loc=online_mu, covariance_matrix=online_sigma)
        # For IndependentGaussianPolicy
        online_dist = tfd.Independent(
            tfd.Normal(loc=online_mu, scale=online_sigma),
            reinterpreted_batch_ndims=1)

        log_probs = tf.reshape(
            online_dist.log_prob(sampled_actions) + 1e-6,
            shape=(B, M))  # [B * M, ] -> [B, M]
        cross_entropy_qp = tf.reduce_sum(
            weights * log_probs, axis=1)  # [B, M] -> [B,]

        # For MultivariateGaussianPolicy
        # online_dist_fixedmu = tfd.MultivariateNormalFullCovariance(
        #     loc=target_mu, covariance_matrix=online_sigma)
        # online_dist_fixedsigma = tfd.MultivariateNormalFullCovariance(
        #     loc=online_mu, covariance_matrix=target_sigma)
        # For IndependentGaussianPolicy
        online_dist_fixedmu = tfd.Independent(
            tfd.Normal(loc=target_mu, scale=online_sigma),
            reinterpreted_batch_ndims=1)
        online_dist_fixedsigma = tfd.Independent(
            tfd.Normal(loc=online_mu, scale=target_sigma),
            reinterpreted_batch_ndims=1)

        kl_mu = tf.reshape(
            target_dist.kl_divergence(online_dist_fixedsigma),
            shape=(B, M))  # [B * M, ] -> [B, M]
        kl_sigma = tf.reshape(
            target_dist.kl_divergence(online_dist_fixedmu),
            shape=(B, M))  # [B * M, ] -> [B, M]

        alpha_mu = tf.math.softplus(self.log_alpha_mu)
        alpha_sigma = tf.math.softplus(self.log_alpha_sigma)

        loss_policy = -cross_entropy_qp  # [B,]
        loss_policy += tf.stop_gradient(alpha_mu) * tf.reduce_mean(
            kl_mu, axis=1)
        loss_policy += tf.stop_gradient(alpha_sigma) * tf.reduce_mean(
            kl_sigma, axis=1)
        loss_policy = tf.reduce_mean(loss_policy)  # [B,] -> scalar

        loss_alpha_mu = tf.reduce_mean(
            alpha_mu *
            tf.stop_gradient(self.eps_mu - tf.reduce_mean(kl_mu, axis=1)))
        loss_alpha_sigma = tf.reduce_mean(
            alpha_sigma *
            tf.stop_gradient(self.eps_sigma - tf.reduce_mean(kl_sigma, axis=1)))
        loss_alpha = loss_alpha_mu + loss_alpha_sigma

    variables = self.policy.trainable_variables
    grads = tape3.gradient(loss_policy, variables)
    grads, _ = tf.clip_by_global_norm(grads, 40.)
    self.policy_optimizer.apply_gradients(zip(grads, variables))

    variables = [self.log_alpha_mu, self.log_alpha_sigma]
    grads = tape3.gradient(loss_alpha, variables)
    grads, _ = tf.clip_by_global_norm(grads, 40.)
    self.alpha_optimizer.apply_gradients(zip(grads, variables))

    del tape3

    with self.summary_writer.as_default():
        tf.summary.scalar("loss_policy", loss_policy, step=self.global_steps)
        tf.summary.scalar("loss_critic", loss_critic, step=self.global_steps)
        tf.summary.scalar("sigma", tf.reduce_mean(online_sigma),
                          step=self.global_steps)
        tf.summary.scalar("kl_mu", tf.reduce_mean(kl_mu),
                          step=self.global_steps)
        tf.summary.scalar("kl_sigma", tf.reduce_mean(kl_sigma),
                          step=self.global_steps)
        tf.summary.scalar("temperature", temperature, step=self.global_steps)
        tf.summary.scalar("alpha_mu", alpha_mu, step=self.global_steps)
        tf.summary.scalar("alpha_sigma", alpha_sigma, step=self.global_steps)
        tf.summary.scalar("replay_buffer", len(self.replay_buffer),
                          step=self.global_steps)
def joint_limit_cost(self, q, std=0.1):
    limits = tf.constant(self.joint_limits, dtype=tf.float32)
    # Penalize approaching the lower limits (limits[:, 0]) and the upper
    # limits (limits[:, 1]) via Normal log-CDFs.
    return (-ds.Normal(limits[:, 0], std).log_cdf(q)
            - ds.Normal(-limits[:, 1], std).log_cdf(-q))
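# Hedged numeric sketch (assumes ds = tfp.distributions and a single joint
# with limits [-1, 1], so both terms use loc = -1): the cost is near zero
# well inside the limits and grows rapidly once q crosses a bound.
import tensorflow_probability as tfp
ds = tfp.distributions
for q in [0.0, 0.9, 1.2]:
    cost = -ds.Normal(-1.0, 0.1).log_cdf(q) - ds.Normal(-1.0, 0.1).log_cdf(-q)
    print(q, float(cost))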
def mu(self):
    """Variational posterior for distribution mean"""
    return tfd.Normal(self.locs, self.scales)
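# Hedged sketch: a mean-field posterior like this is usually paired with a
# prior via the closed-form Normal-Normal KL (values here are illustrative).
import tensorflow_probability as tfp
tfd = tfp.distributions
posterior = tfd.Normal(loc=0.3, scale=0.8)
prior = tfd.Normal(loc=0.0, scale=1.0)
kl = tfd.kl_divergence(posterior, prior)  # analytic, no sampling needed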
def get_action(self, state):
    mu, sigma, _ = self.model(state)
    dist = tfd.Normal(loc=mu[0], scale=sigma[0])
    action = dist.sample([1])[0]
    action = np.clip(action, -self.max_action, self.max_action)
    return action
def train_high(self, BATCH: High_BatchExperiences):
    # BATCH.obs_ : [B, N]
    # BATCH.obs, BATCH.action : [B, T, *]
    batchs = tf.shape(BATCH.obs)[0]
    with tf.device(self.device):
        with tf.GradientTape() as tape:
            s = BATCH.obs[:, 0]  # [B, N]
            true_end = (BATCH.obs_ - s)[:, self.fn_goal_dim:]
            g_dist = tfd.Normal(loc=true_end,
                                scale=0.5 * self.high_scale[None, :])

            ss = tf.expand_dims(BATCH.obs, 0)  # [1, B, T, *]
            ss = tf.tile(ss, [self.sample_g_nums, 1, 1, 1])  # [10, B, T, *]
            ss = tf.reshape(ss, [-1, tf.shape(ss)[-1]])  # [10*B*T, *]

            aa = tf.expand_dims(BATCH.action, 0)  # [1, B, T, *]
            aa = tf.tile(aa, [self.sample_g_nums, 1, 1, 1])  # [10, B, T, *]
            aa = tf.reshape(aa, [-1, tf.shape(aa)[-1]])  # [10*B*T, *]

            gs = tf.concat([
                tf.expand_dims(BATCH.subgoal, 0),
                tf.expand_dims(true_end, 0),
                tf.clip_by_value(g_dist.sample(self.sample_g_nums - 2),
                                 -self.high_scale, self.high_scale)
            ], axis=0)  # [10, B, N]

            all_g = gs + s[:, self.fn_goal_dim:]
            all_g = tf.expand_dims(all_g, 2)  # [10, B, 1, N]
            all_g = tf.tile(
                all_g, [1, 1, self.sub_goal_steps, 1])  # [10, B, T, N]
            all_g = tf.reshape(all_g, [-1, tf.shape(all_g)[-1]])  # [10*B*T, N]
            all_g = all_g - ss[:, self.fn_goal_dim:]  # [10*B*T, N]
            feat = tf.concat([ss, all_g], axis=-1)  # [10*B*T, *]
            _aa = self.low_ac_net.policy_net(feat)  # [10*B*T, A]
            if not self.is_continuous:
                _aa = tf.one_hot(
                    tf.argmax(_aa, axis=-1), self.a_dim, dtype=tf.float32)
            diff = _aa - aa
            diff = tf.reshape(
                diff,
                [self.sample_g_nums, batchs,
                 self.sub_goal_steps, -1])  # [10, B, T, A]
            diff = tf.transpose(diff, [1, 0, 2, 3])  # [B, 10, T, A]
            logps = -0.5 * tf.reduce_sum(
                tf.norm(diff, ord=2, axis=-1)**2, axis=-1)  # [B, 10]
            idx = tf.argmax(logps, axis=-1, output_type=tf.int32)
            idx = tf.stack([tf.range(batchs), idx], axis=1)  # [B, 2]
            g = tf.gather_nd(tf.transpose(gs, [1, 0, 2]), idx)  # [B, N]

            q1, q2 = self.high_ac_net.get_value(s, g)
            q = tf.minimum(q1, q2)
            target_sub_goal = self.high_ac_target_net.policy_net(
                BATCH.obs_) * self.high_scale
            q_target = self.high_ac_target_net.get_min(
                BATCH.obs_, target_sub_goal)
            dc_r = tf.stop_gradient(
                BATCH.reward + self.gamma * (1 - BATCH.done) * q_target)
            td_error1 = q1 - dc_r
            td_error2 = q2 - dc_r
            q1_loss = tf.reduce_mean(tf.square(td_error1))
            q2_loss = tf.reduce_mean(tf.square(td_error2))
            high_critic_loss = q1_loss + q2_loss

        high_critic_grads = tape.gradient(
            high_critic_loss, self.high_ac_net.critic_trainable_variables)
        self.high_critic_optimizer.apply_gradients(
            zip(high_critic_grads,
                self.high_ac_net.critic_trainable_variables))

        with tf.GradientTape() as tape:
            mu = self.high_ac_net.policy_net(s) * self.high_scale
            q_actor = self.high_ac_net.value_net(s, mu)
            high_actor_loss = -tf.reduce_mean(q_actor)
        high_actor_grads = tape.gradient(
            high_actor_loss, self.high_ac_net.actor_trainable_variables)
        self.high_actor_optimizer.apply_gradients(
            zip(high_actor_grads,
                self.high_ac_net.actor_trainable_variables))

        return dict([['LOSS/high_actor_loss', high_actor_loss],
                     ['LOSS/high_critic_loss', high_critic_loss],
                     ['Statistics/high_q_min', tf.reduce_min(q)],
                     ['Statistics/high_q_mean', tf.reduce_mean(q)],
                     ['Statistics/high_q_max', tf.reduce_max(q)]])
def distribution_fn(t):
    scale = 1e-5 + tf.nn.softplus(c + t[Ellipsis, -1])
    return tfd.Independent(tfd.Normal(loc=t[Ellipsis, :n], scale=scale),
                           reinterpreted_batch_ndims=1)
class BatchShapeInferenceTests(test_util.TestCase):

    @parameterized.named_parameters(
        {'testcase_name': '_trivial',
         'value_fn': lambda: tfd.Normal(loc=0., scale=1.),
         'expected_batch_shape_parts': {'loc': [], 'scale': []},
         'expected_batch_shape': []},
        {'testcase_name': '_simple_tensor_broadcasting',
         'value_fn': lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
             loc=[0., 0.],
             scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
         'expected_batch_shape_parts': {'loc': [], 'scale_diag': [2]},
         'expected_batch_shape': [2]},
        {'testcase_name': '_rank_deficient_tensor_broadcasting',
         'value_fn': lambda: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
             loc=0.,
             scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
         'expected_batch_shape_parts': {'loc': [], 'scale_diag': [2]},
         'expected_batch_shape': [2]},
        {'testcase_name': '_dynamic_event_ndims',
         'value_fn': lambda: _MVNTriLWithDynamicParamNdims(  # pylint: disable=g-long-lambda
             loc=[[0., 0.], [1., 1.], [2., 2.]],
             scale_tril=[[1., 0.], [-1., 1.]]),
         'expected_batch_shape_parts': {'loc': [3], 'scale_tril': []},
         'expected_batch_shape': [3]},
        {'testcase_name': '_mixture_same_family',
         'value_fn': lambda: tfd.MixtureSameFamily(  # pylint: disable=g-long-lambda
             mixture_distribution=tfd.Categorical(
                 logits=[[[1., 2., 3.], [4., 5., 6.]]]),
             components_distribution=tfd.Normal(
                 loc=0., scale=[[[1., 2., 3.], [4., 5., 6.]]])),
         'expected_batch_shape_parts': {
             'mixture_distribution': [1, 2],
             'components_distribution': [1, 2]},
         'expected_batch_shape': [1, 2]},
        {'testcase_name': '_deeply_nested',
         'value_fn': lambda: tfd.Independent(  # pylint: disable=g-long-lambda
             tfd.Independent(
                 tfd.Independent(
                     tfd.Independent(
                         tfd.Normal(loc=0., scale=[[[[[[[[1.]]]]]]]]),
                         reinterpreted_batch_ndims=2),
                     reinterpreted_batch_ndims=0),
                 reinterpreted_batch_ndims=1),
             reinterpreted_batch_ndims=1),
         'expected_batch_shape_parts': {'distribution': [1, 1, 1, 1]},
         'expected_batch_shape': [1, 1, 1, 1]},
        {'testcase_name': 'noparams',
         'value_fn': tfb.Exp,
         'expected_batch_shape_parts': {},
         'expected_batch_shape': []})
    @test_util.numpy_disable_test_missing_functionality('b/188002189')
    def test_batch_shape_inference_is_correct(self, value_fn,
                                              expected_batch_shape_parts,
                                              expected_batch_shape):
        value = value_fn()  # Defer construction until we're in the right graph.

        parts = batch_shape_lib.batch_shape_parts(value)
        self.assertAllEqualNested(
            parts,
            nest.map_structure_up_to(parts, tf.TensorShape,
                                     expected_batch_shape_parts))

        self.assertAllEqual(expected_batch_shape,
                            batch_shape_lib.inferred_batch_shape_tensor(value))

        batch_shape = batch_shape_lib.inferred_batch_shape(value)
        self.assertIsInstance(batch_shape, tf.TensorShape)
        self.assertTrue(batch_shape.is_compatible_with(expected_batch_shape))

    def test_bijector_event_ndims(self):
        bij = tfb.Sigmoid(low=tf.zeros([2]), high=tf.ones([3, 2]))
        self.assertAllEqual(batch_shape_lib.inferred_batch_shape(bij), [3, 2])
        self.assertAllEqual(
            batch_shape_lib.inferred_batch_shape_tensor(bij), [3, 2])
        self.assertAllEqual(
            batch_shape_lib.inferred_batch_shape(
                bij, bijector_x_event_ndims=1),
            [3])
        self.assertAllEqual(
            batch_shape_lib.inferred_batch_shape_tensor(
                bij, bijector_x_event_ndims=1),
            [3])

        # Verify that we don't pass Nones through to component
        # `experimental_batch_shape(x_event_ndims=None)` calls, where they'd
        # be incorrectly interpreted as
        # `x_event_ndims=forward_min_event_ndims`.
        joint_bij = tfb.JointMap([bij, bij])
        self.assertAllEqual(
            batch_shape_lib.inferred_batch_shape(
                joint_bij, bijector_x_event_ndims=[None, None]),
            tf.TensorShape(None))