Example #1
 def new(params,
         event_shape='auto',
         pre_softplus=False,
         clip_for_stable=True,
         validate_args=False,
         name=None):
     """Create the distribution instance from a `params` vector."""
     event_shape = _preprocess_eventshape(params, event_shape)
     with tf.compat.v1.name_scope(name, 'Dirichlet', [params, event_shape]):
         params = tf.convert_to_tensor(value=params, name='params')
         event_shape = dist_util.expand_to_vector(
             tf.convert_to_tensor(
                 value=event_shape, name='event_shape', dtype=tf.int32),
             tensor_name='event_shape')
         output_shape = tf.concat(
             [tf.shape(input=params)[:-1], event_shape], axis=0)
         # Clips the Dirichlet parameters to the numerically stable KL region
         if pre_softplus:
             params = tf.nn.softplus(params)
         if clip_for_stable:
             params = tf.clip_by_value(params, 1e-3, 1e3)
         return tfd.Independent(
             tfd.Dirichlet(concentration=tf.reshape(params, output_shape),
                           validate_args=validate_args),
             reinterpreted_batch_ndims=tf.size(input=event_shape),
             validate_args=validate_args)
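For context, a minimal standalone sketch (not from the source above) of the `Independent`/`Dirichlet` composition that `new` returns: wrapping a batched `tfd.Dirichlet` in `tfd.Independent` moves leading batch dimensions into the event shape.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

conc = tf.ones([4, 2, 3])                    # a batch of concentration vectors
inner = tfd.Dirichlet(conc)                  # batch_shape=[4, 2], event_shape=[3]
outer = tfd.Independent(inner, reinterpreted_batch_ndims=1)
print(outer.batch_shape, outer.event_shape)  # [4] [2, 3]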
Example #2
    def testGradientsThroughParams(self):
        logits = tf.Variable(np.zeros((3, 5, 2)),
                             dtype=tf.float32,
                             shape=tf.TensorShape([None, None, 2]))
        concentration = tf.Variable(np.ones((3, 5, 4)),
                                    dtype=tf.float32,
                                    shape=tf.TensorShape(None))
        loc = tf.Variable(np.zeros((3, 5, 4)),
                          dtype=tf.float32,
                          shape=tf.TensorShape(None))
        scale = tf.Variable(1., dtype=tf.float32, shape=tf.TensorShape(None))

        dist = tfd.Mixture(tfd.Categorical(logits=logits),
                           components=[
                               tfd.Dirichlet(concentration),
                               tfd.MultivariateNormalDiag(
                                   loc=loc, scale_identity_multiplier=scale)
                           ],
                           use_static_graph=self.use_static_graph,
                           validate_args=True)

        with tf.GradientTape() as tape:
            loss = tf.reduce_sum(dist.log_prob(tf.ones((3, 5, 4)) / 4.))
        grad = tape.gradient(loss, dist.trainable_variables)
        self.assertLen(grad, 4)
        self.assertAllNotNone(grad)
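The same pattern in miniature (a standalone sketch, not part of the test suite): gradients of a Dirichlet log-density flow back into a `tf.Variable` concentration parameter.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

concentration = tf.Variable(np.ones(4), dtype=tf.float32)
with tf.GradientTape() as tape:
    # Negative log-likelihood of a point on the 3-simplex.
    loss = -tf.reduce_sum(
        tfd.Dirichlet(concentration).log_prob(tf.fill([4], 0.25)))
grad = tape.gradient(loss, [concentration])
print(grad[0].numpy())  # finite, non-None gradient w.r.t. concentration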
Example #3
  def test_bug170030378(self):
    n_item = 50
    n_rater = 7

    stream = test_util.test_seed_stream()
    weight = self.evaluate(
        tfd.Sample(tfd.Dirichlet([0.25, 0.25]), n_item).sample(seed=stream()))
    mixture_dist = tfd.Categorical(probs=weight)  # batch_shape=[50]

    rater_sensitivity = self.evaluate(
        tfd.Sample(tfd.Beta(5., 1.), n_rater).sample(seed=stream()))
    rater_specificity = self.evaluate(
        tfd.Sample(tfd.Beta(2., 5.), n_rater).sample(seed=stream()))

    probs = tf.stack([rater_sensitivity, rater_specificity])[None, ...]

    components_dist = tfd.BatchBroadcast(  # batch_shape=[50, 2]
        tfd.Independent(tfd.Bernoulli(probs=probs),
                        reinterpreted_batch_ndims=1),
        [50, 2])

    obs_dist = tfd.MixtureSameFamily(mixture_dist, components_dist)

    observed = self.evaluate(obs_dist.sample(seed=stream()))
    mixture_logp = obs_dist.log_prob(observed)

    expected_logp = tf.math.reduce_logsumexp(
        tf.math.log(weight) + components_dist.distribution.log_prob(
            observed[:, None, ...]),
        axis=-1)
    self.assertAllClose(expected_logp, mixture_logp)
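The identity the final assertion checks is log p(x) = logsumexp_k(log w_k + log p_k(x)); here is a minimal standalone illustration of it, assuming nothing beyond core TFP:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

mix = tfd.MixtureSameFamily(
    tfd.Categorical(probs=[0.3, 0.7]),
    tfd.Normal(loc=[-1., 1.], scale=[1., 1.]))
x = tf.constant(0.5)
manual = tf.reduce_logsumexp(
    tf.math.log([0.3, 0.7]) + tfd.Normal([-1., 1.], [1., 1.]).log_prob(x))
print(mix.log_prob(x).numpy(), manual.numpy())  # agree up to float error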
Example #4
    def testExcessiveConcretizationOfParams(self):
        logits = tfp_hps.defer_and_count_usage(
            tf.Variable(np.zeros((3, 5, 2)),
                        dtype=tf.float32,
                        shape=tf.TensorShape([None, None, 2]),
                        name='logits'))
        concentration = tfp_hps.defer_and_count_usage(
            tf.Variable(np.ones((3, 5, 4)),
                        dtype=tf.float32,
                        shape=tf.TensorShape(None),
                        name='concentration'))
        loc = tfp_hps.defer_and_count_usage(
            tf.Variable(np.zeros((3, 5, 4)),
                        dtype=tf.float32,
                        shape=tf.TensorShape(None),
                        name='loc'))
        scale = tfp_hps.defer_and_count_usage(
            tf.Variable(1.,
                        dtype=tf.float32,
                        shape=tf.TensorShape(None),
                        name='scale'))

        dist = tfd.Mixture(tfd.Categorical(logits=logits),
                           components=[
                               tfd.Dirichlet(concentration),
                               tfd.Independent(tfd.Normal(loc=loc,
                                                          scale=scale),
                                               reinterpreted_batch_ndims=1)
                           ],
                           use_static_graph=self.use_static_graph,
                           validate_args=True)

        for method in ('batch_shape_tensor', 'event_shape_tensor',
                       'entropy_lower_bound'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=2):
                getattr(dist, method)()

        with tfp_hps.assert_no_excessive_var_usage('sample',
                                                   max_permissible=2):
            dist.sample(seed=test_util.test_seed())

        for method in ('prob', 'log_prob'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=2):
                getattr(dist, method)(tf.ones((3, 5, 4)) / 4.)

        # TODO(b/140579567): The `stddev()` and `variance()` methods require
        # calling both:
        #  - `self.components[i].mean()`
        #  - `self.components[i].stddev()`
        # Thus, these methods incur an additional concretization (or two if
        # `validate_args=True` for `self.components[i]`).

        for method in ('stddev', 'variance'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=3):
                getattr(dist, method)()
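Counting reads is meaningful here because TFP distributions hold a reference to the variable rather than a frozen tensor and re-read it on each method call; a minimal illustration (not from the test):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

concentration = tf.Variable(np.ones(3), dtype=tf.float32)
dist = tfd.Dirichlet(concentration)
print(dist.mean().numpy())        # [1/3, 1/3, 1/3]
concentration.assign([1., 1., 2.])
print(dist.mean().numpy())        # [0.25, 0.25, 0.5] -- the update is visible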
Example #5
 def new(params,
         event_shape=(),
         concentration_activation=softplus1,
         concentration_clip=True,
         validate_args=False,
         name="DirichletLayer"):
   r"""Create the distribution instance from a `params` vector."""
   params = tf.convert_to_tensor(value=params, name='params')
   # Clips the Dirichlet parameters to the numerically stable KL region
   concentration_activation = parse_activation(concentration_activation, 'tf')
   params = concentration_activation(params)
   if concentration_clip:
     params = tf.clip_by_value(params, 1e-3, 1e3)
   return tfd.Dirichlet(concentration=params,
                        validate_args=validate_args,
                        name=name)
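A standalone re-creation of the same pipeline with plain TF ops (ordinary `tf.nn.softplus` stands in for `softplus1`, whose exact definition is assumed):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

params = tf.random.normal([8, 5])                # unconstrained, e.g. dense-layer outputs
concentration = tf.nn.softplus(params)           # map to positive reals
concentration = tf.clip_by_value(concentration, 1e-3, 1e3)  # stable-KL region
dist = tfd.Dirichlet(concentration=concentration)
print(dist.event_shape, dist.sample().shape)     # [5] and [8, 5]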
Example #6
    def get_distributions(self, validate_args=False):
        self.dist1 = tfd.MultivariateNormalDiag(
            loc=self.maybe_static(
                tf.zeros(self.batch_dim_1 + self.event_dim_1, dtype=self.dtype),
                self.is_static),
            scale_diag=self.maybe_static(
                tf.ones(self.batch_dim_1 + self.event_dim_1, dtype=self.dtype),
                self.is_static))

        self.dist2 = tfd.OneHotCategorical(
            logits=self.maybe_static(
                tf.zeros(self.batch_dim_2 + self.event_dim_2), self.is_static),
            dtype=self.dtype)

        self.dist3 = tfd.Dirichlet(
            self.maybe_static(
                tf.zeros(self.batch_dim_3 + self.event_dim_3,
                         dtype=self.dtype), self.is_static))
        return batch_concat.BatchConcat(
            distributions=[self.dist1, self.dist2, self.dist3],
            axis=self.axis,
            validate_args=validate_args)
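A minimal `BatchConcat` sketch (assuming it is exposed as `tfd.BatchConcat`, as in recent TFP releases): the constituent distributions must share an event shape, and their batch shapes are concatenated along the given axis.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

d1 = tfd.MultivariateNormalDiag(loc=tf.zeros([2, 4]))                  # batch [2], event [4]
d2 = tfd.OneHotCategorical(logits=tf.zeros([3, 4]), dtype=tf.float32)  # batch [3], event [4]
d3 = tfd.Dirichlet(tf.ones([5, 4]))                                    # batch [5], event [4]
concat = tfd.BatchConcat(distributions=[d1, d2, d3], axis=0)
print(concat.batch_shape, concat.event_shape)                          # [10] [4]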
Example #7
class MarkovChainBijectorTest(test_util.TestCase):

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        dict(testcase_name='deterministic_prior',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.)),
        dict(testcase_name='deterministic_transition',
             prior_fn=lambda: tfd.Normal(loc=[-100., 0., 100.], scale=1.),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='fully_deterministic',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='mvn_diag',
             prior_fn=(lambda: tfd.MultivariateNormalDiag(loc=[[2.], [2.]],
                                                          scale_diag=[1.])),
             transition_fn=lambda _, x: tfd.VectorDeterministic(x)),
        dict(testcase_name='docstring_dirichlet',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 {'probs': tfd.Dirichlet([1., 1.])}),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 {
                     'probs':
                     tfd.MultivariateNormalDiag(loc=x['probs'],
                                                scale_diag=[0.1, 0.1])
                 },
                 batch_ndims=ps.rank(x['probs']))),
        dict(testcase_name='uniform_step',
             prior_fn=lambda: tfd.Exponential(tf.ones([4, 1])),
             transition_fn=lambda _, x: tfd.Uniform(low=x, high=x + 1.)),
        dict(testcase_name='joint_distribution',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=2,
                 model={
                     'a':
                     tfd.Gamma(tf.zeros([5]), 1.),
                     'b':
                     lambda a: (tfb.Reshape(event_shape_in=[4, 3],
                                            event_shape_out=[2, 3, 2])
                                (tfd.Independent(tfd.Normal(
                                    loc=tf.zeros([5, 4, 3]),
                                    scale=a[..., tf.newaxis, tf.newaxis]),
                                                 reinterpreted_batch_ndims=2)))
                 }),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=ps.rank_from_shape(x['a'].shape),
                 model={
                     'a':
                     tfd.Normal(loc=x['a'], scale=1.),
                     'b':
                     lambda a: tfd.Deterministic(x['b'] + a[
                         ..., tf.newaxis, tf.newaxis, tf.newaxis])
                 })),
        dict(testcase_name='nested_chain',
             prior_fn=lambda: tfd.MarkovChain(
                 initial_state_prior=tfb.Split(2)(
                     tfd.MultivariateNormalDiag(0., [1., 2.])),
                 transition_fn=lambda _, x: tfb.Split(2)(
                     tfd.MultivariateNormalDiag(x[0], [1., 2.])),
                 num_steps=6),
             transition_fn=(
                 lambda _, x: tfd.JointDistributionSequentialAutoBatched(
                     [
                         tfd.MultivariateNormalDiag(x[0], [1.]),
                         tfd.MultivariateNormalDiag(x[1], [1.])
                     ],
                     batch_ndims=ps.rank(x[0])))))
    # pylint: enable=g-long-lambda
    def test_default_bijector(self, prior_fn, transition_fn):
        chain = tfd.MarkovChain(initial_state_prior=prior_fn(),
                                transition_fn=transition_fn,
                                num_steps=7)

        y = self.evaluate(chain.sample(seed=test_util.test_seed()))
        bijector = chain.experimental_default_event_space_bijector()

        self.assertAllEqual(chain.batch_shape_tensor(),
                            bijector.experimental_batch_shape_tensor())

        x = bijector.inverse(y)
        yy = bijector.forward(tf.nest.map_structure(
            tf.identity, x))  # Bypass bijector cache.
        self.assertAllCloseNested(y, yy)

        chain_event_ndims = tf.nest.map_structure(ps.rank_from_shape,
                                                  chain.event_shape_tensor())
        self.assertAllEqualNested(bijector.inverse_min_event_ndims,
                                  chain_event_ndims)

        ildj = bijector.inverse_log_det_jacobian(
            tf.nest.map_structure(tf.identity, y),  # Bypass bijector cache.
            event_ndims=chain_event_ndims)
        if not bijector.is_constant_jacobian:
            self.assertAllEqual(ildj.shape, chain.batch_shape)
        fldj = bijector.forward_log_det_jacobian(
            tf.nest.map_structure(tf.identity, x),  # Bypass bijector cache.
            event_ndims=bijector.inverse_event_ndims(chain_event_ndims))
        self.assertAllClose(ildj, -fldj)

        # Verify that event shapes are passed through and flattened/unflattened
        # correctly.
        inverse_event_shapes = bijector.inverse_event_shape(chain.event_shape)
        x_event_shapes = tf.nest.map_structure(
            lambda t, nd: t.shape[ps.rank(t) - nd:], x,
            bijector.forward_min_event_ndims)
        self.assertAllEqualNested(inverse_event_shapes, x_event_shapes)
        forward_event_shapes = bijector.forward_event_shape(
            inverse_event_shapes)
        self.assertAllEqualNested(forward_event_shapes, chain.event_shape)

        # Verify that the outputs of other methods have the correct structure.
        inverse_event_shape_tensors = bijector.inverse_event_shape_tensor(
            chain.event_shape_tensor())
        self.assertAllEqualNested(inverse_event_shape_tensors, x_event_shapes)
        forward_event_shape_tensors = bijector.forward_event_shape_tensor(
            inverse_event_shape_tensors)
        self.assertAllEqualNested(forward_event_shape_tensors,
                                  chain.event_shape_tensor())
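For reference, the distribution under test in its simplest form (a sketch using only documented `tfd.MarkovChain` arguments): a 7-step Gaussian random walk whose default event-space bijector round-trips sampled paths, as the test verifies above.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

chain = tfd.MarkovChain(
    initial_state_prior=tfd.Normal(loc=0., scale=1.),
    transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.),
    num_steps=7)
path = chain.sample(seed=42)                      # shape [7]: one state per step
bijector = chain.experimental_default_event_space_bijector()
print(bijector.forward(bijector.inverse(path)))   # recovers the sampled path
print(chain.log_prob(path))                       # scores the whole trajectory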