Example #1
    def _fn(dtype,
            shape,
            name,
            trainable,
            add_variable_fn,
            initializer=tf.random_normal_initializer(stddev=0.1),
            regularizer=None,
            constraint=None,
            **kwargs):
        # `tensor_loc_scale_fn`, `precision_from_scale`,
        # `reparametrize_loc_scale`, `loc_ratio`, and `prec_ratio` are
        # captured from the enclosing scope of this excerpt.
        loc_scale_fn = tensor_loc_scale_fn(loc_initializer=initializer,
                                           loc_regularizer=regularizer,
                                           loc_constraint=constraint,
                                           **kwargs)

        loc, scale = loc_scale_fn(dtype, shape, name, trainable,
                                  add_variable_fn)
        if scale is None:
            dist = tfd.Deterministic(loc=loc)
        else:
            # Build the deferred precision only when `scale` exists;
            # `DeferredTensor` cannot wrap `None`.
            prec = tfp.util.DeferredTensor(scale,
                                           precision_from_scale,
                                           name='precision')
            loc_reparametrized, scale_reparametrized = \
                reparametrize_loc_scale(loc, prec, loc_ratio, prec_ratio)
            dist = tfd.Normal(loc=loc_reparametrized,
                              scale=scale_reparametrized)
        batch_ndims = tf.size(dist.batch_shape_tensor())
        return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims)
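For context on the `DeferredTensor` used above: it re-applies a transform every time the wrapped variable is read, so a distribution parameter stays in sync with the variable it derives from. A minimal standalone sketch (the softplus transform is illustrative, not taken from the example):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

raw_scale = tf.Variable(1.0)
# Re-evaluated on every read, so assignments to `raw_scale` propagate.
scale = tfp.util.DeferredTensor(raw_scale, tf.math.softplus)
dist = tfd.Normal(loc=0., scale=scale)
raw_scale.assign(2.0)  # `dist` now uses scale = softplus(2.0)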
Example #2
def new(params,
        event_shape=(),
        log_prob=None,
        reinterpreted_batch_ndims=None,
        validate_args=False,
        name='DeterministicLayer'):
    """Create the distribution instance from a `params` vector."""
    # Assumes `import types` and TFP's internal `distribution_util`
    # imported as `dist_util` in the surrounding module.
    params = tf.convert_to_tensor(value=params, name='params')
    event_shape = dist_util.expand_to_vector(
        tf.convert_to_tensor(value=event_shape,
                             name='event_shape',
                             dtype=tf.int32),
        tensor_name='event_shape',
    )
    # Fold the trailing parameter axis into the requested event shape.
    output_shape = tf.concat(
        [tf.shape(input=params)[:-1], event_shape],
        axis=0,
    )
    dist = tfd.Deterministic(loc=tf.reshape(params, output_shape),
                             validate_args=validate_args,
                             name=name)
    # Override the log-prob function if a callable was supplied.
    if log_prob is not None and callable(log_prob):
        dist.log_prob = types.MethodType(log_prob, dist)
    # Reinterpret trailing batch dims as part of the event, if requested.
    if reinterpreted_batch_ndims is not None and reinterpreted_batch_ndims > 0:
        dist = tfd.Independent(
            dist, reinterpreted_batch_ndims=int(reinterpreted_batch_ndims))
    return dist
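A hedged sketch of calling this factory, assuming `new`, `tfd`, and `dist_util` are in scope as above; the shapes are illustrative:

import tensorflow as tf

params = tf.random.normal([3, 4])
# Treat each length-4 vector as a single deterministic event.
dist = new(params, event_shape=[4], reinterpreted_batch_ndims=1)
sample = dist.sample()  # shape [3, 4]; equal to `params`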
Example #3
        def model():
            # Initial hidden state of the chain.
            i = yield Root(tfd.Categorical(probs=initial_prob, dtype=tf.int32))

            for _ in range(n_steps - 1):
                i = yield tfd.Categorical(
                    probs=tf.gather(transition_matrix, i),
                    dtype=tf.int32)

            # Expose the final state as a deterministic node.
            yield tfd.Deterministic(i)
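A sketch of wrapping this coroutine, with illustrative values for the free variables `initial_prob`, `transition_matrix`, and `n_steps` (the excerpt does not define them), assuming `model` is declared where these names are in scope:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
Root = tfd.JointDistributionCoroutine.Root

initial_prob = tf.constant([0.9, 0.1])
transition_matrix = tf.constant([[0.7, 0.3],
                                 [0.2, 0.8]])
n_steps = 5

chain = tfd.JointDistributionCoroutine(model)
# `n_steps` Categorical states plus the final Deterministic copy.
states = chain.sample(seed=42)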
Example #4
def model():
    # `_horseshoe` and `nest_util.map_structure_coroutine` come from the
    # surrounding TFP code: the `_horseshoe` coroutine is mapped over
    # each leaf of `scale`, producing a matching structure of weights.
    weights = yield from nest_util.map_structure_coroutine(
        _horseshoe,
        scale={
            'a': tf.ones([5]) * 100.,
            'b': tf.ones([2]) * 1e-2
        },
        _with_tuple_paths=True)
    # Record the overall weight norm as a deterministic node.
    yield tfd.Deterministic(
        tf.sqrt(tf.norm(weights['a'])**2 + tf.norm(weights['b'])**2),
        name='weights_norm')
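The pattern here, recording a derived quantity through a terminal `Deterministic` node so it appears in samples, can be shown standalone. This sketch swaps the internal `_horseshoe` helper for plain normal priors:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

@tfd.JointDistributionCoroutineAutoBatched
def toy_model():
    a = yield tfd.Sample(tfd.Normal(0., 100.), 5, name='a')
    b = yield tfd.Sample(tfd.Normal(0., 1e-2), 2, name='b')
    # Recorded in each sample; contributes zero to log_prob.
    yield tfd.Deterministic(
        tf.sqrt(tf.norm(a)**2 + tf.norm(b)**2), name='weights_norm')

draw = toy_model.sample(seed=7)
print(draw.weights_norm)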
Example #5
def transition_fn(_, previous_state):
  return tfd.JointDistributionNamed(
      {
          # The autoregressive coefficients and the `log_scale` each follow
          # an independent slow-moving random walk.
          'coefs': tfd.Independent(
              tfd.Normal(loc=previous_state['coefs'], scale=0.01),
              reinterpreted_batch_ndims=1),
          'log_scale': tfd.Normal(loc=previous_state['log_scale'],
                                  scale=0.01),
          # The level is a linear combination of the previous *two* levels,
          # with additional noise of scale `exp(log_scale)`.
          'level': lambda coefs, log_scale: tfd.Normal(  # pylint: disable=g-long-lambda
              loc=(coefs[..., 0] * previous_state['level'] +
                   coefs[..., 1] * previous_state['previous_level']),
              scale=tf.exp(log_scale)),
          # Store the previous level to access at the next step.
          'previous_level': tfd.Deterministic(previous_state['level'])})
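This transition function has the `(step, state) -> distribution` shape that `tfd.MarkovChain` expects; a sketch of wiring it up, with an assumed initial-state prior over the same keys:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

initial_state_prior = tfd.JointDistributionNamed({
    'coefs': tfd.Independent(tfd.Normal(loc=tf.zeros([2]), scale=1.),
                             reinterpreted_batch_ndims=1),
    'log_scale': tfd.Normal(loc=0., scale=1.),
    'level': tfd.Normal(loc=0., scale=1.),
    'previous_level': tfd.Deterministic(0.)})

chain = tfd.MarkovChain(initial_state_prior=initial_state_prior,
                        transition_fn=transition_fn,
                        num_steps=10)
states = chain.sample(seed=42)  # each entry has a leading [10] axis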
Example #6
def transition_fn(_, previous_state):
  return tfd.JointDistributionNamedAutoBatched(
      # The previous state may include batch dimensions. Since the log scale
      # is a scalar quantity, its shape is the batch shape.
      batch_ndims=ps.rank(previous_state['log_scale']),
      model={
          # The autoregressive coefficients and the `log_scale` each follow
          # an independent slow-moving random walk.
          'coefs': tfd.Normal(loc=previous_state['coefs'], scale=0.01),
          'log_scale': tfd.Normal(loc=previous_state['log_scale'],
                                  scale=0.01),
          # The level is a linear combination of the previous *two* levels,
          # with additional noise of scale `exp(log_scale)`.
          'level': lambda coefs, log_scale: tfd.Normal(  # pylint: disable=g-long-lambda
              loc=(coefs[..., 0] * previous_state['level'] +
                   coefs[..., 1] * previous_state['previous_level']),
              scale=tf.exp(log_scale)),
          # Store the previous level to access at the next step.
          'previous_level': tfd.Deterministic(previous_state['level'])})
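`ps` above is TFP's `prefer_static` utility module; passing `batch_ndims=ps.rank(...)` lets the auto-batched wrapper absorb the explicit `tfd.Independent` bookkeeping of the previous example. The import used inside TFP's own source is:

from tensorflow_probability.python.internal import prefer_static as ps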
Example #7
File: util.py Project: ywangV/probability
  def _fn(dtype, shape, name, trainable, add_variable_fn):
    """Creates multivariate `Deterministic` or `Normal` distribution.

    Args:
      dtype: Type of parameter's event.
      shape: Python `list`-like representing the parameter's event shape.
      name: Python `str` name prepended to any created (or existing)
        `tf.Variable`s.
      trainable: Python `bool` indicating all created `tf.Variable`s should be
        added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
      add_variable_fn: `tf.get_variable`-like `callable` used to create (or
        access existing) `tf.Variable`s.

    Returns:
      Multivariate `Deterministic` or `Normal` distribution.
    """
    loc, scale = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)
    if scale is None:
      dist = tfd.Deterministic(loc=loc)
    else:
      dist = tfd.Normal(loc=loc, scale=scale)
    batch_ndims = tf.size(input=dist.batch_shape_tensor())
    return tfd.Independent(dist, reinterpreted_batch_ndims=batch_ndims)
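Factories with this `(dtype, shape, name, trainable, add_variable_fn)` signature are what `tfp.layers` consumes for arguments such as `kernel_posterior_fn`. A sketch using the library's stock builder (the excerpted `_fn` resembles the inner closure of this builder):

import tensorflow_probability as tfp

layer = tfp.layers.DenseReparameterization(
    units=16,
    kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn())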
Example #8
class MarkovChainBijectorTest(test_util.TestCase):

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        dict(testcase_name='deterministic_prior',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.)),
        dict(testcase_name='deterministic_transition',
             prior_fn=lambda: tfd.Normal(loc=[-100., 0., 100.], scale=1.),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='fully_deterministic',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='mvn_diag',
             prior_fn=(lambda: tfd.MultivariateNormalDiag(loc=[[2.], [2.]],
                                                          scale_diag=[1.])),
             transition_fn=lambda _, x: tfd.VectorDeterministic(x)),
        dict(testcase_name='docstring_dirichlet',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 {'probs': tfd.Dirichlet([1., 1.])}),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 {
                     'probs':
                     tfd.MultivariateNormalDiag(loc=x['probs'],
                                                scale_diag=[0.1, 0.1])
                 },
                 batch_ndims=ps.rank(x['probs']))),
        dict(testcase_name='uniform_step',
             prior_fn=lambda: tfd.Exponential(tf.ones([4, 1])),
             transition_fn=lambda _, x: tfd.Uniform(low=x, high=x + 1.)),
        dict(testcase_name='joint_distribution',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=2,
                 model={
                     'a':
                     tfd.Gamma(tf.zeros([5]), 1.),
                     'b':
                     lambda a: (tfb.Reshape(event_shape_in=[4, 3],
                                            event_shape_out=[2, 3, 2])
                                (tfd.Independent(tfd.Normal(
                                    loc=tf.zeros([5, 4, 3]),
                                    scale=a[..., tf.newaxis, tf.newaxis]),
                                                 reinterpreted_batch_ndims=2)))
                 }),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=ps.rank_from_shape(x['a'].shape),
                 model={
                     'a':
                     tfd.Normal(loc=x['a'], scale=1.),
                     'b':
                     lambda a: tfd.Deterministic(x['b'] + a[
                         ..., tf.newaxis, tf.newaxis, tf.newaxis])
                 })),
        dict(testcase_name='nested_chain',
             prior_fn=lambda: tfd.MarkovChain(
                 initial_state_prior=tfb.Split(2)(
                     tfd.MultivariateNormalDiag(0., [1., 2.])),
                 transition_fn=lambda _, x: tfb.Split(2)(
                     tfd.MultivariateNormalDiag(x[0], [1., 2.])),
                 num_steps=6),
             transition_fn=(
                 lambda _, x: tfd.JointDistributionSequentialAutoBatched(
                     [
                         tfd.MultivariateNormalDiag(x[0], [1.]),
                         tfd.MultivariateNormalDiag(x[1], [1.])
                     ],
                     batch_ndims=ps.rank(x[0])))))
    # pylint: enable=g-long-lambda
    def test_default_bijector(self, prior_fn, transition_fn):
        chain = tfd.MarkovChain(initial_state_prior=prior_fn(),
                                transition_fn=transition_fn,
                                num_steps=7)

        y = self.evaluate(chain.sample(seed=test_util.test_seed()))
        bijector = chain.experimental_default_event_space_bijector()

        self.assertAllEqual(chain.batch_shape_tensor(),
                            bijector.experimental_batch_shape_tensor())

        x = bijector.inverse(y)
        yy = bijector.forward(tf.nest.map_structure(
            tf.identity, x))  # Bypass bijector cache.
        self.assertAllCloseNested(y, yy)

        chain_event_ndims = tf.nest.map_structure(ps.rank_from_shape,
                                                  chain.event_shape_tensor())
        self.assertAllEqualNested(bijector.inverse_min_event_ndims,
                                  chain_event_ndims)

        ildj = bijector.inverse_log_det_jacobian(
            tf.nest.map_structure(tf.identity, y),  # Bypass bijector cache.
            event_ndims=chain_event_ndims)
        if not bijector.is_constant_jacobian:
            self.assertAllEqual(ildj.shape, chain.batch_shape)
        fldj = bijector.forward_log_det_jacobian(
            tf.nest.map_structure(tf.identity, x),  # Bypass bijector cache.
            event_ndims=bijector.inverse_event_ndims(chain_event_ndims))
        self.assertAllClose(ildj, -fldj)

        # Verify that event shapes are passed through and flattened/unflattened
        # correctly.
        inverse_event_shapes = bijector.inverse_event_shape(chain.event_shape)
        x_event_shapes = tf.nest.map_structure(
            lambda t, nd: t.shape[ps.rank(t) - nd:], x,
            bijector.forward_min_event_ndims)
        self.assertAllEqualNested(inverse_event_shapes, x_event_shapes)
        forward_event_shapes = bijector.forward_event_shape(
            inverse_event_shapes)
        self.assertAllEqualNested(forward_event_shapes, chain.event_shape)

        # Verify that the outputs of other methods have the correct structure.
        inverse_event_shape_tensors = bijector.inverse_event_shape_tensor(
            chain.event_shape_tensor())
        self.assertAllEqualNested(inverse_event_shape_tensors, x_event_shapes)
        forward_event_shape_tensors = bijector.forward_event_shape_tensor(
            inverse_event_shape_tensors)
        self.assertAllEqualNested(forward_event_shape_tensors,
                                  chain.event_shape_tensor())
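A condensed, standalone version of the round trip this test exercises, using the simplest parameterization from the list above:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

chain = tfd.MarkovChain(
    initial_state_prior=tfd.Deterministic([-100., 0., 100.]),
    transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.),
    num_steps=7)
bijector = chain.experimental_default_event_space_bijector()

y = chain.sample(seed=42)
x = bijector.inverse(y)
tf.debugging.assert_near(bijector.forward(x), y)  # round trip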