Exemplo n.º 1
0
    def _fn(dtype, shape, name, trainable, add_variable_fn):
        """Creates multivariate `Deterministic` or `Normal` distribution.

        All arguments are forwarded unchanged to `loc_scale_fn`, which
        supplies the `loc` and (optionally `None`) `scale` parameters.

        Args:
          dtype: Type of parameter's event.
          shape: Python `list`-like representing the parameter's event shape.
          name: Python `str` name prepended to any created (or existing)
            `tf.Variable`s.
          trainable: Python `bool` indicating all created `tf.Variable`s
            should be added to the graph collection
            `GraphKeys.TRAINABLE_VARIABLES`.
          add_variable_fn: `tf.get_variable`-like `callable` used to create
            (or access existing) `tf.Variable`s.

        Returns:
          Multivariate `Deterministic` or `Normal` distribution.
        """
        loc, scale = loc_scale_fn(
            dtype, shape, name, trainable, add_variable_fn)

        # A missing scale means the parameter is a point mass at `loc`.
        if scale is not None:
            base = normal_lib.Normal(loc=loc, scale=scale)
        else:
            base = deterministic_lib.Deterministic(loc=loc)

        # Reinterpret every batch dimension of the base distribution as part
        # of the event, so the result is a single multivariate distribution.
        num_batch_dims = tf.size(input=base.batch_shape_tensor())
        return independent_lib.Independent(
            base, reinterpreted_batch_ndims=num_batch_dims)
def _wrap_as_distributions(structure):
  """Wraps every leaf of `structure` in an `Independent` `Deterministic`.

  Args:
    structure: Nested structure (traversed with `tf.nest.map_structure`) whose
      leaves are used as the `Deterministic` locations.

  Returns:
    A structure with the same nesting whose leaves are
    `independent.Independent` distributions.
  """

  def _as_point_mass(value):
    point_mass = deterministic.Deterministic(value)
    # Particles are a batch dimension, so all but one of the `rank(value)`
    # dimensions are reinterpreted as event dimensions.
    event_ndims = tf.rank(value) - 1
    return independent.Independent(
        point_mass, reinterpreted_batch_ndims=event_ndims)

  return tf.nest.map_structure(_as_point_mass, structure)
Exemplo n.º 3
0
 def _fn(dtype, shape, name, trainable, add_variable_fn):
     """Builds an `Independent` distribution from `loc_scale_fn`'s output.

     All arguments are forwarded unchanged to `loc_scale_fn`.

     Returns:
       An `Independent` wrapping a `Deterministic` (when the returned scale
       is `None`) or a `Normal`, with every batch dimension reinterpreted
       as an event dimension.
     """
     loc, scale = loc_scale_fn(
         dtype, shape, name, trainable, add_variable_fn)

     # A missing scale means the parameter is a point mass at `loc`.
     if scale is not None:
         base = normal_lib.Normal(loc=loc, scale=scale)
     else:
         base = deterministic_lib.Deterministic(loc=loc)

     # The number of batch dimensions equals the size of the batch-shape
     # tensor.
     num_batch_dims = tf2.size(base.batch_shape_tensor())
     return independent_lib.Independent(
         base, reinterpreted_batch_ndims=num_batch_dims)
Exemplo n.º 4
0
    def testBackwardsCompatibilityDeterministic(self):
        """Runs both KL entry points on a `Deterministic` vs. each `Normal`.

        There are no assertions: the test passes as long as none of the
        `kl_divergence` calls raises.
        """
        tfp_normal = normal.Normal(0.0, 1.0)
        tf_normal = tf.distributions.Normal(0.0, 1.0)
        point_mass = deterministic.Deterministic(0.0)

        # Same call order as before: the `tf.distributions` normal first,
        # then the TFP normal, each through both KL entry points.
        for reference in (tf_normal, tfp_normal):
            kullback_leibler.kl_divergence(point_mass, reference)
            tf.distributions.kl_divergence(point_mass, reference)