Пример #1
0
  def testMinEventNdimsWithPartiallyDependentJointMap(self):
    """A one-element `JointMap` around a multipart bijector keeps its ndims.

    `dependent` concatenates its two input parts and then re-splits the
    result. Wrapping it in a single-element `JointMap` (with `Restructure`
    adapters on either side) must not change its min event ndims or its
    `_parts_interact` attribute.
    """
    dependent = tfb.Chain([tfb.Split(2), tfb.Invert(tfb.Split(2))])
    to_singleton_list = tfb.Restructure(input_structure=[0, 1],
                                        output_structure=[[0, 1]])
    wrapped = tfb.Chain([
        tfb.Invert(to_singleton_list),
        tfb.JointMap([dependent]),
        to_singleton_list,
    ])
    for attr_name in ('forward_min_event_ndims',
                      'inverse_min_event_ndims',
                      '_parts_interact'):
      self.assertAllEqualNested(getattr(dependent, attr_name),
                                getattr(wrapped, attr_name))
  def test_unknown_event_rank(self):
    """TransformedDistribution over a base whose event rank is only dynamic."""
    if tf.executing_eagerly():
      self.skipTest('Eager execution.')
    # `reinterpreted_batch_ndims` comes from a placeholder, so the split of
    # dims into batch vs. event is not known when the graph is built.
    unknown_rank_dist = tfd.Independent(
        tfd.Normal(loc=tf.ones([2, 1, 3]), scale=2.),
        reinterpreted_batch_ndims=tf1.placeholder_with_default(1, shape=[]))
    fully_unknown = tf.TensorShape(None)

    scaled = tfd.TransformedDistribution(
        distribution=unknown_rank_dist,
        bijector=tfb.Scale(1.),
        validate_args=True)
    self.assertEqual(scaled.batch_shape, fully_unknown)
    self.assertEqual(scaled.event_shape, fully_unknown)
    self.assertAllEqual(scaled.batch_shape_tensor(), [2, 1])
    self.assertAllEqual(scaled.event_shape_tensor(), [3])

    concatenated = tfd.TransformedDistribution(
        distribution=tfd.JointDistributionSequentialAutoBatched(
            [unknown_rank_dist, unknown_rank_dist]),
        bijector=tfb.Invert(tfb.Split(2)),
        validate_args=True)
    # The static shapes are conservatively reported as unknown, even though a
    # batch shape of `[]` would also be correct here.
    self.assertEqual(concatenated.batch_shape, fully_unknown)
    self.assertEqual(concatenated.event_shape, fully_unknown)
    self.assertAllEqual(concatenated.batch_shape_tensor(), [])
    self.assertAllEqual(concatenated.event_shape_tensor(), [2, 1, 6])
Пример #3
0
 def testAssertRaisesWrongNumSplits(self):
     """Inverting a list whose length differs from `num_splits` must fail."""
     num_splits = 4
     # Only three parts are supplied to a bijector expecting four.
     y = [np.random.rand(2, 3)] * 3
     bijector = tfb.Split(num_splits, validate_args=True)
     # `assertRaisesRegexp` is a deprecated alias (removed from `unittest` in
     # Python 3.12); `assertRaisesRegex` is the supported spelling.
     with self.assertRaisesRegex(ValueError,
                                 "don't have the same sequence length"):
         self.evaluate(bijector.inverse(y))
Пример #4
0
    def testEventShape(self, num_or_size_splits, expected_split_sizes):
        """Static forward/inverse event shapes of Split along axis -2."""
        split_sizes = self.build_input(num_or_size_splits)
        total_size = np.sum(expected_split_sizes)
        static_in = tf.TensorShape([total_size, 2])
        static_out = [tf.TensorShape([d, 2]) for d in expected_split_sizes]
        bijector = tfb.Split(num_or_size_splits=split_sizes,
                             axis=-2,
                             validate_args=True)

        def forward_as_lists(shape):
            return [part.as_list()
                    for part in bijector.forward_event_shape(shape)]

        # The test expects the split dim of each of the three parts to be
        # reported as unknown here.
        self.assertAllEqual(forward_as_lists(static_in), [[None, 2]] * 3)
        self.assertEqual(bijector.inverse_event_shape(static_out).as_list(),
                         static_in.as_list())

        # An unknown trailing dim stays unknown in every part.
        self.assertAllEqual(
            forward_as_lists(tf.TensorShape([total_size, None])),
            [[None, None]] * 3)

        # A single unknown part size leaves the concatenated size unknown.
        for parts in ([[None, 3], [3, 3], [2, 3]],
                      [[2, 3], [None, 3], [2, 3]]):
            self.assertAllEqual(
                bijector.inverse_event_shape(parts).as_list(), [None, 3])
Пример #5
0
 def _testAssertRaisesMismatchedOutputShapes(self):
   """Inverse must reject parts whose sizes contradict the split sizes."""
   split_sizes = self.build_input([5, -1, 3])
   # The first part has size 6 although the declared split size is 5.
   y = [np.random.rand(3, 1, size) for size in (6, 2, 3)]
   bijector = tfb.Split(split_sizes, validate_args=True)
   # The error is only asserted when the split sizes are statically known.
   if tf.get_static_value(split_sizes) is None:
     return
   with self.assertRaisesError('does not match expected `split_size`'):
     self.evaluate(bijector.inverse(y))
    def testStddev(self):
        """stddev/variance through an affine chain, and where they're undefined."""
        base_stddev = 2.
        shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
        scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
        # A shift leaves the spread unchanged; a scale multiplies it by |scale|.
        expected_stddev = tf.abs(base_stddev * scale)

        base = tfd.Normal(loc=tf.zeros_like(shift),
                          scale=base_stddev * tf.ones_like(scale),
                          validate_args=True)
        affine = tfb.Chain([tfb.Shift(shift=shift), tfb.Scale(scale=scale)],
                           validate_args=True)
        normal = self._cls()(distribution=base,
                             bijector=affine,
                             validate_args=True)
        self.assertAllClose(expected_stddev, normal.stddev())
        self.assertAllClose(expected_stddev**2, normal.variance())

        # Splitting the (now event) last dim splits the stddev accordingly.
        split_normal = self._cls()(
            distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
            bijector=tfb.Split(3),
            validate_args=True)
        self.assertAllCloseNested(
            tf.split(expected_stddev, num_or_size_splits=3, axis=-1),
            split_normal.stddev())

        # A transformation that mixes event dims has no elementwise stddev.
        scaled_normal = self._cls()(
            distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
            bijector=tfb.ScaleMatvecTriL([[1., 0.], [-1., 2.]]),
            validate_args=True)
        with self.assertRaisesRegex(NotImplementedError,
                                    'is a multivariate transformation'):
            scaled_normal.stddev()
