Example #1
 def topics_prior_distribution(
         self) -> Union[Dirichlet, MultivariateNormalDiag]:
     r""" Create the prior distribution (i.e. the Dirichlet topics distribution),
 `batch_shape=(1,)` and `event_shape=(n_topics,)` """
     # warm-up: stop gradient updates for the prior parameters
     logits = tf.cond(
         self.step < self.warmup,
         true_fn=lambda: tf.stop_gradient(self.topics_prior_logits),
         false_fn=lambda: self.topics_prior_logits)
     if self.posterior == "dirichlet":
         concentration = tf.nn.softplus(logits)
         concentration = tf.clip_by_value(concentration, 1e-3, 1e3)
         prior = Dirichlet(concentration=concentration, name="TopicsPrior")
     # logistic-normal
     elif self.posterior == "gaussian":
         prior = MultivariateNormalDiag(loc=logits,
                                        scale_identity_multiplier=1.,
                                        name="TopicsPrior")
     else:
         raise ValueError(f"Unknown posterior: {self.posterior}")
     return prior
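
For context, here is a minimal standalone sketch of the same warm-up-gated Dirichlet prior, outside of any model class. The `n_topics`, `warmup`, `step`, and `topics_prior_logits` variables below are stand-ins invented for illustration; they mimic the attributes the method reads from `self`.

    import tensorflow as tf
    import tensorflow_probability as tfp

    n_topics = 10
    warmup = 1000
    step = tf.Variable(0, dtype=tf.int64)                       # training step counter
    topics_prior_logits = tf.Variable(tf.zeros([1, n_topics]))  # learnable prior logits

    # before `warmup` steps, gradients to the prior logits are blocked
    logits = tf.cond(step < warmup,
                     true_fn=lambda: tf.stop_gradient(topics_prior_logits),
                     false_fn=lambda: topics_prior_logits)
    concentration = tf.clip_by_value(tf.nn.softplus(logits), 1e-3, 1e3)
    prior = tfp.distributions.Dirichlet(concentration=concentration,
                                        name="TopicsPrior")
    print(prior.batch_shape, prior.event_shape)  # (1,) and (n_topics,)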
Example #2
 def predict_topics(
     self,
     inputs: Union[TensorTypes, List[TensorTypes], DatasetV2],
     hard_topics: bool = False,
     verbose: bool = False
 ) -> Union[Dirichlet, MultivariateNormalDiag, tf.Tensor]:
     r"""Infer the per-document topics distribution from `inputs`; if
     `hard_topics=True`, return the argmax topic index for each document
     instead of the distribution."""
     if not isinstance(inputs, DatasetV2):
         inputs = [inputs]
     if verbose:
         inputs = tqdm(inputs, desc="Predicting topics")
     concentration = []
     loc, scale_diag = [], []
     for x in inputs:
         (_, qZ_X), _ = self(x, training=False)
         if self.lda.posterior == 'dirichlet':
             concentration.append(qZ_X.concentration)
         elif self.lda.posterior == 'gaussian':
             loc.append(qZ_X.loc)
             # `.diag` is the public property of the LinearOperatorDiag scale
             scale_diag.append(qZ_X.scale.diag)
     # final distribution
     if self.lda.posterior == 'dirichlet':
         concentration = tf.concat(concentration, axis=0)
         dist = Dirichlet(concentration=concentration,
                          name="TopicsDistribution")
         if hard_topics:
             return tf.argmax(dist.mean(), axis=-1)
         return dist
     elif self.lda.posterior == 'gaussian':
         loc = tf.concat(loc, axis=0)
         scale_diag = tf.concat(scale_diag, axis=0)
         dist = MultivariateNormalDiag(loc=loc,
                                       scale_diag=scale_diag,
                                       name="TopicsDistribution")
         if hard_topics:
             probs = tf.nn.softmax(dist.mean(), axis=-1)
             return tf.argmax(probs, axis=-1)
         return dist
     else:
         raise ValueError(f"Unknown posterior: {self.lda.posterior}")
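
To make the final aggregation step concrete, here is a small standalone sketch (the two random concentration batches, the 16-document batch size, and the 10 topics are made-up numbers for illustration): batched Dirichlet parameters are concatenated into a single distribution, and `hard_topics` corresponds to taking the argmax of its mean.

    import tensorflow as tf
    import tensorflow_probability as tfp

    # two batches of per-document Dirichlet concentrations, e.g. from an encoder
    concentration_batches = [tf.random.uniform((16, 10), minval=0.1, maxval=5.0),
                             tf.random.uniform((16, 10), minval=0.1, maxval=5.0)]
    concentration = tf.concat(concentration_batches, axis=0)       # (32, 10)
    dist = tfp.distributions.Dirichlet(concentration=concentration,
                                       name="TopicsDistribution")
    doc_topic_probs = dist.mean()                      # (32, 10) topic proportions per doc
    hard_topics = tf.argmax(doc_topic_probs, axis=-1)  # (32,) one topic id per doc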