Example #1
0
    def testComposite(self):
        """Checks `trace_scan` when the trace contains an auto-composite object.

        The loop adds each element of `[1., 2.]` to a running state that starts
        at `0.`. At every step the trace records the state, twice the state,
        and an auto-composite `tfd.Normal` whose `loc` is the state and whose
        `scale` is `0.1`.
        """
        composite_normal_cls = auto_composite_tensor.auto_composite_tensor(
            tfd.Normal, omit_kwargs=('name',))

        def _accumulate(state, element):
            # Running sum of the scanned elements.
            return state + element

        def _snapshot(state):
            # Two plain traced values followed by one composite-tensor value.
            return [state, 2 * state, composite_normal_cls(state, 0.1)]

        scan_output = loop_util.trace_scan(
            loop_fn=_accumulate,
            initial_state=0.,
            elems=[1., 2.],
            trace_fn=_snapshot)
        final_state, trace = scan_output

        # Static shapes: scalar final state, one traced entry per element.
        self.assertAllClose([], tensorshape_util.as_list(final_state.shape))
        for position in (0, 1):
            self.assertAllClose(
                [2], tensorshape_util.as_list(trace[position].shape))

        # Values: 0 + 1 + 2 == 3, with intermediate states 1 and 3.
        self.assertAllClose(3, final_state)
        self.assertAllClose([1, 3], trace[0])
        self.assertAllClose([2, 6], trace[1])

        # The traced distribution keeps its type and has stacked parameters.
        traced_normal = trace[2]
        self.assertIsInstance(traced_normal, tfd.Normal)
        self.assertAllClose([1., 3.], traced_normal.loc)
        self.assertAllClose([0.1, 0.1], traced_normal.scale)
Example #2
0
    def __new__(mcs, classname, baseclasses, attrs):  # pylint: disable=bad-mcs-classmethod-argument
        """Build the class, then pass it through `auto_composite_tensor`.

        Wrapping happens for every class created by this metaclass, so that
        each subclass gets its own type_spec instead of inheriting one.

        Args:
          mcs: This metaclass.
          classname: Name of the class being created; forwarded unchanged to
            the parent metaclass.
          baseclasses: Base classes of the class being created; forwarded
            unchanged to the parent metaclass.
          attrs: Attributes of the class being created; forwarded unchanged to
            the parent metaclass.

        Returns:
          Whatever `auto_composite_tensor` returns for the newly built class.
        """
        composite_options = dict(
            omit_kwargs=('parameters',),
            non_identifying_kwargs=('name',),
            module_name='tfp.math._psdkernels')
        new_cls = super(_AutoCompositeTensorPsdKernelMeta, mcs).__new__(  # pylint: disable=too-many-function-args
            mcs, classname, baseclasses, attrs)
        return auto_composite_tensor.auto_composite_tensor(
            new_cls, **composite_options)
Example #3
0
from tensorflow_probability.python.distributions import joint_distribution_sequential as jds
from tensorflow_probability.python.experimental.distributions import mvn_precision_factor_linop as mvn_pfl
from tensorflow_probability.python.experimental.stats import sample_stats
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import unnest
from tensorflow_probability.python.mcmc import kernel as kernel_base
from tensorflow_probability.python.mcmc.internal import util as mcmc_util

# Names exported by this module.
__all__ = [
    'DiagonalMassMatrixAdaptation',
]

# Add auto-composite tensors to the global namespace to avoid creating new
# classes inside functions.
# Each `_Composite*` name below is the result of passing an existing class
# through `auto_composite_tensor` once, at import time, with
# `omit_kwargs=('name',)`.
_CompositeJointDistributionSequential = auto_composite_tensor.auto_composite_tensor(
    jds.JointDistributionSequential, omit_kwargs=('name', ))
_CompositeLinearOperatorDiag = auto_composite_tensor.auto_composite_tensor(
    tf.linalg.LinearOperatorDiag, omit_kwargs=('name', ))
_CompositeMultivariateNormalPrecisionFactorLinearOperator = auto_composite_tensor.auto_composite_tensor(
    mvn_pfl.MultivariateNormalPrecisionFactorLinearOperator,
    omit_kwargs=('name', ))
_CompositeIndependent = auto_composite_tensor.auto_composite_tensor(
    independent.Independent, omit_kwargs=('name', ))