Пример #7
0
  def testInverseWithEventDimsOmitted(self):
    """`inverse_log_det_jacobian` works without an explicit `event_ndims`."""
    bij = tfb.Split(2)
    parts = [tf.ones((3, 4, 5)), tf.ones((3, 4, 5))]
    ildj = bij.inverse_log_det_jacobian(parts)
    # The test expects the log-det-Jacobian of the split to be exactly zero.
    self.assertAllEqual(0.0, self.evaluate(ildj))
Пример #8
0
  def testEventShape(self, num_or_size_splits, expected_split_sizes):
    """Checks static `forward_event_shape` / `inverse_event_shape` for Split.

    Covers fully known shapes, shapes with an unknown dimension, and shapes
    of known rank only, splitting along axis -2.

    Args:
      num_or_size_splits: Split spec, passed through `self.build_input`
        before being handed to `tfb.Split` (presumably this may turn it into a
        non-static value depending on the test variant -- confirm).
      expected_split_sizes: Concrete size of each output part along axis -2.
    """
    num_or_size_splits = self.build_input(num_or_size_splits)
    total_size = np.sum(expected_split_sizes)
    shape_in_static = tf.TensorShape([total_size, 2])
    shape_out_static = [
        tf.TensorShape([d, 2]) for d in expected_split_sizes]
    bijector = tfb.Split(
        num_or_size_splits=num_or_size_splits, axis=-2, validate_args=True)

    # Test that forward_ and inverse_event_shape are correct when
    # event_shape_in/_out are statically known, even when the input shapes
    # are only partially specified.
    self.assertAllEqual(
        bijector.forward_event_shape(shape_in_static), shape_out_static)
    self.assertEqual(
        bijector.inverse_event_shape(shape_out_static), shape_in_static)

    # Shape is always known for splitting in eager mode, so we skip these tests.
    if tf.executing_eagerly():
      return
    # An unknown trailing dim stays unknown in every part; the split sizes
    # along axis -2 are still expected to be recovered.
    self.assertAllEqual(
        [s.as_list() for s in bijector.forward_event_shape(
            tf.TensorShape([total_size, None]))],
        [[d, None] for d in expected_split_sizes])

    # Statically known split sizes, with `None` for entries that are unknown.
    # When the bijector has no `split_sizes` (built from a number of splits),
    # fall back to `expected_split_sizes`.
    if bijector.split_sizes is None:
      static_split_sizes = tensorshape_util.constant_value_as_shape(
          expected_split_sizes).as_list()
    else:
      static_split_sizes = tensorshape_util.constant_value_as_shape(
          num_or_size_splits).as_list()

    # The concatenated size is only statically known if every split size is.
    static_total_size = None if None in static_split_sizes else total_size

    # Test correctness with an inverse input dimension of None that coincides
    # with the `-1` element in a not fully specified `split_sizes`.
    shape_with_maybe_unknown_dim = (
        [[None, 3]] + [[d, 3] for d in expected_split_sizes[1:]])
    self.assertAllEqual(
        bijector.inverse_event_shape(shape_with_maybe_unknown_dim).as_list(),
        [static_total_size, 3])

    # Test correctness with an input dimension of None that does not coincide
    # with a `-1` split_size. Indexing `[2]` assumes at least three parts.
    shape_with_deducable_dim = [[d, 3] for d in expected_split_sizes]
    shape_with_deducable_dim[2] = [None, 3]
    self.assertAllEqual(
        bijector.inverse_event_shape(
            shape_with_deducable_dim).as_list(), [total_size, 3])

    # Test correctness for an input shape of known rank only.
    if bijector.split_sizes is not None:
      shape_with_unknown_total = (
          [[d, None] for d in static_split_sizes])
    else:
      shape_with_unknown_total = [[None, None]] * len(expected_split_sizes)
    self.assertAllEqual(
        [s.as_list() for s in bijector.forward_event_shape(
            tf.TensorShape([None, None]))],
        shape_with_unknown_total)
Пример #9
0
  def test_single_part_str_repr_match_expected(self):
    """`str` and `repr` of simple, parameterized and multipart bijectors."""

    def check(bij, expected_str, expected_repr):
      self.assertContainsInOrder([expected_str], str(bij))
      self.assertContainsInOrder([expected_repr], repr(bij))

    check(
        tfb.Exp(),
        'tfp.bijectors.Exp("exp", batch_shape=[], min_event_ndims=0)',
        "<tfp.bijectors.Exp 'exp' batch_shape=[] forward_min_event_ndims=0 "
        "inverse_min_event_ndims=0 dtype_x=? dtype_y=?>")

    check(
        tfb.Scale([1., 1.], name='myscale'),
        'tfp.bijectors.Scale("myscale", batch_shape=[2], min_event_ndims=0, '
        'dtype=float32)',
        "<tfp.bijectors.Scale 'myscale' batch_shape=[2] "
        "forward_min_event_ndims=0 inverse_min_event_ndims=0 dtype_x=float32 "
        "dtype_y=float32>")

    # Multipart output: inverse ndims and output dtypes are per-part lists.
    check(
        tfb.Split([3, 4, 2], name='s_p_l_i_t'),
        'tfp.bijectors.Split("s_p_l_i_t", batch_shape=[], '
        'forward_min_event_ndims=1, inverse_min_event_ndims=[1, 1, 1])',
        "<tfp.bijectors.Split 's_p_l_i_t' batch_shape=[] "
        "forward_min_event_ndims=1 inverse_min_event_ndims=[1, 1, 1] "
        "dtype_x=? dtype_y=[?, ?, ?]>")
Пример #10
0
 def _testAssertRaisesTooSmallInputShape(self):
     """Forward must fail when the input is too short for the split sizes."""
     split_sizes = self.build_input([-1, 2, 3])
     # The known sizes sum to 5, but the input's last dim is only 4. With
     # `shape=None` the static shape is hidden, so presumably the check runs
     # at execution time -- confirm.
     short_input = tf.Variable(tf.zeros((2, 4)), shape=None)
     self.evaluate(short_input.initializer)
     bijector = tfb.Split(split_sizes, validate_args=True)
     expected_message = 'size of the input along `axis`'
     with self.assertRaisesError(expected_message):
         self.evaluate(bijector.forward(short_input))
