def _check_shape(*elements):
    """Run every tensor in ``elements`` through ``with_shape``.

    Both ``elements`` and ``expected_shapes`` (taken from the enclosing
    scope, not visible in this block) are flattened with ``nest.flatten``
    and paired positionally; each tensor is passed to ``with_shape``
    together with its matching shape.  The results are then re-packed
    into the same nested structure as ``elements``.

    NOTE(review): ``with_shape`` presumably validates/sets the static
    shape of the tensor -- confirm against its definition.

    Args:
        *elements: Arbitrarily nested structure of tensors.

    Returns:
        A structure with the same nesting as ``elements`` holding the
        values returned by ``with_shape``.
    """
    tensors = nest.flatten(elements)
    shapes = nest.flatten(expected_shapes)
    # ``map`` over two iterables pairs them positionally and stops at the
    # shorter one, exactly like ``zip``.
    validated = list(map(with_shape, shapes, tensors))
    return nest.pack_sequence_as(elements, validated)
def _build_lower_bound(self) -> tf.Tensor:
    """Build the lower-bound tensor used for logging.

    This tensor is only used for logging, not visualization; a more
    accurate name for it would be ``mean_lower_bound``.

    The value is ``mean_log_py_xw - config.kl * kl / n_data``, passed
    through ``with_shape([], ...)``, minus the mean of ``loss_prec``.
    The result is wrapped in ``tf.check_numerics`` with the message
    ``"lb"``.

    Returns:
        The numerics-checked lower-bound tensor.
    """
    # Attributes are read in the same order as the single-expression form.
    log_likelihood_term = self.mean_log_py_xw
    scaled_kl_term = self.config.kl * self.kl / self.n_data
    # NOTE(review): ``with_shape([], ...)`` presumably asserts a scalar --
    # confirm against the helper's definition.
    scalar_bound = with_shape([], log_likelihood_term - scaled_kl_term)

    precision_term = tf.reduce_mean(self.loss_prec)
    return tf.check_numerics(scalar_bound - precision_term, "lb")
def build_kl(self):
    """Sum the ``kl_exact`` terms of every entry in ``self._kl_weights``.

    Accumulation starts from the float ``0.`` and proceeds left to right,
    so an empty ``_kl_weights`` yields ``0.``.  The total is passed
    through ``with_shape([], ...)`` before being returned.

    NOTE(review): ``with_shape([], ...)`` presumably asserts a scalar --
    confirm against the helper's definition.

    Returns:
        The value returned by ``with_shape`` for the accumulated KL.
    """
    total_kl = sum((weight.kl_exact for weight in self._kl_weights), 0.)
    return with_shape([], total_kl)