def test_single_param_multi_ellipsis(self):
  """A slice spec containing more than one `...` must raise ValueError."""
  # `assertRaisesRegexp` is a deprecated alias that was removed from
  # unittest in Python 3.12; `assertRaisesRegex` is the supported name.
  # The dots are escaped so the pattern matches a literal `...` instead of
  # any three characters.
  with self.assertRaisesRegex(ValueError, r'Found multiple `\.\.\.`'):
    slicing._slice_single_param(
        array_ops.zeros([7, 6, 5, 4, 3]),
        param_ndims_to_matrix_ndims=2,
        slices=make_slices[:, ..., 2, ...],
        batch_shape=constant_op.constant([7, 6, 5]))
def test_single_param_too_many_slices(self):
  """A slice spec with too many entries for the batch shape must raise."""
  # Any of these error types is accepted by the test.
  acceptable_errors = (IndexError, ValueError, errors.InvalidArgumentError)
  with self.assertRaises(acceptable_errors):
    param = array_ops.zeros([7, 6, 5, 4, 3])
    # Four non-ellipsis entries, while batch_shape below has rank 3.
    oversized_spec = make_slices[:, :3, ..., -2:, :]
    slicing._slice_single_param(
        param,
        param_ndims_to_matrix_ndims=2,
        slices=oversized_spec,
        batch_shape=constant_op.constant([7, 6, 5]))
def test_single_param_slice_start_broadcastdim(self):
  """Slicing a batch dim on which the param has size 1 leaves it at size 1."""
  # The param's second dim is 1 while batch_shape says 6.
  param = array_ops.zeros([7, 1, 5, 4, 3])
  result = slicing._slice_single_param(
      param,
      param_ndims_to_matrix_ndims=2,
      slices=make_slices[:, 2:],
      batch_shape=constant_op.constant([7, 6, 5]))
  result_shape = self.evaluate(result).shape
  self.assertAllEqual((7, 1, 5, 4, 3), result_shape)
def test_single_param_slice_newaxis_trailing(self):
  """A newaxis placed after `...` inserts a unit dim before the last batch dim."""
  param = array_ops.zeros([7, 6, 5, 4, 3])
  spec = make_slices[..., array_ops.newaxis, :]
  result = slicing._slice_single_param(
      param,
      param_ndims_to_matrix_ndims=2,
      slices=spec,
      batch_shape=constant_op.constant([7, 6, 5]))
  result_shape = self.evaluate(result).shape
  self.assertAllEqual((7, 6, 1, 5, 4, 3), result_shape)
def test_single_param_slice_stop_leadingdim(self):
  """A stop-only slice on the leading batch dim truncates that dim."""
  param = array_ops.zeros([7, 6, 5, 4, 3])
  leading_prefix = make_slices[:2]
  full_batch = constant_op.constant([7, 6, 5], dtype=dtypes.int32)
  result = slicing._slice_single_param(
      param,
      param_ndims_to_matrix_ndims=2,
      slices=leading_prefix,
      batch_shape=full_batch)
  result_shape = self.evaluate(result).shape
  self.assertAllEqual((2, 6, 5, 4, 3), result_shape)
def test_single_param_slice_withstep_broadcastdim(self):
  """Stepped/out-of-range slices on size-1 param batch dims keep them at 1."""
  event_size = 3
  # Both batch dims of the param are 1, while batch_shape is [2, 7].
  param = array_ops.zeros([1, 1, event_size])
  stepped_spec = make_slices[44:-52:-3, -94::]
  full_batch = constant_op.constant([2, 7], dtype=dtypes.int32)
  result = slicing._slice_single_param(
      param,
      param_ndims_to_matrix_ndims=1,
      slices=stepped_spec,
      batch_shape=full_batch)
  result_shape = self.evaluate(result).shape
  self.assertAllEqual((1, 1, event_size), result_shape)
def test_single_param_slice_broadcast_batch_leading_newaxis(self):
  """Newaxis slices on a param with fewer batch dims than batch_shape."""
  # batch = [4, 3], event = [1]; batch_shape below has rank 3 ([7, 4, 3]).
  param = array_ops.zeros([4, 3, 1])
  spec = make_slices[array_ops.newaxis, ..., array_ops.newaxis, 2:,
                     array_ops.newaxis]
  sliced = slicing._slice_single_param(
      param,
      param_ndims_to_matrix_ndims=1,
      slices=spec,
      batch_shape=constant_op.constant([7, 4, 3]))
  # Reference: apply the same slices to a dense tensor of shape [1, 4, 3]
  # (the param's batch dims [4, 3] left-padded with a 1), then append the
  # event dim.
  reference = array_ops.zeros([1, 4, 3])
  reference = reference[array_ops.newaxis, ..., array_ops.newaxis, 2:,
                        array_ops.newaxis]
  event_shape = [1]
  expected = reference.shape + event_shape
  self.assertAllEqual(expected, self.evaluate(sliced).shape)
def test_single_param_slice_tensor_broadcastdim(self):
  """A scalar tensor index into a size-1 param batch dim removes that dim."""
  # The param's static shape is hidden (shape=None). The index is a scalar
  # placeholder, presumably so that its value is not a Python constant.
  dense = array_ops.zeros([7, 1, 5, 4, 3])
  param = array_ops.placeholder_with_default(dense, shape=None)
  index_value = constant_op.constant(2, dtype=dtypes.int32)
  idx = array_ops.placeholder_with_default(index_value, shape=[])
  result = slicing._slice_single_param(
      param,
      param_ndims_to_matrix_ndims=2,
      slices=make_slices[:, idx],
      batch_shape=constant_op.constant([7, 6, 5]))
  result_shape = self.evaluate(result).shape
  self.assertAllEqual((7, 5, 4, 3), result_shape)