Пример #11
0
 def testAssertRaisesWrongNumberOfOutputs(self):
     """Inverting more parts than there are split sizes must fail."""
     split_sizes = self.build_input([5, 3, -1])
     # Four parts are supplied for a three-way split.
     y = [np.random.rand(2, i) for i in [5, 3, 1, 2]]
     bijector = tfb.Split(split_sizes, validate_args=True)
     # `assertRaisesRegexp` is a deprecated alias (removed from `unittest` in
     # Python 3.12); `assertRaisesRegex` is the supported spelling.
     with self.assertRaisesRegex(ValueError,
                                 "don't have the same sequence length"):
         self.evaluate(bijector.inverse(y))
  def test_batch_broadcast_vector_to_parts(self):
    """Batch shape carried by one split branch broadcasts to every branch."""
    batch_shape = [4, 2]
    true_split_sizes = [1, 3, 2]
    event_size = sum(true_split_sizes)

    # The base distribution has no batch shape, so broadcasting is required.
    loc = tf.random.normal([event_size], seed=test_util.test_seed())
    log_scale = tf.random.normal([event_size], seed=test_util.test_seed())
    base_dist = tfd.MultivariateNormalDiag(loc=loc,
                                           scale_diag=tf.exp(log_scale))

    # Only the last branch's bijector has a batch shape.
    last_branch = tfb.Shift(tf.ones(batch_shape + [true_split_sizes[-1]]))
    bijector = tfb.Chain([
        tfb.JointMap([tfb.Identity(), tfb.Identity(), last_branch]),
        tfb.Split(true_split_sizes, axis=-1),
    ])
    split_dist = tfd.TransformedDistribution(base_dist, bijector)
    self.assertAllEqual(split_dist.batch_shape, batch_shape)

    # TD should feed batched base samples *into* the split, so the batch
    # shape shows up in every branch, not only the shifted one.
    samples = split_dist.sample(seed=test_util.test_seed())
    self.assertAllEqualNested(
        tf.nest.map_structure(lambda part: part.shape, samples),
        [batch_shape + [size] for size in true_split_sizes])
  def testCovarianceNotImplemented(self):
    """`covariance` raises unless the bijector is an affine vector map."""
    mvn = tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 2.])
    # Each entry: (bijector factory, expected error-message pattern). The
    # bijector is built inside the `assertRaisesRegex` context.
    cases = (
        # Non-affine bijector.
        (tfb.Exp, '`covariance` is not implemented'),
        # Non-injective bijector.
        (tfb.AbsoluteValue, '`covariance` is not implemented'),
        # Non-vector event shape.
        (lambda: tfb.Reshape(event_shape_out=[2, 1], event_shape_in=[2]),
         '`covariance` is only implemented'),
        # Multipart bijector.
        (lambda: tfb.Split(2), '`covariance` is only implemented'),
    )
    for make_bijector, pattern in cases:
      with self.assertRaisesRegex(NotImplementedError, pattern):
        tfd.TransformedDistribution(
            distribution=mvn, bijector=make_bijector()).covariance()
Пример #14
0
 def testAssertRaisesWrongNumberOfOutputs(self):
     """Inverting more parts than there are split sizes must fail."""
     split_sizes = self.build_input([5, 3, -1])
     # Four parts are supplied for a three-way split.
     y = [np.random.rand(2, i) for i in [5, 3, 1, 2]]
     bijector = tfb.Split(split_sizes, validate_args=True)
     # `assertRaisesRegexp` is a deprecated alias (removed from `unittest` in
     # Python 3.12); `assertRaisesRegex` is the supported spelling.
     with self.assertRaisesRegex(ValueError,
                                 'does not match the number of splits'):
         self.evaluate(bijector.inverse(y))
Пример #15
0
 def testAssertRaisesWrongNumSplits(self):
     """Inverting a list whose length differs from `num_splits` must fail."""
     num_splits = 4
     # Only three parts are supplied to a bijector expecting four.
     y = [np.random.rand(2, 3)] * 3
     bijector = tfb.Split(num_splits, validate_args=True)
     # `assertRaisesRegexp` is a deprecated alias (removed from `unittest` in
     # Python 3.12); `assertRaisesRegex` is the supported spelling.
     with self.assertRaisesRegex(ValueError,
                                 'does not match the number of splits'):
         self.evaluate(bijector.inverse(y))
Пример #16
0
 def testCompositeTensor(self):
   """A flattened-and-repacked Split behaves the same inside `tf.function`."""
   split_sizes = self.build_input([1, 2, 2])
   bijector = tfb.Split(split_sizes, validate_args=True)
   x = tf.ones([3, 2, 5])
   components = tf.nest.flatten(bijector, expand_composites=True)
   rebuilt = tf.nest.pack_sequence_as(
       bijector, components, expand_composites=True)

   @tf.function
   def forward_in_function(b_):
     return b_.forward(x)

   self.assertAllClose(bijector.forward(x), forward_in_function(rebuilt))
