def call(self, inputs, training=None):
    """Return the one-hot categorical posterior computed from ``inputs``.

    The logits come from ``self.logits(inputs)``. When ``self.beta`` is
    positive, the batch-mean KL divergence from the posterior to a uniform
    one-hot categorical (all-zero logits) is registered through
    ``self.add_loss``. The term is added as-is; it is not multiplied by
    ``self.beta`` here. ``training`` is accepted but not read.
    """
    posterior_logits = self.logits(inputs)
    posterior = distributions.OneHotCategorical(
        logits=posterior_logits,
        dtype=tf.float32,
    )
    # No regularisation requested: hand back the posterior untouched.
    if not self.beta > 0.0:
        return posterior

    uniform_prior = distributions.OneHotCategorical(
        logits=tf.zeros_like(posterior_logits),
        dtype=tf.float32,
    )
    self.add_loss(tf.reduce_mean(posterior.kl_divergence(uniform_prior)))
    return posterior
def sample_z(x, mu, hyperparams):
    """Sample a one-hot cluster assignment for ``x`` given centroids ``mu``.

    The likelihood of ``x`` under a Normal centred on each entry of ``mu``
    (with ``hyperparams['tau2']`` passed as the Normal ``scale``) is
    normalised to sum to one when the total is positive, and used as the
    probabilities of a one-hot categorical draw.

    Returns:
        (z, probability): the sampled one-hot vector and the (possibly
        normalised) probabilities, both as NumPy arrays.
    """
    scale = np.full(shape=mu.shape, fill_value=hyperparams['tau2'],
                    dtype='float32')
    likelihood = tfd.Normal(loc=mu, scale=scale).prob(x).numpy()

    # Normalise only when there is mass to normalise; an all-zero vector is
    # passed through unchanged.
    total = sum(likelihood)
    if total > 0:
        likelihood = likelihood / total

    assignment = tfd.OneHotCategorical(probs=likelihood).sample().numpy()
    return assignment, likelihood
def __init__(self, mixing_distribution, components, **kwargs):
    """Initialise the mixture and cache a stacked one-hot component distribution.

    After the parent constructor runs, the ``probs`` of every entry in
    ``self.components`` are stacked along axis -2 and wrapped in a single
    ``tfd.OneHotCategorical`` that shares this distribution's dtype.
    """
    super().__init__(mixing_distribution, components, **kwargs)
    self.uniform_mask = False

    stacked_probs = tf.stack([c.probs for c in self.components], axis=-2)
    self._components_probs = tf.constant(stacked_probs)
    self._components_distribution = tfd.OneHotCategorical(
        probs=self._components_probs,
        dtype=self.dtype,
    )
def _compute_kl_loss(self, post_probs, prior_probs):
    """Balanced KL loss between posterior and prior one-hot categoricals.

    Computes KL[ Q(z_post) || P(z_prior) ] twice, each time stopping the
    gradient on one side, and mixes the two with ``self.config.kl_alpha``:

        alpha * KL[sg(Q) || P] + (1 - alpha) * KL[Q || sg(P)]

    Args:
        post_probs: posterior class probabilities; documented by the original
            author as shaped (L, B, latent_dim, n_atoms).
        prior_probs: prior class probabilities, same shape as ``post_probs``.

    Returns:
        Scalar tensor: the mean of the mixed KL over all remaining axes.
    """
    def _one_hot_dist(probs):
        # The last batch axis is folded into the event so the KL is summed
        # over it.
        return tfd.Independent(tfd.OneHotCategorical(probs=probs),
                               reinterpreted_batch_ndims=1)

    #: Small offset keeps the KL finite when a probability is exactly zero
    post_probs += 1e-5
    prior_probs += 1e-5

    #: Term that only trains the prior (posterior is held constant)
    kl_train_prior = tfd.kl_divergence(
        _one_hot_dist(tf.stop_gradient(post_probs)),
        _one_hot_dist(prior_probs))

    #: Term that only trains the posterior (prior is held constant)
    kl_train_post = tfd.kl_divergence(
        _one_hot_dist(post_probs),
        _one_hot_dist(tf.stop_gradient(prior_probs)))

    alpha = self.config.kl_alpha
    balanced = alpha * kl_train_prior + (1. - alpha) * kl_train_post
    return tf.reduce_mean(balanced)
def model_fn(self):
    """Generator describing the joint model, one ``yield`` per random variable.

    Yield order: latent regression weights, latent scales, the one-hot design
    rows of the test samples, the latent ``z``, the per-feature bias, the two
    kernel-regression coefficient vectors, the expression scale, the
    log-expression ``x`` and, unless ``self.use_point_estimates`` is set, the
    approximate read likelihood.
    """
    factor_shape = [self.num_factors, self.k]
    degree_shape = [self.kernel_regression_degree]

    # ---- regression in latent space ----
    weight_prior = tfd.Normal(loc=tf.zeros(factor_shape),
                              scale=tf.fill(factor_shape, 10.0))
    latent_weights = yield JDCRoot(Independent(weight_prior))

    latent_scale_prior = tfd.HalfCauchy(loc=tf.zeros(factor_shape), scale=1.0)
    latent_scales = yield JDCRoot(Independent(latent_scale_prior))

    test_design_prior = tfd.OneHotCategorical(logits=tf.zeros([
        self.num_testing_samples,
        self.num_factors - self.num_confounders,
    ]))
    test_design = yield JDCRoot(Independent(test_design_prior))

    # Known design rows (with a leading axis added) followed by sampled ones.
    design = tf.concat([tf.expand_dims(self.F, 0), test_design], axis=-2)

    latent = yield Independent(
        tfd.Normal(loc=tf.matmul(design, latent_weights),
                   scale=tf.matmul(design, latent_scales)))

    bias_prior = tfd.Normal(
        loc=tf.fill([self.num_features], np.float32(self.x_bias_loc0)),
        scale=np.float32(self.x_bias_scale0))
    feature_bias = yield JDCRoot(Independent(bias_prior))

    # ---- decoded log-expression space ----
    expression_loc = feature_bias + self.decoder(latent) - self.sample_scales

    concentration_coeffs = yield JDCRoot(
        Independent(tfd.HalfCauchy(loc=tf.zeros(degree_shape), scale=1.0)))
    mode_coeffs = yield JDCRoot(
        Independent(tfd.HalfCauchy(loc=tf.zeros(degree_shape), scale=1.0)))

    regression_weights = kernel_regression_weights(
        self.kernel_regression_bandwidth, feature_bias, self.x_scale_hinges)
    expression_scale = yield Independent(
        mean_variance_model(regression_weights, concentration_coeffs,
                            mode_coeffs))

    # ---- log expression distribution ----
    log_expression = yield Independent(
        tfd.StudentT(df=1.0, loc=expression_loc, scale=expression_scale))

    if self.use_point_estimates:
        return
    yield tfd.Independent(
        rnaseq_approx_likelihood_from_vars(self.vars, log_expression))
def __call__(self):
    """Build the backend's OneHotCategorical from this object's parameters.

    ``get_backend()`` selects the library: ``'pytorch'`` uses
    ``torch.distributions``, anything else uses TensorFlow Probability. Both
    receive ``self['logits']`` and ``self['probs']``.
    """
    if get_backend() != 'pytorch':
        from tensorflow_probability import distributions as tfd
        return tfd.OneHotCategorical(logits=self['logits'],
                                     probs=self['probs'])

    import torch.distributions as tod
    torch_one_hot = tod.one_hot_categorical.OneHotCategorical
    return torch_one_hot(logits=self['logits'], probs=self['probs'])
