def to_noncentered(centered_state):
     set_values = ed_transforms.make_value_setter(*centered_state)
     with ed.tape() as noncentered_tape:
         with ed.interception(ed_transforms.ncp):
             with ed.interception(set_values):
                 model(*model_args)
     return [tf.identity(v) for v in list(noncentered_tape.values())[:-1]]
 def to_centered(uncentered_state):
     set_values = ed_transforms.make_value_setter(*uncentered_state)
     with ed.interception(set_values):
         with ed.interception(parametrisation):
             with ed.tape() as centered_tape:
                 model(*model_args)
     return [tf.identity(v) for v in list(centered_tape.values())[:-1]]
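A minimal, self-contained sketch of the tape-plus-value-setter pattern these two helpers rely on (assuming the tfp.edward2 API used throughout this listing; `ed_transforms`, `model` and `model_args` above come from the surrounding repository and are not needed here):

from tensorflow_probability import edward2 as ed  # newer setups: import edward2 as ed

def toy_model():
    loc = ed.Normal(loc=0., scale=1., name="loc")
    x = ed.Normal(loc=loc, scale=0.5, name="x")
    return x

# Record every random variable while forcing `loc` to a fixed value.
with ed.tape() as toy_tape:
    with ed.interception(ed.make_value_setter(loc=3.)):
        toy_model()

# ed.tape() yields an OrderedDict keyed by variable name, which is exactly
# what the helpers above slice to drop the observed variables.
print(list(toy_tape.keys()))  # ["loc", "x"]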
Example #3
def custom_elbo(pmodel, qmodel, sample_dict):
    # create combined model
    plate_size = pmodel._get_plate_size(sample_dict)

    # expand the qmodel (use the interceptor too, in case the q model uses data from sample_dict)
    with ed.interception(inf.util.interceptor.set_values(**sample_dict)):
        qvars, _ = qmodel.expand_model(plate_size)

    # expand the pmodel, using the interceptor's set_values function, to condition on sample_dict and the expanded qvars
    with ed.interception(
            inf.util.interceptor.set_values(**{
                **qvars,
                **sample_dict
            })):
        pvars, _ = pmodel.expand_model(plate_size)

    # compute energy
    energy = tf.reduce_sum(
        [tf.reduce_sum(p.log_prob(p.value)) for p in pvars.values()])

    # compute entropy
    entropy = -tf.reduce_sum(
        [tf.reduce_sum(q.log_prob(q.value)) for q in qvars.values()])

    # compute ELBO
    ELBO = energy + entropy

    # This function will be minimized. Return minus ELBO
    return -ELBO
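The energy/entropy split above is a single-sample ELBO estimate. A minimal sketch of the same split with plain TFP distributions (illustrative only, not part of InferPy, and covering just the latent-variable terms):

import tensorflow_probability as tfp
tfd = tfp.distributions

p_z = tfd.Normal(loc=0., scale=1.)   # prior
q_z = tfd.Normal(loc=1., scale=0.5)  # variational posterior
z = q_z.sample()
energy = p_z.log_prob(z)             # one-sample estimate of E_q[log p]
entropy = -q_z.log_prob(z)           # one-sample estimate of the entropy term
elbo_estimate = energy + entropy     # minimized as -ELBO, as in custom_elbo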
Example #4
    def to_noncentered(centered_state):
        set_values = ed_transforms.make_value_setter(*centered_state)
        with ed.tape() as noncentered_tape:
            with ed.interception(ed_transforms.ncp):
                with ed.interception(set_values):
                    model(*model_args)

        param_vals = [
            tf.identity(v) for k, v in noncentered_tape.items()
            if k not in observed_data.keys()
        ]
        return param_vals
    def testMakeValueSetterWorksWithPartialAssignment(self):
        def normal_with_unknown_mean():
            loc = ed.Normal(loc=0., scale=1., name="loc")
            x = ed.Normal(loc=loc, scale=0.5, name="x")
            return x

        # Setting only the latents produces the posterior predictive distribution.
        loc_value = 3.
        with ed.interception(ed.make_value_setter(loc=loc_value)):
            x_predictive = normal_with_unknown_mean()
        self.assertAllEqual(self.evaluate(x_predictive.distribution.mean()),
                            loc_value)

        # Setting observed values allows calling the log joint as a fn of latents.
        x_value = 4.

        def model_with_observed_x():
            with ed.interception(ed.make_value_setter(x=x_value)):
                normal_with_unknown_mean()

        observed_log_joint_fn = ed.make_log_joint_fn(model_with_observed_x)

        expected_joint_log_prob = (
            tfd.Normal(0., 1.).log_prob(loc_value) +
            tfd.Normal(loc_value, 0.5).log_prob(x_value))
        self.assertEqual(self.evaluate(expected_joint_log_prob),
                         self.evaluate(observed_log_joint_fn(loc=loc_value)))
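A hedged follow-up sketch (not part of the test suite): once `x` is pinned by the value setter, the log joint is an ordinary function of `loc` and can be scanned or optimized directly. Assumes eager execution and the same edward2 API as the test above.

import numpy as np
from tensorflow_probability import edward2 as ed  # or: import edward2 as ed

def normal_with_unknown_mean():
    loc = ed.Normal(loc=0., scale=1., name="loc")
    return ed.Normal(loc=loc, scale=0.5, name="x")

def model_with_observed_x():
    with ed.interception(ed.make_value_setter(x=4.)):
        normal_with_unknown_mean()

log_joint_fn = ed.make_log_joint_fn(model_with_observed_x)
grid = np.linspace(-1., 7., 81).astype(np.float32)
map_loc = grid[int(np.argmax([float(log_joint_fn(loc=v)) for v in grid]))]
# With x observed at 4., the grid MAP lands at the analytic posterior mean
# (4. / 0.25) / (1. + 1. / 0.25) = 3.2.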
Example #6
        def log_joint_fn(*args, **kwargs):  # pylint: disable=unused-argument
            states = dict(zip(self.unobserved.keys(), args))
            states.update(self.observed)
            interceptor = interceptors.CollectLogProb(states)
            with ed.interception(interceptor):
                self._f(self._cfg)

            log_prob = sum(interceptor.log_probs)
            return log_prob
Example #7
 def sample(self, size=1, data={}):
     """ Generates a sample for eache variable in the model """
     expanded_vars, expanded_params = self.expand_model(size)
     with ed.interception(util.interceptor.set_values(**data)):
         expanded_vars, expanded_params = self.expand_model(size)
     return {
         name: tf.convert_to_tensor(var)
         for name, var in expanded_vars.items()
     }
    def testMakeValueSetterSetsValues(self):
        def normal_with_unknown_mean():
            loc = ed.Normal(loc=0., scale=1., name="loc")
            x = ed.Normal(loc=loc, scale=0.5, name="x")
            return loc, x

        loc_value, x_value = 3., 4.
        with ed.interception(ed.make_value_setter(loc=loc_value, x=x_value)):
            loc_rv, x_rv = normal_with_unknown_mean()
        self.assertAllEqual(self.evaluate((loc_rv, x_rv)),
                            (loc_value, x_value))
 def testInterception(self, cls, value, kwargs):
   def interceptor(f, *fargs, **fkwargs):
     name = fkwargs.get("name", None)
     if name == "rv2":
       fkwargs["value"] = value
     return f(*fargs, **fkwargs)
   rv1 = cls(value=value, name="rv1", **kwargs)
   with ed.interception(interceptor):
     rv2 = cls(name="rv2", **kwargs)
   rv1_value, rv2_value = self.evaluate([rv1.value, rv2.value])
   self.assertEqual(rv1_value, value)
   self.assertEqual(rv2_value, value)
 def _testInterception(self, cls, value, *args, **kwargs):
   def interceptor(f, *fargs, **fkwargs):
     name = fkwargs.get("name", None)
     if name == "rv2":
       fkwargs["value"] = value
     return f(*fargs, **fkwargs)
   rv1 = cls(*args, value=value, name="rv1", **kwargs)
   with ed.interception(interceptor):
     rv2 = cls(*args, name="rv2", **kwargs)
   rv1_value, rv2_value = self.evaluate([rv1.value, rv2.value])
   self.assertEqual(rv1_value, value)
   self.assertEqual(rv2_value, value)
Example #12
    def testInterceptionException(self):
        def f():
            raise NotImplementedError()

        def interceptor(f, *fargs, **fkwargs):
            return f(*fargs, **fkwargs)

        old_interceptor = ed.get_interceptor()
        with self.assertRaises(NotImplementedError):
            with ed.interception(interceptor):
                f()
        new_interceptor = ed.get_interceptor()
        self.assertEqual(old_interceptor, new_interceptor)
  def testInterceptionForwarding(self):
    def double(f, *args, **kwargs):
      return 2. * ed.interceptable(f)(*args, **kwargs)

    def set_xy(f, *args, **kwargs):
      if kwargs.get("name") == "x":
        kwargs["value"] = 1.
      if kwargs.get("name") == "y":
        kwargs["value"] = 0.42
      return ed.interceptable(f)(*args, **kwargs)

    def model():
      x = ed.Normal(loc=0., scale=1., name="x")
      y = ed.Normal(loc=x, scale=1., name="y")
      return x + y

    with ed.interception(set_xy):
      with ed.interception(double):
        z = model()

    value = 2. * 1. + 2. * 0.42
    z_value = self.evaluate(z)
    self.assertAlmostEqual(z_value, value, places=5)
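A small variation on the same forwarding idea (a sketch under the same API assumptions): any interceptor that forwards via ed.interceptable composes with a value setter placed outside it.

from tensorflow_probability import edward2 as ed  # or: import edward2 as ed

def log_names(f, *args, **kwargs):
    print("constructing", kwargs.get("name"))
    return ed.interceptable(f)(*args, **kwargs)

with ed.interception(ed.make_value_setter(x=1.)):
    with ed.interception(log_names):
        x = ed.Normal(loc=0., scale=1., name="x")
# Prints "constructing x"; x.value is 1. because the outer value setter still ran.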
  def testTapeInnerForwarding(self):
    def double(f, *args, **kwargs):
      return 2. * ed.interceptable(f)(*args, **kwargs)

    def model():
      x = ed.Normal(loc=0., scale=1., name="x")
      y = ed.Normal(loc=x, scale=1., name="y")
      return x + y

    with ed.interception(double):
      with ed.tape() as model_tape:
        output = model()

    expected_value, actual_value = self.evaluate([
        model_tape["x"] + model_tape["y"], output])
    self.assertEqual(list(six.iterkeys(model_tape)), ["x", "y"])
    self.assertEqual(expected_value, actual_value)
Example #17
    def test_point(self, sample=True):
        def not_observed(var, *args, **kwargs):  # pylint: disable=unused-argument
            return kwargs['name'] not in self.observed

        values_collector = interceptors.CollectVariables(filter=not_observed)
        chain = [values_collector]
        if not sample:

            def get_mode(state, rv, *args, **kwargs):  # pylint: disable=unused-argument
                return rv.distribution.mode()

            chain.insert(0, interceptors.Generic(after=get_mode))

        with self.graph.as_default(), ed.interception(
                interceptors.Chain(*chain)):
            self._f(self.cfg)
        with self.session.as_default():
            returns = self.session.run(list(values_collector.result.values()))
        return dict(zip(values_collector.result.keys(), returns))
Example #18
 def testDenseMean(self, layer):
   """Tests that forward pass can use other values, e.g., posterior mean."""
   tf.keras.backend.set_learning_phase(0)  # test time
   def take_mean(f, *args, **kwargs):
     """Sets random variable value to its mean."""
     rv = f(*args, **kwargs)
     rv._value = rv.distribution.mean()
     return rv
   inputs = tf.to_float(np.random.rand(5, 3, 7))
   model = layer(4, activation=tf.nn.relu, use_bias=False)
   outputs1 = model(inputs)
   with ed.interception(take_mean):
     outputs2 = model(inputs)
   self.evaluate(tf.global_variables_initializer())
   res1, res2 = self.evaluate([outputs1, outputs2])
   self.assertEqual(res1.shape, (5, 3, 4))
   self.assertNotAllClose(res1, res2)
   if layer != bayes.DenseDVI:
     self.assertAllClose(res2, np.zeros((5, 3, 4)), atol=1e-4)
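The same take_mean interceptor works on a bare random variable, which makes the mechanism easier to see in isolation (a sketch assuming the same edward2 API as above):

from tensorflow_probability import edward2 as ed  # or: import edward2 as ed

def take_mean(f, *args, **kwargs):
    """Sets the random variable's value to its mean instead of a sample."""
    rv = f(*args, **kwargs)
    rv._value = rv.distribution.mean()
    return rv

with ed.interception(take_mean):
    x = ed.Normal(loc=2., scale=3., name="x")
# x.value is exactly 2. (the mean) rather than a random draw.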
Example #19
 def testDenseReparameterizationMean(self):
   """Tests that forward pass can use other values, e.g., posterior mean."""
   def take_mean(f, *args, **kwargs):
     """Sets random variable value to its mean."""
     rv = f(*args, **kwargs)
     rv._value = rv.distribution.mean()
     return rv
   inputs = tf.to_float(np.random.rand(5, 3, 7))
   layer = bayes.DenseReparameterization(4,
                                         activation=tf.nn.relu,
                                         use_bias=False)
   outputs1 = layer(inputs)
   with ed.interception(take_mean):
     outputs2 = layer(inputs)
   self.evaluate(tf.global_variables_initializer())
   res1, res2 = self.evaluate([outputs1, outputs2])
   self.assertEqual(res1.shape, (5, 3, 4))
   self.assertNotAllClose(res1, res2)
   self.assertAllClose(res2, np.zeros((5, 3, 4)), atol=1e-4)
Example #20
def ELBO(pmodel, qmodel, plate_size, batch_weight=1):
    # expand the qmodel
    qvars, _ = qmodel.expand_model(plate_size)

    # expand the pmodel, using the interceptor's set_values function, to condition on the expanded qvars
    with ed.interception(util.interceptor.set_values(**qvars)):
        pvars, _ = pmodel.expand_model(plate_size)

    # compute energy
    energy = tf.reduce_sum([(batch_weight if p.is_datamodel else 1) *
                            tf.reduce_sum(p.log_prob(p.value))
                            for p in pvars.values()])

    # compute entropy
    entropy = -tf.reduce_sum([(batch_weight if q.is_datamodel else 1) *
                              tf.reduce_sum(q.log_prob(q.value))
                              for q in qvars.values() if not q.is_datamodel])

    # compute ELBO
    ELBO = energy + entropy

    # This function will be minimized. Return minus ELBO
    return -ELBO
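A hedged note on batch_weight (an assumption about intent, not taken from the library documentation): when the ELBO is estimated on a minibatch, the data-model terms are typically reweighted by N/M so the objective stays on the scale of the full data set.

N, M = 10000, 100      # assumed full data set size and batch size
batch_weight = N / M   # 100.0
# loss = ELBO(pmodel, qmodel, plate_size=M, batch_weight=batch_weight)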
Example #21
def _make_likelihood(rv_dict, model):
    """Produces optimizable tensor for model likelihood.

    Args:
        rv_dict: (dict of RandomVariable) Dictionary of random variables
            representing variational family for each model parameter.
        model: (Model) A model that contains definition, likelihood and
            training labels.

    Returns:
        log_likelihood: (tf.Tensor) A likelihood tensor with registered
            gradient with respect to VI parameters.
        outcome_rv: (ed.RandomVariable) A random variable representing
            model's predictive distribution.
        model_tape: (ContextManager) A ContextManager recording the
            model variables in model graph.
    """
    with ed.tape() as model_tape:
        with ed.interception(model_util.make_value_setter(**rv_dict)):
            outcome_rv = model.definition()

    log_likelihood = model.likelihood(outcome_rv, model.outcome_obs)

    return log_likelihood, outcome_rv, model_tape
Example #22
 def model_vip(*params):
     with ed.interception(insightful_parametrisation):
         return model_config.model(*params)
Example #23
 def model_vip(*params):
     with ed.interception(learnable_parametrisation):
         return model_config.model(*params)
Example #24
def main(argv):
    del argv  # unused
    if tf.gfile.Exists(FLAGS.model_dir):
        tf.logging.warning("Warning: deleting old log directory at {}".format(
            FLAGS.model_dir))
        tf.gfile.DeleteRecursively(FLAGS.model_dir)
    tf.gfile.MakeDirs(FLAGS.model_dir)
    tf.enable_eager_execution()

    grammar = SmilesGrammar()
    synthetic_data_distribution = ProbabilisticGrammar(
        grammar=grammar,
        latent_size=FLAGS.latent_size,
        num_units=FLAGS.num_units)

    print("Random examples from synthetic data distribution:")
    for _ in range(5):
        productions = synthetic_data_distribution()
        string = grammar.convert_to_string(productions)
        print(string)

    probabilistic_grammar = ProbabilisticGrammar(grammar=grammar,
                                                 latent_size=FLAGS.latent_size,
                                                 num_units=FLAGS.num_units)
    probabilistic_grammar_variational = ProbabilisticGrammarVariational(
        latent_size=FLAGS.latent_size)

    checkpoint = tf.train.Checkpoint(
        synthetic_data_distribution=synthetic_data_distribution,
        probabilistic_grammar=probabilistic_grammar,
        probabilistic_grammar_variational=probabilistic_grammar_variational)
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
    writer = tf.contrib.summary.create_file_writer(FLAGS.model_dir)
    writer.set_as_default()

    start_time = time.time()
    for step in range(FLAGS.max_steps):
        productions = synthetic_data_distribution()
        with tf.GradientTape() as tape:
            # Sample from amortized variational distribution and record its trace.
            with ed.tape() as variational_tape:
                _ = probabilistic_grammar_variational(productions)

            # Set model trace to take on the data's values and the sample from the
            # variational distribution.
            values = {"latent_code": variational_tape["latent_code_posterior"]}
            values.update({
                "production_" + str(t): production
                for t, production in enumerate(tf.unstack(productions, axis=1))
            })
            with ed.tape() as model_tape:
                with ed.interception(make_value_setter(**values)):
                    _ = probabilistic_grammar()

            # Compute the ELBO given the variational sample, averaged over the batch
            # size and the number of time steps (number of productions). Although the
            # ELBO per data point sums over time steps, we average in order to have a
            # value that remains on the same scale across batches.
            log_likelihood = 0.
            for name, rv in six.iteritems(model_tape):
                if name.startswith("production"):
                    log_likelihood += rv.distribution.log_prob(rv.value)

            kl = tfp.distributions.kl_divergence(
                variational_tape["latent_code_posterior"].distribution,
                model_tape["latent_code"].distribution)

            timesteps = tf.to_float(productions.shape[1])
            elbo = tf.reduce_mean(log_likelihood - kl) / timesteps
            loss = -elbo
            with tf.contrib.summary.record_summaries_every_n_global_steps(500):
                tf.contrib.summary.scalar(
                    "log_likelihood",
                    tf.reduce_mean(log_likelihood) / timesteps)
                tf.contrib.summary.scalar("kl", tf.reduce_mean(kl) / timesteps)
                tf.contrib.summary.scalar("elbo", elbo)

        variables = (probabilistic_grammar.variables +
                     probabilistic_grammar_variational.variables)
        grads = tape.gradient(loss, variables)
        grads_and_vars = zip(grads, variables)
        optimizer.apply_gradients(grads_and_vars, global_step)

        if step % 500 == 0:
            duration = time.time() - start_time
            print("Step: {:>3d} Loss: {:.3f} ({:.3f} sec)".format(
                step, loss, duration))
            checkpoint.save(file_prefix=FLAGS.model_dir)
def model_fn(features, labels, mode, params, config):
    """Builds the model function for use in an Estimator.

  Arguments:
    features: The input features for the Estimator.
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some hyperparameters as a dictionary.
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.
  """
    del labels, config

    # Set up the model's learnable parameters.
    logit_concentration = tf.get_variable(
        "logit_concentration",
        shape=[1, params["num_topics"]],
        initializer=tf.constant_initializer(
            _softplus_inverse(params["prior_initial_value"])))
    concentration = _clip_dirichlet_parameters(
        tf.nn.softplus(logit_concentration))

    num_words = features.shape[1]
    topics_words_logits = tf.get_variable(
        "topics_words_logits",
        shape=[params["num_topics"], num_words],
        initializer=tf.glorot_normal_initializer())
    topics_words = tf.nn.softmax(topics_words_logits, axis=-1)

    # Compute expected log-likelihood. First, sample from the variational
    # distribution; second, compute the log-likelihood given the sample.
    lda_variational = make_lda_variational(params["activation"],
                                           params["num_topics"],
                                           params["layer_sizes"])
    with ed.tape() as variational_tape:
        _ = lda_variational(features)

    with ed.tape() as model_tape:
        with ed.interception(
                make_value_setter(
                    topics=variational_tape["topics_posterior"])):
            posterior_predictive = latent_dirichlet_allocation(
                concentration, topics_words)

    log_likelihood = posterior_predictive.distribution.log_prob(features)
    tf.summary.scalar("log_likelihood", tf.reduce_mean(log_likelihood))

    # Compute the KL-divergence between two Dirichlets analytically.
    # The sampled KL does not work well for "sparse" distributions
    # (see Appendix D of [2]).
    kl = variational_tape["topics_posterior"].distribution.kl_divergence(
        model_tape["topics"].distribution)
    tf.summary.scalar("kl", tf.reduce_mean(kl))

    # Ensure that the KL is non-negative (up to a very small slack).
    # Negative KL can happen due to numerical instability.
    with tf.control_dependencies([tf.assert_greater(kl, -1e-3, message="kl")]):
        kl = tf.identity(kl)

    elbo = log_likelihood - kl
    avg_elbo = tf.reduce_mean(elbo)
    tf.summary.scalar("elbo", avg_elbo)
    loss = -avg_elbo

    # Perform variational inference by minimizing the -ELBO.
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(params["learning_rate"])

    # This implements the "burn-in" for prior parameters (see Appendix D of [2]).
    # For the first prior_burn_in_steps steps they are fixed, and then trained
    # jointly with the other parameters.
    grads_and_vars = optimizer.compute_gradients(loss)
    grads_and_vars_except_prior = [
        x for x in grads_and_vars if x[1] != logit_concentration
    ]

    def train_op_except_prior():
        return optimizer.apply_gradients(grads_and_vars_except_prior,
                                         global_step=global_step)

    def train_op_all():
        return optimizer.apply_gradients(grads_and_vars,
                                         global_step=global_step)

    train_op = tf.cond(global_step < params["prior_burn_in_steps"],
                       true_fn=train_op_except_prior,
                       false_fn=train_op_all)

    # The perplexity is an exponent of the average negative ELBO per word.
    words_per_document = tf.reduce_sum(features, axis=1)
    log_perplexity = -elbo / words_per_document
    tf.summary.scalar("perplexity", tf.exp(tf.reduce_mean(log_perplexity)))
    (log_perplexity_tensor,
     log_perplexity_update) = tf.metrics.mean(log_perplexity)
    perplexity_tensor = tf.exp(log_perplexity_tensor)

    # Obtain the topics summary. Implemented as a py_func for simplicity.
    topics = tf.py_func(functools.partial(get_topics_strings,
                                          vocabulary=params["vocabulary"]),
                        [topics_words, concentration],
                        tf.string,
                        stateful=False)
    tf.summary.text("topics", topics)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops={
            "elbo": tf.metrics.mean(elbo),
            "log_likelihood": tf.metrics.mean(log_likelihood),
            "kl": tf.metrics.mean(kl),
            "perplexity": (perplexity_tensor, log_perplexity_update),
            "topics": (topics, tf.no_op()),
        },
    )
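The analytic Dirichlet-to-Dirichlet KL used above is available directly on the TFP distributions; a standalone sketch:

import tensorflow_probability as tfp
tfd = tfp.distributions

q = tfd.Dirichlet(concentration=[2., 3., 4.])  # e.g. the topics posterior
p = tfd.Dirichlet(concentration=[1., 1., 1.])  # e.g. the model prior
kl = q.kl_divergence(p)  # closed form, so no sampled-KL sparsity issues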
def main(argv):
  del argv  # unused
  FLAGS.layer_sizes = [int(layer_size) for layer_size in FLAGS.layer_sizes]
  if len(FLAGS.layer_sizes) != 3:
    raise NotImplementedError("Specifying fewer or more than 3 layers is not "
                              "currently available.")
  if tf.gfile.Exists(FLAGS.model_dir):
    tf.logging.warning(
        "Warning: deleting old log directory at {}".format(FLAGS.model_dir))
    tf.gfile.DeleteRecursively(FLAGS.model_dir)
  tf.gfile.MakeDirs(FLAGS.model_dir)

  if FLAGS.fake_data:
    bag_of_words = np.random.poisson(1., size=[10, 25])
    words = [str(i) for i in range(25)]
  else:
    bag_of_words, words = load_nips2011_papers(FLAGS.data_dir)

  total_count = np.sum(bag_of_words)
  bag_of_words = tf.to_float(bag_of_words)
  data_size, feature_size = bag_of_words.shape

  # Compute expected log-likelihood. First, sample from the variational
  # distribution; second, compute the log-likelihood given the sample.
  qw2, qw1, qw0, qz2, qz1, qz0 = deep_exponential_family_variational(
      data_size,
      feature_size,
      FLAGS.layer_sizes)

  with ed.tape() as model_tape:
    with ed.interception(make_value_setter(w2=qw2, w1=qw1, w0=qw0,
                                           z2=qz2, z1=qz1, z0=qz0)):
      posterior_predictive = deep_exponential_family(data_size,
                                                     feature_size,
                                                     FLAGS.layer_sizes,
                                                     FLAGS.shape)

  log_likelihood = posterior_predictive.distribution.log_prob(bag_of_words)
  log_likelihood = tf.reduce_sum(log_likelihood)
  tf.summary.scalar("log_likelihood", log_likelihood)

  # Compute analytic KL-divergence between variational and prior distributions.
  kl = 0.
  for rv_name, variational_rv in [("z0", qz0), ("z1", qz1), ("z2", qz2),
                                  ("w0", qw0), ("w1", qw1), ("w2", qw2)]:
    kl += tf.reduce_sum(variational_rv.distribution.kl_divergence(
        model_tape[rv_name].distribution))

  tf.summary.scalar("kl", kl)

  elbo = log_likelihood - kl
  tf.summary.scalar("elbo", elbo)
  optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
  train_op = optimizer.minimize(-elbo)

  sess = tf.Session()
  summary = tf.summary.merge_all()
  summary_writer = tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
  start_time = time.time()

  sess.run(tf.global_variables_initializer())
  for step in range(FLAGS.max_steps):
    start_time = time.time()
    _, elbo_value = sess.run([train_op, elbo])
    if step % 500 == 0:
      duration = time.time() - start_time
      print("Step: {:>3d} Loss: {:.3f} ({:.3f} sec)".format(
          step, elbo_value, duration))
      summary_str = sess.run(summary)
      summary_writer.add_summary(summary_str, step)
      summary_writer.flush()

      # Compute perplexity of the full data set. The model's negative
      # log-likelihood of data is upper bounded by the variational objective.
      negative_log_likelihood = -elbo_value
      perplexity = np.exp(negative_log_likelihood / total_count)
      print("Negative log-likelihood <= {:0.3f}".format(
          negative_log_likelihood))
      print("Perplexity <= {:0.3f}".format(perplexity))

      # Print top 10 words for first 10 topics.
      qw0_values = sess.run(qw0)
      for k in range(min(10, FLAGS.layer_sizes[-1])):
        top_words_idx = qw0_values[k, :].argsort()[-10:][::-1]
        top_words = " ".join([words[i] for i in top_words_idx])
        print("Topic {}: {}".format(k, top_words))
 def transformed_model():
   with ed.interception(trivial_interceptor):
     model()
 def model_noncentered(*params):
   with ed.interception(ed_transforms.ncp):
     return model_config.model(*params)
 
 # create data_train for training
 dtype = np.float32
 n_classes = max(train_target) + 1
 Data = {i: np.array(train_data[train_target == i, :]) for i in range(n_classes)}
 sample_shape = {i: sum(train_target == i) for i in range(n_classes)}  # sample_shape for each class
 data_train = Data[0]
 for i in range(n_classes - 1):
     data_train = np.vstack((data_train, Data[i + 1]))
 # get the variational variables of EpiAnno
 qmu, qsigma, qz, qw, qnoise = EpiAnno.Q(latent_dim, data_train.shape[1], n_classes, sample_shape)
 qmu_dict = {v.distribution.name.split("_")[0].split("/")[0][1:]: v for v in qmu}
 qw_dict = {v.distribution.name.split("_")[0].split("/")[0][1:]: v for v in qw}
 qz_dict = {v.distribution.name.split("_")[0].split("/")[0][1:]: v for v in qz}
 # set the variational variables on EpiAnno, then build the ELBO
 with ed.interception(EpiAnno.set_values(**qmu_dict, sigma=qsigma, **qw_dict, x=data_train,
                                         **qz_dict, noise=qnoise)):
     pmu, psigma, pz, pw, pnoise, px = EpiAnno.EpiAnno(10, data_train.shape[1], n_classes, sample_shape)
 elbo = EpiAnno.ELBO(pmu, psigma, pz, pw, pnoise, px, qmu, qsigma, qz, qw, qnoise,
                     data_train.shape[1], data_train.shape[0])
 # train the EpiAnno model, keeping the last 1000 parameter samples from training
 with tf.Session(config=tf_config) as sess:
     posterior_mu, posterior_sigma, posterior_qw = EpiAnno.train(qmu, qsigma, qw, elbo, sess,
                                                                 learning_rate, num_epochs, verbose)
 # randomly select 10 of the 1000 parameter samples
 select = np.random.randint(0, 1000, 10)
 pred_target_test = []
 for i in range(10):
     pred_z_test = EpiAnno.predict_z(posterior_qw[select[i]], test_data)
     pred_target_test.append(EpiAnno.predict_target(posterior_mu[select[i]], posterior_sigma[select[i]], pred_z_test))
 pred_target_test = np.array(pred_target_test)
 pred_target = []
 for i in range(pred_target_test.shape[1]):
     b = np.bincount(pred_target_test[:, i])
Example #31
 def model_ncp(*params):
     with ed.interception(interceptor):
         return model_config.model(*params)
# 3. Mean-field VI
""" """""" """""" """""" """""" """"""
""" 3.1. Set up the computational graph """
mfvi_graph = tf.Graph()

with mfvi_graph.as_default():
    # sample from variational family
    (q_f, q_f_deriv, q_sig, qf_mean, qf_sdev, qf_deriv_mean,
     qf_deriv_sdev) = gpr_mono.variational_mfvi(X=X_train,
                                                X_deriv=X_deriv,
                                                ls=DEFAULT_LS_VAL)

    # compute the expected predictive log-likelihood
    with ed.tape() as model_tape:
        with ed.interception(
                make_value_setter(gp_f=q_f, gp_f_deriv=q_f_deriv,
                                  sigma=q_sig)):
            gp_f, gp_f_deriv, _, y, _ = gpr_mono.model(X=X_train,
                                                       X_deriv=X_deriv,
                                                       ls=DEFAULT_LS_VAL)

    # add penalized likelihood
    log_likelihood = gpr_mono.make_log_likelihood_tensor(
        gp_f,
        gp_f_deriv,
        y,
        y_train,
        deriv_prior_scale=DEFAULT_DERIV_CDF_SCALE)

    # compute the KL divergence
    kl = 0.
def PAC2VI(dataSource=tf.keras.datasets.fashion_mnist,
           NPixels=14,
           algorithm=0,
           PARTICLES=20,
           batch_size=100,
           num_epochs=50,
           num_hidden_units=20):
    """ Run experiments for MAP, Variational, PAC^2-Variational and PAC^2_T-Variational algorithms for the self-supervised classification task with a Categorical data model.
        Args:
            dataSource: The data set used in the evaluation.
            NLabels: The number of labels to predict.
            NPixels: The size of the images: NPixels x NPixels.
            algorithm: Integer indicating the algorithm to be run.
                0- MAP Learning
                1- Variational Learning
                2- PAC^2-Variational Learning
                3- PAC^2_T-Variational Learning
            PARTICLES: Number of Monte Carlo samples used to compute the posterior predictive distribution.
            batch_size: Size of the batch.
            num_epochs: Number of epochs.
            num_hidden_units: Number of hidden units in the MLP.
        Returns:
            NLL: The negative log-likelihood over the test data set.
    """

    np.random.seed(1)
    tf.set_random_seed(1)

    sess = tf.Session()

    (x_train, y_train), (x_test, y_test) = dataSource.load_data()

    if (dataSource.__name__.__contains__('cifar')):
        x_train = sess.run(
            tf.cast(tf.squeeze(tf.image.rgb_to_grayscale(x_train)),
                    dtype=tf.float32))
        x_test = sess.run(
            tf.cast(tf.squeeze(tf.image.rgb_to_grayscale(x_test)),
                    dtype=tf.float32))

    x_train = (x_train < 128).astype(np.int32)
    x_test = (x_test < 128).astype(np.int32)

    NPixels = np.int(NPixels / 2)

    y_train = x_train[:, NPixels:]
    x_train = x_train[:, 0:NPixels]

    y_test = x_test[:, NPixels:]
    x_test = x_test[:, 0:NPixels]

    NPixels = NPixels * NPixels * 2

    N = x_train.shape[0]
    M = batch_size

    x_batch = tf.placeholder(dtype=tf.float32,
                             name="x_batch",
                             shape=[None, NPixels])
    y_batch = tf.placeholder(dtype=tf.int32,
                             name="y_batch",
                             shape=[None, NPixels])

    def model(NHIDDEN, x):
        W = ed.Normal(loc=tf.zeros([NPixels, NHIDDEN]), scale=1., name="W")
        b = ed.Normal(loc=tf.zeros([1, NHIDDEN]), scale=1., name="b")

        W_out = ed.Normal(loc=tf.zeros([NHIDDEN, 2 * NPixels]),
                          scale=1.,
                          name="W_out")
        b_out = ed.Normal(loc=tf.zeros([1, 2 * NPixels]),
                          scale=1.,
                          name="b_out")

        hidden_layer = tf.nn.relu(tf.matmul(x, W) + b)
        out = tf.matmul(hidden_layer, W_out) + b_out
        y = ed.Categorical(logits=tf.reshape(
            out, [tf.shape(x_batch)[0], NPixels, 2]),
                           name="y")

        return W, b, W_out, b_out, x, y

    def qmodel(NHIDDEN):
        W_loc = tf.Variable(
            tf.random_normal([NPixels, NHIDDEN], 0.0, 0.1, dtype=tf.float32))
        b_loc = tf.Variable(
            tf.random_normal([1, NHIDDEN], 0.0, 0.1, dtype=tf.float32))

        if algorithm == 0:
            W_scale = 0.000001
            b_scale = 0.000001
        else:
            W_scale = tf.nn.softplus(
                tf.Variable(
                    tf.random_normal([NPixels, NHIDDEN],
                                     -3.,
                                     stddev=0.1,
                                     dtype=tf.float32)))
            b_scale = tf.nn.softplus(
                tf.Variable(
                    tf.random_normal([1, NHIDDEN],
                                     -3.,
                                     stddev=0.1,
                                     dtype=tf.float32)))

        qW = ed.Normal(W_loc, scale=W_scale, name="W")
        qW_ = ed.Normal(W_loc, scale=W_scale, name="W")

        qb = ed.Normal(b_loc, scale=b_scale, name="b")
        qb_ = ed.Normal(b_loc, scale=b_scale, name="b")

        W_out_loc = tf.Variable(
            tf.random_normal([NHIDDEN, 2 * NPixels],
                             0.0,
                             0.1,
                             dtype=tf.float32))
        b_out_loc = tf.Variable(
            tf.random_normal([1, 2 * NPixels], 0.0, 0.1, dtype=tf.float32))
        if algorithm == 0:
            W_out_scale = 0.000001
            b_out_scale = 0.000001
        else:
            W_out_scale = tf.nn.softplus(
                tf.Variable(
                    tf.random_normal([NHIDDEN, 2 * NPixels],
                                     -3.,
                                     stddev=0.1,
                                     dtype=tf.float32)))
            b_out_scale = tf.nn.softplus(
                tf.Variable(
                    tf.random_normal([1, 2 * NPixels],
                                     -3.,
                                     stddev=0.1,
                                     dtype=tf.float32)))

        qW_out = ed.Normal(W_out_loc, scale=W_out_scale, name="W_out")
        qb_out = ed.Normal(b_out_loc, scale=b_out_scale, name="b_out")

        qW_out_ = ed.Normal(W_out_loc, scale=W_out_scale, name="W_out")
        qb_out_ = ed.Normal(b_out_loc, scale=b_out_scale, name="b_out")

        return qW, qW_, qb, qb_, qW_out, qW_out_, qb_out, qb_out_

    W, b, W_out, b_out, x, y = model(num_hidden_units, x_batch)

    qW, qW_, qb, qb_, qW_out, qW_out_, qb_out, qb_out_ = qmodel(
        num_hidden_units)

    with ed.interception(
            ed.make_value_setter(W=qW, b=qb, W_out=qW_out, b_out=qb_out)):
        pW, pb, pW_out, pb_out, px, py = model(num_hidden_units, x)

    with ed.interception(
            ed.make_value_setter(W=qW_, b=qb_, W_out=qW_out_, b_out=qb_out_)):
        pW_, pb_, pW_out_, pb_out_, px_, py_ = model(num_hidden_units, x)

    pylogprob = tf.expand_dims(
        tf.reduce_sum(py.distribution.log_prob(y_batch), axis=1), 1)
    py_logprob = tf.expand_dims(
        tf.reduce_sum(py_.distribution.log_prob(y_batch), axis=1), 1)

    logmax = tf.stop_gradient(tf.math.maximum(pylogprob, py_logprob) + 0.1)
    logmean_logmax = tf.math.reduce_logsumexp(tf.concat(
        [pylogprob - logmax, py_logprob - logmax], 1),
                                              axis=1) - tf.log(2.)
    alpha = tf.expand_dims(logmean_logmax, 1)

    if (algorithm == 3):
        hmax = 2 * tf.stop_gradient(
            alpha / tf.math.pow(1 - tf.math.exp(alpha), 2) +
            tf.math.pow(tf.math.exp(alpha) * (1 - tf.math.exp(alpha)), -1))
    else:
        hmax = 1.

    var = 0.5 * (
        tf.reduce_mean(tf.exp(2 * pylogprob - 2 * logmax) * hmax) -
        tf.reduce_mean(tf.exp(pylogprob + py_logprob - 2 * logmax) * hmax))

    datalikelihood = tf.reduce_mean(pylogprob)


    logprior = tf.reduce_sum(pW.distribution.log_prob(pW.value)) + \
             tf.reduce_sum(pb.distribution.log_prob(pb.value)) + \
             tf.reduce_sum(pW_out.distribution.log_prob(pW_out.value)) + \
             tf.reduce_sum(pb_out.distribution.log_prob(pb_out.value))


    entropy = tf.reduce_sum(qW.distribution.log_prob(qW.value)) + \
              tf.reduce_sum(qb.distribution.log_prob(qb.value)) + \
              tf.reduce_sum(qW_out.distribution.log_prob(qW_out.value)) + \
              tf.reduce_sum(qb_out.distribution.log_prob(qb_out.value))

    entropy = -entropy

    KL = (-entropy - logprior) / N

    if (algorithm == 2 or algorithm == 3):
        elbo = datalikelihood + var - KL
    elif algorithm == 1:
        elbo = datalikelihood - KL
    elif algorithm == 0:
        elbo = datalikelihood + logprior / N

    verbose = True
    optimizer = tf.train.AdamOptimizer(0.001)
    t = []
    train = optimizer.minimize(-elbo)
    init = tf.global_variables_initializer()
    sess.run(init)

    for i in range(num_epochs + 1):
        perm = np.random.permutation(N)
        x_train = np.take(x_train, perm, axis=0)
        y_train = np.take(y_train, perm, axis=0)

        x_batches = np.array_split(x_train, N / M)
        y_batches = np.array_split(y_train, N / M)

        for j in range(N // M):
            batch_x = np.reshape(
                x_batches[j], [x_batches[j].shape[0], -1]).astype(np.float32)
            batch_y = np.reshape(
                y_batches[j], [y_batches[j].shape[0], -1]).astype(np.float32)

            value, _ = sess.run([elbo, train],
                                feed_dict={
                                    x_batch: batch_x,
                                    y_batch: batch_y
                                })
            t.append(-value)
            if verbose:
                #if j % 1 == 0: print(".", end="", flush=True)
                if i % 50 == 0 and j % 1000 == 0:
                    #if j >= 5 :
                    print("\nEpoch: " + str(i))
                    str_elbo = str(t[-1])
                    print("\n" + str(j) + " epochs\t" + str_elbo,
                          end="",
                          flush=True)
                    print("\n" + str(j) + " data\t" + str(
                        sess.run(datalikelihood,
                                 feed_dict={
                                     x_batch: batch_x,
                                     y_batch: batch_y
                                 })),
                          end="",
                          flush=True)
                    print("\n" + str(j) + " var\t" + str(
                        sess.run(var,
                                 feed_dict={
                                     x_batch: batch_x,
                                     y_batch: batch_y
                                 })),
                          end="",
                          flush=True)
                    print("\n" + str(j) + " KL\t" + str(
                        sess.run(KL,
                                 feed_dict={
                                     x_batch: batch_x,
                                     y_batch: batch_y
                                 })),
                          end="",
                          flush=True)
                    print("\n" + str(j) + " energy\t" + str(
                        sess.run(logprior,
                                 feed_dict={
                                     x_batch: batch_x,
                                     y_batch: batch_y
                                 })),
                          end="",
                          flush=True)
                    print("\n" + str(j) + " entropy\t" + str(
                        sess.run(entropy,
                                 feed_dict={
                                     x_batch: batch_x,
                                     y_batch: batch_y
                                 })),
                          end="",
                          flush=True)
                    print("\n" + str(j) + " hmax\t" + str(
                        sess.run(tf.reduce_mean(hmax),
                                 feed_dict={
                                     x_batch: batch_x,
                                     y_batch: batch_y
                                 })),
                          end="",
                          flush=True)
                    print("\n" + str(j) + " alpha\t" + str(
                        sess.run(tf.reduce_mean(alpha),
                                 feed_dict={
                                     x_batch: batch_x,
                                     y_batch: batch_y
                                 })),
                          end="",
                          flush=True)
                    print("\n" + str(j) + " logmax\t" + str(
                        sess.run(tf.reduce_mean(logmax),
                                 feed_dict={
                                     x_batch: batch_x,
                                     y_batch: batch_y
                                 })),
                          end="",
                          flush=True)

    M = 1000

    N = x_test.shape[0]
    x_batches = np.array_split(x_test, N / M)
    y_batches = np.array_split(y_test, N / M)

    NLL = 0

    for j in range(N // M):
        batch_x = np.reshape(x_batches[j],
                             [x_batches[j].shape[0], -1]).astype(np.float32)
        batch_y = np.reshape(y_batches[j],
                             [y_batches[j].shape[0], -1]).astype(np.float32)
        y_pred_list = []
        for i in range(PARTICLES):
            y_pred_list.append(
                sess.run(pylogprob,
                         feed_dict={
                             x_batch: batch_x,
                             y_batch: batch_y
                         }))
        y_preds = np.concatenate(y_pred_list, axis=1)
        score = tf.reduce_sum(
            tf.math.reduce_logsumexp(y_preds, axis=1) -
            tf.log(np.float32(PARTICLES)))
        score = sess.run(score)
        NLL = NLL + score
        if verbose:
            if j % 1 == 0: print(".", end="", flush=True)
            if j % 1 == 0:
                str_elbo = str(score)
                print("\n" + str(j) + " epochs\t" + str_elbo,
                      end="",
                      flush=True)

    print("\nNLL: " + str(NLL))

    return NLL
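The PARTICLES loop above averages predictive likelihoods in log space. A small NumPy sketch of the same max-shifted log-mean-exp, log((1/K) * sum_k p_k), mirroring the logmax trick used earlier in the function:

import numpy as np

log_p = np.array([[-2.3, -2.1, -2.7],   # per-example log-likelihoods,
                  [-5.0, -4.8, -5.2]])  # one column per particle (K = 3)
m = log_p.max(axis=1, keepdims=True)
log_pred = m[:, 0] + np.log(np.exp(log_p - m).mean(axis=1))
nll = -log_pred.sum()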
Example #34
 def variational_model(*args):
     with ed.interception(mean_field):
         return model(*args)
Example #35
    def func(*args, **kwargs):
        # The name used to identify the random variable by string
        if 'name' not in kwargs:
            kwargs['name'] = util.name.generate('randvar')
        rv_name = kwargs.get('name')

        # compute the maximum shape among the input shapes; smaller ones are broadcast in _sanitize_input
        # if batch_shape is provided, use such shape instead
        if 'batch_shape' in kwargs:
            b = kwargs.pop('batch_shape')
            if np.isscalar(b):
                b = [b]
            max_shape = b
        else:
            max_shape = _maximum_shape(args + tuple(kwargs.values()))

        if contextmanager.data_model.is_active():
            if 'sample_shape' in kwargs:
                # warn that sample_shape will be ignored
                warnings.warn('Random Variables defined inside a probabilistic model ignore the sample_shape argument.')
                kwargs.pop('sample_shape', None)
            sample_shape = ()  # used in case that RV is in probmodel, but not in a datamodel
        else:
            # only used if prob model is active
            sample_shape = kwargs.pop('sample_shape', ())

        # sanitizing consists of tf.stack-ing lists; each element is broadcast to match the shape
        sanitized_args = [_sanitize_input(arg, max_shape) for arg in args]
        sanitized_kwargs = {k: _sanitize_input(v, max_shape) for k, v in kwargs.items()}

        # If inside a data model, omit the sample_shape in kwargs if it exists and use the size from the data_model.
        # NOTE: needed here because we need to know the shape of the distribution.
        # The sample shape is not used yet; the tensors are created only to
        # compute the dependencies by using the tf graph.
        tfp_dist = distribution_cls(*sanitized_args, **sanitized_kwargs)
        if contextmanager.data_model.is_active():
            # create graph once tensors are registered in graph
            contextmanager.randvar_registry.update_graph(rv_name)

            # compute sample_shape now that we have computed the dependencies
            sample_shape = contextmanager.data_model.get_sample_shape(rv_name)

            # create tf.Variables so the Random Variable can be observed
            shape = ([sample_shape] if sample_shape else []) + \
                tfp_dist.batch_shape.as_list() + \
                tfp_dist.event_shape.as_list()
            
            # take into account the dtype of tfp_dist in order to create the initial value correctly
            initial_value = tf.zeros(shape, dtype=tfp_dist.dtype) if shape else tf.constant(0,  dtype=tfp_dist.dtype)

            is_observed, is_observed_var, observed_value_var = _make_predictable_variables(initial_value, rv_name)

            with ed.interception(util.interceptor.set_values_condition(is_observed_var, observed_value_var)):
                ed_random_var = _make_edward_random_variable(tfp_dist)(sample_shape=sample_shape, name=rv_name)

            is_datamodel = True
        else:
            # create tf.Variables so the Random Variable can be observed
            shape = tfp_dist.batch_shape.as_list() + tfp_dist.event_shape.as_list()

            # take into account the dtype of tfp_dist in order to create the initial value correctly
            initial_value = tf.zeros(shape, dtype=tfp_dist.dtype) if shape else tf.constant(0,  dtype=tfp_dist.dtype)

            is_observed, is_observed_var, observed_value_var = _make_predictable_variables(initial_value, rv_name)

            # sample_shape is sample_shape in kwargs or ()
            with ed.interception(util.interceptor.set_values_condition(is_observed_var, observed_value_var)):
                ed_random_var = ed_random_variable_cls(*sanitized_args, **sanitized_kwargs, sample_shape=sample_shape)
            is_datamodel = False

        rv = RandomVariable(
            var=ed_random_var,
            name=rv_name,
            is_datamodel=is_datamodel,
            ed_cls=ed_random_variable_cls,
            var_args=sanitized_args,
            var_kwargs=sanitized_kwargs,
            sample_shape=sample_shape,
            is_observed=is_observed,
            is_observed_var=is_observed_var,
            observed_value_var=observed_value_var,
        )

        # register the variable as it is created. Used to detect dependencies
        contextmanager.randvar_registry.register_variable(rv)
        contextmanager.randvar_registry.update_graph()

        # Doc for help menu
        rv.__doc__ += docs
        rv.__name__ = name

        return rv
Example #36

DATA_SIZE = 100
FEATURE_SIZE = 41
UNITS = [23, 7, 2]
SHAPE = 0.1
x, w2, w1, w0, z2, z1, z0 = deep_exponential_family(DATA_SIZE, FEATURE_SIZE,
                                                    UNITS, SHAPE)
qw2, qw1, qw0, qz2, qz1, qz0 = deep_exponential_family_variational(
    w2, w1, w0, z2, z1, z0)

# x_sample = np.random.poisson(5., size=[DATA_SIZE, FEATURE_SIZE])  # synthetic training data whose size matches the model
x_sample = tf.placeholder(tf.float32,
                          shape=[DATA_SIZE, FEATURE_SIZE])  # a placeholder works as well
with ed.tape() as model_tape:
    with ed.interception(
            make_value_setter(w2=qw2, w1=qw1, w0=qw0, z2=qz2, z1=qz1, z0=qz0)):
        # replace the model's random variables with the variational posteriors to form the posterior predictive
        posterior_predictive, _, _, _, _, _, _ = deep_exponential_family(
            DATA_SIZE, FEATURE_SIZE, UNITS, SHAPE)
log_likelihood = posterior_predictive.distribution.log_prob(x_sample)
print(log_likelihood)  # the log-likelihood of x_sample under the posterior predictive

# define the variational part of the loss
kl = 0.
for rv_name, variational_rv in [("z0", qz0), ("z1", qz1), ("z2", qz2),
                                ("w0", qw0), ("w1", qw1), ("w2", qw2)]:
    # rv_name is the name of the prior random variable
    # variational_rv is the corresponding variational posterior
    kl += tf.reduce_sum(
        variational_rv.distribution.kl_divergence(
            model_tape[rv_name].distribution))  # KL(q || p), with q the posterior and p the prior
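Continuing this snippet the same way as the full deep exponential family example earlier in this listing, the pieces combine into the ELBO whose negative is minimized (the learning rate here is an assumption):

elbo = tf.reduce_sum(log_likelihood) - kl
train_op = tf.train.AdamOptimizer(1e-3).minimize(-elbo)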
def model_fn(features, labels, mode, params, config):
  """Builds the model function for use in an Estimator.

  Arguments:
    features: The input features for the Estimator.
    labels: The labels, unused here.
    mode: Signifies whether it is train or test or predict.
    params: Some hyperparameters as a dictionary.
    config: The RunConfig, unused here.

  Returns:
    EstimatorSpec: A tf.estimator.EstimatorSpec instance.
  """
  del labels, config

  # Set up the model's learnable parameters.
  logit_concentration = tf.get_variable(
      "logit_concentration",
      shape=[1, params["num_topics"]],
      initializer=tf.constant_initializer(
          _softplus_inverse(params["prior_initial_value"])))
  concentration = _clip_dirichlet_parameters(
      tf.nn.softplus(logit_concentration))

  num_words = features.shape[1]
  topics_words_logits = tf.get_variable(
      "topics_words_logits",
      shape=[params["num_topics"], num_words],
      initializer=tf.glorot_normal_initializer())
  topics_words = tf.nn.softmax(topics_words_logits, axis=-1)

  # Compute expected log-likelihood. First, sample from the variational
  # distribution; second, compute the log-likelihood given the sample.
  lda_variational = make_lda_variational(
      params["activation"],
      params["num_topics"],
      params["layer_sizes"])
  with ed.tape() as variational_tape:
    _ = lda_variational(features)

  with ed.tape() as model_tape:
    with ed.interception(
        make_value_setter(topics=variational_tape["topics_posterior"])):
      posterior_predictive = latent_dirichlet_allocation(concentration,
                                                         topics_words)

  log_likelihood = posterior_predictive.distribution.log_prob(features)
  tf.summary.scalar("log_likelihood", tf.reduce_mean(log_likelihood))

  # Compute the KL-divergence between two Dirichlets analytically.
  # The sampled KL does not work well for "sparse" distributions
  # (see Appendix D of [2]).
  kl = variational_tape["topics_posterior"].distribution.kl_divergence(
      model_tape["topics"].distribution)
  tf.summary.scalar("kl", tf.reduce_mean(kl))

  # Ensure that the KL is non-negative (up to a very small slack).
  # Negative KL can happen due to numerical instability.
  with tf.control_dependencies([tf.assert_greater(kl, -1e-3, message="kl")]):
    kl = tf.identity(kl)

  elbo = log_likelihood - kl
  avg_elbo = tf.reduce_mean(elbo)
  tf.summary.scalar("elbo", avg_elbo)
  loss = -avg_elbo

  # Perform variational inference by minimizing the -ELBO.
  global_step = tf.train.get_or_create_global_step()
  optimizer = tf.train.AdamOptimizer(params["learning_rate"])

  # This implements the "burn-in" for prior parameters (see Appendix D of [2]).
  # For the first prior_burn_in_steps steps they are fixed, and then trained
  # jointly with the other parameters.
  grads_and_vars = optimizer.compute_gradients(loss)
  grads_and_vars_except_prior = [
      x for x in grads_and_vars if x[1] != logit_concentration]

  def train_op_except_prior():
    return optimizer.apply_gradients(
        grads_and_vars_except_prior,
        global_step=global_step)

  def train_op_all():
    return optimizer.apply_gradients(
        grads_and_vars,
        global_step=global_step)

  train_op = tf.cond(
      global_step < params["prior_burn_in_steps"],
      true_fn=train_op_except_prior,
      false_fn=train_op_all)

  # The perplexity is an exponent of the average negative ELBO per word.
  words_per_document = tf.reduce_sum(features, axis=1)
  log_perplexity = -elbo / words_per_document
  tf.summary.scalar("perplexity", tf.exp(tf.reduce_mean(log_perplexity)))
  (log_perplexity_tensor, log_perplexity_update) = tf.metrics.mean(
      log_perplexity)
  perplexity_tensor = tf.exp(log_perplexity_tensor)

  # Obtain the topics summary. Implemented as a py_func for simplicity.
  topics = tf.py_func(
      functools.partial(get_topics_strings, vocabulary=params["vocabulary"]),
      [topics_words, concentration], tf.string, stateful=False)
  tf.summary.text("topics", topics)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops={
          "elbo": tf.metrics.mean(elbo),
          "log_likelihood": tf.metrics.mean(log_likelihood),
          "kl": tf.metrics.mean(kl),
          "perplexity": (perplexity_tensor, log_perplexity_update),
          "topics": (topics, tf.no_op()),
      },
  )