Пример #17
0
  def test_transform_parts_to_vector(self, known_split_sizes):
    """Concatenating a multipart distribution into a vector-valued one.

    Args:
      known_split_sizes: Split sizes handed to `tfb.Split`. The true part sizes
        are `[1, 3, 2]`; presumably some entries may be left unknown (e.g.
        `-1`) by the test parameterization -- confirm.
    """
    batch_shape = [4, 2]
    true_split_sizes = [1, 3, 2]

    # Create a joint distribution with parts of the specified sizes.
    seed = test_util.test_seed_stream()
    component_dists = tf.nest.map_structure(
        lambda size: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
            loc=tf.random.normal(batch_shape + [size], seed=seed()),
            scale_diag=tf.exp(
                tf.random.normal(batch_shape + [size], seed=seed()))),
        true_split_sizes)
    base_dist = tfd.JointDistributionSequential(component_dists)

    # Transform to a vector-valued distribution by concatenating the parts.
    bijector = tfb.Invert(tfb.Split(known_split_sizes, axis=-1))

    # Overriding shapes is rejected. (`assertRaisesRegexp` is a deprecated
    # alias, removed from `unittest` in Python 3.12.)
    with self.assertRaisesRegex(ValueError, 'Overriding the batch shape'):
      tfd.TransformedDistribution(base_dist, bijector, batch_shape=[3])

    with self.assertRaisesRegex(ValueError, 'Overriding the event shape'):
      tfd.TransformedDistribution(base_dist, bijector, event_shape=[3])

    concat_dist = tfd.TransformedDistribution(base_dist, bijector)
    self.assertAllEqual(concat_dist.event_shape, [sum(true_split_sizes)])
    self.assertAllEqual(self.evaluate(concat_dist.event_shape_tensor()),
                        [sum(true_split_sizes)])
    self.assertAllEqual(concat_dist.batch_shape, batch_shape)
    self.assertAllEqual(self.evaluate(concat_dist.batch_shape_tensor()),
                        batch_shape)

    # Since the Split bijector has (constant) unit Jacobian determinant, the
    # transformed entropy and mean/mode should match the base entropy and
    # (concatenated) base mean/mode.
    self.assertAllEqual(*self.evaluate(
        (base_dist.entropy(), concat_dist.entropy())))

    self.assertAllEqual(*self.evaluate(
        (concat_dist.mean(), bijector.forward(base_dist.mean()))))
    self.assertAllEqual(*self.evaluate(
        (concat_dist.mode(), bijector.forward(base_dist.mode()))))

    # Since the Split bijector has zero log-det-Jacobian, the transformed
    # `log_prob` and `prob` should match the base distribution.
    sample_shape = [3]
    x = base_dist.sample(sample_shape, seed=seed())
    y = bijector.forward(x)
    for attr in ('log_prob', 'prob'):
      base_attr = getattr(base_dist, attr)(x)
      concat_attr = getattr(concat_dist, attr)(y)
      self.assertAllClose(*self.evaluate((base_attr, concat_attr)))

    # Test that `.sample()` works and returns a result of the expected structure
    # and shape.
    y_sampled = concat_dist.sample(sample_shape, seed=seed())
    self.assertAllEqual(y.shape, y_sampled.shape)
    def test_transform_vector_to_parts(self, known_split_sizes):
        """Splitting a vector-valued distribution into a multipart one.

        Args:
          known_split_sizes: Split sizes handed to `tfb.Split`; the true part
            sizes are `[1, 3, 2]`.
        """
        batch_shape = [4, 2]
        true_split_sizes = [1, 3, 2]
        base_event_size = sum(true_split_sizes)

        loc = tf.random.normal(batch_shape + [base_event_size],
                               seed=test_util.test_seed())
        log_scale = tf.random.normal(batch_shape + [base_event_size],
                                     seed=test_util.test_seed())
        base_dist = tfd.MultivariateNormalDiag(loc=loc,
                                               scale_diag=tf.exp(log_scale))

        bijector = tfb.Split(known_split_sizes, axis=-1)
        split_dist = tfd.TransformedDistribution(base_dist, bijector)

        self.assertRegex(
            str(split_dist),
            '{}.*batch_shape.*event_shape.*dtype'.format(split_dist.name))

        # Each part's event shape is its split size; batch shape is unchanged.
        expected_event_shape = [np.array([size]) for size in true_split_sizes]
        self.assertAllEqual(
            [np.array(shape) for shape in split_dist.event_shape],
            expected_event_shape)
        self.assertAllEqual(self.evaluate(split_dist.event_shape_tensor()),
                            expected_event_shape)
        self.assertAllEqual(split_dist.batch_shape, batch_shape)
        self.assertAllEqual(self.evaluate(split_dist.batch_shape_tensor()),
                            batch_shape)

        # Split has a unit Jacobian determinant (zero log-det-Jacobian), so
        # entropy matches the base and mean/mode match the split base values.
        self.assertAllEqual(*self.evaluate((base_dist.entropy(),
                                            split_dist.entropy())))
        for stat in ('mean', 'mode'):
            self.assertAllEqualNested(*self.evaluate((
                getattr(split_dist, stat)(),
                bijector.forward(getattr(base_dist, stat)()))))

        # With a zero log-det-Jacobian, `log_prob`/`prob` of split values
        # match the base distribution at the unsplit values.
        sample_shape = [3]
        x = base_dist.sample(sample_shape, seed=test_util.test_seed())
        y = bijector.forward(x)
        for attr in ('log_prob', 'prob'):
            split_attr = getattr(split_dist, attr)(y)
            base_attr = getattr(base_dist, attr)(x)
            self.assertAllClose(*self.evaluate((base_attr, split_attr)),
                                rtol=1e-5)

        # `.sample()` returns parts with the expected structure and shapes.
        y_sampled = split_dist.sample(sample_shape, seed=test_util.test_seed())
        self.assertAllEqual([part.shape for part in y],
                            [part.shape for part in y_sampled])
Пример #19
0
 def test_invert_str_and_repr_match_expected(self):
     """`Invert(Split)` swaps ndims/dtypes and embeds the inner repr."""
     bij = tfb.Invert(tfb.Split([3, 4, 2]))
     expected_str = (
         'tfp.bijectors.Invert("invert_split", batch_shape=[], '
         'forward_min_event_ndims=[1, 1, 1], inverse_min_event_ndims=1, '
         'bijector=Split)')
     inner_repr = (
         "<tfp.bijectors.Split 'split' batch_shape=[] "
         "forward_min_event_ndims=1 inverse_min_event_ndims=[1, 1, 1] "
         "dtype_x=? dtype_y=[?, ?, ?]>")
     expected_repr = (
         "<tfp.bijectors.Invert 'invert_split' batch_shape=[] "
         "forward_min_event_ndims=[1, 1, 1] inverse_min_event_ndims=1 "
         "dtype_x=[?, ?, ?] dtype_y=? "
         "bijector=" + inner_repr + ">")
     self.assertContainsInOrder([expected_str], str(bij))
     self.assertContainsInOrder([expected_repr], repr(bij))
Пример #20
0
    def test_composition_str_and_repr_match_expected_dynamic_shape(self):
        """`str`/`repr` of Chains whose parameters come from `self._tensor`.

        The expected reprs show `batch_shape=?`, i.e. the batch shape is not
        statically known for these parameters.
        """
        # Single-part chain: min event ndims are plain ints.
        bij = tfb.Chain([
            tfb.Exp(),
            tfb.Shift(self._tensor([1., 2.])),
            tfb.SoftmaxCentered()
        ])
        self.assertContainsInOrder([
            'tfp.bijectors.Chain(',
            ('min_event_ndims=1, bijectors=[Exp, Shift, SoftmaxCentered])')
        ], str(bij))
        self.assertContainsInOrder([
            '<tfp.bijectors.Chain ',
            ('batch_shape=? forward_min_event_ndims=1 inverse_min_event_ndims=1 '
             'dtype_x=float32 dtype_y=float32 bijectors=[<tfp.bijectors.Exp'),
            '>, <tfp.bijectors.Shift', '>, <tfp.bijectors.SoftmaxCentered',
            '>]>'
        ], repr(bij))

        # Multipart chain ending in a dict: inverse ndims and output dtypes are
        # reported per key.
        bij = tfb.Chain([
            tfb.JointMap({
                'a': tfb.Exp(),
                'b': tfb.ScaleMatvecDiag(self._tensor([2., 2.]))
            }),
            tfb.Restructure({
                'a': 0,
                'b': 1
            }, [0, 1]),
            tfb.Split(2),
            tfb.Invert(tfb.SoftmaxCentered()),
        ])
        self.assertContainsInOrder([
            'tfp.bijectors.Chain(',
            ('forward_min_event_ndims=1, '
             'inverse_min_event_ndims={a: 1, b: 1}, '
             'bijectors=[JointMap({a: Exp, b: ScaleMatvecDiag}), '
             'Restructure, Split, Invert(SoftmaxCentered)])')
        ], str(bij))
        self.assertContainsInOrder([
            '<tfp.bijectors.Chain ',
            ('batch_shape=? forward_min_event_ndims=1 '
             "inverse_min_event_ndims={'a': 1, 'b': 1} dtype_x=float32 "
             "dtype_y={'a': ?, 'b': float32} "
             "bijectors=[<tfp.bijectors.JointMap "),
            '>, <tfp.bijectors.Restructure', '>, <tfp.bijectors.Split',
            '>, <tfp.bijectors.Invert', '>]>'
        ], repr(bij))
