Beispiel #1
0
 def test_single_param_multi_ellipsis(self):
   """Slices containing more than one `...` must raise a ValueError."""
   # `assertRaisesRegexp` is a deprecated alias that was removed from unittest
   # in Python 3.12; `assertRaisesRegex` is the supported name. The dots are
   # escaped so the pattern matches a literal `...` rather than any 3 chars.
   with self.assertRaisesRegex(ValueError, r'Found multiple `\.\.\.`'):
     slicing._slice_single_param(
         tf.zeros([7, 6, 5, 4, 3]),
         param_event_ndims=2,
         slices=make_slices[:, ..., 2, ...],
         dist_batch_shape=tf.constant([7, 6, 5]))
Beispiel #2
0
 def test_single_param_too_many_slices(self):
   """More slice components than batch dims must be rejected."""
   # Batch shape has rank 3, but the slices name 4 explicit components plus
   # `...`. Either a static ValueError or a runtime InvalidArgumentError is
   # accepted.
   accepted_errors = (ValueError, tf.errors.InvalidArgumentError)
   with self.assertRaises(accepted_errors):
     param = tf.zeros([7, 6, 5, 4, 3])
     slicing._slice_single_param(
         param,
         param_event_ndims=2,
         slices=make_slices[:, :3, ..., -2:, :],
         dist_batch_shape=tf.constant([7, 6, 5]))
Beispiel #3
0
 def test_single_param_slice_start_broadcastdim(self):
   """A start-slice on a size-1 (broadcast) param dim leaves it at size 1."""
   param = tf.zeros([7, 1, 5, 4, 3])
   batch_shape = tf.constant([7, 6, 5])
   result = slicing._slice_single_param(
       param,
       param_event_ndims=2,
       slices=make_slices[:, 2:],
       dist_batch_shape=batch_shape)
   # Param dim 1 is 1 while the batch dim is 6; the expected shape below shows
   # that `2:` does not shrink the broadcast dim.
   expected_shape = (7, 1, 5, 4, 3)
   self.assertAllEqual(expected_shape, self.evaluate(result).shape)
Beispiel #4
0
 def test_single_param_slice_newaxis_trailing(self):
   """`[..., tf.newaxis, :]` inserts a unit dim before the last batch dim."""
   param = tf.zeros([7, 6, 5, 4, 3])
   batch_shape = tf.constant([7, 6, 5])
   trailing_newaxis = make_slices[..., tf.newaxis, :]
   result = slicing._slice_single_param(
       param,
       param_event_ndims=2,
       slices=trailing_newaxis,
       dist_batch_shape=batch_shape)
   expected_shape = (7, 6, 1, 5, 4, 3)
   self.assertAllEqual(expected_shape, self.evaluate(result).shape)
Beispiel #5
0
 def test_single_param_slice_stop_leadingdim(self):
   """A stop-slice on the leading batch dim truncates that dim."""
   param = tf.zeros([7, 6, 5, 4, 3])
   batch_shape = tf.constant([7, 6, 5], dtype=tf.int32)
   result = slicing._slice_single_param(
       param,
       param_event_ndims=2,
       slices=make_slices[:2],
       dist_batch_shape=batch_shape)
   # Only the leading dim changes (7 -> 2); event dims (4, 3) are untouched.
   expected_shape = (2, 6, 5, 4, 3)
   self.assertAllEqual(expected_shape, self.evaluate(result).shape)
Beispiel #6
0
 def test_single_param_slice_withstep_broadcastdim(self):
   """Strided/negative slices over all-broadcast batch dims keep size-1 dims."""
   event_size = 3
   batch_shape = tf.constant([2, 7], dtype=tf.int32)
   result = slicing._slice_single_param(
       tf.zeros([1, 1, event_size]),
       param_event_ndims=1,
       slices=make_slices[44:-52:-3, -94::],
       dist_batch_shape=batch_shape)
   # Both param batch dims are 1, so the asserted shape is unchanged even
   # though the slices use out-of-range bounds and a negative step.
   expected_shape = (1, 1, event_size)
   self.assertAllEqual(expected_shape, self.evaluate(result).shape)
Beispiel #7
0
 def test_single_param_slice_tensor_broadcastdim(self):
   """A Tensor index into a size-1 (broadcast) batch dim removes that dim."""
   # `shape=None` hides the param's static shape; presumably both
   # placeholders are used to exercise the non-static slicing path.
   param = tf1.placeholder_with_default(
       tf.zeros([7, 1, 5, 4, 3]), shape=None)
   idx = tf1.placeholder_with_default(
       tf.constant(2, dtype=tf.int32), shape=[])
   batch_shape = tf.constant([7, 6, 5])
   result = slicing._slice_single_param(
       param,
       param_event_ndims=2,
       slices=make_slices[:, idx],
       dist_batch_shape=batch_shape)
   expected_shape = (7, 5, 4, 3)
   self.assertAllEqual(expected_shape, self.evaluate(result).shape)
Beispiel #8
0
 def test_single_param_slice_broadcast_batch(self):
   """Slicing with `tf.newaxis` when the param batch shape is broadcast.

   The param has batch shape [4, 3] against a distribution batch shape of
   [7, 4, 3]. The expected result is computed by applying the same slices to
   a [1, 4, 3] tensor and appending the event shape [1].
   """
   if not tf.executing_eagerly():
     # Report the test as skipped rather than silently passing in graph mode.
     self.skipTest('Only exercised in eager mode.')
   sliced = slicing._slice_single_param(
       tf.zeros([4, 3, 1]),  # batch = [4, 3], event = [1]
       param_event_ndims=1,
       slices=make_slices[..., tf.newaxis, 2:, tf.newaxis],
       dist_batch_shape=tf.constant([7, 4, 3]))
   expected_batch = tf.zeros([1, 4, 3])[..., tf.newaxis, 2:, tf.newaxis].shape
   # `as_list` yields plain ints (matching the sibling leading-newaxis test),
   # whereas `list(TensorShape)` may yield Dimension objects under TF1.
   expected = tensorshape_util.as_list(expected_batch) + [1]
   self.assertAllEqual(expected, self.evaluate(sliced).shape)
Beispiel #9
0
 def test_single_param_slice_broadcast_batch_leading_newaxis(self):
   """Leading and interior `tf.newaxis` slices on a broadcast batch shape.

   The param has batch shape [4, 3] against a distribution batch shape of
   [7, 4, 3]. The expected result is computed by applying the same slices to
   a [1, 4, 3] tensor and appending the event shape [1].
   """
   if not tf.executing_eagerly():
     # Report the test as skipped rather than silently passing in graph mode.
     self.skipTest('Only exercised in eager mode.')
   sliced = slicing._slice_single_param(
       tf.zeros([4, 3, 1]),  # batch = [4, 3], event = [1]
       param_event_ndims=1,
       slices=make_slices[tf.newaxis, ..., tf.newaxis, 2:, tf.newaxis],
       dist_batch_shape=tf.constant([7, 4, 3]))
   expected_batch = tf.zeros([1, 4, 3])[
       tf.newaxis, ..., tf.newaxis, 2:, tf.newaxis].shape
   expected = tensorshape_util.as_list(expected_batch) + [1]
   self.assertAllEqual(expected, self.evaluate(sliced).shape)