コード例 #1
0
 def test_dense_2d_3d_squeezable(self):
     """Dense 2-D vs. 3-D inputs, checked in both argument orders.

     The 3-D input has shape (2, 2, 1) and the 2-D input (2, 2); after
     `remove_squeezable_dimensions` the squeezed tensor must be
     shape-compatible with the 2-D one.
     """
     two_d = constant_op.constant([[1, 2], [3, 4]])
     three_d = constant_op.constant([[[1], [2]], [[3], [4]]])

     # 3-D tensor passed as the second argument.
     _, squeezed_second = losses_utils.remove_squeezable_dimensions(
         two_d, three_d)
     squeezed_second.shape.assert_is_compatible_with(two_d.shape)
     self.assertEqual(squeezed_second.shape.ndims, two_d.shape.ndims)

     # 3-D tensor passed as the first argument.
     squeezed_first, _ = losses_utils.remove_squeezable_dimensions(
         three_d, two_d)
     squeezed_first.shape.assert_is_compatible_with(two_d.shape)
コード例 #2
0
 def test_placeholder(self):
     """Test dynamic rank tensors.

     Both inputs are created with ``shape=None``, so their static shapes
     have unknown rank and ``assert_is_compatible_with`` cannot fail.
     The meaningful check is therefore the runtime comparison of
     ``array_ops.shape`` results, which is done for BOTH argument orders.
     """
     with ops.Graph().as_default():
         x = array_ops.placeholder_with_default([1., 2., 3.], shape=None)
         y = array_ops.placeholder_with_default([[1.], [2.], [3.]],
                                                shape=None)
         _, y_p = losses_utils.remove_squeezable_dimensions(x, y)
         y_p.shape.assert_is_compatible_with(x.shape)
         self.assertAllEqual(array_ops.shape(x), array_ops.shape(y_p))
         x_p, _ = losses_utils.remove_squeezable_dimensions(y, x)
         x_p.shape.assert_is_compatible_with(x.shape)
         # The static check above is vacuous for unknown-rank shapes, so
         # the reversed argument order must also be verified at runtime.
         self.assertAllEqual(array_ops.shape(x), array_ops.shape(x_p))
コード例 #3
0
    def test_ragged_3d_4d_squeezable(self):
        """Ragged 3-D vs. 4-D inputs, checked in both argument orders.

        Shapes:
            x: (2, (sequence={1, 2}), 3)
            y: (2, (sequence={1, 2}), 3, 1)
        """
        ragged = ragged_factory_ops.constant(
            [[[1, 2, 3]], [[4, 5, 6], [7, 8, 9]]])
        expanded = array_ops.expand_dims(ragged, axis=-1)

        # Sanity-check the input ranks before squeezing.
        self.assertEqual(ragged.shape.ndims, 3)
        self.assertEqual(expanded.shape.ndims, 4)

        # 4-D tensor passed as the second argument.
        _, squeezed_second = losses_utils.remove_squeezable_dimensions(
            ragged, expanded)
        squeezed_second.shape.assert_is_compatible_with(ragged.shape)
        self.assertEqual(squeezed_second.shape.ndims, 3)

        # 4-D tensor passed as the first argument.
        squeezed_first, _ = losses_utils.remove_squeezable_dimensions(
            expanded, ragged)
        squeezed_first.shape.assert_is_compatible_with(ragged.shape)
        self.assertEqual(squeezed_first.shape.ndims, 3)
コード例 #4
0
 def test_ragged_3d_same_shape(self):
     """Equal-rank ragged inputs: the first output keeps its rank.

     Shape: (2, (sequence={1, 2}), 3)
     """
     ragged = ragged_factory_ops.constant(
         [[[1, 2, 3]], [[4, 5, 6], [7, 8, 9]]])
     expected_rank = ragged.shape.ndims
     first_out, _ = losses_utils.remove_squeezable_dimensions(ragged, ragged)
     self.assertEqual(first_out.shape.ndims, expected_rank)