Пример #21
0
    def _testBijector(  # pylint: disable=invalid-name
            self, num_or_size_splits, expected_split_sizes, shape_in, axis):
        """Do a basic sanity check of forward, inverse, jacobian.

        Args:
          num_or_size_splits: Split spec, passed through `self.build_input`.
          expected_split_sizes: Concrete size of each part along `axis`.
          shape_in: Full input shape as a Python list (it is copied and
            item-assigned below).
          axis: Split axis; `abs(axis)` is used as the event rank, which
            assumes `axis` counts from the end (negative).
        """
        num_or_size_splits = self.build_input(num_or_size_splits)
        bijector = tfb.Split(num_or_size_splits, axis=axis, validate_args=True)
        self.assertStartsWith(bijector.name, 'split')

        # Round trip against `tf.split` as the reference implementation.
        x = np.random.rand(*shape_in)
        y = tf.split(x, num_or_size_splits, axis=axis)
        self.assertAllClose(self.evaluate(y),
                            self.evaluate(bijector.forward(x)),
                            atol=0., rtol=1e-2)
        self.assertAllClose(x, self.evaluate(bijector.inverse(y)),
                            atol=0., rtol=1e-4)

        # Per-part shapes: the input shape with `axis` replaced by the size.
        shape_out = []
        for size in expected_split_sizes:
            part_shape = shape_in[:]
            part_shape[axis] = size
            shape_out.append(self.build_input(part_shape))

        self.assertAllEqual(
            self.evaluate(bijector.inverse_event_shape_tensor(shape_out)),
            shape_in)

        shape_in_tensor = self.build_input(shape_in)
        self.assertAllEqual(
            self.evaluate(
                bijector.forward_event_shape_tensor(shape_in_tensor)),
            self.evaluate(shape_out))

        # Both log-det-Jacobians are expected to be exactly zero.
        event_ndims = abs(axis)
        per_part_ndims = [event_ndims for _ in expected_split_sizes]
        ildj = bijector.inverse_log_det_jacobian(y, event_ndims=per_part_ndims)
        self.assertEqual(0., self.evaluate(ildj))
        fldj = bijector.forward_log_det_jacobian(x, event_ndims=event_ndims)
        self.assertEqual(0., self.evaluate(fldj))
Пример #22
0
  def test_slice_transformed_distribution_with_chain(self):
    """Slicing a TD whose bijector chain contributes batch dimensions."""
    base = tfd.MultivariateNormalDiag(
        loc=tf.zeros([4]), scale_diag=tf.ones([1, 4]))
    chain = tfb.Chain([
        tfb.JointMap([tfb.Identity(), tfb.Shift(tf.ones([4, 3, 2]))]),
        tfb.Split(2),
        tfb.ScaleMatvecDiag(tf.ones([5, 1, 3, 4])),
        tfb.Exp(),
    ])
    dist = tfd.TransformedDistribution(distribution=base, bijector=chain)

    def sample_shapes(d):
      return tf.nest.map_structure(lambda t: t.shape,
                                   d.sample(seed=test_util.test_seed()))

    self.assertAllEqual(dist.batch_shape_tensor(), [5, 4, 3])
    self.assertAllEqualNested(sample_shapes(dist),
                              [[5, 4, 3, 2], [5, 4, 3, 2]])

    sliced = dist[tf.newaxis, ..., 0, :, :-1]
    self.assertAllEqual(sliced.batch_shape_tensor(), [1, 4, 2])
    self.assertAllEqualNested(sample_shapes(sliced),
                              [[1, 4, 2, 2], [1, 4, 2, 2]])
Пример #23
0
 def testAssertRaisesUnknownNumSplits(self):
     """Construction must fail if the number of split sizes is not static."""
     # `shape=[None]` hides the length of the split-size vector.
     split_sizes = tf1.placeholder_with_default([-1, 2, 1], shape=[None])
     # `assertRaisesRegexp` is a deprecated alias (removed from `unittest` in
     # Python 3.12); `assertRaisesRegex` is the supported spelling.
     with self.assertRaisesRegex(
             ValueError, 'must have a statically-known number of elements'):
         tfb.Split(num_or_size_splits=split_sizes, validate_args=True)
