def test_dense_2d_3d_squeezable(self):
    """Dense rank-2 vs rank-3 inputs where the extra trailing dim is 1.

    shapes:
      x: (2, 2)
      y: (2, 2, 1)

    The size-1 trailing dimension of `y` must be removed whether `y` is
    passed as the first or the second argument.
    """
    x = constant_op.constant([[1, 2], [3, 4]])
    y = constant_op.constant([[[1], [2]], [[3], [4]]])
    # `y` as the second argument: the second output is the squeezed `y`.
    _, y_p = losses_utils.remove_squeezable_dimensions(x, y)
    y_p.shape.assert_is_compatible_with(x.shape)
    self.assertEqual(y_p.shape.ndims, x.shape.ndims)
    # `y` as the first argument: the first output is the squeezed `y`.
    y_first_p, _ = losses_utils.remove_squeezable_dimensions(y, x)
    y_first_p.shape.assert_is_compatible_with(x.shape)
    # Pin the rank in this direction too, mirroring the check above.
    self.assertEqual(y_first_p.shape.ndims, x.shape.ndims)
def test_placeholder(self):
    """Test dynamic rank tensors.

    Both inputs are created with `shape=None`, so their static rank is
    unknown when the graph is built.

    runtime shapes:
      x: (3,)
      y: (3, 1)
    """
    with ops.Graph().as_default():
        x = array_ops.placeholder_with_default([1., 2., 3.], shape=None)
        y = array_ops.placeholder_with_default([[1.], [2.], [3.]], shape=None)
        # `y` as the second argument.
        _, y_p = losses_utils.remove_squeezable_dimensions(x, y)
        y_p.shape.assert_is_compatible_with(x.shape)
        self.assertAllEqual(array_ops.shape(x), array_ops.shape(y_p))
        # `y` as the first argument.
        y_first_p, _ = losses_utils.remove_squeezable_dimensions(y, x)
        y_first_p.shape.assert_is_compatible_with(x.shape)
        # With an unknown static shape (shape=None), the compatibility check
        # above passes trivially. Verify the runtime shape instead, as the
        # first direction does.
        self.assertAllEqual(array_ops.shape(x), array_ops.shape(y_first_p))
def test_ragged_3d_4d_squeezable(self):
    """Ragged rank-3 vs rank-4 inputs where the extra trailing dim is 1.

    shapes:
      base:     (2, (sequence={1, 2}), 3)
      expanded: (2, (sequence={1, 2}), 3, 1)
    """
    base = ragged_factory_ops.constant(
        [[[1, 2, 3]], [[4, 5, 6], [7, 8, 9]]])
    expanded = array_ops.expand_dims(base, axis=-1)
    self.assertEqual(base.shape.ndims, 3)
    self.assertEqual(expanded.shape.ndims, 4)

    # Expanded tensor passed second: the second output loses its last dim.
    _, squeezed_second = losses_utils.remove_squeezable_dimensions(
        base, expanded)
    squeezed_second.shape.assert_is_compatible_with(base.shape)
    self.assertEqual(squeezed_second.shape.ndims, 3)

    # Expanded tensor passed first: the first output loses its last dim.
    squeezed_first, _ = losses_utils.remove_squeezable_dimensions(
        expanded, base)
    squeezed_first.shape.assert_is_compatible_with(base.shape)
    self.assertEqual(squeezed_first.shape.ndims, 3)
def test_ragged_3d_same_shape(self):
    """Equal-rank ragged inputs keep their rank after the call.

    shape: (2, (sequence={1, 2}), 3)
    """
    ragged = ragged_factory_ops.constant(
        [[[1, 2, 3]], [[4, 5, 6], [7, 8, 9]]])
    expected_rank = ragged.shape.ndims
    first_out, _ = losses_utils.remove_squeezable_dimensions(ragged, ragged)
    self.assertEqual(first_out.shape.ndims, expected_rank)