def __init__(self, N, K, B, temperature=1.0, **kwargs):
    """Create random delta components and trainable logits of shape (N, B, K).

    Args:
        N, B: leading dimensions of the component and logit tensors.
        K: number of categories.
        temperature: stored on the instance unchanged.
        **kwargs: forwarded to the parent constructor.
    """
    super().__init__(**kwargs)
    self.temperature = temperature
    float_type = self.dtype

    # Each (n, b) entry gets a single random category, i.e. delta probs.
    category_ids = np.random.choice(np.arange(K), (N, B))
    delta_probs = tf.one_hot(category_ids, K, dtype=float_type)
    self.components = tfd.OneHotCategorical(probs=delta_probs,
                                            dtype=float_type)  # N x B x K

    initial_logits = tf.random.normal((N, B, K), dtype=float_type)
    self.logits = tf.Variable(initial_logits, name="logits")
def decoder(topics: tf.Tensor) -> tfd.OneHotCategorical:
    """
    Turn topic values into a OneHotCategorical over words.

    The word probabilities are ``topics @ topics_words``; ``topics_words`` is
    taken from the enclosing scope.

    :param topics: Tensor containing topic values
    :return: OneHotCategorical named "bag_of_words"
    """
    # Observations are bags of words rather than one-hot vectors; per the
    # original author, OneHotCategorical.log_prob still yields the right
    # probability for them.
    return tfd.OneHotCategorical(
        probs=tf.matmul(topics, topics_words),
        name="bag_of_words",
    )
def gmm(nb_clusters, hyperparams, batch_size):
    """Draw ``batch_size`` points from a randomly generated Gaussian mixture.

    Sampling order: mixture weights (Dirichlet), centroids (Normal), one-hot
    assignment indicators, then the observations. ``hyperparams['tau2']`` is
    passed directly as the ``scale`` of the observation Normal.

    Returns:
        (mu, theta, assignments, x) as NumPy arrays: centroids, mixture
        weights, integer cluster index per point, and the sampled points.
    """
    concentration = np.full(shape=nb_clusters,
                            fill_value=hyperparams['alpha'],
                            dtype='float32')
    # assignment probabilities
    mixture_weights = tfd.Dirichlet(concentration=concentration).sample()
    # centroids
    centroid_prior = tfd.Normal(loc=hyperparams['mu0'],
                                scale=hyperparams['sigma0'])
    centroids = centroid_prior.sample(nb_clusters)
    # assignment indicators
    indicators = tfd.OneHotCategorical(probs=mixture_weights).sample(batch_size)

    cluster_ids = tf.argmax(indicators.numpy(), axis=1)
    point_means = tf.gather(centroids, cluster_ids)  # centroid of each point
    point_scales = tf.fill(dims=batch_size, value=hyperparams['tau2'])
    points = tfd.Normal(loc=point_means, scale=point_scales).sample()

    return (centroids.numpy(), mixture_weights.numpy(), cluster_ids.numpy(),
            points.numpy())
def test_forward_mean():
    """Check ``mue.hmm_mean`` against ``HiddenMarkovModel.mean`` on a 3-state HMM."""
    initial_probs = np.array([0.9, 0.08, 0.02])
    transition_probs = np.array([[0.1, 0.8, 0.1],
                                 [0.5, 0.3, 0.2],
                                 [0.4, 0.4, 0.2]])
    emission_probs = np.array([[0.99, 0.01],
                               [0.01, 0.99],
                               [0.5, 0.5]])

    def _as_logits(probs):
        return tf.math.log(tf.convert_to_tensor(probs))

    model = tfpd.HiddenMarkovModel(
        tfpd.Categorical(logits=_as_logits(initial_probs)),
        tfpd.Categorical(logits=_as_logits(transition_probs)),
        tfpd.OneHotCategorical(logits=_as_logits(emission_probs)),
        5)

    reference_mean = model.mean()
    computed_mean = mue.hmm_mean(model, 5)
    assert np.allclose(reference_mean.numpy(), computed_mean.numpy())
def get_prior(desired_shape, prior_type="normal"):
    """Build a prior distribution for the requested shape.

    Args:
        desired_shape: shape specification; passed to
            ``tfd.Normal.param_shapes`` for the normal prior and to
            ``get_size`` for the categorical priors.
        prior_type: string, the type of prior. ``"normal"`` selects a
            standard Normal. Any string containing ``"categorical"`` selects
            a uniform categorical, one-hot encoded if the string also
            contains ``"one_hot"``.

    Returns:
        ``(distribution, shapes)`` for the normal prior, where ``shapes`` is
        the parameter-shape dict; ``(distribution, n)`` for the categorical
        priors, where ``n`` is the number of categories; ``None`` when
        ``prior_type`` matches none of the above.
    """
    # Compare by value: `is` tests object identity and silently fails for
    # strings that are equal but not the same interned object.
    if prior_type == "normal":
        shapes = tfd.Normal.param_shapes(desired_shape)
        loc, scale = tf.zeros(shapes["loc"]), tf.ones(shapes["scale"])
        return tfd.Normal(loc, scale), shapes

    if "categorical" in prior_type:
        n = get_size(desired_shape)
        probs = [1 / n] * n  # uniform over the n categories
        if "one_hot" in prior_type:
            return tfd.OneHotCategorical(probs=probs), n
        return tfd.Categorical(probs=probs), n

    # Unrecognised prior types keep the original behaviour of returning None.
    return None
def sample_z_prior(self, h):
    """Sample the categorical latent from the prior network given ``h``.

    Returns:
        (z, z_probs): a straight-through one-hot sample and the class
        probabilities, both shaped (batch, latent_dim, n_atoms).
    """
    hidden = self.dense_z_prior1(h)
    flat_logits = self.dense_z_prior2(hidden)
    atom_logits = tf.reshape(
        flat_logits, [flat_logits.shape[0], self.latent_dim, self.n_atoms])
    z_probs = tf.nn.softmax(atom_logits, axis=2)

    #: One categorical per latent dimension, joined into a single event
    prior = tfd.Independent(tfd.OneHotCategorical(probs=z_probs),
                            reinterpreted_batch_ndims=1)
    hard_sample = tf.cast(prior.sample(), tf.float32)

    #: Straight-through estimator: the value is the hard sample, the gradient
    #: flows through z_probs
    z = hard_sample + z_probs - tf.stop_gradient(z_probs)
    return z, z_probs