Пример #24
0
class BijectorBatchShapesTest(test_util.TestCase):
  """Consistency tests for bijector `experimental_batch_shape*` methods."""

  @parameterized.named_parameters(
      ('exp', tfb.Exp, None),
      ('scale',
       lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid',
       lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])), None),
      ('scale_matvec',
       lambda: tfb.ScaleMatvecDiag([[0.], [3.]]), None),
      ('invert',
       lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))), None),
      ('reshape',
       lambda: tfb.Reshape([1], event_shape_in=[1, 1]), None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [0, 1]),
      ('jointmap_11',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [1, 1]),
      ('jointmap_20',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 0]),
      ('jointmap_22',
       lambda: tfb.JointMap([tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 2]),
      ('restructure_with_ragged_event_ndims',
       lambda: tfb.Restructure(input_structure=[0, 1],  # pylint: disable=g-long-lambda
                               output_structure={'a': 0, 'b': 1}),
       [0, 1]))
  def test_batch_shape_matches_output_shapes(self,
                                             bijector_fn,
                                             override_x_event_ndims=None):
    """Static/dynamic batch shapes agree and match forward/inverse outputs.

    Args:
      bijector_fn: Zero-arg callable building the bijector under test (called
        inside the test).
      override_x_event_ndims: Optional x-side event ndims; when `None`, the
        bijector's min event ndims are used on both sides.
    """
    bijector = bijector_fn()
    if override_x_event_ndims is None:
      x_event_ndims = bijector.forward_min_event_ndims
      y_event_ndims = bijector.inverse_min_event_ndims
    else:
      x_event_ndims = override_x_event_ndims
      y_event_ndims = bijector.forward_event_ndims(x_event_ndims)

    # All ways of calculating the batch shape should yield the same result.
    batch_shape_x = bijector.experimental_batch_shape(
        x_event_ndims=x_event_ndims)
    batch_shape_y = bijector.experimental_batch_shape(
        y_event_ndims=y_event_ndims)
    self.assertEqual(batch_shape_x, batch_shape_y)

    batch_shape_tensor_x = bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    batch_shape_tensor_y = bijector.experimental_batch_shape_tensor(
        y_event_ndims=y_event_ndims)
    self.assertAllEqual(batch_shape_tensor_x, batch_shape_tensor_y)
    self.assertAllEqual(batch_shape_tensor_x, batch_shape_x)

    # Check that we're robust to integer type.
    batch_shape_tensor_x64 = bijector.experimental_batch_shape_tensor(
        x_event_ndims=tf.nest.map_structure(np.int64, x_event_ndims))
    batch_shape_tensor_y64 = bijector.experimental_batch_shape_tensor(
        y_event_ndims=tf.nest.map_structure(np.int64, y_event_ndims))
    self.assertAllEqual(batch_shape_tensor_x64, batch_shape_tensor_y64)
    self.assertAllEqual(batch_shape_tensor_x64, batch_shape_x)

    # Pushing a value through the bijector should return a Tensor(s) with
    # the expected batch shape...
    # Inputs are all-ones with rank equal to each part's event ndims (every
    # dim of size 1), so they carry no batch dims of their own.
    xs = tf.nest.map_structure(lambda nd: tf.ones([1] * nd), x_event_ndims)
    ys = bijector.forward(xs)
    for y_part, nd in zip(tf.nest.flatten(ys), tf.nest.flatten(y_event_ndims)):
      # Leading (non-event) dims of this output part.
      part_batch_shape = ps.shape(y_part)[:ps.rank(y_part) - nd]
      # The part's batch shape must broadcast into `batch_shape_y`.
      self.assertAllEqual(batch_shape_y,
                          ps.broadcast_shape(batch_shape_y, part_batch_shape))

    # ... which should also be the shape of the fldj.
    fldj = bijector.forward_log_det_jacobian(xs, event_ndims=x_event_ndims)
    self.assertAllEqual(batch_shape_y, ps.shape(fldj))

    # Also check the inverse case.
    xs = bijector.inverse(tf.nest.map_structure(tf.identity, ys))
    for x_part, nd in zip(tf.nest.flatten(xs), tf.nest.flatten(x_event_ndims)):
      part_batch_shape = ps.shape(x_part)[:ps.rank(x_part) - nd]
      self.assertAllEqual(batch_shape_x,
                          ps.broadcast_shape(batch_shape_x, part_batch_shape))
    ildj = bijector.inverse_log_det_jacobian(ys, event_ndims=y_event_ndims)
    self.assertAllEqual(batch_shape_x, ps.shape(ildj))

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale([3.14159])),
      ('chain', lambda: tfb.Exp()(tfb.Scale([3.14159]))))
  def test_ndims_specification(self, bijector_fn):
    """Default event ndims, and rejection of passing both x and y ndims."""
    bijector = bijector_fn()

    # If no `event_ndims` is passed, should assume min_event_ndims.
    self.assertAllEqual(bijector.experimental_batch_shape(), [1])
    self.assertAllEqual(bijector.experimental_batch_shape_tensor(), [1])

    with self.assertRaisesRegex(
        ValueError, 'Only one of `x_event_ndims` and `y_event_ndims`'):
      bijector.experimental_batch_shape(x_event_ndims=0, y_event_ndims=0)

    with  self.assertRaisesRegex(
        ValueError, 'Only one of `x_event_ndims` and `y_event_ndims`'):
      bijector.experimental_batch_shape_tensor(x_event_ndims=0, y_event_ndims=0)

  @parameterized.named_parameters(
      ('scale', lambda: tfb.Scale(tf.ones([4, 2])), None),
      ('sigmoid', lambda: tfb.Sigmoid(low=tf.zeros([3]), high=tf.ones([4, 1])),
       None),
      ('invert', lambda: tfb.Invert(tfb.ScaleMatvecDiag(tf.ones([2, 1]))),
       None),
      ('chain',
       lambda: tfb.Chain([tfb.Power(power=[[2.], [3.]]),  # pylint: disable=g-long-lambda
                          tfb.Invert(tfb.Split(2))]),
       None),
      ('jointmap_01', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [0, 1]),
      ('jointmap_11', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [1, 1]),
      ('jointmap_20', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 0]),
      ('jointmap_22', lambda: tfb.JointMap(  # pylint: disable=g-long-lambda
          [tfb.Scale([5, 3]), tfb.Scale([1, 4])]), [2, 2]),
      ('nested_jointmap',
       lambda: tfb.JointMap([tfb.JointMap({'a': tfb.Scale([1.]),  # pylint: disable=g-long-lambda
                                           'b': tfb.Exp()}),
                             tfb.Scale([1, 4])(tfb.Invert(tfb.Split(2)))]),
       [{'a': 0, 'b': 0}, [2, 2]]))
  def test_with_broadcast_batch_shape(self, bijector_fn, x_event_ndims=None):
    """Broadcasting parameters to a new batch shape updates all batch shapes.

    Checks both the bijector's overall batch shape and the per-parameter
    batch shapes after `_broadcast_parameters_with_batch_shape`.
    """
    bijector = bijector_fn()
    if x_event_ndims is None:
      x_event_ndims = bijector.forward_min_event_ndims
    batch_shape = bijector.experimental_batch_shape(x_event_ndims=x_event_ndims)
    param_batch_shapes = batch_shape_lib.batch_shape_parts(
        bijector, bijector_x_event_ndims=x_event_ndims)

    new_batch_shape = [4, 2, 1, 1, 1]
    broadcast_bijector = bijector._broadcast_parameters_with_batch_shape(
        new_batch_shape, x_event_ndims)
    broadcast_batch_shape = broadcast_bijector.experimental_batch_shape_tensor(
        x_event_ndims=x_event_ndims)
    self.assertAllEqual(broadcast_batch_shape,
                        ps.broadcast_shape(batch_shape, new_batch_shape))

    # Check that all params have the expected batch shape.
    broadcast_param_batch_shapes = batch_shape_lib.batch_shape_parts(
        broadcast_bijector, bijector_x_event_ndims=x_event_ndims)

    # Expected batch shape of parameter `p` (original batch shape `s`) after
    # broadcasting against `new_batch_shape`.
    def _maybe_broadcast_param_batch_shape(p, s):
      if isinstance(p, tfb.Invert) and not p.bijector._params_event_ndims():
        return s  # Can't broadcast a bijector that doesn't itself have params.
      return ps.broadcast_shape(s, new_batch_shape)
    expected_broadcast_param_batch_shapes = tf.nest.map_structure(
        _maybe_broadcast_param_batch_shape,
        {param: getattr(bijector, param) for param in param_batch_shapes},
        param_batch_shapes)
    self.assertAllEqualNested(broadcast_param_batch_shapes,
                              expected_broadcast_param_batch_shapes)
