Exemplo n.º 1
0
def is_valid(matrix, atol=1e-3, name="rotation_matrix_2d_is_valid"):
    r"""Checks whether `matrix` holds valid 2d rotation matrices.

    A matrix $$\mathbf{R}$$ is treated as a valid rotation matrix when
    $$\mathbf{R}^T\mathbf{R} = \mathbf{I}$$ and $$\det(\mathbf{R}) = 1$$, with
    `atol` used as the absolute tolerance. The numerical test itself is
    delegated to `rotation_matrix_common.is_valid`.

    Note:
        In the following, A1 to An are optional batch dimensions.

    Args:
        matrix: A tensor of shape `[A1, ..., An, 2, 2]`, where the last two
            dimensions represent a 2d rotation matrix.
        atol: The absolute tolerance parameter.
        name: A name for this op that defaults to
            "rotation_matrix_2d_is_valid".

    Returns:
        A tensor of type `bool` and shape `[A1, ..., An, 1]` where False
        indicates that the input is not a valid rotation matrix.
    """
    # (axis, size) pairs the trailing dimensions must satisfy: a 2x2 matrix.
    trailing_2x2 = ((-2, 2), (-1, 2))

    with tf.name_scope(name):
        rotation = tf.convert_to_tensor(value=matrix)

        # Static shape check: rank must exceed 1 (room for the 2x2 block) and
        # the last two dimensions must both equal 2.
        shape.check_static(
            tensor=rotation,
            tensor_name="matrix",
            has_rank_greater_than=1,
            has_dim_equals=trailing_2x2,
        )

        return rotation_matrix_common.is_valid(rotation, atol)
Exemplo n.º 2
0
def is_valid(matrix, atol=1e-3, name=None):
    r"""Checks whether `matrix` holds valid 3d rotation matrices.

    A matrix $$\mathbf{R}$$ is treated as a valid rotation matrix when
    $$\mathbf{R}^T\mathbf{R} = \mathbf{I}$$ and $$\det(\mathbf{R}) = 1$$, with
    `atol` used as the absolute tolerance. The numerical test itself is
    delegated to `rotation_matrix_common.is_valid`.

    Note:
        In the following, A1 to An are optional batch dimensions.

    Args:
        matrix: A tensor of shape `[A1, ..., An, 3, 3]`, where the last two
            dimensions represent a 3d rotation matrix.
        atol: Absolute tolerance parameter.
        name: A name for this op that defaults to
            "rotation_matrix_3d_is_valid".

    Returns:
        A tensor of type `bool` and shape `[A1, ..., An, 1]` where False
        indicates that the input is not a valid rotation matrix.
    """
    # (axis, size) pairs the trailing dimensions must satisfy: a 3x3 matrix.
    trailing_3x3 = ((-2, 3), (-1, 3))
    # The v1 scope receives the caller's original argument, not the converted
    # tensor, exactly as before.
    scope_inputs = [matrix]

    with tf.compat.v1.name_scope(name, "rotation_matrix_3d_is_valid",
                                 scope_inputs):
        rotation = tf.convert_to_tensor(value=matrix)

        # Static shape check: rank must exceed 1 (room for the 3x3 block) and
        # the last two dimensions must both equal 3.
        shape.check_static(
            tensor=rotation,
            tensor_name="matrix",
            has_rank_greater_than=1,
            has_dim_equals=trailing_3x3,
        )

        return rotation_matrix_common.is_valid(rotation, atol)