Example #1
0
 def test_single_param_multi_ellipsis(self):
   """More than one `...` in `slices` must raise a `ValueError`."""
   # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12;
   # `assertRaisesRegex` is the supported name. The dots are escaped so the
   # pattern matches a literal `...` rather than any three characters.
   with self.assertRaisesRegex(ValueError, r'Found multiple `\.\.\.`'):
     slicing._slice_single_param(
         tf.zeros([7, 6, 5, 4, 3]),
         param_event_ndims=2,
         slices=make_slices[:, ..., 2, ...],
         batch_shape=tf.constant([7, 6, 5]))
Example #2
0
 def test_single_param_too_many_slices(self):
   """Giving more non-ellipsis slice entries than batch dims is rejected."""
   # `batch_shape` has rank 3, but the slices below name four dims besides the
   # ellipsis. Any one of these error types is accepted.
   accepted_errors = (IndexError, ValueError, tf.errors.InvalidArgumentError)
   with self.assertRaises(accepted_errors):
     slicing._slice_single_param(
         tf.zeros([7, 6, 5, 4, 3]),
         param_event_ndims=2,
         slices=make_slices[:, :3, ..., -2:, :],
         batch_shape=tf.constant([7, 6, 5]))
Example #3
0
 def test_single_param_slice_start_broadcastdim(self):
   """A `start:` slice on a size-1 (broadcast) batch dim leaves it at 1."""
   param = tf.zeros([7, 1, 5, 4, 3])
   start_slice = make_slices[:, 2:]
   batch_shape = tf.constant([7, 6, 5])
   result = slicing._slice_single_param(
       param, param_event_ndims=2, slices=start_slice,
       batch_shape=batch_shape)
   # Dim 1 of `param` is 1 while `batch_shape` says 6, so it stays size 1.
   expected_shape = (7, 1, 5, 4, 3)
   self.assertAllEqual(expected_shape, self.evaluate(result).shape)
Example #4
0
 def test_single_param_slice_newaxis_trailing(self):
   """`tf.newaxis` before the last batch dim inserts a size-1 axis there."""
   event_ndims = 2
   sliced_param = slicing._slice_single_param(
       tf.zeros([7, 6, 5, 4, 3]),
       param_event_ndims=event_ndims,
       slices=make_slices[..., tf.newaxis, :],
       batch_shape=tf.constant([7, 6, 5]))
   actual_shape = self.evaluate(sliced_param).shape
   # Batch [7, 6, 5] becomes [7, 6, 1, 5]; event dims [4, 3] are untouched.
   self.assertAllEqual((7, 6, 1, 5, 4, 3), actual_shape)
Example #5
0
 def test_single_param_slice_stop_leadingdim(self):
   """A `:stop` slice truncates the leading batch dim from 7 to 2."""
   leading_stop = make_slices[:2]
   result = slicing._slice_single_param(
       tf.zeros([7, 6, 5, 4, 3]),
       param_event_ndims=2,
       slices=leading_stop,
       batch_shape=tf.constant([7, 6, 5], dtype=tf.int32))
   self.assertAllEqual((2, 6, 5, 4, 3), self.evaluate(result).shape)
Example #6
0
 def test_slice_single_param_atomic(self):
   """Slicing `tfb.Identity()` should leave its batch shape empty."""
   identity = tfb.Identity()
   sliced_bijector = slicing._slice_single_param(
       identity, param_event_ndims=0,
       slices=make_slices[..., tf.newaxis, 2:, tf.newaxis],
       batch_shape=tf.constant([7, 4, 3]))
   batch_shape_after = sliced_bijector.experimental_batch_shape_tensor()
   self.assertAllEqual([], batch_shape_after)
Example #7
0
 def test_single_param_slice_withstep_broadcastdim(self):
   """Strided slices on size-1 (broadcast) batch dims keep them at size 1."""
   event_dim = 3
   param_shape = [1, 1, event_dim]
   sliced = slicing._slice_single_param(
       tf.zeros(param_shape),
       param_event_ndims=1,
       # Starts/stops lie outside `batch_shape` [2, 7], and the first slice
       # has a negative step; both batch dims of the param are size 1.
       slices=make_slices[44:-52:-3, -94::],
       batch_shape=tf.constant([2, 7], dtype=tf.int32))
   self.assertAllEqual(tuple(param_shape), self.evaluate(sliced).shape)
