Пример #1
0
  def __call__(self, kl_fn):
    """Perform the KL registration.

    Stores `kl_fn` in this module's `_DIVERGENCES` registry under
    `self._key` (the pair of argument classes given to this registrar) and
    also registers it with the deprecated `tf.distributions` KL registry,
    with deprecation warnings silenced. Returning `kl_fn` unchanged allows
    this object to be used as a decorator.

    Args:
      kl_fn: The function to use for the KL divergence.

    Returns:
      kl_fn

    Raises:
      TypeError: if kl_fn is not a callable.
      ValueError: if a KL divergence function has already been registered for
        the given argument classes.
    """
    if not callable(kl_fn):
      # Wrap the value in a 1-tuple: with a bare `% kl_fn`, a tuple argument
      # would be consumed as the format arguments and raise an unrelated
      # "not all arguments converted" TypeError instead of this message.
      raise TypeError("kl_fn must be callable, received: %s" % (kl_fn,))
    if self._key in _DIVERGENCES:
      raise ValueError("KL(%s || %s) has already been registered to: %s"
                       % (self._key[0].__name__, self._key[1].__name__,
                          _DIVERGENCES[self._key]))
    _DIVERGENCES[self._key] = kl_fn
    # TODO(b/117098119): For backwards compatibility, we register the
    # distributions in both this registry and the deprecated TF's registry.
    #
    # Additionally, for distributions which have deprecated copies, we register
    # all 3 combinations in their respective files (see test for the list).
    #
    # NOTE(review): the local registry is updated before the TF registration;
    # if the TF call raises, the entry above remains in `_DIVERGENCES`.
    with deprecation.silence():
      tf.distributions.RegisterKL(*self._key)(kl_fn)
    return kl_fn
Пример #2
0
def kl_divergence(distribution_a, distribution_b,
                  allow_nan_stats=True, name=None):
  """Get the KL-divergence KL(distribution_a || distribution_b).

  If there is no KL method registered specifically for `type(distribution_a)`
  and `type(distribution_b)`, then the class hierarchies of these types are
  searched.

  If one KL method is registered between any pairs of classes in these two
  parent hierarchies, it is used.

  If more than one such registered method exists, the method whose registered
  classes have the shortest sum MRO paths to the input types is used.

  If more than one such shortest path exists, the first method
  identified in the search is used (favoring a shorter MRO distance to
  `type(distribution_a)`).

  If no method is found in this module's registry, the call (including
  `allow_nan_stats` and `name`) is delegated to the deprecated
  `tf.distributions.kl_divergence` for backwards compatibility.

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    allow_nan_stats: Python `bool`, default `True`. When `True`, the result
      may contain "`NaN`" for batch members where the KL is undefined. When
      `False`, an assertion is attached so that evaluating the result raises
      if one or more of its batch members are `NaN`.
    name: Python `str` name forwarded to the KL implementation for naming the
      Ops it creates.

  Returns:
    A Tensor with the batchwise KL-divergence between `distribution_a`
    and `distribution_b`.

  Raises:
    NotImplementedError: If no KL method is defined for distribution types
      of `distribution_a` and `distribution_b` (raised by the
      `tf.distributions` fallback).
  """
  kl_fn = _registered_kl(type(distribution_a), type(distribution_b))
  if kl_fn is None:
    # TODO(b/117098119): For backwards compatibility, we check TF's registry as
    # well. This typically happens when this function is called on a pair of
    # TF's distributions.
    # Forward `allow_nan_stats` and `name` so the fallback honors the same
    # contract as the registered path instead of silently dropping them.
    with deprecation.silence():
      return tf.distributions.kl_divergence(
          distribution_a, distribution_b,
          allow_nan_stats=allow_nan_stats, name=name)

  with tf.name_scope("KullbackLeibler"):
    kl_t = kl_fn(distribution_a, distribution_b, name=name)
    if allow_nan_stats:
      return kl_t

    # Check KL for NaNs: the returned tensor depends on an Assert that fails
    # at run time if any batch member of the KL is NaN.
    kl_t = tf.identity(kl_t, name="kl")

    with tf.control_dependencies([
        tf.Assert(
            tf.logical_not(
                tf.reduce_any(tf.is_nan(kl_t))),
            ["KL calculation between %s and %s returned NaN values "
             "(and was called with allow_nan_stats=False). Values:"
             % (distribution_a.name, distribution_b.name), kl_t])]):
      return tf.identity(kl_t, name="checked_kl")
Пример #3
0
  def test_silence(self, mock_warning):
    """`deprecation.silence()` suppresses warnings only inside its scope."""

    @deprecation.deprecated(
        "2016-07-04", "This is how you update...", warn_once=False)
    def _fn():
      pass

    # Outside the context manager: each call emits one warning.
    _fn()
    self.assertEqual(1, mock_warning.call_count)

    # Inside the context manager: the call must not add a warning.
    with deprecation.silence():
      _fn()
    self.assertEqual(1, mock_warning.call_count)

    # After leaving the context manager: warnings resume.
    _fn()
    self.assertEqual(2, mock_warning.call_count)