Пример #25
0
class BatchShapeInferenceTests(test_util.TestCase):
  """Tests batch-shape inference and parameter broadcasting helpers."""

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      dict(testcase_name='_trivial',
           value_fn=lambda: tfd.Normal(loc=0., scale=1.),
           expected_batch_shape=[]),
      dict(testcase_name='_simple_tensor_broadcasting',
           value_fn=lambda: tfd.MultivariateNormalDiag(
               loc=[0., 0.],
               scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
           expected_batch_shape=[2]),
      dict(testcase_name='_rank_deficient_tensor_broadcasting',
           value_fn=lambda: tfd.MultivariateNormalDiag(
               loc=0.,
               scale_diag=tf.convert_to_tensor([[1., 1.], [1., 1.]])),
           expected_batch_shape=[2]),
      dict(testcase_name='_mixture_same_family',
           value_fn=lambda: tfd.MixtureSameFamily(
               mixture_distribution=tfd.Categorical(
                   logits=[[[1., 2., 3.],
                            [4., 5., 6.]]]),
               components_distribution=tfd.Normal(
                   loc=0.,
                   scale=[[[1., 2., 3.],
                           [4., 5., 6.]]])),
           expected_batch_shape=[1, 2]),
      dict(testcase_name='_deeply_nested',
           value_fn=lambda: tfd.Independent(
               tfd.Independent(
                   tfd.Independent(
                       tfd.Independent(
                           tfd.Normal(loc=0., scale=[[[[[[[[1.]]]]]]]]),
                           reinterpreted_batch_ndims=2),
                       reinterpreted_batch_ndims=0),
                   reinterpreted_batch_ndims=1),
               reinterpreted_batch_ndims=1),
           expected_batch_shape=[1, 1, 1, 1]),
  )
  def test_batch_shape_inference_is_correct(
      self, value_fn, expected_batch_shape):
    """Dynamic and static batch shapes both agree with the expectation."""
    # Build lazily so the distribution is created in the test's graph.
    dist = value_fn()
    self.assertAllEqual(expected_batch_shape, dist.batch_shape_tensor())

    static_batch_shape = dist.batch_shape
    self.assertIsInstance(static_batch_shape, tf.TensorShape)
    self.assertTrue(
        static_batch_shape.is_compatible_with(expected_batch_shape))

  def assert_all_parameters_have_full_batch_shape(
      self, dist, expected_batch_shape):
    """Asserts `dist` and every one of its parameters has the given shape.

    Args:
      dist: distribution whose batch shape, and whose per-parameter batch
        shapes as reported by `batch_shape_lib.batch_shape_parts`, are checked.
      expected_batch_shape: shape each of those must equal exactly.
    """
    self.assertAllEqual(expected_batch_shape, dist.batch_shape_tensor())
    for shape in batch_shape_lib.batch_shape_parts(dist).values():
      self.assertAllEqual(expected_batch_shape, shape)

  @parameterized.named_parameters(
      dict(testcase_name='_trivial',
           dist_fn=lambda: tfd.Normal(loc=0., scale=1.)),
      dict(testcase_name='_simple_tensor_broadcasting',
           dist_fn=lambda: tfd.MultivariateNormalDiag(
               loc=[0., 0.],
               scale_diag=[[1., 1.], [1., 1.]])),
      dict(testcase_name='_rank_deficient_tensor_broadcasting',
           dist_fn=lambda: tfd.MultivariateNormalDiag(
               loc=0.,
               scale_diag=[[1., 1.], [1., 1.]])),
      dict(testcase_name='_deeply_nested',
           dist_fn=lambda: tfd.Independent(
               tfd.Independent(
                   tfd.Independent(
                       tfd.Independent(
                           tfd.Normal(loc=0.,
                                      scale=[[[[[[[[1.]]]]]]]]),
                           reinterpreted_batch_ndims=2),
                       reinterpreted_batch_ndims=0),
                   reinterpreted_batch_ndims=1),
               reinterpreted_batch_ndims=1)),
      dict(testcase_name='_transformed_dist_simple',
           dist_fn=lambda: tfd.TransformedDistribution(
               tfd.Normal(loc=[[1., 2., 3.], [3., 4., 5.]], scale=[1.]),
               tfb.Scale(scale=[2., 3., 4.]))),
      dict(testcase_name='_transformed_dist_with_chain',
           dist_fn=lambda: tfd.TransformedDistribution(
               tfd.Normal(loc=[[1., 2., 3.], [3., 4., 5.]], scale=[1.]),
               tfb.Shift(-4.)(tfb.Scale(scale=[2., 3., 4.])))),
      dict(testcase_name='_transformed_dist_multipart_nested',
           dist_fn=lambda: tfd.TransformedDistribution(
               tfd.TransformedDistribution(
                   tfd.TransformedDistribution(
                       tfd.MultivariateNormalDiag(tf.zeros([4, 6]),
                                                  tf.ones([6])),
                       tfb.Split([3, 3])),
                   tfb.JointMap([tfb.Identity(), tfb.Reshape([3, 1])])),
               tfb.JointMap([tfb.Scale(scale=[2., 3., 4.]), tfb.Shift(1.)]))),
  )
  # pylint: enable=g-long-lambda
  def test_batch_broadcasting(self, dist_fn):
    """Broadcast parameters all end up with the full batch shape."""
    dist = dist_fn()

    # Broadcast to the distribution's own static batch shape.
    self_broadcast = dist._broadcast_parameters_with_batch_shape(
        dist.batch_shape)
    self.assert_all_parameters_have_full_batch_shape(
        self_broadcast,
        expected_batch_shape=self_broadcast.batch_shape_tensor())

    # Broadcast to a shape with two extra leading batch dimensions.
    bigger_shape = ps.concat([[7, 4], dist.batch_shape], axis=0)
    expanded_dist = dist.copy(
        **batch_shape_lib.broadcast_parameters_with_batch_shape(
            dist, bigger_shape))
    self.assert_all_parameters_have_full_batch_shape(
        expanded_dist, expected_batch_shape=bigger_shape)
Пример #26
0
 def _testAssertRaisesNegativeSplitSizes(self):
     """Checks that a negative split size (-2) is rejected.

     NOTE: unittest-style discovery only collects methods whose names start
     with `test`, so the leading underscore keeps this from running;
     presumably disabled on purpose.
     """
     sizes = self.build_input([-2, 3, 5])
     with self.assertRaisesError('must be either non-negative integers or'):
         # Construction and evaluation both happen inside the context, so the
         # expected error may surface from either step.
         self.evaluate(
             tfb.Split(sizes, validate_args=True).forward(tf.zeros((4, 10))))
Пример #27
0
 def _testAssertRaisesMultipleUnknownSplitSizes(self):
     """Checks that more than one unknown (-1) split size is rejected.

     NOTE: unittest-style discovery only collects methods whose names start
     with `test`, so the leading underscore keeps this from running;
     presumably disabled on purpose.
     """
     sizes = self.build_input([-1, 4, -1, 8])
     with self.assertRaisesError('must have at most one'):
         # Both construction and the forward pass are inside the context so
         # the error is caught whichever step raises it.
         self.evaluate(
             tfb.Split(sizes, validate_args=True).forward(tf.zeros((3, 14))))
