def __init__(self, variational, model, log_joint=None, latent_names=None,
             latent_axis=None):
    """
    Construct the :class:`VariationalChain`.

    Args:
        variational (BayesianNet): The variational net.
        model (BayesianNet): The model net.
        log_joint (tf.Tensor): Log-joint of the model net.  When
            :obj:`None`, the log-joint is obtained by summing the
            log-densities of every variable in `model`.
            (default :obj:`None`)
        latent_names (Iterable[str]): Names of the latent variables used in
            variational inference.  When :obj:`None`, every variable of the
            `variational` net is taken. (default :obj:`None`)
        latent_axis: Axis or axes regarded as the sampling dimensions of the
            latent variables.  These axes are summed up in the variational
            lower-bounds and training objectives. (default :obj:`None`)
    """
    # Materialize the names as a tuple; it is both stored and iterated below.
    latent_names = tuple(
        variational if latent_names is None else latent_names)

    with tf.name_scope('VariationalChain'):
        if log_joint is None:
            with tf.name_scope('model_log_joint'):
                log_joint = add_n_broadcast(
                    model.local_log_probs(iter(model)))
        with tf.name_scope('latent_log_probs'):
            latent_densities = variational.local_log_probs(latent_names)

    # NOTE(review): the original source was whitespace-collapsed; the
    # statements below are placed outside the 'VariationalChain' name scope.
    # Confirm this matches the intended scoping.
    self._variational = variational
    self._model = model
    self._log_joint = log_joint
    self._latent_names = latent_names
    self._latent_axis = latent_axis
    # `self.log_joint` is presumably a property (defined elsewhere in the
    # class) exposing `self._log_joint` assigned just above.
    self._vi = VariationalInference(
        log_joint=self.log_joint,
        latent_log_probs=latent_densities,
        axis=latent_axis,
    )
def __init__(self, log_joint, latent_log_probs, axis=None):
    """
    Construct the :class:`VariationalInference`.

    Args:
        log_joint (tf.Tensor): Log-joint of the model.
        latent_log_probs (Iterable[tf.Tensor]): Log-densities of the latent
            variables, taken from the variational net.
        axis: Axis or axes regarded as the sampling dimensions of the latent
            variables.  These axes are summed up in the variational
            lower-bounds and training objectives. (default :obj:`None`)
    """
    self._log_joint = tf.convert_to_tensor(log_joint)
    self._latent_log_probs = tuple(map(tf.convert_to_tensor, latent_log_probs))
    # Total latent log-density: the individual log-densities added together
    # (presumably with broadcasting, per the helper's name).
    self._latent_log_prob = add_n_broadcast(
        self._latent_log_probs, name='latent_log_prob')
    self._axis = axis

    # The helper objects receive `self`; they are created only after every
    # attribute above has been set.
    self._lower_bound = VariationalLowerBounds(self)
    self._training = VariationalTrainingObjectives(self)
    self._evaluation = VariationalEvaluation(self)
def log_joint_func(observed):
    """
    Compute the model log-joint for the given observations.

    Args:
        observed: Observations forwarded to ``model_func`` (defined in an
            enclosing scope, not visible here) to build the model net.

    Returns:
        The log-densities of ``z`` and ``x`` in the model net, summed by
        ``add_n_broadcast``.
    """
    net = model_func(observed)
    # Use the plural `local_log_probs`, which takes an iterable of names and
    # yields the log-densities that `add_n_broadcast` sums.  This mirrors how
    # the model log-joint is built elsewhere from `local_log_probs`.
    return add_n_broadcast(net.local_log_probs(['z', 'x']))