예제 #1
0
 def test_single_param_multi_ellipsis(self):
     """A slice spec containing more than one `...` must raise ValueError."""
     # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex`
     # (removed in Python 3.12). The dots are escaped so the pattern only
     # matches a literal `...` rather than any three characters.
     with self.assertRaisesRegex(ValueError, r'Found multiple `\.\.\.`'):
         slicing._slice_single_param(array_ops.zeros([7, 6, 5, 4, 3]),
                                     param_ndims_to_matrix_ndims=2,
                                     slices=make_slices[:, ..., 2, ...],
                                     batch_shape=constant_op.constant(
                                         [7, 6, 5]))
예제 #2
0
 def test_single_param_too_many_slices(self):
     """Supplying more slices than there are batch dims should fail.

     `batch_shape` has three dims, but the spec below has four non-ellipsis
     entries. Any one of the listed error types is accepted.
     """
     acceptable_errors = (IndexError, ValueError,
                          errors.InvalidArgumentError)
     with self.assertRaises(acceptable_errors):
         slicing._slice_single_param(
             array_ops.zeros([7, 6, 5, 4, 3]),
             param_ndims_to_matrix_ndims=2,
             slices=make_slices[:, :3, ..., -2:, :],
             batch_shape=constant_op.constant([7, 6, 5]))
예제 #3
0
 def test_single_param_slice_start_broadcastdim(self):
     """A start-only slice on a size-1 param dim leaves that dim at size 1.

     The param's second dim is 1 while `batch_shape` says 6; slicing it with
     `2:` is expected to keep the param's shape unchanged.
     """
     param = array_ops.zeros([7, 1, 5, 4, 3])
     spec = make_slices[:, 2:]
     sliced = slicing._slice_single_param(
         param,
         param_ndims_to_matrix_ndims=2,
         slices=spec,
         batch_shape=constant_op.constant([7, 6, 5]))
     expected_shape = (7, 1, 5, 4, 3)
     self.assertAllEqual(expected_shape, self.evaluate(sliced).shape)
예제 #4
0
 def test_single_param_slice_newaxis_trailing(self):
     """A `newaxis` just before the last batch slice inserts a unit dim."""
     param = array_ops.zeros([7, 6, 5, 4, 3])
     spec = make_slices[..., array_ops.newaxis, :]
     sliced = slicing._slice_single_param(
         param,
         param_ndims_to_matrix_ndims=2,
         slices=spec,
         batch_shape=constant_op.constant([7, 6, 5]))
     # Batch (7, 6, 5) becomes (7, 6, 1, 5); the trailing (4, 3) is untouched.
     expected_shape = (7, 6, 1, 5, 4, 3)
     self.assertAllEqual(expected_shape, self.evaluate(sliced).shape)
예제 #5
0
 def test_single_param_slice_stop_leadingdim(self):
     """A stop-only slice on the leading batch dim truncates it to length 2."""
     param = array_ops.zeros([7, 6, 5, 4, 3])
     spec = make_slices[:2]
     batch_shape = constant_op.constant([7, 6, 5], dtype=dtypes.int32)
     sliced = slicing._slice_single_param(
         param, param_ndims_to_matrix_ndims=2, slices=spec,
         batch_shape=batch_shape)
     self.assertAllEqual((2, 6, 5, 4, 3), self.evaluate(sliced).shape)
예제 #6
0
 def test_single_param_slice_withstep_broadcastdim(self):
     """Stepped, out-of-range slices on size-1 param dims keep them size 1.

     Both batch dims of the param are 1 (vs. `batch_shape` [2, 7]), and the
     slice bounds (44, -52, -94) lie outside those batch sizes.
     """
     event_dim = 3
     param = array_ops.zeros([1, 1, event_dim])
     spec = make_slices[44:-52:-3, -94::]
     batch_shape = constant_op.constant([2, 7], dtype=dtypes.int32)
     sliced = slicing._slice_single_param(
         param, param_ndims_to_matrix_ndims=1, slices=spec,
         batch_shape=batch_shape)
     self.assertAllEqual((1, 1, event_dim), self.evaluate(sliced).shape)
예제 #7
0
 def test_single_param_slice_broadcast_batch_leading_newaxis(self):
     """Slice a param whose batch rank is lower than `batch_shape`'s rank.

     The expected shape comes from applying the same slice spec to a zeros
     tensor of the param's batch shape left-padded to rank 3 ([1, 4, 3]),
     and then appending the param's event dim [1].
     """
     param = array_ops.zeros([4, 3, 1])  # batch = [4, 3], event = [1]
     spec = make_slices[array_ops.newaxis, ..., array_ops.newaxis, 2:,
                        array_ops.newaxis]
     sliced = slicing._slice_single_param(
         param,
         param_ndims_to_matrix_ndims=1,
         slices=spec,
         batch_shape=constant_op.constant([7, 4, 3]))
     padded_batch = array_ops.zeros([1, 4, 3])
     reference = padded_batch[array_ops.newaxis, ..., array_ops.newaxis, 2:,
                              array_ops.newaxis]
     expected = reference.shape + [1]
     self.assertAllEqual(expected, self.evaluate(sliced).shape)
예제 #8
0
 def test_single_param_slice_tensor_broadcastdim(self):
     """An integer-tensor index on a size-1 param dim removes that dim.

     The param is wrapped with `shape=None` (unknown static shape) and the
     index is a scalar int32 tensor rather than a Python int, both via
     `placeholder_with_default`.
     """
     param = array_ops.placeholder_with_default(
         array_ops.zeros([7, 1, 5, 4, 3]), shape=None)
     idx = array_ops.placeholder_with_default(
         constant_op.constant(2, dtype=dtypes.int32), shape=[])
     sliced = slicing._slice_single_param(
         param,
         param_ndims_to_matrix_ndims=2,
         slices=make_slices[:, idx],
         batch_shape=constant_op.constant([7, 6, 5]))
     self.assertAllEqual((7, 5, 4, 3), self.evaluate(sliced).shape)