def encode(x, uln0, rln0, lln0, latent_length, latent_alphabet_size,
           alphabet_size, padded_data_length, transfer_mats,
           dtype=tf.float64, eps=1e-32):
    """First layer of encoder, using the MuE mean.

    Builds a HiddenMarkovModel whose ancestral sequence is ``log(x)`` (floored
    at -1e32), with uniform insert biases, and returns
    ``hmm_mean(model, latent_length)``. ``latent_alphabet_size`` is accepted
    but not read.
    """
    def _normalise_and_tile(raw_ln):
        # Log-normalise the vector, then repeat it for every padded position.
        log_normalised = raw_ln - tf.reduce_logsumexp(raw_ln)
        tiling = tf.ones((padded_data_length, 2), dtype=dtype)
        return tiling * log_normalised[None, :]

    # Initial sequence; log(0) = -inf is replaced with a large negative number.
    ancestral_ln = tf.maximum(tf.math.log(x), -1e32)
    # Insert biases follow a uniform distribution over the alphabet.
    insert_ln = -np.log(alphabet_size) * tf.ones_like(ancestral_ln)

    # Deletion, insertion and substitution parameters.
    deletion_ln = _normalise_and_tile(uln0)
    insertion_ln = _normalise_and_tile(rln0)
    substitution_ln = lln0 - tf.reduce_logsumexp(lln0, axis=1, keepdims=True)

    # HiddenMarkovModel with one-hot encoded output.
    initial_logits, transition_logits, emission_logits = make_hmm_params(
        ancestral_ln, insert_ln, deletion_ln, insertion_ln, substitution_ln,
        transfer_mats, eps=eps, dtype=dtype)
    encoder_hmm = tfpd.HiddenMarkovModel(
        tfpd.Categorical(logits=initial_logits),
        tfpd.Categorical(logits=transition_logits),
        tfpd.OneHotCategorical(logits=emission_logits),
        latent_length)
    return hmm_mean(encoder_hmm, latent_length)
def RegressMuE(z, latent_dims, latent_length, latent_alphabet_size,
               alphabet_size, seq_len, transfer_mats, bt_scale, b0_scale,
               u_conc, r_conc, l_conc, dtype=tf.float32):
    """Regress MuE model.

    The ancestral sequence and the insert biases are both linear in ``z``
    (factors ``bt`` plus offsets ``b0``). Deletion, insertion and substitution
    probabilities get Dirichlet priors. Returns the HiddenMarkovModel random
    variable named "x".

    NOTE(review): ``eps`` is not a parameter or local of this function, so it
    is resolved from the enclosing/module scope -- confirm it is defined there.
    """
    profile_shape = [latent_length+1, latent_alphabet_size]

    # Factors.
    bt = Normal(0., bt_scale, sample_shape=[2, latent_dims] + profile_shape,
                name="bt")
    # Offset.
    b0 = Normal(0., b0_scale, sample_shape=[2] + profile_shape, name="b0")

    # Ancestral sequence (index 0) and insert biases (index 1).
    ancestral = tf.einsum('j,jkl->kl', z, bt[0, :, :, :]) + b0[0, :, :]
    insert_bias = tf.einsum('j,jkl->kl', z, bt[1, :, :, :]) + b0[1, :, :]

    # Assemble priors -- in this version, we use a Dirichlet.
    del_conc, ins_conc, sub_conc = get_prior_conc(
        latent_length, latent_alphabet_size, alphabet_size,
        u_conc, r_conc, l_conc, dtype=dtype)
    deletion = Dirichlet(del_conc, name="u")      # Deletion probability.
    insertion = Dirichlet(ins_conc, name="r")     # Insertion probability.
    substitution = Dirichlet(sub_conc, name="l")  # Substitution probability.

    # Generate data from the MuE.
    ancestral_ln = ancestral - tf.reduce_logsumexp(
        ancestral, axis=1, keepdims=True)
    insert_bias_ln = insert_bias - tf.reduce_logsumexp(
        insert_bias, axis=1, keepdims=True)
    initial_logits, transition_logits, emission_logits = mue.make_hmm_params(
        ancestral_ln, insert_bias_ln,
        tf.math.log(deletion), tf.math.log(insertion),
        tf.math.log(substitution),
        transfer_mats, eps=eps, dtype=dtype)
    return HiddenMarkovModel(
        tfpd.Categorical(logits=initial_logits),
        tfpd.Categorical(logits=transition_logits),
        tfpd.OneHotCategorical(logits=emission_logits),
        seq_len, name="x")
def design_matrix_model_fn(self):
    """Generator yielding the test-sample design prior; returns the full design.

    The one-hot rows sampled for the test samples are appended to the known
    design ``self.F`` (with a leading axis added). ``self.confounders`` is
    concatenated as extra columns when present, and the result is multiplied
    by ``self.F_mask``.
    """
    free_factor_count = self.num_factors - self.num_confounders
    uniform_logits = tf.zeros([self.num_testing_samples, free_factor_count])
    test_rows = yield JDCRoot(
        Independent(tfd.OneHotCategorical(logits=uniform_logits)))

    design = tf.concat([tf.expand_dims(self.F, 0), test_rows], axis=-2)

    # F_confounders = yield JDCRoot(Independent(tfd.Normal(
    #     loc=tf.zeros([self.num_samples, self.num_confounders]),
    #     scale=1.0)))
    # F_full = tf.concat([F_full, F_confounders], axis=-1)

    if self.confounders is not None:
        confounder_columns = tf.expand_dims(self.confounders, 0)
        design = tf.concat([design, confounder_columns], axis=-1)

    return tf.matmul(design, self.F_mask)
