Exemplo n.º 1
0
def create_function_unit_test(
        function_name: str,
        unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
    """Wrap `tf.math.<function_name>` as a tf_function_unit_test.

    Before decoration, the looked-up function and the spec's input signature
    are adapted:
      * a complex signature is rewritten with
        `tf_utils.rewrite_complex_signature`, which replaces both the function
        and the signature;
      * `top_k` is passed through `_wrap_top_k`;
      * when `--dynamic_dims` is set, the signature is passed through
        `tf_utils.apply_function(..., tf_utils.make_dims_dynamic)`.

    Args:
      function_name: attribute name looked up on `tf.math`.
      unit_test_spec: supplies the input signature, input generator, input
        args, extra call kwargs, and the unit test name.

    Returns:
      Whatever the `tf_test_utils.tf_function_unit_test(...)` decorator
      (configured with rtol=atol=1e-5) returns when applied to a callable that
      forwards its positional arguments plus `unit_test_spec.kwargs` to the
      adapted function.
    """
    target = getattr(tf.math, function_name)
    input_signature = unit_test_spec.input_signature

    if tf_utils.is_complex(input_signature):
        # The rewrite yields a replacement (function, signature) pair.
        target, input_signature = tf_utils.rewrite_complex_signature(
            target, input_signature)

    if function_name == "top_k":
        target = _wrap_top_k(target)

    # `unit_test_spec.kwargs` is read on every call, not captured up front.
    # The varargs name stays `args` in case tracing derives input names from it.
    forward_to_target = (
        lambda *args: target(*args, **unit_test_spec.kwargs))

    if FLAGS.dynamic_dims:
        input_signature = tf_utils.apply_function(
            input_signature, tf_utils.make_dims_dynamic)

    decorator = tf_test_utils.tf_function_unit_test(
        input_signature=input_signature,
        input_generator=unit_test_spec.input_generator,
        input_args=unit_test_spec.input_args,
        name=unit_test_spec.unit_test_name,
        rtol=1e-5,
        atol=1e-5,
    )
    return decorator(forward_to_target)
Exemplo n.º 2
0
 def __init__(self):
     """Register one unit-test attribute per selected tf.math spec.

     For every function name in --functions, each spec listed in
     FUNCTIONS_TO_UNIT_TEST_SPECS for that name is turned into a test with
     create_function_unit_test. The test is stored on the instance under the
     spec's unit_test_name. Specs whose input signature is complex are
     skipped unless --test_complex is set.
     """
     super().__init__()
     for name in FLAGS.functions:
         for spec in FUNCTIONS_TO_UNIT_TEST_SPECS[name]:
             # Keep the spec when --test_complex is on; otherwise keep it only
             # if its signature is not complex. is_complex is evaluated only
             # when --test_complex is off.
             keep = FLAGS.test_complex or not tf_utils.is_complex(
                 spec.input_signature)
             if keep:
                 unit_test = create_function_unit_test(name, spec)
                 setattr(self, spec.unit_test_name, unit_test)
Exemplo n.º 3
0
def main(argv):
    """Entry point: list complex-tested functions, or generate and run tests.

    With --list_functions_with_complex_tests, print the name of every function
    that has at least one spec with a complex input signature, one per line,
    formatted as a quoted, comma-terminated Python list entry. Then return
    without running any tests.

    Otherwise --functions is required. In that case, call
    TfMathTest.generate_unit_tests(TfMathModule) and then tf.test.main().

    Args:
      argv: command-line arguments; unused.

    Raises:
      flags.IllegalFlagValueError: if --list_functions_with_complex_tests is
        not set and --functions was not specified.
    """
    del argv  # Unused.
    # Only call enable_v2_behavior when this TensorFlow build provides it.
    if hasattr(tf, "enable_v2_behavior"):
        tf.enable_v2_behavior()

    if FLAGS.list_functions_with_complex_tests:
        for function_name, unit_test_specs in (
                FUNCTIONS_TO_UNIT_TEST_SPECS.items()):
            # A function can have several complex specs. Print its name at
            # most once so the output is a duplicate-free list of names.
            if any(tf_utils.is_complex(spec.input_signature)
                   for spec in unit_test_specs):
                print(f'    "{function_name}",')
        return

    if FLAGS.functions is None:
        raise flags.IllegalFlagValueError(
            "'--functions' must be specified if "
            "'--list_functions_with_complex_tests' isn't")

    TfMathTest.generate_unit_tests(TfMathModule)
    tf.test.main()