def __call__(self, kl_fn):
  """Perform the KL registration.

  Args:
    kl_fn: The function to use for the KL divergence.

  Returns:
    kl_fn

  Raises:
    TypeError: if kl_fn is not a callable.
    ValueError: if a KL divergence function has already been
      registered for the given argument classes.
  """
  if not callable(kl_fn):
    raise TypeError("kl_fn must be callable, received: %s" % kl_fn)
  if self._key in _DIVERGENCES:
    raise ValueError("KL(%s || %s) has already been registered to: %s"
                     % (self._key[0].__name__, self._key[1].__name__,
                        _DIVERGENCES[self._key]))
  _DIVERGENCES[self._key] = kl_fn
  # TODO(b/117098119): For backwards compatibility, we register the
  # distributions in both this registry and the deprecated TF's registry.
  #
  # Additionally, for distributions which have deprecated copies, we register
  # all 3 combinations in their respective files (see test for the list).
  with deprecation.silence():
    tf.distributions.RegisterKL(*self._key)(kl_fn)
  return kl_fn
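# --- Usage sketch (illustration only, not part of the source above) ----------
# Assuming this `__call__` belongs to a `RegisterKL`-style decorator class, as
# exposed publicly by `tfp.distributions.RegisterKL`, a registration looks like
# the following. `_MyNormalA` / `_MyNormalB` are hypothetical subclasses used
# only so the example does not collide with an existing registration.
import tensorflow_probability as tfp

tfd = tfp.distributions


class _MyNormalA(tfd.Normal):  # hypothetical, for illustration
  pass


class _MyNormalB(tfd.Normal):  # hypothetical, for illustration
  pass


@tfd.RegisterKL(_MyNormalA, _MyNormalB)
def _kl_my_normal_a_b(a, b, name=None):
  # Delegate to the closed-form Normal/Normal KL; a real registration would
  # implement the divergence for its own distribution pair.
  return tfd.kl_divergence(tfd.Normal(a.loc, a.scale),
                           tfd.Normal(b.loc, b.scale), name=name)

# Registering a second function for the same class pair raises the ValueError
# documented above; passing a non-callable raises TypeError. Without this
# registration, the MRO-based lookup would still fall back to the
# Normal/Normal KL, since both classes subclass `tfd.Normal`.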
def kl_divergence(distribution_a, distribution_b,
                  allow_nan_stats=True, name=None):
  """Get the KL-divergence KL(distribution_a || distribution_b).

  If there is no KL method registered specifically for `type(distribution_a)`
  and `type(distribution_b)`, then the class hierarchies of these types are
  searched.

  If one KL method is registered between any pairs of classes in these two
  parent hierarchies, it is used.

  If more than one such registered method exists, the method whose registered
  classes have the shortest sum MRO paths to the input types is used.

  If more than one such shortest path exists, the first method
  identified in the search is used (favoring a shorter MRO distance to
  `type(distribution_a)`).

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    A Tensor with the batchwise KL-divergence between `distribution_a`
    and `distribution_b`.

  Raises:
    NotImplementedError: If no KL method is defined for distribution types
      of `distribution_a` and `distribution_b`.
  """
  kl_fn = _registered_kl(type(distribution_a), type(distribution_b))
  if kl_fn is None:
    # TODO(b/117098119): For backwards compatibility, we check TF's registry as
    # well. This typically happens when this function is called on a pair of
    # TF's distributions.
    with deprecation.silence():
      return tf.distributions.kl_divergence(distribution_a, distribution_b)

  with tf.name_scope("KullbackLeibler"):
    kl_t = kl_fn(distribution_a, distribution_b, name=name)
    if allow_nan_stats:
      return kl_t

    # Check KL for NaNs
    kl_t = tf.identity(kl_t, name="kl")

    with tf.control_dependencies([
        tf.Assert(
            tf.logical_not(
                tf.reduce_any(tf.is_nan(kl_t))),
            ["KL calculation between %s and %s returned NaN values "
             "(and was called with allow_nan_stats=False). Values:"
             % (distribution_a.name, distribution_b.name), kl_t])]):
      return tf.identity(kl_t, name="checked_kl")
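# --- Sketch of the registry lookup described above (assumption, not the actual
# source of `_registered_kl`). The docstring describes an MRO-based search:
# every (parent_a, parent_b) pair from the two class hierarchies is checked
# against the registry, and the registration with the smallest combined MRO
# distance wins, with ties broken in favor of a shorter distance to
# `type(distribution_a)`. A minimal sketch, assuming a `divergences` dict
# mapping (class_a, class_b) -> kl_fn:
import inspect


def _registered_kl_sketch(type_a, type_b, divergences):
  """Return the registered KL fn with the shortest summed MRO distance."""
  best_fn = None
  best_dist = None
  for dist_a, parent_a in enumerate(inspect.getmro(type_a)):
    for dist_b, parent_b in enumerate(inspect.getmro(type_b)):
      candidate = divergences.get((parent_a, parent_b), None)
      # Strict "<" keeps the first hit at a given total distance, which is the
      # one closest to type_a because the outer loop walks type_a's MRO.
      if candidate is not None and (best_dist is None
                                    or dist_a + dist_b < best_dist):
        best_fn = candidate
        best_dist = dist_a + dist_b
  return best_fn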
def test_silence(self, mock_warning):
  date = "2016-07-04"
  instructions = "This is how you update..."

  @deprecation.deprecated(date, instructions, warn_once=False)
  def _fn():
    pass

  _fn()
  self.assertEqual(1, mock_warning.call_count)

  with deprecation.silence():
    _fn()
  self.assertEqual(1, mock_warning.call_count)

  _fn()
  self.assertEqual(2, mock_warning.call_count)
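# --- Sketch (assumption, not TensorFlow's actual implementation) -------------
# The test above relies on `deprecation.silence()` suppressing the warning that
# `@deprecation.deprecated` would otherwise emit on every call. A context
# manager with that behavior can be sketched as a module-level flag that the
# warning path consults; `silence_sketch` and the flag name are hypothetical.
import contextlib

_PRINT_DEPRECATION_WARNINGS = True  # hypothetical flag checked before warning


@contextlib.contextmanager
def silence_sketch():
  """Temporarily disable deprecation warnings within the `with` block."""
  global _PRINT_DEPRECATION_WARNINGS
  previous = _PRINT_DEPRECATION_WARNINGS
  _PRINT_DEPRECATION_WARNINGS = False
  try:
    yield
  finally:
    _PRINT_DEPRECATION_WARNINGS = previous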
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes representing statistical distributions and ops for working with them.

Use [tfp.distributions](/probability/api_docs/python/tfp/distributions) instead.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util import deprecation


# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member,g-import-not-at-top

with deprecation.silence():
  from tensorflow.contrib.distributions.python.ops import bijectors
  from tensorflow.contrib.distributions.python.ops.autoregressive import *
  from tensorflow.contrib.distributions.python.ops.batch_reshape import *
  from tensorflow.contrib.distributions.python.ops.binomial import *
  from tensorflow.contrib.distributions.python.ops.cauchy import *
  from tensorflow.contrib.distributions.python.ops.chi2 import *
  from tensorflow.contrib.distributions.python.ops.conditional_distribution import *
  from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
  from tensorflow.contrib.distributions.python.ops.deterministic import *
  from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular
  from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse
  from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
  from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp
  from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse
  from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag
def kl_divergence(distribution_a, distribution_b,
                  allow_nan_stats=True, name=None):
  """Get the KL-divergence KL(distribution_a || distribution_b).

  If there is no KL method registered specifically for `type(distribution_a)`
  and `type(distribution_b)`, then the class hierarchies of these types are
  searched.

  If one KL method is registered between any pairs of classes in these two
  parent hierarchies, it is used.

  If more than one such registered method exists, the method whose registered
  classes have the shortest sum MRO paths to the input types is used.

  If more than one such shortest path exists, the first method
  identified in the search is used (favoring a shorter MRO distance to
  `type(distribution_a)`).

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    A Tensor with the batchwise KL-divergence between `distribution_a`
    and `distribution_b`.

  Raises:
    NotImplementedError: If no KL method is defined for distribution types
      of `distribution_a` and `distribution_b`.
  """
  kl_fn = _registered_kl(type(distribution_a), type(distribution_b))
  if kl_fn is None:
    # TODO(b/117098119): For backwards compatibility, we check TF's registry as
    # well. This typically happens when this function is called on a pair of
    # TF's distributions.
    with deprecation.silence():
      return tf.compat.v1.distributions.kl_divergence(distribution_a,
                                                      distribution_b)

  with tf.name_scope("KullbackLeibler"):
    kl_t = kl_fn(distribution_a, distribution_b, name=name)
    if allow_nan_stats:
      return kl_t

    # Check KL for NaNs
    kl_t = tf.identity(kl_t, name="kl")

    with tf.control_dependencies([
        tf.Assert(
            tf.logical_not(tf.reduce_any(input_tensor=tf.math.is_nan(kl_t))),
            [
                "KL calculation between %s and %s returned NaN values "
                "(and was called with allow_nan_stats=False). Values:" %
                (distribution_a.name, distribution_b.name), kl_t
            ])
    ]):
      return tf.identity(kl_t, name="checked_kl")
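# --- Usage sketch (illustration only) -----------------------------------------
# Computing a batchwise KL between two registered distributions via the public
# TFP API (`tfp.distributions.kl_divergence`). With `allow_nan_stats=False`,
# the returned tensor carries the assertion above and fails if any batch
# member evaluates to NaN.
import tensorflow_probability as tfp

tfd = tfp.distributions

a = tfd.Normal(loc=[0., 1.], scale=[1., 2.])
b = tfd.Normal(loc=[0.5, 0.], scale=[1., 1.])
kl = tfd.kl_divergence(a, b, allow_nan_stats=False)  # shape [2], one KL per batch member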
else:
    # Install stacktrace handler
    try:
        from tensorflow.python.framework import test_util
        test_util.InstallStackTraceHandler()
    except Exception:
        pass

    # silence the massive deprecation warnings in TF 1.13+
    if (int(_version[0]), int(_version[1])) >= (1, 13):
        try:
            from tensorflow.python.util.deprecation import silence
        except Exception:
            pass
        else:
            silence().__enter__()
        try:
            from tensorflow.python.util import deprecation_wrapper
            deprecation_wrapper._PER_MODULE_WARNING_LIMIT = 0
        except Exception:
            pass

    # Monkey-patch tf.test.is_gpu_available to avoid side effects:
    # https://github.com/tensorflow/tensorflow/issues/26460
    try:
        list_dev = tf.config.experimental.list_physical_devices
    except AttributeError:
        pass
    else:
        old_is_gpu_available = tf.test.is_gpu_available
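        # Continuation sketch (assumption): the snippet above is cut off right
        # after saving `old_is_gpu_available`. One way to finish the patch and
        # avoid the session-creating side effect described in the linked issue
        # is to answer the query from the physical-device list instead;
        # `_patched_is_gpu_available` is a hypothetical name.
        def _patched_is_gpu_available(*args, **kwargs):
            # Report GPU availability from the device list; the original
            # arguments (e.g. cuda_only) are ignored in this sketch.
            return len(list_dev('GPU')) > 0

        tf.test.is_gpu_available = _patched_is_gpu_available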