def main(argv):
  """Build the discrete-bottleneck autoencoder graph and run the training loop.

  Steps, in order: reset `FLAGS.model_dir`; build the input pipeline, encoder,
  decoder and vector quantizer; apply the bottleneck selected by
  `FLAGS.bottleneck_type` ("deterministic", "categorical" or
  "gumbel_softmax"); assemble the training loss and a separate held-out
  marginal NLL estimate; register summaries; then train for `FLAGS.max_steps`
  steps, logging every 100 steps and checkpointing every `FLAGS.viz_steps`.

  Args:
    argv: command-line arguments; unused.

  Raises:
    ValueError: if `FLAGS.bottleneck_type` is not one of the three known types.
  """
  del argv  # unused

  # The activation flag holds a name; replace it with the tf.nn callable.
  FLAGS.activation = getattr(tf.nn, FLAGS.activation)
  # Start from an empty model directory.
  if tf.gfile.Exists(FLAGS.model_dir):
    tf.logging.warn("Deleting old log directory at {}".format(FLAGS.model_dir))
    tf.gfile.DeleteRecursively(FLAGS.model_dir)
  tf.gfile.MakeDirs(FLAGS.model_dir)

  with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()
    # TODO(b/113163167): Speed up and tune hyperparameters for Bernoulli MNIST.
    # `handle` is fed at run time to choose between the two iterators.
    (images, _, handle,
     training_iterator, heldout_iterator) = build_input_pipeline(
         FLAGS.data_dir, FLAGS.batch_size, heldout_size=10000,
         mnist_type=FLAGS.mnist_type)

    encoder = Encoder(FLAGS.base_depth,
                      FLAGS.activation,
                      FLAGS.latent_size,
                      FLAGS.code_size)
    decoder = Decoder(FLAGS.base_depth,
                      FLAGS.activation,
                      FLAGS.latent_size * FLAGS.code_size,
                      IMAGE_SHAPE)
    vector_quantizer = VectorQuantizer(FLAGS.num_codes, FLAGS.code_size)

    codes = encoder(images)
    nearest_codebook_entries, one_hot_assignments, dist = vector_quantizer(
        codes)

    # Every branch leaves `one_hot_assignments` with a leading sample axis and
    # sets `neg_q_entropy`.
    if FLAGS.bottleneck_type == "deterministic":
      one_hot_assignments = one_hot_assignments[tf.newaxis, Ellipsis]
      neg_q_entropy = 0.
      # Perform straight-through.
      # The forward value equals the hard assignment; the gradient flows
      # through the softmax probabilities.
      class_probs = tf.nn.softmax(-dist[tf.newaxis, Ellipsis])
      one_hot_assignments = class_probs + tf.stop_gradient(
          one_hot_assignments - class_probs)
    elif FLAGS.bottleneck_type == "categorical":
      one_hot_assignments, neg_q_entropy = (
          categorical_bottleneck(dist,
                                 vector_quantizer,
                                 FLAGS.num_iaf_flows,
                                 FLAGS.average_categorical_samples,
                                 FLAGS.use_transformer_for_iaf_parameters,
                                 FLAGS.sum_over_latents,
                                 FLAGS.num_samples))
    elif FLAGS.bottleneck_type == "gumbel_softmax":
      one_hot_assignments, neg_q_entropy = (
          gumbel_softmax_bottleneck(dist,
                                    vector_quantizer,
                                    FLAGS.temperature,
                                    FLAGS.num_iaf_flows,
                                    FLAGS.use_transformer_for_iaf_parameters,
                                    FLAGS.num_samples,
                                    FLAGS.sum_over_latents,
                                    summary=True))
    else:
      raise ValueError("Unknown bottleneck type.")

    # Weighted sum of codebook rows: assignments times codebook, reduced over
    # the code axis (axis 3).
    bottleneck_output = tf.reduce_sum(
        one_hot_assignments[Ellipsis, tf.newaxis] *
        tf.reshape(vector_quantizer.codebook,
                   [1, 1, 1, FLAGS.num_codes, FLAGS.code_size]),
        axis=3)

    decoder_distribution = decoder(bottleneck_output)
    reconstructed_images = decoder_distribution.mean()[0]  # get first sample

    reconstruction_loss = -tf.reduce_mean(decoder_distribution.log_prob(images))
    # Pulls encoder outputs toward their nearest codebook entries; the
    # codebook side is held constant by stop_gradient.
    commitment_loss = tf.reduce_mean(
        tf.square(codes[tf.newaxis, Ellipsis] -
                  tf.stop_gradient(nearest_codebook_entries)))
    commitment_loss = add_ema_control_dependencies(
        vector_quantizer,
        tf.reduce_mean(one_hot_assignments, axis=0),  # reduce mean over samples
        codes,
        commitment_loss,
        FLAGS.decay)

    if FLAGS.use_autoregressive_prior:
      prior_fn = make_transformer_prior(FLAGS.num_codes, FLAGS.code_size)
    else:
      prior_fn = make_uniform_prior()

    # Optionally block the prior loss from back-propagating into the
    # assignments.
    prior_inputs = one_hot_assignments
    if FLAGS.stop_gradient_for_prior:
      prior_inputs = tf.stop_gradient(one_hot_assignments)
    prior_loss = make_prior_loss(prior_fn,
                                 prior_inputs,
                                 sum_over_latents=FLAGS.sum_over_latents)

    loss = (reconstruction_loss + FLAGS.beta * commitment_loss + prior_loss +
            FLAGS.entropy_scale * neg_q_entropy)

    # ---- Held-out marginal NLL estimate ----
    # Each bottleneck type defines heldout_prior_loss,
    # heldout_reconstruction_loss and heldout_neg_q_entropy.
    if FLAGS.bottleneck_type == "deterministic":
      # NOTE: this rescales `prior_loss` itself after `loss` was built, so the
      # rescaled value is what the "losses/prior_loss" summary reports.
      if not FLAGS.sum_over_latents:
        prior_loss = prior_loss * FLAGS.latent_size
      heldout_prior_loss = prior_loss
      heldout_reconstruction_loss = reconstruction_loss
      heldout_neg_q_entropy = tf.constant(0.)
    if FLAGS.bottleneck_type == "categorical":
      # To accurately evaluate heldout NLL, we need to sum over latent dimension
      # and use a single sample for the categorical (and not multinomial) prior.
      (heldout_one_hot_assignments,
       heldout_neg_q_entropy) = categorical_bottleneck(
           dist,
           vector_quantizer,
           FLAGS.num_iaf_flows,
           FLAGS.average_categorical_samples,
           FLAGS.use_transformer_for_iaf_parameters,
           sum_over_latents=True,
           num_samples=1,
           summary=False)
      heldout_bottleneck_output = tf.reduce_sum(
          heldout_one_hot_assignments[Ellipsis, tf.newaxis] *
          tf.reshape(vector_quantizer.codebook,
                     [1, 1, 1, FLAGS.num_codes, FLAGS.code_size]),
          axis=3)
      heldout_prior_loss = make_prior_loss(
          prior_fn,
          heldout_one_hot_assignments,
          sum_over_latents=True)
      heldout_decoder_distribution = decoder(heldout_bottleneck_output)
      heldout_reconstruction_loss = -tf.reduce_mean(
          heldout_decoder_distribution.log_prob(images))
    elif FLAGS.bottleneck_type == "gumbel_softmax":
      # Held-out evaluation draws hard one-hot samples instead of relaxed ones.
      num_test_samples = 1
      heldout_q_dist = tfd.OneHotCategorical(logits=-dist, dtype=tf.float32)
      heldout_one_hot_assignments = heldout_q_dist.sample(num_test_samples)
      heldout_neg_q_entropy = heldout_q_dist.log_prob(
          heldout_one_hot_assignments)

      # Push the samples through the flows, looking up the
      # "scale_bias_<i>" variables by name under AUTO_REUSE.
      for flow_num in range(FLAGS.num_iaf_flows):
        with tf.variable_scope("iaf_variables", reuse=tf.AUTO_REUSE):
          shifted_codes = shift_assignments(heldout_one_hot_assignments)
          scale_bias = tf.get_variable("scale_bias_" + str(flow_num))
          if FLAGS.use_transformer_for_iaf_parameters:
            unconstrained_scale = iaf_scale_from_transformer(
                shifted_codes,
                FLAGS.code_size,
                name=str(flow_num))
          else:
            unconstrained_scale = iaf_scale_from_matmul(shifted_codes,
                                                        name=str(flow_num))
          # Don't need to add inverse log determinant jacobian since samples are
          # discrete (no change in volume when bijecting discrete variables).
          heldout_one_hot_assignments, _ = iaf_flow(heldout_one_hot_assignments,
                                                    unconstrained_scale,
                                                    scale_bias,
                                                    summary=False)

      # Sum the log-probabilities over the latent axis, then average.
      heldout_neg_q_entropy = tf.reduce_sum(
          tf.reshape(heldout_neg_q_entropy, [-1, FLAGS.latent_size]), axis=1)
      heldout_neg_q_entropy = tf.reduce_mean(heldout_neg_q_entropy)
      heldout_nearest_codebook_entries = tf.reduce_sum(
          heldout_one_hot_assignments[Ellipsis, tf.newaxis] *
          tf.reshape(vector_quantizer.codebook,
                     [1, 1, 1, FLAGS.num_codes, FLAGS.code_size]),
          axis=3)
      # We still evaluate the prior on the transformed samples. But in order
      # for this categorical distribution to be valid, we have to binarize.
      heldout_one_hot_assignments = tf.one_hot(
          tf.argmax(heldout_one_hot_assignments, axis=-1),
          depth=FLAGS.num_codes)
      heldout_prior_loss = make_prior_loss(
          prior_fn,
          heldout_one_hot_assignments,
          sum_over_latents=True)
      heldout_decoder_distribution = decoder(heldout_nearest_codebook_entries)
      heldout_reconstruction_loss = -tf.reduce_mean(
          heldout_decoder_distribution.log_prob(images))

    marginal_nll = (heldout_prior_loss + heldout_reconstruction_loss +
                    heldout_neg_q_entropy)

    # Training scalars go to the default collection; held-out scalars go to
    # the "heldout" collection so they can be merged and run separately.
    tf.summary.scalar("losses/total_loss", loss)
    tf.summary.scalar("losses/neg_q_entropy_loss",
                      neg_q_entropy * FLAGS.entropy_scale)
    tf.summary.scalar("losses/reconstruction_loss", reconstruction_loss)
    tf.summary.scalar("losses/prior_loss", prior_loss)
    tf.summary.scalar("losses/commitment_loss", FLAGS.beta * commitment_loss)
    tf.summary.scalar("heldout/neg_q_entropy_loss",
                      heldout_neg_q_entropy,
                      collections=["heldout"])
    tf.summary.scalar("heldout/reconstruction_loss",
                      heldout_reconstruction_loss,
                      collections=["heldout"])
    tf.summary.scalar("heldout/prior_loss",
                      heldout_prior_loss,
                      collections=["heldout"])

    # Decode 10 samples from prior for visualization.
    if FLAGS.use_autoregressive_prior:
      # Seed with an all-zero position, which is stripped again after the loop.
      assignments = tf.zeros([10, 1, FLAGS.num_codes])
      # Decode autoregressively.
      for d in range(FLAGS.latent_size):
        logits = prior_fn(assignments).logits_parameter()
        latent_dim_logit = logits[0, :, tf.newaxis, d, :]
        sample = tfd.OneHotCategorical(
            logits=latent_dim_logit, dtype=tf.float32).sample()
        assignments = tf.concat([assignments, sample], axis=1)
      assignments = assignments[:, 1:, :]
    else:
      # Uniform prior: all-zero logits; the mean over the single-sample axis
      # just drops that axis.
      logits = tf.zeros([10, FLAGS.latent_size, FLAGS.num_codes])
      assignments = tf.reduce_mean(tfd.OneHotCategorical(
          logits=logits, dtype=tf.float32).sample(1), axis=0)

    prior_samples = tf.reduce_sum(
        assignments[Ellipsis, tf.newaxis] *
        tf.reshape(vector_quantizer.codebook,
                   [1, 1, FLAGS.num_codes, FLAGS.code_size]),
        axis=2)
    # Add the leading sample axis that the decoder output is indexed by below.
    prior_samples = prior_samples[tf.newaxis, Ellipsis]
    decoded_distribution_given_random_prior = decoder(prior_samples)
    random_images = decoded_distribution_given_random_prior.mean()[0]

    # Save summaries.
    # The heldout image summaries reuse the `images` and `reconstructed_images`
    # tensors; which data they show depends on the handle fed at run time.
    tf.summary.image("train_inputs",
                     tf.cast(images, tf.float32),
                     max_outputs=10,
                     collections=["train_image"])
    tf.summary.image("train_reconstructions",
                     reconstructed_images,
                     max_outputs=10,
                     collections=["train_image"])
    tf.summary.image("train_prior_samples",
                     tf.cast(random_images, tf.float32),
                     max_outputs=10,
                     collections=["train_image"])
    tf.summary.image("heldout_inputs",
                     tf.cast(images, tf.float32),
                     max_outputs=10,
                     collections=["heldout_image"])
    tf.summary.image("heldout_reconstructions",
                     reconstructed_images,
                     max_outputs=10,
                     collections=["heldout_image"])
    tf.summary.scalar("heldout/marginal_loss",
                      marginal_nll,
                      collections=["heldout"])

    # Perform inference by minimizing the loss function.
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

    if FLAGS.num_iaf_flows > 0:
      # With flows: the IAF variables are excluded from updates while
      # global_step < FLAGS.iaf_startup_steps; afterwards either everything or
      # everything except the encoder is updated.
      encoder_variables = encoder.variables
      iaf_variables = tf.get_collection(
          tf.GraphKeys.TRAINABLE_VARIABLES, scope="iaf_variables")
      grads_and_vars = optimizer.compute_gradients(loss)
      grads_and_vars_except_encoder = [
          x for x in grads_and_vars if x[1] not in encoder_variables]
      grads_and_vars_except_iaf = [
          x for x in grads_and_vars if x[1] not in iaf_variables]

      def train_op_except_iaf():
        return optimizer.apply_gradients(
            grads_and_vars_except_iaf,
            global_step=global_step)

      def train_op_except_encoder():
        return optimizer.apply_gradients(
            grads_and_vars_except_encoder,
            global_step=global_step)

      def train_op_all():
        return optimizer.apply_gradients(
            grads_and_vars,
            global_step=global_step)

      if FLAGS.stop_training_encoder_after_startup:
        after_startup_train_op = train_op_except_encoder
      else:
        after_startup_train_op = train_op_all

      train_op = tf.cond(
          global_step < FLAGS.iaf_startup_steps,
          true_fn=train_op_except_iaf,
          false_fn=after_startup_train_op)
    else:
      train_op = optimizer.minimize(loss, global_step=global_step)

    # One merged summary op per collection.
    summary = tf.summary.merge_all()
    heldout_summary = tf.summary.merge_all(key="heldout")
    train_image_summary = tf.summary.merge_all(key="train_image")
    heldout_image_summary = tf.summary.merge_all(key="heldout_image")
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session(FLAGS.master) as sess:
      summary_writer = tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
      sess.run(init)

      # Run the training loop.
      train_handle = sess.run(training_iterator.string_handle())
      heldout_handle = sess.run(heldout_iterator.string_handle())
      for step in range(FLAGS.max_steps):
        start_time = time.time()
        _, loss_value = sess.run([train_op, loss],
                                 feed_dict={handle: train_handle})
        duration = time.time() - start_time
        if step % 100 == 0:
          marginal_nll_val = sess.run(marginal_nll,
                                      feed_dict={handle: heldout_handle})
          print("Step: {:>3d} Training Loss: {:.3f} Heldout NLL: {:.3f} "
                "({:.3f} sec)".format(step, loss_value, marginal_nll_val,
                                      duration))

          # Update the events file.
          summary_str = sess.run(summary, feed_dict={handle: train_handle})
          summary_writer.add_summary(summary_str, step)
          summary_writer.flush()
          summary_str_heldout = sess.run(heldout_summary,
                                         feed_dict={handle: heldout_handle})
          summary_writer.add_summary(summary_str_heldout, step)
          summary_writer.flush()

        # Periodically save a checkpoint and visualize model progress.
        if (step + 1) % FLAGS.viz_steps == 0:
          summary_str_train_images = sess.run(
              train_image_summary,
              feed_dict={handle: train_handle})
          summary_str_heldout_images = sess.run(
              heldout_image_summary,
              feed_dict={handle: heldout_handle})
          summary_writer.add_summary(summary_str_train_images, step)
          summary_writer.add_summary(summary_str_heldout_images, step)
          checkpoint_file = os.path.join(FLAGS.model_dir, "model.ckpt")
          saver.save(sess, checkpoint_file, global_step=step)
