def create_function_unit_test(
        function_name: str,
        unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
    """Creates a tf_function_unit_test from the provided UnitTestSpec.

    Looks up ``function_name`` on ``tf.math``. If the spec's input signature is
    complex, the callable and signature are rewritten with
    ``tf_utils.rewrite_complex_signature``. ``top_k`` is additionally wrapped
    with ``_wrap_top_k``. When ``FLAGS.dynamic_dims`` is set, the signature is
    passed through ``tf_utils.make_dims_dynamic``.

    Args:
      function_name: Attribute name to fetch from ``tf.math``.
      unit_test_spec: Supplies the input signature, input generator, input
        args, keyword args and unit test name for the generated test.

    Returns:
      The result of applying ``tf_test_utils.tf_function_unit_test(...)``
      (configured with rtol=1e-5, atol=1e-5) to a wrapper that forwards its
      positional args plus ``unit_test_spec.kwargs`` to the math function.
    """
    math_fn = getattr(tf.math, function_name)
    input_signature = unit_test_spec.input_signature

    # Complex signatures need both the callable and the signature rewritten.
    if tf_utils.is_complex(input_signature):
        math_fn, input_signature = tf_utils.rewrite_complex_signature(
            math_fn, input_signature)

    # top_k is special-cased with an extra wrapper defined elsewhere in the file.
    if function_name == "top_k":
        math_fn = _wrap_top_k(math_fn)

    # Kept as a lambda (not a def) so the callable handed to the decorator is
    # unchanged; the spec's kwargs are looked up on every call.
    wrapped_function = lambda *args: math_fn(*args, **unit_test_spec.kwargs)

    if FLAGS.dynamic_dims:
        input_signature = tf_utils.apply_function(
            input_signature, tf_utils.make_dims_dynamic)

    unit_test_decorator = tf_test_utils.tf_function_unit_test(
        input_signature=input_signature,
        input_generator=unit_test_spec.input_generator,
        input_args=unit_test_spec.input_args,
        name=unit_test_spec.unit_test_name,
        rtol=1e-5,
        atol=1e-5,
    )
    return unit_test_decorator(wrapped_function)
def __init__(self):
    """Attaches one generated unit test per selected spec to this instance.

    For every name in ``FLAGS.functions``, each spec listed under that name in
    ``FUNCTIONS_TO_UNIT_TEST_SPECS`` is built with
    ``create_function_unit_test``. The result is stored on ``self`` under the
    spec's ``unit_test_name``. Specs whose input signature is complex are
    skipped unless ``FLAGS.test_complex`` is set.
    """
    super().__init__()
    for function_name in FLAGS.functions:
        for spec in FUNCTIONS_TO_UNIT_TEST_SPECS[function_name]:
            # Short-circuits so is_complex is only consulted when complex
            # tests are disabled.
            keep_spec = FLAGS.test_complex or not tf_utils.is_complex(
                spec.input_signature)
            if not keep_spec:
                continue
            generated_test = create_function_unit_test(function_name, spec)
            setattr(self, spec.unit_test_name, generated_test)
def main(argv):
    """Lists functions with complex tests, or generates and runs the tests.

    With ``--list_functions_with_complex_tests``, prints one line per function
    in ``FUNCTIONS_TO_UNIT_TEST_SPECS`` that has at least one spec with a
    complex input signature, then returns without running tests. Otherwise it
    generates unit tests for ``TfMathModule`` on ``TfMathTest`` and hands
    control to ``tf.test.main()``.

    Args:
      argv: Positional command-line arguments; unused.

    Raises:
      flags.IllegalFlagValueError: If ``--list_functions_with_complex_tests``
        is not set and ``FLAGS.functions`` is None.
    """
    del argv  # Unused.
    # Only call enable_v2_behavior when this TensorFlow build provides it.
    if hasattr(tf, "enable_v2_behavior"):
        tf.enable_v2_behavior()

    if FLAGS.list_functions_with_complex_tests:
        for function_name, unit_test_specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items():
            # Print each function at most once. Printing per complex spec
            # listed a function repeatedly when it had several complex specs.
            if any(tf_utils.is_complex(spec.input_signature)
                   for spec in unit_test_specs):
                print(f' "{function_name}",')
        return

    if FLAGS.functions is None:
        raise flags.IllegalFlagValueError(
            "'--functions' must be specified if "
            "'--list_functions_with_complex_tests' isn't")

    TfMathTest.generate_unit_tests(TfMathModule)
    tf.test.main()