Example #8
0
 def test_slice_single_param_distribution(self):
   """Slicing a `tfd.Normal` matches slicing a plain tensor of shape [1, 4, 3]."""
   normal = tfd.Normal(loc=tf.zeros([4, 3, 1]),  # batch = [4, 3], event = [2]
                       scale=tf.ones([2]))
   slices = make_slices[..., tf.newaxis, 2:, tf.newaxis]
   sliced = slicing._slice_single_param(
       normal, param_event_ndims=1, slices=slices,
       batch_shape=tf.constant([7, 4, 3]))
   # Reference shape: the same slice applied to a [1, 4, 3] tensor.
   reference_shape = list(
       tf.zeros([1, 4, 3])[..., tf.newaxis, 2:, tf.newaxis].shape)
   # `[:-1]` drops the last entry, i.e. the dim treated as event via
   # `param_event_ndims=1`.
   self.assertAllEqual(reference_shape, sliced.batch_shape_tensor()[:-1])
Example #9
0
 def test_slice_single_param_bijector_composition(self):
   """Slicing a nested `JointMap`/`Chain`/`Invert`/`Scale` bijector."""
   scale = tfb.Scale(tf.ones([4, 3, 1]))
   nested = tfb.JointMap({'a': tfb.Chain([tfb.Invert(scale)])})
   sliced = slicing._slice_single_param(
       nested,
       param_event_ndims={'a': 1},
       slices=make_slices[..., tf.newaxis, 2:, tf.newaxis],
       batch_shape=tf.constant([7, 4, 3]))
   # Reference shape: the same slice applied to a [1, 4, 3] tensor.
   expected = list(
       tf.zeros([1, 4, 3])[..., tf.newaxis, 2:, tf.newaxis].shape)
   actual = sliced.experimental_batch_shape_tensor(x_event_ndims={'a': 1})
   self.assertAllEqual(expected, actual)
Example #10
0
 def test_single_param_slice_broadcast_batch(self):
   """Slices a param whose batch rank (2) is below `batch_shape`'s rank (3).

   The expected shape is the same slice applied to a [1, 4, 3] tensor,
   followed by the param's trailing size-1 event dim.
   """
   if not tf.executing_eagerly():
     # Previously a bare `return`, which reported the graph-mode run as a
     # pass; skipping makes it visible that nothing was checked.
     self.skipTest('Only runs in eager mode.')
   sliced = slicing._slice_single_param(
       tf.zeros([4, 3, 1]),  # batch = [4, 3], event = [1]
       param_event_ndims=1,
       slices=make_slices[..., tf.newaxis, 2:, tf.newaxis],
       batch_shape=tf.constant([7, 4, 3]))
   self.assertAllEqual(
       list(tf.zeros([1, 4, 3])[..., tf.newaxis, 2:, tf.newaxis].shape) + [1],
       self.evaluate(sliced).shape)
Example #11
0
 def test_single_param_slice_tensor_broadcastdim(self):
   """Index a size-1 (broadcast) batch dim with a dynamic scalar tensor."""
   # `placeholder_with_default` hides `param`'s static shape (shape=None) and
   # `idx`'s static value, so slicing must work from dynamic information.
   dynamic_param = tf1.placeholder_with_default(
       tf.zeros([7, 1, 5, 4, 3]), shape=None)
   dynamic_idx = tf1.placeholder_with_default(
       tf.constant(2, dtype=tf.int32), shape=[])
   sliced = slicing._slice_single_param(
       dynamic_param,
       param_event_ndims=2,
       slices=make_slices[:, dynamic_idx],
       batch_shape=tf.constant([7, 6, 5]))
   # The integer index removes batch dim 1 entirely.
   self.assertAllEqual((7, 5, 4, 3), self.evaluate(sliced).shape)
Example #12
0
 def test_single_param_slice_broadcast_batch_leading_newaxis(self):
   """Slices a rank-2-batch param against rank-3 `batch_shape`, leading newaxis.

   The expected shape is the same slice applied to a [1, 4, 3] tensor,
   followed by the param's trailing size-1 event dim.
   """
   if not tf.executing_eagerly():
     # Previously a bare `return`, which reported the graph-mode run as a
     # pass; skipping makes it visible that nothing was checked.
     self.skipTest('Only runs in eager mode.')
   sliced = slicing._slice_single_param(
       tf.zeros([4, 3, 1]),  # batch = [4, 3], event = [1]
       param_event_ndims=1,
       slices=make_slices[tf.newaxis, ..., tf.newaxis, 2:, tf.newaxis],
       batch_shape=tf.constant([7, 4, 3]))
   expected = tensorshape_util.as_list((
       tf.zeros([1, 4, 3])[tf.newaxis, ..., tf.newaxis, 2:, tf.newaxis]
       ).shape) + [1]
   self.assertAllEqual(expected, self.evaluate(sliced).shape)