Пример #4
0
# limitations under the License.
# ==============================================================================
"""Classes representing statistical distributions and ops for working with them.

Use [tfp.distributions](/probability/api_docs/python/tfp/distributions) instead.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util import deprecation


# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member,g-import-not-at-top

with deprecation.silence():
  from tensorflow.contrib.distributions.python.ops import bijectors
  from tensorflow.contrib.distributions.python.ops.autoregressive import *
  from tensorflow.contrib.distributions.python.ops.batch_reshape import *
  from tensorflow.contrib.distributions.python.ops.binomial import *
  from tensorflow.contrib.distributions.python.ops.cauchy import *
  from tensorflow.contrib.distributions.python.ops.chi2 import *
  from tensorflow.contrib.distributions.python.ops.conditional_distribution import *
  from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
  from tensorflow.contrib.distributions.python.ops.deterministic import *
  from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular
  from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse
  from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
  from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp
  from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse
  from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag
Пример #5
0
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes representing statistical distributions and ops for working with them.

Use [tfp.distributions](/probability/api_docs/python/tfp/distributions) instead.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util import deprecation

# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member,g-import-not-at-top

with deprecation.silence():
    from tensorflow.contrib.distributions.python.ops import bijectors
    from tensorflow.contrib.distributions.python.ops.autoregressive import *
    from tensorflow.contrib.distributions.python.ops.batch_reshape import *
    from tensorflow.contrib.distributions.python.ops.binomial import *
    from tensorflow.contrib.distributions.python.ops.cauchy import *
    from tensorflow.contrib.distributions.python.ops.chi2 import *
    from tensorflow.contrib.distributions.python.ops.conditional_distribution import *
    from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
    from tensorflow.contrib.distributions.python.ops.deterministic import *
    from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular
    from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse
    from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
    from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp
    from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse
    from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag
Пример #6
0
def kl_divergence(distribution_a, distribution_b,
                  allow_nan_stats=True, name=None):
  """Get the KL-divergence KL(distribution_a || distribution_b).

  If there is no KL method registered specifically for `type(distribution_a)`
  and `type(distribution_b)`, then the class hierarchies of these types are
  searched.

  If one KL method is registered between any pairs of classes in these two
  parent hierarchies, it is used.

  If more than one such registered method exists, the method whose registered
  classes have the shortest sum MRO paths to the input types is used.

  If more than one such shortest path exists, the first method
  identified in the search is used (favoring a shorter MRO distance to
  `type(distribution_a)`).

  If no method is found in this module's registry, the call (including
  `allow_nan_stats` and `name`) is delegated to the deprecated
  `tf.compat.v1.distributions.kl_divergence` for backwards compatibility.

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    allow_nan_stats: Python `bool`, default `True`. When `True`, the result
      may contain "`NaN`" for batch members where the KL is undefined. When
      `False`, an assertion is attached so that evaluating the result raises
      if one or more of its batch members are `NaN`.
    name: Python `str` name forwarded to the KL implementation for naming the
      Ops it creates.

  Returns:
    A Tensor with the batchwise KL-divergence between `distribution_a`
    and `distribution_b`.

  Raises:
    NotImplementedError: If no KL method is defined for distribution types
      of `distribution_a` and `distribution_b` (raised by the
      `tf.compat.v1.distributions` fallback).
  """
  kl_fn = _registered_kl(type(distribution_a), type(distribution_b))
  if kl_fn is None:
    # TODO(b/117098119): For backwards compatibility, we check TF's registry as
    # well. This typically happens when this function is called on a pair of
    # TF's distributions.
    # Forward `allow_nan_stats` and `name` so the fallback honors the same
    # contract as the registered path instead of silently dropping them.
    with deprecation.silence():
      return tf.compat.v1.distributions.kl_divergence(
          distribution_a, distribution_b,
          allow_nan_stats=allow_nan_stats, name=name)

  with tf.name_scope("KullbackLeibler"):
    kl_t = kl_fn(distribution_a, distribution_b, name=name)
    if allow_nan_stats:
      return kl_t

    # Check KL for NaNs: the returned tensor depends on an Assert that fails
    # at run time if any batch member of the KL is NaN.
    kl_t = tf.identity(kl_t, name="kl")

    with tf.control_dependencies([
        tf.Assert(
            tf.logical_not(tf.reduce_any(input_tensor=tf.math.is_nan(kl_t))), [
                "KL calculation between %s and %s returned NaN values "
                "(and was called with allow_nan_stats=False). Values:" %
                (distribution_a.name, distribution_b.name), kl_t
            ])
    ]):
      return tf.identity(kl_t, name="checked_kl")
Пример #7
0
else:
    # Install stacktrace handler
    # Best-effort: any failure (the import or the call itself) is deliberately
    # ignored so this setup never breaks the caller.
    try:
        from tensorflow.python.framework import test_util
        test_util.InstallStackTraceHandler()
    except Exception:
        pass

    # silence the massive deprecation warnings in TF 1.13+
    # NOTE(review): `_version` is defined outside this view; presumably the TF
    # version string split into components (major, minor, ...) — confirm.
    if (int(_version[0]), int(_version[1])) >= (1, 13):
        try:
            from tensorflow.python.util.deprecation import silence
        except Exception:
            pass
        else:
            # The `silence()` context manager is entered and intentionally
            # never exited. Presumably this keeps deprecation warnings silenced
            # for the rest of the process, assuming `silence()` restores the
            # previous state only in `__exit__` — confirm against TF.
            silence().__enter__()
        try:
            # Best-effort: set `deprecation_wrapper`'s per-module warning
            # limit to 0, presumably suppressing its module-level deprecation
            # warnings (TODO confirm). Not every TF version ships this module,
            # hence the broad except.
            from tensorflow.python.util import deprecation_wrapper
            deprecation_wrapper._PER_MODULE_WARNING_LIMIT = 0
        except Exception:
            pass

    # Monkey-patch tf.test.is_gpu_available to avoid side effects:
    # https://github.com/tensorflow/tensorflow/issues/26460
    try:
        list_dev = tf.config.experimental.list_physical_devices
    except AttributeError:
        pass
    else:
        old_is_gpu_available = tf.test.is_gpu_available