Пример #28
0
class MarkovChainBijectorTest(test_util.TestCase):
    """Tests the default event-space bijector of `tfd.MarkovChain`."""

    # Each case supplies a zero-argument `prior_fn` for the initial state and
    # a two-argument `transition_fn` (its first argument is unused in every
    # case here) returning a distribution built from the state `x`.
    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        dict(testcase_name='deterministic_prior',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.)),
        dict(testcase_name='deterministic_transition',
             prior_fn=lambda: tfd.Normal(loc=[-100., 0., 100.], scale=1.),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='fully_deterministic',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='mvn_diag',
             prior_fn=(lambda: tfd.MultivariateNormalDiag(loc=[[2.], [2.]],
                                                          scale_diag=[1.])),
             transition_fn=lambda _, x: tfd.VectorDeterministic(x)),
        dict(testcase_name='docstring_dirichlet',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 {'probs': tfd.Dirichlet([1., 1.])}),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 {
                     'probs':
                     tfd.MultivariateNormalDiag(loc=x['probs'],
                                                scale_diag=[0.1, 0.1])
                 },
                 batch_ndims=ps.rank(x['probs']))),
        dict(testcase_name='uniform_step',
             prior_fn=lambda: tfd.Exponential(tf.ones([4, 1])),
             transition_fn=lambda _, x: tfd.Uniform(low=x, high=x + 1.)),
        # Joint (dict-valued) state whose 'b' part is reshaped by a bijector.
        dict(testcase_name='joint_distribution',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=2,
                 model={
                     'a':
                     tfd.Gamma(tf.zeros([5]), 1.),
                     'b':
                     lambda a: (tfb.Reshape(event_shape_in=[4, 3],
                                            event_shape_out=[2, 3, 2])
                                (tfd.Independent(tfd.Normal(
                                    loc=tf.zeros([5, 4, 3]),
                                    scale=a[..., tf.newaxis, tf.newaxis]),
                                                 reinterpreted_batch_ndims=2)))
                 }),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=ps.rank_from_shape(x['a'].shape),
                 model={
                     'a':
                     tfd.Normal(loc=x['a'], scale=1.),
                     'b':
                     lambda a: tfd.Deterministic(x['b'] + a[
                         ..., tf.newaxis, tf.newaxis, tf.newaxis])
                 })),
        # The prior is itself a MarkovChain over two-part (Split) states.
        dict(testcase_name='nested_chain',
             prior_fn=lambda: tfd.
             MarkovChain(initial_state_prior=tfb.Split(2)
                         (tfd.MultivariateNormalDiag(0., [1., 2.])),
                         transition_fn=lambda _, x: tfb.Split(2)
                         (tfd.MultivariateNormalDiag(x[0], [1., 2.])),
                         num_steps=6),
             transition_fn=(
                 lambda _, x: tfd.JointDistributionSequentialAutoBatched(
                     [
                         tfd.MultivariateNormalDiag(x[0], [1.]),
                         tfd.MultivariateNormalDiag(x[1], [1.])
                     ],
                     batch_ndims=ps.rank(x[0])))))
    # pylint: enable=g-long-lambda
    def test_default_bijector(self, prior_fn, transition_fn):
        """Checks a MarkovChain's default event-space bijector for consistency.

        Builds a 7-step `tfd.MarkovChain` from `prior_fn()` and
        `transition_fn`, draws a sample `y`, and verifies that the chain's
        `experimental_default_event_space_bijector()`:

        * has the same batch shape as the chain;
        * round-trips `y` through `inverse` and then `forward`;
        * has `inverse_min_event_ndims` equal to the ranks of the chain's
          event shapes;
        * yields an inverse log-det-Jacobian with the chain's batch shape
          (checked only for non-constant Jacobians) that is the negation of
          the forward log-det-Jacobian;
        * maps static and dynamic event shapes consistently in both
          directions.

        Args:
          prior_fn: zero-argument callable returning the initial-state
            distribution.
          transition_fn: two-argument callable passed to `tfd.MarkovChain`.
        """
        chain = tfd.MarkovChain(initial_state_prior=prior_fn(),
                                transition_fn=transition_fn,
                                num_steps=7)

        y = self.evaluate(chain.sample(seed=test_util.test_seed()))
        bijector = chain.experimental_default_event_space_bijector()

        self.assertAllEqual(chain.batch_shape_tensor(),
                            bijector.experimental_batch_shape_tensor())

        x = bijector.inverse(y)
        yy = bijector.forward(tf.nest.map_structure(
            tf.identity, x))  # Bypass bijector cache.
        self.assertAllCloseNested(y, yy)

        # Rank of each (possibly nested) part of the chain's event shape.
        chain_event_ndims = tf.nest.map_structure(ps.rank_from_shape,
                                                  chain.event_shape_tensor())
        self.assertAllEqualNested(bijector.inverse_min_event_ndims,
                                  chain_event_ndims)

        ildj = bijector.inverse_log_det_jacobian(
            tf.nest.map_structure(tf.identity, y),  # Bypass bijector cache.
            event_ndims=chain_event_ndims)
        # NOTE(review): presumably skipped for constant Jacobians because such
        # an ildj need not be broadcast to the full batch shape -- confirm.
        if not bijector.is_constant_jacobian:
            self.assertAllEqual(ildj.shape, chain.batch_shape)
        fldj = bijector.forward_log_det_jacobian(
            tf.nest.map_structure(tf.identity, x),  # Bypass bijector cache.
            event_ndims=bijector.inverse_event_ndims(chain_event_ndims))
        self.assertAllClose(ildj, -fldj)

        # Verify that event shapes are passed through and flattened/unflattened
        # correctly.
        inverse_event_shapes = bijector.inverse_event_shape(chain.event_shape)
        # Static event shape of each part of `x`: its trailing
        # `forward_min_event_ndims` dimensions.
        x_event_shapes = tf.nest.map_structure(
            lambda t, nd: t.shape[ps.rank(t) - nd:], x,
            bijector.forward_min_event_ndims)
        self.assertAllEqualNested(inverse_event_shapes, x_event_shapes)
        forward_event_shapes = bijector.forward_event_shape(
            inverse_event_shapes)
        self.assertAllEqualNested(forward_event_shapes, chain.event_shape)

        # Verify that the outputs of other methods have the correct structure.
        inverse_event_shape_tensors = bijector.inverse_event_shape_tensor(
            chain.event_shape_tensor())
        self.assertAllEqualNested(inverse_event_shape_tensors, x_event_shapes)
        forward_event_shape_tensors = bijector.forward_event_shape_tensor(
            inverse_event_shape_tensors)
        self.assertAllEqualNested(forward_event_shape_tensors,
                                  chain.event_shape_tensor())
Пример #29
0
 def testAssertRaisesNumSplitsNonDivisible(self):
     """Splitting an axis into a count that does not divide it must fail.

     Axis -2 of the (4, 5, 6) input has size 5, which is not divisible by
     `num_splits` = 3, so the forward pass is expected to raise ValueError.
     """
     num_splits = 3
     x = np.random.rand(4, 5, 6)
     bijector = tfb.Split(num_splits, axis=-2, validate_args=True)
     with self.assertRaisesRegex(ValueError, 'number of splits'):
         self.evaluate(bijector.forward(x))
Пример #30
0
 def testAssertRaisesNonVectorSplitSizes(self):
     """Rank-2 split sizes ([[1, 2, 2]]) must be rejected with ValueError."""
     split_sizes = self.build_input([[1, 2, 2]])
     with self.assertRaisesRegex(ValueError, 'must be an integer or 1-D'):
         tfb.Split(split_sizes, validate_args=True)