def test_additional_event_ndims(self):
    """Checks that `additional_event_ndims` trims trailing batch dimensions.

    The Sigmoid bijector is built with `low` of shape [2] and `high` of shape
    [3, 2]. Its inferred batch shape is [3, 2]. When one additional event
    dimension is requested, the expected batch shape becomes [3].
    """
    sigmoid = tfb.Sigmoid(low=tf.zeros([2]), high=tf.ones([3, 2]))

    # No additional event dims: the full broadcast shape is batch.
    full_static = batch_shape_lib.inferred_batch_shape(sigmoid)
    self.assertAllEqual(full_static, [3, 2])
    full_tensor = batch_shape_lib.inferred_batch_shape_tensor(sigmoid)
    self.assertAllEqual(full_tensor, [3, 2])

    # One additional event dim: the rightmost batch dim is no longer batch.
    trimmed_static = batch_shape_lib.inferred_batch_shape(
        sigmoid, additional_event_ndims=1)
    self.assertAllEqual(trimmed_static, [3])
    trimmed_tensor = batch_shape_lib.inferred_batch_shape_tensor(
        sigmoid, additional_event_ndims=1)
    self.assertAllEqual(trimmed_tensor, [3])
Example #2
0
    def test_bijector_event_ndims(self):
        """Checks `bijector_x_event_ndims` handling in batch-shape inference.

        Covers two cases:
        - A single Sigmoid bijector: `bijector_x_event_ndims=1` trims the
          inferred batch shape from [3, 2] to [3].
        - A JointMap: every component is given `None`, and the expected
          result is a shape of unknown rank (`tf.TensorShape(None)`).
        """
        sigmoid = tfb.Sigmoid(low=tf.zeros([2]), high=tf.ones([3, 2]))

        self.assertAllEqual(
            batch_shape_lib.inferred_batch_shape(sigmoid), [3, 2])
        self.assertAllEqual(
            batch_shape_lib.inferred_batch_shape_tensor(sigmoid), [3, 2])

        trimmed_static = batch_shape_lib.inferred_batch_shape(
            sigmoid, bijector_x_event_ndims=1)
        self.assertAllEqual(trimmed_static, [3])
        trimmed_tensor = batch_shape_lib.inferred_batch_shape_tensor(
            sigmoid, bijector_x_event_ndims=1)
        self.assertAllEqual(trimmed_tensor, [3])

        # `None` entries must not be forwarded to the components'
        # `experimental_batch_shape(x_event_ndims=None)` calls. There they
        # would be read as `x_event_ndims=forward_min_event_ndims`.
        joint = tfb.JointMap([sigmoid, sigmoid])
        joint_shape = batch_shape_lib.inferred_batch_shape(
            joint, bijector_x_event_ndims=[None, None])
        self.assertAllEqual(joint_shape, tf.TensorShape(None))
Example #3
0
    def test_batch_shape_inference_is_correct(self, value_fn,
                                              expected_batch_shape_parts,
                                              expected_batch_shape):
        """Checks batch-shape inference against expected values.

        Args:
          value_fn: zero-argument callable that builds the object under test.
          expected_batch_shape_parts: nested structure of expected per-part
            batch shapes. It is mapped to `tf.TensorShape` up to the structure
            returned by `batch_shape_lib.batch_shape_parts`.
          expected_batch_shape: expected overall batch shape.
        """
        # Construction is deferred until we're in the right graph.
        subject = value_fn()

        actual_parts = batch_shape_lib.batch_shape_parts(subject)
        # Turn the expected leaves into TensorShapes, following the structure
        # of `actual_parts`, so that both sides compare like-for-like.
        wanted_parts = nest.map_structure_up_to(
            actual_parts, tf.TensorShape, expected_batch_shape_parts)
        self.assertAllEqualNested(actual_parts, wanted_parts)

        self.assertAllEqual(
            expected_batch_shape,
            batch_shape_lib.inferred_batch_shape_tensor(subject))

        static_shape = batch_shape_lib.inferred_batch_shape(subject)
        self.assertIsInstance(static_shape, tf.TensorShape)
        # This requires only compatibility, not equality, so unknown
        # dimensions in the static shape are tolerated.
        self.assertTrue(static_shape.is_compatible_with(expected_batch_shape))
 def _batch_shape(self, index_points=None):
     """Returns the batch shape of `self` from `batch_shape_lib`.

     Args:
       index_points: optional value forwarded to
         `batch_shape_lib.inferred_batch_shape` as the `index_points` keyword.
         When it is `None`, the keyword is omitted entirely.
     """
     extra = {} if index_points is None else {'index_points': index_points}
     return batch_shape_lib.inferred_batch_shape(self, **extra)