def hmc_like_momentum_distribution_setter_fn(kernel_results, new_distribution):
    """Setter for `momentum_distribution` so it can be adapted."""
    # Note that unnest.replace_innermost has a special path for going into
    # `accepted_results` preferentially, so this will set
    # `accepted_results.momentum_distribution`.
    return unnest.replace_innermost(kernel_results,
Example #4
0
# Backend flag tested by the branch below.
# NOTE(review): `JAX_MODE` is not defined in the visible lines; presumably it
# is set just above this point -- confirm.
NUMPY_MODE = False

# TODO(b/182603117): Remove this block once distributions are auto-composite.
if JAX_MODE or NUMPY_MODE:
    # In these modes the `_Composite*` names are plain aliases of the original
    # classes; nothing is wrapped.
    _CompositeJointDistributionSequential = jds.JointDistributionSequential
    _CompositeShardedJointDistributionSequential = sharded_jd_lib.JointDistributionSequential
    _CompositeMultivariateNormalPrecisionFactorLinearOperator = mvn_pfl.MultivariateNormalPrecisionFactorLinearOperator
    _CompositeIndependent = independent.Independent
    _CompositeTransformedDistribution = transformed_distribution.TransformedDistribution

else:
    # Imported inside this branch, so `auto_composite_tensor` is only loaded
    # when neither backend flag is set.
    from tensorflow_probability.python.internal import auto_composite_tensor  # pylint: disable=g-import-not-at-top

    # Add auto-composite tensors to the global namespace to avoid creating new
    # classes inside functions.
    # The same `_Composite*` names as in the branch above are bound here, this
    # time to the result of `auto_composite_tensor` applied to each class.
    _CompositeJointDistributionSequential = auto_composite_tensor.auto_composite_tensor(
        jds.JointDistributionSequential, omit_kwargs=('name', ))
    _CompositeShardedJointDistributionSequential = auto_composite_tensor.auto_composite_tensor(
        sharded_jd_lib.JointDistributionSequential, omit_kwargs=('name', ))
    _CompositeMultivariateNormalPrecisionFactorLinearOperator = auto_composite_tensor.auto_composite_tensor(
        mvn_pfl.MultivariateNormalPrecisionFactorLinearOperator,
        omit_kwargs=('name', ))
    _CompositeIndependent = auto_composite_tensor.auto_composite_tensor(
        independent.Independent, omit_kwargs=('name', ))
    # Unlike the others, this wrapper also omits `kwargs_split_fn` and
    # `parameters`.
    _CompositeTransformedDistribution = auto_composite_tensor.auto_composite_tensor(
        transformed_distribution.TransformedDistribution,
        omit_kwargs=('name', 'kwargs_split_fn', 'parameters'))


def make_momentum_distribution(state_parts,
                               batch_shape,
                               running_variance_parts=None,
Example #5
0
 def _decorator(cls):
     """Apply `auto_composite_tensor`, configured with the enclosing `kwargs`, to `cls`."""
     # Calling with keyword arguments only yields a callable; the class is
     # then passed to that callable.
     configured_wrapper = auto_composite_tensor.auto_composite_tensor(**kwargs)
     return configured_wrapper(cls)
Example #6
0
if JAX_MODE or NUMPY_MODE:
    _CompositeJointDistributionSequential = jds.JointDistributionSequential
    _CompositeShardedJointDistributionSequential = sharded_jd_lib.JointDistributionSequential
    _CompositeLinearOperatorDiag = tf.linalg.LinearOperatorDiag
    _CompositeMultivariateNormalPrecisionFactorLinearOperator = mvn_pfl.MultivariateNormalPrecisionFactorLinearOperator
    _CompositeIndependent = independent.Independent
    _CompositeReshape = reshape.Reshape
    _CompositeTransformedDistribution = transformed_distribution.TransformedDistribution
    _CompositeBatchBroadcast = batch_broadcast.BatchBroadcast

else:
    from tensorflow_probability.python.internal import auto_composite_tensor  # pylint: disable=g-import-not-at-top

    # Add auto-composite tensors to the global namespace to avoid creating new
    # classes inside functions.
    _CompositeJointDistributionSequential = auto_composite_tensor.auto_composite_tensor(
        jds.JointDistributionSequential, omit_kwargs=('name', ))
    _CompositeShardedJointDistributionSequential = auto_composite_tensor.auto_composite_tensor(
        sharded_jd_lib.JointDistributionSequential, omit_kwargs=('name', ))
    _CompositeLinearOperatorDiag = auto_composite_tensor.auto_composite_tensor(
        tf.linalg.LinearOperatorDiag, omit_kwargs=('name', ))
    _CompositeMultivariateNormalPrecisionFactorLinearOperator = auto_composite_tensor.auto_composite_tensor(
        mvn_pfl.MultivariateNormalPrecisionFactorLinearOperator,
        omit_kwargs=('name', ))
    _CompositeIndependent = auto_composite_tensor.auto_composite_tensor(
        independent.Independent, omit_kwargs=('name', ))
    _CompositeReshape = auto_composite_tensor.auto_composite_tensor(
        reshape.Reshape, omit_kwargs=('name', ))
    _CompositeTransformedDistribution = auto_composite_tensor.auto_composite_tensor(
        transformed_distribution.TransformedDistribution,
        omit_kwargs=('name', 'kwargs_split_fn', 'parameters'))
    _CompositeBatchBroadcast = auto_composite_tensor.auto_composite_tensor(
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for auto_composite_tensor."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import auto_composite_tensor as auto_ct
from tensorflow_probability.python.internal import test_util


# Auto-composite versions of three `tf.linalg` linear operator classes, built
# once at import time with the default `auto_composite_tensor` arguments.
AutoIdentity = auto_ct.auto_composite_tensor(tf.linalg.LinearOperatorIdentity)
AutoDiag = auto_ct.auto_composite_tensor(tf.linalg.LinearOperatorDiag)
AutoBlockDiag = auto_ct.auto_composite_tensor(tf.linalg.LinearOperatorBlockDiag)


class AutoCompositeTensorTest(test_util.TestCase):

  def test_example(self):
    @auto_ct.auto_composite_tensor
    class Adder(object):

      def __init__(self, x, y):
        self._x = tf.convert_to_tensor(x)
        self._y = tf.convert_to_tensor(y)

      def xpy(self):