def categorical_bottleneck(dist,
                           vector_quantizer,
                           num_iaf_flows=0,
                           average_categorical_samples=True,
                           use_transformer_for_iaf_parameters=False,
                           sum_over_latents=True,
                           num_samples=1,
                           summary=True):
  """Implements soft EM bottleneck using averaged categorical samples.

  Args:
    dist: Distances between encoder outputs and codebook entries. Negative
      distances are used as categorical logits. A float Tensor whose second
      axis is the latent size and whose last axis indexes codebook entries
      (it is reshaped to `[-1, FLAGS.num_codes]` for sampling).
    vector_quantizer: An instance of the VectorQuantizer class.
    num_iaf_flows: Number of inverse autoregressive flows.
    average_categorical_samples: Whether to take the average of `num_samples`
      categorical samples as in Roy et al. or to use `num_samples` categorical
      samples to approximate gradient.
    use_transformer_for_iaf_parameters: Whether to use a Transformer instead of
      a lower-triangular mat-mul for generating IAF parameters.
    sum_over_latents: Whether to sum over latent dimension when computing
      entropy.
    num_samples: Number of categorical samples.
    summary: Whether to log entropy histogram.

  Returns:
    one_hot_assignments: Simplex-valued assignments sampled from categorical.
    neg_q_entropy: Negative entropy of categorical distribution, estimated as
      the mean log-probability of the drawn samples.
  """
  latent_size = dist.shape[1]
  # Draw `num_samples` code indices per latent position.
  # NOTE(review): the logits are reshaped with FLAGS.num_codes while every
  # other use below reads vector_quantizer.num_codes -- presumably equal;
  # confirm.
  x_means_idx = tf.multinomial(
      logits=tf.reshape(-dist, [-1, FLAGS.num_codes]),
      num_samples=num_samples)
  one_hot_assignments = tf.one_hot(x_means_idx,
                                   depth=vector_quantizer.num_codes)
  one_hot_assignments = tf.reshape(
      one_hot_assignments,
      [-1, latent_size, num_samples, vector_quantizer.num_codes])
  if average_categorical_samples:
    # The sample counts are scored under a Multinomial; their average becomes
    # a single soft assignment behind a leading sample axis of size 1.
    summed_assignments = tf.reduce_sum(one_hot_assignments, axis=2)
    averaged_assignments = tf.reduce_mean(one_hot_assignments, axis=2)
    entropy_dist = tfd.Multinomial(
        total_count=tf.cast(num_samples, tf.float32),
        logits=-dist)
    neg_q_entropy = entropy_dist.log_prob(summed_assignments)
    one_hot_assignments = averaged_assignments[tf.newaxis, Ellipsis]
  else:
    # Keep every sample: move the sample axis to the front and score each
    # sample under the one-hot categorical.
    one_hot_assignments = tf.transpose(one_hot_assignments, [2, 0, 1, 3])
    entropy_dist = tfd.OneHotCategorical(logits=-dist, dtype=tf.float32)
    neg_q_entropy = entropy_dist.log_prob(one_hot_assignments)

  if summary:
    tf.summary.histogram("neg_q_entropy_0",
                         tf.reshape(tf.reduce_sum(neg_q_entropy, axis=-1),
                                    [-1]))

  # Perform straight-through.
  # The forward value equals the sampled assignment; the gradient flows
  # through the softmax probabilities.
  class_probs = tf.nn.softmax(-dist[tf.newaxis, Ellipsis])
  one_hot_assignments = class_probs + tf.stop_gradient(
      one_hot_assignments - class_probs)

  # Perform IAF flows
  for flow_num in range(num_iaf_flows):
    with tf.variable_scope("iaf_variables", reuse=tf.AUTO_REUSE):
      # Pad the one_hot_assignments by zeroing out the first latent dimension
      # and shifting the rest down by one (and removing the last dimension).
      shifted_codes = shift_assignments(one_hot_assignments)
      if use_transformer_for_iaf_parameters:
        unconstrained_scale = iaf_scale_from_transformer(
            shifted_codes,
            vector_quantizer.code_size,
            name=str(flow_num))
      else:
        unconstrained_scale = iaf_scale_from_matmul(shifted_codes,
                                                    name=str(flow_num))
      # Initialize scale bias to be log(e/2 - 1) so initial scale + scale_bias
      # is 1.
      initial_scale_bias = tf.fill([latent_size, vector_quantizer.num_codes],
                                   INITIAL_SCALE_BIAS)
      scale_bias = tf.get_variable("scale_bias_" + str(flow_num),
                                   initializer=initial_scale_bias)
      # Since categorical is discrete, we don't need to add inverse log
      # determinant.
      one_hot_assignments, _ = iaf_flow(one_hot_assignments,
                                        unconstrained_scale,
                                        scale_bias,
                                        summary=summary)

  if sum_over_latents:
    neg_q_entropy = tf.reduce_sum(neg_q_entropy, axis=-1)
  neg_q_entropy = tf.reduce_mean(neg_q_entropy)

  return one_hot_assignments, neg_q_entropy
