def test_find_operation_with_name_raises_if_not_found(self):
    """find_operation_with_name raises ValueError for names it cannot find.

    Covers both an unknown name in the default search, and a known TF name
    looked up in a list that only holds Python-syntax operations.
    """
    with self.assertRaises(ValueError):
        all_operations.find_operation_with_name('bad name')

    # Build the restricted list outside the assertRaises context so that only
    # the lookup under test (not the setup call) can satisfy the expectation.
    python_operations = all_operations.get_python_operations()
    with self.assertRaises(ValueError):
        all_operations.find_operation_with_name(
            'tf.add(x, y)', operation_list=python_operations)
 def test_get_python_operations(self):
     """Python-syntax operations are Operations and include indexing."""
     operations = all_operations.get_python_operations()
     # Per-element assertions name the offending object on failure, whereas
     # assertTrue(all(...)) only reports "False is not true".
     for operation in operations:
         self.assertIsInstance(operation, operation_base.Operation)
     self.assertIn('IndexingOperation',
                   [operation.name for operation in operations])
# Example No. 3
# 0
def print_supported_operations():
    """Print the supported operations to stdout, grouped by category.

    TensorFlow and SparseTensor operations are listed one name per line, each
    section followed by a blank line. Python-syntax operations are listed as
    their name plus a colon, padded to 35 characters, followed by the
    expression the operation reconstructs from placeholder arguments named
    ``arg1``, ``arg2``, ...
    """
    # Sections that only list operation names: (header, registry getter).
    name_only_sections = (
        ('TensorFlow functions:\n' '---------------------',
         all_operations.get_tf_operations),
        ('SparseTensor functions:\n' '-----------------------',
         all_operations.get_sparse_operations),
    )
    for header, list_operations in name_only_sections:
        print(header)
        for op in list_operations():
            print(op.name)
        # Blank separator line before the next section.
        print()

    print('Python-syntax operations:\n' '-------------------------')
    for op in all_operations.get_python_operations():
        placeholders = ['arg{}'.format(k) for k in range(1, op.num_args + 1)]
        syntax_form = op.reconstruct_expression_from_strings(placeholders)
        label = op.name + ':'
        print('{:35} {}'.format(label, syntax_form))
class AllOperationsTest(parameterized.TestCase):
    """Tests for the operation registries exposed by `all_operations`."""

    def test_get_python_operations(self):
        """Python-syntax operations are Operations and include indexing."""
        operations = all_operations.get_python_operations()
        # Per-element assertions name the offending object on failure, whereas
        # assertTrue(all(...)) only reports "False is not true".
        for operation in operations:
            self.assertIsInstance(operation, operation_base.Operation)
        self.assertIn('IndexingOperation',
                      [operation.name for operation in operations])

    def test_get_tf_operations(self):
        """TF operations mirror TF_FUNCTIONS and exclude sparse functions."""
        operations = all_operations.get_tf_operations()
        self.assertLen(operations, len(tf_functions.TF_FUNCTIONS))
        names = [operation.name for operation in operations]
        self.assertIn('tf.add(x, y)', names)
        self.assertNotIn('tf.sparse.add(a, b)', names)

    def test_get_sparse_operations(self):
        """Sparse operations mirror SPARSE_FUNCTIONS and exclude dense ones."""
        operations = all_operations.get_sparse_operations()
        self.assertLen(operations, len(tf_functions.SPARSE_FUNCTIONS))
        names = [operation.name for operation in operations]
        self.assertNotIn('tf.add(x, y)', names)
        self.assertIn('tf.sparse.add(a, b)', names)

    def test_get_operations_correct_type(self):
        """Every element returned by get_operations is an Operation."""
        operations = all_operations.get_operations(
            include_sparse_operations=True)
        for element in operations:
            self.assertIsInstance(element, operation_base.Operation)

    @parameterized.named_parameters(
        ('with_sparse', True,
         len(tf_functions.TF_FUNCTIONS) + len(tf_functions.SPARSE_FUNCTIONS) +
         len(all_operations.get_python_operations())),
        ('without_sparse', False, len(tf_functions.TF_FUNCTIONS) +
         len(all_operations.get_python_operations())))
    def test_get_operations_correct_cardinality(self,
                                                include_sparse_operations,
                                                expected_cardinality):
        """get_operations returns exactly the expected number of operations."""
        operations = all_operations.get_operations(
            include_sparse_operations=include_sparse_operations)
        self.assertLen(operations, expected_cardinality)

    @parameterized.named_parameters(
        ('indexing', 'IndexingOperation', False),
        ('slicing_axis_0_both', 'SlicingAxis0BothOperation', False),
        ('tf_add', 'tf.add(x, y)', False),
        ('tf_cast', 'tf.cast(x, dtype)', False),
        ('tf_sparse_expand_dims', 'tf.sparse.expand_dims(sp_input, axis)',
         True))
    def test_get_operations_includes_expected(self, name, is_sparse):
        """Named operation is present unless it is sparse and sparse is off."""
        for include_sparse_operations in [True, False]:
            # subTest reports which flag value failed and keeps checking the
            # other value instead of stopping at the first failure.
            with self.subTest(
                    include_sparse_operations=include_sparse_operations):
                operations = all_operations.get_operations(
                    include_sparse_operations=include_sparse_operations)
                names = [operation.name for operation in operations]
                if include_sparse_operations or not is_sparse:
                    self.assertIn(name, names)
                else:
                    self.assertNotIn(name, names)

    def test_get_operations_unique_names(self):
        """No two operations share a name."""
        operations = all_operations.get_operations(
            include_sparse_operations=True)
        names_set = {operation.name for operation in operations}
        self.assertLen(names_set, len(operations))

    def test_get_operations_all_have_docstrings(self):
        """Every operation carries a non-empty metadata docstring."""
        operations = all_operations.get_operations(
            include_sparse_operations=True)
        # Listing the offenders makes a failure actionable.
        missing_docstrings = [operation.name for operation in operations
                              if not operation.metadata.docstring]
        self.assertEmpty(missing_docstrings)

    def test_get_operations_consistent_order(self):
        """Repeated calls return operations in the same order."""
        operations_1 = all_operations.get_operations(
            include_sparse_operations=True)
        operations_2 = all_operations.get_operations(
            include_sparse_operations=True)
        self.assertEqual([op.name for op in operations_1],
                         [op.name for op in operations_2])

    def test_find_operation_with_name(self):
        """Lookup by name works with the default and an explicit list."""
        operation = all_operations.find_operation_with_name('tf.add(x, y)')
        self.assertEqual(operation.name, 'tf.add(x, y)')

        operation = all_operations.find_operation_with_name(
            'tf.add(x, y)', operation_list=all_operations.get_tf_operations())
        self.assertEqual(operation.name, 'tf.add(x, y)')

    def test_find_operation_with_name_raises_if_not_found(self):
        """Lookup raises ValueError when the name is not in the searched list."""
        with self.assertRaises(ValueError):
            all_operations.find_operation_with_name('bad name')

        # Build the restricted list outside the assertRaises context so that
        # only the lookup under test (not the setup call) can satisfy it.
        python_operations = all_operations.get_python_operations()
        with self.assertRaises(ValueError):
            all_operations.find_operation_with_name(
                'tf.add(x, y)', operation_list=python_operations)