def test_single_param_multi_ellipsis(self):
  """Slices containing more than one Ellipsis are rejected with ValueError."""
  # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex` that was
  # removed from unittest in Python 3.12. The dots are escaped so the pattern
  # matches a literal `...` rather than any three characters.
  with self.assertRaisesRegex(ValueError, r'Found multiple `\.\.\.`'):
    slicing._slice_single_param(
        tf.zeros([7, 6, 5, 4, 3]),
        param_event_ndims=2,
        slices=make_slices[:, ..., 2, ...],
        batch_shape=tf.constant([7, 6, 5]))
def test_single_param_too_many_slices(self):
  """More slice components than the param has dimensions must fail.

  The failure may surface as an IndexError, ValueError or a TF
  InvalidArgumentError, depending on where the mismatch is detected.
  """
  acceptable_errors = (IndexError, ValueError, tf.errors.InvalidArgumentError)
  with self.assertRaises(acceptable_errors):
    slicing._slice_single_param(
        tf.zeros([7, 6, 5, 4, 3]),
        param_event_ndims=2,
        slices=make_slices[:, :3, ..., -2:, :],
        batch_shape=tf.constant([7, 6, 5]))
def test_single_param_slice_start_broadcastdim(self):
  """A start-slice on a size-1 (broadcast) batch dim leaves that dim at 1."""
  # Param batch dim 1 has size 1 while the full batch shape has 6 there.
  param = tf.zeros([7, 1, 5, 4, 3])
  full_batch_shape = tf.constant([7, 6, 5])
  result = slicing._slice_single_param(
      param,
      param_event_ndims=2,
      slices=make_slices[:, 2:],
      batch_shape=full_batch_shape)
  expected_shape = (7, 1, 5, 4, 3)
  self.assertAllEqual(expected_shape, self.evaluate(result).shape)
def test_single_param_slice_newaxis_trailing(self):
  """`[..., tf.newaxis, :]` inserts a unit dim before the last batch dim."""
  param = tf.zeros([7, 6, 5, 4, 3])
  full_batch_shape = tf.constant([7, 6, 5])
  result = slicing._slice_single_param(
      param,
      param_event_ndims=2,
      slices=make_slices[..., tf.newaxis, :],
      batch_shape=full_batch_shape)
  # Batch [7, 6, 5] becomes [7, 6, 1, 5]; event dims [4, 3] are untouched.
  expected_shape = (7, 6, 1, 5, 4, 3)
  self.assertAllEqual(expected_shape, self.evaluate(result).shape)
def test_single_param_slice_stop_leadingdim(self):
  """A stop-only slice `[:2]` truncates the leading batch dim to 2."""
  param = tf.zeros([7, 6, 5, 4, 3])
  full_batch_shape = tf.constant([7, 6, 5], dtype=tf.int32)
  result = slicing._slice_single_param(
      param,
      param_event_ndims=2,
      slices=make_slices[:2],
      batch_shape=full_batch_shape)
  expected_shape = (2, 6, 5, 4, 3)
  self.assertAllEqual(expected_shape, self.evaluate(result).shape)
def test_slice_single_param_atomic(self):
  """Slicing `tfb.Identity()` yields an object whose batch shape is empty."""
  bijector = tfb.Identity()
  result = slicing._slice_single_param(
      bijector,
      param_event_ndims=0,
      slices=make_slices[..., tf.newaxis, 2:, tf.newaxis],
      batch_shape=tf.constant([7, 4, 3]))
  empty_shape = []
  self.assertAllEqual(empty_shape, result.experimental_batch_shape_tensor())
def test_single_param_slice_withstep_broadcastdim(self):
  """Stepped, out-of-range slices on size-1 batch dims keep those dims at 1."""
  event_dim = 3
  # Both batch dims have size 1, although the full batch shape is [2, 7].
  param = tf.zeros([1, 1, event_dim])
  full_batch_shape = tf.constant([2, 7], dtype=tf.int32)
  result = slicing._slice_single_param(
      param,
      param_event_ndims=1,
      slices=make_slices[44:-52:-3, -94::],
      batch_shape=full_batch_shape)
  expected_shape = (1, 1, event_dim)
  self.assertAllEqual(expected_shape, self.evaluate(result).shape)