def prior_fn(codes):
  """Return a uniform one-hot categorical prior shaped like `codes`.

  Only the shape and dtype of `codes` are used: the logits are all zeros.
  """
  return tfd.OneHotCategorical(logits=tf.zeros_like(codes), dtype=tf.float32)
def update_actor_critic(self, trajectory, batch_size=512, strategy="PPO"):
    """Run 10 minibatch updates of the value and policy networks.

    Targets and per-step weights come from ``self.compute_target``. Every
    tensor is flattened to ``N = weights.shape[0] * weights.shape[1]`` rows.
    Each iteration draws ``batch_size`` row indices with
    ``np.random.choice(N, batch_size)``, fits the value network to the
    targets, then takes one policy-gradient step with an entropy bonus scaled
    by ``self.config.ent_scale``.

    Args:
        trajectory: mapping with 'state', 'reward', 'next_state', 'discount'
            and 'action' entries.
        batch_size: number of rows sampled per update.
        strategy: "PPO" (probability ratio clipped to [0.9, 1.1]) or
            "VanillaPG" (log-probability of the taken action times advantage).

    Returns:
        dict of scalar diagnostics taken from the last minibatch.

    Raises:
        ValueError: if ``strategy`` is not "PPO" or "VanillaPG". This is
            checked before any network is updated.
    """
    # Fail fast. Previously an unknown strategy surfaced only as an
    # UnboundLocalError after the value network had already been updated.
    if strategy not in ("VanillaPG", "PPO"):
        raise ValueError(
            "Unknown strategy {!r}; expected 'PPO' or 'VanillaPG'".format(
                strategy))

    #: adv: (L*B, 1)
    targets, weights = self.compute_target(trajectory['state'],
                                           trajectory['reward'],
                                           trajectory['next_state'],
                                           trajectory['discount'])

    #: (H, L*B, ...)
    states = trajectory['state']
    selected_actions = trajectory['action']

    #: Flatten the two leading axes so rows can be sampled uniformly
    N = weights.shape[0] * weights.shape[1]
    states = tf.reshape(states, [N, -1])
    selected_actions = tf.reshape(selected_actions, [N, -1])
    targets = tf.reshape(targets, [N, -1])
    weights = tf.reshape(weights, [N, -1])

    #: Log-probs of the policy before any update; the PPO ratio is measured
    #: against these
    _, old_action_probs = self.policy(states)
    old_logprobs = tf.math.log(old_action_probs + 1e-5)

    for _ in range(10):
        indices = np.random.choice(N, batch_size)

        _states = tf.gather(states, indices)
        _targets = tf.gather(targets, indices)
        _selected_actions = tf.gather(selected_actions, indices)
        _old_logprobs = tf.gather(old_logprobs, indices)
        _weights = tf.gather(weights, indices)

        #: Update value network
        with tf.GradientTape() as tape1:
            v_pred = self.value(_states)
            advantages = _targets - v_pred
            value_loss = 0.5 * tf.square(advantages)
            discount_value_loss = tf.reduce_mean(value_loss * _weights)

        grads = tape1.gradient(discount_value_loss,
                               self.value.trainable_variables)
        self.value_optimizer.apply_gradients(
            zip(grads, self.value.trainable_variables))

        #: Update policy network. `advantages` was produced by the value
        #: network, so it is a constant with respect to the policy variables.
        with tf.GradientTape() as tape2:
            _, action_probs = self.policy(_states)
            action_probs += 1e-5  # keeps the logs finite

            if strategy == "VanillaPG":
                selected_action_logprobs = tf.reduce_sum(
                    _selected_actions * tf.math.log(action_probs),
                    axis=1, keepdims=True)
                objective = selected_action_logprobs * advantages
            else:  # "PPO"
                new_logprobs = tf.math.log(action_probs)
                #: The one-hot action mask selects the ratio of the taken action
                ratio = tf.reduce_sum(
                    _selected_actions * tf.exp(new_logprobs - _old_logprobs),
                    axis=1, keepdims=True)
                ratio_clipped = tf.clip_by_value(ratio, 0.9, 1.1)

                obj_unclipped = ratio * advantages
                obj_clipped = ratio_clipped * advantages
                objective = tf.minimum(obj_unclipped, obj_clipped)

            #: Entropy bonus and sign flip are shared by both strategies
            dist = tfd.Independent(
                tfd.OneHotCategorical(probs=action_probs),
                reinterpreted_batch_ndims=0)
            ent = dist.entropy()

            policy_loss = objective + self.config.ent_scale * ent[..., None]
            policy_loss *= -1  # maximise the objective by minimising its negative
            discounted_policy_loss = tf.reduce_mean(policy_loss * _weights)

        grads = tape2.gradient(discounted_policy_loss,
                               self.policy.trainable_variables)
        self.policy_optimizer.apply_gradients(
            zip(grads, self.policy.trainable_variables))

    info = {
        "policy_loss": tf.reduce_mean(policy_loss),
        "objective": tf.reduce_mean(objective),
        "actor_entropy": tf.reduce_mean(ent),
        "value_loss": tf.reduce_mean(value_loss),
        "target_0": tf.reduce_mean(_targets),
    }
    return info
