Example #1
    def call(self, inputs):
        """Runs the model forward to generate a sequence of productions.

        Args:
          inputs: Unused.

        Returns:
          productions: Tensor of shape [1, num_productions, num_production_rules].
            Slices along the `num_productions` dimension represent one-hot vectors.
        """
        del inputs  # unused
        latent_code = ed.MultivariateNormalDiag(loc=tf.zeros(self.latent_size),
                                                sample_shape=1,
                                                name="latent_code")
        state = self.lstm.zero_state(1, dtype=tf.float32)
        t = 0
        productions = []
        stack = [self.grammar.start_symbol]
        while stack:
            symbol = stack.pop()
            net, state = self.lstm(latent_code, state)
            logits = (self.output_layer(net) +
                      self.grammar.mask(symbol, on_value=0., off_value=-1e9))
            production = ed.OneHotCategorical(logits=logits,
                                              name="production_" + str(t))
            _, rhs = self.grammar.production_rules[tf.argmax(production,
                                                             axis=-1)]
            for symbol in rhs:
                if symbol in self.grammar.nonterminal_symbols:
                    stack.append(symbol)
            productions.append(production)
            t += 1
        return tf.stack(productions, axis=1)
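
Both snippets on this page assume TensorFlow imported as `tf` and Edward2 imported as `ed`. The decoder above keeps a stack of nonterminal symbols and, at each step, adds `grammar.mask(...)` to the logits so the OneHotCategorical can only choose production rules that are valid for the current symbol. Below is a minimal sketch of that masking trick in plain TensorFlow; the four-rule grammar, `valid_rules`, and the logit values are hypothetical, made up for illustration.

import tensorflow as tf

# Hypothetical grammar with 4 production rules; only rules 1 and 3 can
# expand the current symbol (cf. grammar.mask(symbol, on_value=0.,
# off_value=-1e9) in the decoder above).
valid_rules = tf.constant([0., 1., 0., 1.])
mask = (1. - valid_rules) * -1e9             # 0. if valid, -1e9 if invalid

logits = tf.constant([2.0, 0.5, 1.0, -0.3])  # raw output_layer logits
masked_logits = logits + mask
probs = tf.nn.softmax(masked_logits)         # masked rules get ~0 probability

Because -1e9 dominates any realistic logit, the softmax assigns numerically zero probability to the masked rules while leaving the relative odds of the valid ones unchanged.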
Example #2
import tensorflow as tf
from tensorflow_probability import edward2 as ed  # assumed Edward2 import


def latent_dirichlet_allocation(concentration, topics_words):
  """Latent Dirichlet Allocation in terms of its generative process.

  The model posits a distribution over bags of words and is parameterized by
  a concentration and the topic-word probabilities. It collapses per-word
  topic assignments.

  Args:
    concentration: A Tensor of shape [1, num_topics], which parameterizes the
      Dirichlet prior over topics.
    topics_words: A Tensor of shape [num_topics, num_words], where each row
      (topic) denotes the probability of each word being in that topic.

  Returns:
    bag_of_words: A random variable capturing a sample from the model, of shape
      [1, num_words]. It represents one generated document as a bag of words.
  """
  topics = ed.Dirichlet(concentration=concentration, name="topics")
  word_probs = tf.matmul(topics, topics_words)

  # The observations are bags of words and therefore not one-hot. However,
  # log_prob of OneHotCategorical computes the probability correctly in
  # this case.
  bag_of_words = ed.OneHotCategorical(probs=word_probs, name="bag_of_words")

  return bag_of_words
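
A minimal usage sketch for the generative process above, assuming the Edward2 import shown at the top of the example; the topic count, vocabulary size, and parameter values are made up for illustration.

import tensorflow as tf

num_topics, num_words = 3, 100
concentration = tf.fill([1, num_topics], 0.5)  # symmetric Dirichlet prior
# Row-normalize random scores into per-topic word distributions.
topics_words = tf.nn.softmax(tf.random.normal([num_topics, num_words]),
                             axis=-1)

bag_of_words = latent_dirichlet_allocation(concentration, topics_words)
# `bag_of_words` is an Edward2 random variable of shape [1, num_words]; its
# underlying distribution's log_prob can score an observed word-count
# vector, per the comment in the model body.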