Example 1
0
  def test_is_static(self, tensor_shape, is_static):
    """Checks that shape.is_static matches the expected staticness flag."""
    # Placeholders (the dynamic-shape case below) only exist in graph mode.
    if tf.executing_eagerly():
      return
    placeholder = tf.compat.v1.placeholder(shape=tensor_shape, dtype=tf.float32)

    # Query is_static twice: once with the raw shape list, once with the
    # TensorShape of the placeholder built from it.
    for subtest_name, queried_shape in (
        ("tensor_shape_is_list", tensor_shape),
        ("tensor_shape", placeholder.shape),
    ):
      with self.subTest(name=subtest_name):
        self.assertEqual(shape.is_static(queried_shape), is_static)
Example 2
0
 def test_compare_batch_dimensions_raises_exceptions(
         self, error_msg, tensor_shapes, last_axes, broadcast_compatible):
     """Checks that compare_batch_dimensions fails with the expected error."""
     if not tensor_shapes:
         # A non-sequence value, used to exercise input validation.
         tensors = 0
     elif all(map(shape.is_static, tensor_shapes)):
         # Every shape is fully known: concrete tensors suffice.
         tensors = [tf.ones(tensor_shape) for tensor_shape in tensor_shapes]
     elif tf.executing_eagerly():
         # Dynamic shapes are not supported in eager mode.
         return
     else:
         # At least one shape has unknown dimensions: use placeholders.
         tensors = [
             tf.compat.v1.placeholder(shape=tensor_shape, dtype=tf.float32)
             for tensor_shape in tensor_shapes
         ]
     self.assert_exception_is_raised(
         shape.compare_batch_dimensions,
         error_msg,
         shapes=[],
         tensors=tensors,
         last_axes=last_axes,
         broadcast_compatible=broadcast_compatible)
Example 3
0
def _is_dynamic_shape(tensors):
  """Helper function to test if any tensor in a list has a dynamic shape.

  Args:
    tensors: A list or tuple of tensors with shapes to test.

  Returns:
    True if any tensor in the list has a dynamic shape, False otherwise.
  """
  if not isinstance(tensors, (list, tuple)):
    raise ValueError("'tensors' must be list of tuple.")
  return not all([shape.is_static(tensor.shape) for tensor in tensors])
Example 4
0
 def test_compare_batch_dimensions_raises_no_exceptions(
     self, tensor_shapes, last_axes, broadcast_compatible, initial_axes):
   """Checks compare_batch_dimensions accepts compatible batch shapes."""
   if all(map(shape.is_static, tensor_shapes)):
     # Fully known shapes: build concrete tensors directly.
     tensors = [tf.ones(tensor_shape) for tensor_shape in tensor_shapes]
   else:
     # Placeholders (dynamic shapes) require graph mode; skip under eager.
     if tf.executing_eagerly():
       return
     tensors = [
         tf.compat.v1.placeholder(shape=tensor_shape, dtype=tf.float32)
         for tensor_shape in tensor_shapes
     ]
   self.assert_exception_is_not_raised(
       shape.compare_batch_dimensions,
       shapes=[],
       tensors=tensors,
       last_axes=last_axes,
       broadcast_compatible=broadcast_compatible,
       initial_axes=initial_axes)