def sample(self, feat):
    """Draw one-hot actions from the policy's probabilities for ``feat``.

    The instance is called on ``feat``; only the second element of its result
    (the probabilities) is used.
    """
    _, action_probs = self(feat)
    action_dist = tfd.Independent(
        tfd.OneHotCategorical(probs=action_probs),
        reinterpreted_batch_ndims=0,
    )
    return action_dist.sample()
def construct_model(self, learning_rate=None):
    """Rebuild the Bayesian neural network graph, its loss and its train op.

    Inside ``self.graph`` this replaces ``self.sess`` with a fresh interactive
    session, builds a prior and a variational posterior network of
    ``self.num_layers`` layers over ``self.rescaled_features``, maps the
    network outputs to one predictive distribution per kernel listed in
    ``self.config``, sums the negative log-likelihoods of the targets under
    the posterior predictions into ``self.loss``, and initialises all
    variables.

    Args:
        learning_rate: Adam learning rate; defaults to ``self.learning_rate``
            when None.
    """
    if learning_rate is None:
        learning_rate = self.learning_rate

    with self.graph.as_default():

        # Replace any previous session with a new interactive one.
        self.sess.close()
        self.sess = tf.compat.v1.InteractiveSession()
        # NOTE(review): the return value is discarded here -- confirm whether
        # a `with` block was intended.
        self.sess.as_default()

        self.x = tf.convert_to_tensor(self.rescaled_features, dtype=tf.float32)
        self.y = tf.convert_to_tensor(self.targets, dtype=tf.float32)

        # construct precisness
        # Every row holds the squared kernel ranges; sampled precisions are
        # divided by this below.
        self.tau_rescaling = np.zeros((self.num_obs, self.bnn_output_size))
        kernel_ranges = self.config.kernel_ranges
        for obs_index in range(self.num_obs):
            self.tau_rescaling[obs_index] += kernel_ranges
        self.tau_rescaling = self.tau_rescaling**2

        # construct weight and bias shapes
        # tanh on the first and intermediate layers, identity on the output.
        activations = [tf.nn.tanh]
        weight_shapes, bias_shapes = [[self.feature_size, self.hidden_shape]], [[self.hidden_shape]]
        for _ in range(1, self.num_layers - 1):
            activations.append(tf.nn.tanh)
            weight_shapes.append([self.hidden_shape, self.hidden_shape])
            bias_shapes.append([self.hidden_shape])
        activations.append(lambda x: x)
        weight_shapes.append([self.hidden_shape, self.bnn_output_size])
        bias_shapes.append([self.bnn_output_size])

        # ---------------
        # construct prior
        # ---------------
        # Fixed Normal priors on every weight and bias; one sample of each is
        # pushed through the network.
        self.prior_layer_outputs = [self.x]
        self.priors = {}
        for layer_index in range(self.num_layers):
            weight_shape, bias_shape = weight_shapes[layer_index], bias_shapes[layer_index]
            activation = activations[layer_index]

            weight = tfd.Normal(loc=tf.zeros(weight_shape) + self.weight_loc, scale=tf.zeros(weight_shape) + self.weight_scale)
            bias = tfd.Normal(loc=tf.zeros(bias_shape) + self.bias_loc, scale=tf.zeros(bias_shape) + self.bias_scale)
            self.priors['weight_%d' % layer_index] = weight
            self.priors['bias_%d' % layer_index] = bias

            prior_layer_output = activation(tf.matmul(self.prior_layer_outputs[-1], weight.sample()) + bias.sample())
            self.prior_layer_outputs.append(prior_layer_output)

        self.prior_bnn_output = self.prior_layer_outputs[-1]

        # draw precisions from gamma distribution
        self.prior_tau_normed = tfd.Gamma(
            12*(self.num_obs/self.frac_feas)**2 + tf.zeros((self.num_obs, self.bnn_output_size)),
            tf.ones((self.num_obs, self.bnn_output_size)),
        )
        self.prior_tau = self.prior_tau_normed.sample() / self.tau_rescaling
        # Standard deviation implied by the precision: 1 / sqrt(tau).
        self.prior_scale = tfd.Deterministic(1. / tf.sqrt(self.prior_tau))

        # -------------------
        # construct posterior
        # -------------------
        # Same architecture as the prior, but locations and (softplus-
        # transformed) scales are trainable variables.
        self.post_layer_outputs = [self.x]
        self.posteriors = {}
        for layer_index in range(self.num_layers):
            weight_shape, bias_shape = weight_shapes[layer_index], bias_shapes[layer_index]
            activation = activations[layer_index]

            weight = tfd.Normal(loc=tf.Variable(tf.random.normal(weight_shape)), scale=tf.nn.softplus(tf.Variable(tf.zeros(weight_shape))))
            bias = tfd.Normal(loc=tf.Variable(tf.random.normal(bias_shape)), scale=tf.nn.softplus(tf.Variable(tf.zeros(bias_shape))))

            self.posteriors['weight_%d' % layer_index] = weight
            self.posteriors['bias_%d' % layer_index] = bias

            post_layer_output = activation(tf.matmul(self.post_layer_outputs[-1], weight.sample()) + bias.sample())
            self.post_layer_outputs.append(post_layer_output)

        self.post_bnn_output = self.post_layer_outputs[-1]
        self.post_tau_normed = tfd.Gamma(
            12*(self.num_obs/self.frac_feas)**2 + tf.Variable(tf.zeros((self.num_obs, self.bnn_output_size))),
            tf.nn.softplus(tf.Variable(tf.ones((self.num_obs, self.bnn_output_size)))),
        )
        self.post_tau = self.post_tau_normed.sample() / self.tau_rescaling
        self.post_sqrt_tau = tf.sqrt(self.post_tau)
        self.post_scale = tfd.Deterministic(1. / self.post_sqrt_tau)

        # map bnn output to prediction
        post_kernels = {}
        targets_dict = {}
        inferences = []

        # `target_element_index` counts parameters; `kernel_element_index`
        # counts network output columns (a parameter spans `kernel_size` of
        # them).
        target_element_index = 0
        kernel_element_index = 0

        while kernel_element_index < len(self.config.kernel_names):

            kernel_type = self.config.kernel_types[kernel_element_index]
            kernel_size = self.config.kernel_sizes[kernel_element_index]

            # NOTE(review): feature_begin / feature_end are assigned but not
            # read anywhere in this method.
            feature_begin, feature_end = target_element_index, target_element_index + 1
            kernel_begin, kernel_end = kernel_element_index, kernel_element_index + kernel_size

            prior_relevant = self.prior_bnn_output[:, kernel_begin: kernel_end]
            post_relevant = self.post_bnn_output[:, kernel_begin: kernel_end]

            if kernel_type == 'continuous':

                target = self.y[:, kernel_begin: kernel_end]
                lowers, uppers = self.config.kernel_lowers[kernel_begin: kernel_end], self.config.kernel_uppers[kernel_begin : kernel_end]

                # Squash the network output into the parameter range, widened
                # by 10% of the range on each side.
                prior_support = (uppers - lowers) * (1.2 * tf.nn.sigmoid(prior_relevant) - 0.1) + lowers
                post_support = (uppers - lowers) * (1.2 * tf.nn.sigmoid(post_relevant) - 0.1) + lowers

                prior_predict = tfd.Normal(prior_support, self.prior_scale[:, kernel_begin: kernel_end].sample())
                post_predict = tfd.Normal(post_support, self.post_scale[:, kernel_begin: kernel_end].sample())

                targets_dict[prior_predict] = target
                post_kernels['param_%d' % target_element_index] = {
                    'loc': tfd.Deterministic(post_support),
                    'sqrt_prec': tfd.Deterministic(self.post_sqrt_tau[:, kernel_begin: kernel_end]),
                    'scale': tfd.Deterministic(self.post_scale[:, kernel_begin: kernel_end].sample())}

                inference = {'pred': post_predict, 'target': target}
                inferences.append(inference)

            elif kernel_type in ['categorical', 'discrete']:
                target = tf.cast(self.y[:, kernel_begin: kernel_end], tf.int32)

                # The temperature decreases toward 0.5 as num_obs / frac_feas
                # grows.
                prior_temperature = 0.5 + 10.0 / (self.num_obs / self.frac_feas)
                #prior_temperature = 1.0
                post_temperature = prior_temperature

                prior_support = prior_relevant
                post_support = post_relevant

                # A relaxed (continuous) one-hot sample provides the class
                # probabilities of the discrete predictive distribution.
                prior_predict_relaxed = tfd.RelaxedOneHotCategorical(prior_temperature, prior_support)
                prior_predict = tfd.OneHotCategorical(probs=prior_predict_relaxed.sample())

                post_predict_relaxed = tfd.RelaxedOneHotCategorical(post_temperature, post_support)
                post_predict = tfd.OneHotCategorical(probs=post_predict_relaxed.sample())

                targets_dict[prior_predict] = target
                post_kernels['param_%d' % target_element_index] = {'probs': post_predict_relaxed}

                inference = {'pred': post_predict, 'target': target}
                inferences.append(inference)

                ''' Temperature annealing schedule: - temperature of 100 yields 1e-2 deviation from uniform - temperature of 10 yields 1e-1 deviation from uniform - temperature of 1 yields *almost* perfect agreement with expectation - temperature of 0.1 yields perfect agreement with expectation '''

            else:
                # NOTE(review): the error object is constructed but not
                # raised. Unless its constructor aborts on its own, unknown
                # kernel types are skipped here -- confirm.
                GryffinUnknownSettingsError(f'did not understand kernel type: {kernel_type}')

            target_element_index += 1
            kernel_element_index += kernel_size

        self.post_kernels = post_kernels
        self.targets_dict = targets_dict

        # Negative log-likelihood of the targets under the posterior
        # predictions, summed over all parameters.
        self.loss = 0.
        for inference in inferences:
            self.loss += - tf.reduce_sum(inference['pred'].log_prob(inference['target']))

        self.optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
        self.train_op = self.optimizer.minimize(self.loss)

        tf.compat.v1.global_variables_initializer().run()
def __init__(self, temp, logits=None, probs=None, dtype=None):
    """Set up the relaxed distribution plus an exact one-hot companion.

    Args:
        temp: temperature forwarded to the parent constructor.
        logits, probs: distribution parameters, forwarded both to the parent
            and to the exact ``tfd.OneHotCategorical`` kept in ``self._exact``.
        dtype: sample dtype; when falsy, the compute dtype of the global
            precision policy is used instead.
    """
    if dtype:
        self._sample_dtype = dtype
    else:
        self._sample_dtype = prec.global_policy().compute_dtype
    self._exact = tfd.OneHotCategorical(logits=logits, probs=probs)
    super().__init__(temp, logits=logits, probs=probs)
def sample_posterior_zs(self, r, nsamples=1000):
    """Build the one-hot posterior over assignments and sample from it.

    Args:
        r: class probabilities; also determines the dtype of the samples.
        nsamples: number of draws.

    Returns:
        (posterior_z, zs): the distribution and ``nsamples`` draws from it.
    """
    assignment_posterior = tfd.OneHotCategorical(probs=r, dtype=r.dtype)
    return assignment_posterior, assignment_posterior.sample(nsamples)