def test_additional_event_ndims(self):
  """Checks `additional_event_ndims` handling in batch-shape inference.

  A `Sigmoid` built from `low` of shape [2] and `high` of shape [3, 2] is
  expected to report batch shape [3, 2] by default, and batch shape [3] when
  called with `additional_event_ndims=1`. Both the static and the
  tensor-valued inference functions are checked for each case.

  NOTE(review): `test_bijector_event_ndims` covers the same scenario through
  the `bijector_x_event_ndims` keyword; confirm `additional_event_ndims` is
  still accepted by the `batch_shape_lib` version under test.
  """
  sigmoid = tfb.Sigmoid(low=tf.zeros([2]), high=tf.ones([3, 2]))
  infer_fns = (batch_shape_lib.inferred_batch_shape,
               batch_shape_lib.inferred_batch_shape_tensor)
  cases = (
      ({}, [3, 2]),
      ({'additional_event_ndims': 1}, [3]),
  )
  for extra_kwargs, expected in cases:
    for infer in infer_fns:
      self.assertAllEqual(infer(sigmoid, **extra_kwargs), expected)
def test_bijector_event_ndims(self):
  """Checks `bijector_x_event_ndims` handling in batch-shape inference.

  A `Sigmoid` built from `low` of shape [2] and `high` of shape [3, 2] is
  expected to report batch shape [3, 2] by default, and batch shape [3] when
  called with `bijector_x_event_ndims=1`. Both the static and the
  tensor-valued inference functions are checked for each case. A `JointMap`
  queried with all-`None` event ndims must yield a fully unknown shape.
  """
  sigmoid = tfb.Sigmoid(low=tf.zeros([2]), high=tf.ones([3, 2]))
  infer_fns = (batch_shape_lib.inferred_batch_shape,
               batch_shape_lib.inferred_batch_shape_tensor)
  cases = (
      ({}, [3, 2]),
      ({'bijector_x_event_ndims': 1}, [3]),
  )
  for extra_kwargs, expected in cases:
    for infer in infer_fns:
      self.assertAllEqual(infer(sigmoid, **extra_kwargs), expected)

  # Nones must not be forwarded to the component
  # `experimental_batch_shape(x_event_ndims=None)` calls: there they would be
  # misread as `x_event_ndims=forward_min_event_ndims` instead of "unknown".
  joint = tfb.JointMap([sigmoid, sigmoid])
  unknown_shape = tf.TensorShape(None)
  inferred = batch_shape_lib.inferred_batch_shape(
      joint, bijector_x_event_ndims=[None, None])
  self.assertAllEqual(inferred, unknown_shape)
def test_batch_shape_inference_is_correct(self, value_fn,
                                          expected_batch_shape_parts,
                                          expected_batch_shape):
  """Checks per-part and overall batch-shape inference for one object.

  Args:
    value_fn: Zero-argument callable that builds the object under test. It is
      invoked inside the test body so construction happens in the right graph.
    expected_batch_shape_parts: Structure of expected per-part batch shapes.
      Each leaf, taken up to the structure of the inferred parts, is wrapped
      in `tf.TensorShape` before the nested comparison.
    expected_batch_shape: Expected overall batch shape. The tensor-valued
      inference must equal it. The static inference must be a
      `tf.TensorShape` that is compatible with it.
  """
  # Deferred construction: build only once we are inside the test's graph.
  obj = value_fn()

  actual_parts = batch_shape_lib.batch_shape_parts(obj)
  wanted_parts = nest.map_structure_up_to(
      actual_parts, tf.TensorShape, expected_batch_shape_parts)
  self.assertAllEqualNested(actual_parts, wanted_parts)

  dynamic_shape = batch_shape_lib.inferred_batch_shape_tensor(obj)
  self.assertAllEqual(expected_batch_shape, dynamic_shape)

  static_shape = batch_shape_lib.inferred_batch_shape(obj)
  self.assertIsInstance(static_shape, tf.TensorShape)
  self.assertTrue(static_shape.is_compatible_with(expected_batch_shape))
def _batch_shape(self, index_points=None):
  """Returns the batch shape inferred by `batch_shape_lib.inferred_batch_shape`.

  Args:
    index_points: Optional value forwarded as the `index_points` keyword to
      `batch_shape_lib.inferred_batch_shape`. It is presumably an override for
      this object's own `index_points` parameter (confirm against the library).
      When `None`, no keyword is forwarded.

  Returns:
    The result of `batch_shape_lib.inferred_batch_shape(self, ...)`.
  """
  overrides = {} if index_points is None else {'index_points': index_points}
  return batch_shape_lib.inferred_batch_shape(self, **overrides)