def test_slice_single_param_distribution(self):
  """A distribution-valued param is sliced along its batch dimensions."""
  # loc [4, 3, 1] and scale [2] broadcast to [4, 3, 2]. With
  # param_event_ndims=1 the trailing dim is treated as an event dim, so the
  # batch part seen by the slicer is [4, 3].
  dist = tfd.Normal(loc=tf.zeros([4, 3, 1]), scale=tf.ones([2]))
  slice_spec = make_slices[..., tf.newaxis, 2:, tf.newaxis]
  result = slicing._slice_single_param(
      dist,
      param_event_ndims=1,
      slices=slice_spec,
      batch_shape=tf.constant([7, 4, 3]))
  reference_shape = tf.zeros([1, 4, 3])[..., tf.newaxis, 2:, tf.newaxis].shape
  # Drop the trailing (event-like) dim before comparing batch shapes.
  self.assertAllEqual(
      list(reference_shape), result.batch_shape_tensor()[:-1])
def test_slice_single_param_bijector_composition(self):
  """Slicing reaches through nested bijectors (JointMap -> Chain -> Invert)."""
  scale_bijector = tfb.Invert(tfb.Scale(tf.ones([4, 3, 1])))
  joint = tfb.JointMap({'a': tfb.Chain([scale_bijector])})
  slice_spec = make_slices[..., tf.newaxis, 2:, tf.newaxis]
  result = slicing._slice_single_param(
      joint,
      param_event_ndims={'a': 1},
      slices=slice_spec,
      batch_shape=tf.constant([7, 4, 3]))
  reference_shape = tf.zeros([1, 4, 3])[..., tf.newaxis, 2:, tf.newaxis].shape
  actual_batch_shape = result.experimental_batch_shape_tensor(
      x_event_ndims={'a': 1})
  self.assertAllEqual(list(reference_shape), actual_batch_shape)
def test_single_param_slice_broadcast_batch(self):
  """A param with batch [4, 3] is sliced against a full batch of [7, 4, 3].

  Only exercised in eager mode. Previously the test silently returned in
  graph mode, which reported a pass without checking anything; it now
  reports a skip instead.
  """
  if not tf.executing_eagerly():
    # NOTE(review): the reason this is eager-only is not visible here;
    # confirm and document it.
    self.skipTest('Test only runs in eager mode.')
  sliced = slicing._slice_single_param(
      tf.zeros([4, 3, 1]),  # batch = [4, 3], event = [1]
      param_event_ndims=1,
      slices=make_slices[..., tf.newaxis, 2:, tf.newaxis],
      batch_shape=tf.constant([7, 4, 3]))
  # Expected batch part mirrors slicing a [1, 4, 3] tensor; the trailing
  # `+ [1]` is the untouched event dim.
  self.assertAllEqual(
      list(tf.zeros([1, 4, 3])[..., tf.newaxis, 2:, tf.newaxis].shape) + [1],
      self.evaluate(sliced).shape)
def test_single_param_slice_tensor_broadcastdim(self):
  """An integer-Tensor index on a size-1 (broadcast) batch dim drops it."""
  # The placeholders presumably hide static shape/value information in graph
  # mode so the dynamic slicing path is exercised.
  param = tf1.placeholder_with_default(
      tf.zeros([7, 1, 5, 4, 3]), shape=None)
  index = tf1.placeholder_with_default(
      tf.constant(2, dtype=tf.int32), shape=[])
  result = slicing._slice_single_param(
      param,
      param_event_ndims=2,
      slices=make_slices[:, index],
      batch_shape=tf.constant([7, 6, 5]))
  expected_shape = (7, 5, 4, 3)
  self.assertAllEqual(expected_shape, self.evaluate(result).shape)
def test_single_param_slice_broadcast_batch_leading_newaxis(self):
  """Like the broadcast-batch case, but the slices start with tf.newaxis.

  Only exercised in eager mode. Previously the test silently returned in
  graph mode, which reported a pass without checking anything; it now
  reports a skip instead.
  """
  if not tf.executing_eagerly():
    # NOTE(review): the reason this is eager-only is not visible here;
    # confirm and document it.
    self.skipTest('Test only runs in eager mode.')
  sliced = slicing._slice_single_param(
      tf.zeros([4, 3, 1]),  # batch = [4, 3], event = [1]
      param_event_ndims=1,
      slices=make_slices[tf.newaxis, ..., tf.newaxis, 2:, tf.newaxis],
      batch_shape=tf.constant([7, 4, 3]))
  # Expected batch part mirrors slicing a [1, 4, 3] tensor; the trailing
  # `+ [1]` is the untouched event dim.
  expected = tensorshape_util.as_list((
      tf.zeros([1, 4, 3])[tf.newaxis, ..., tf.newaxis, 2:, tf.newaxis]
  ).shape) + [1]
  self.assertAllEqual(expected, self.evaluate(sliced).shape)