Ejemplo n.º 1
0
 def __init__(self):
   """Registers a generated unit test for every selected UnitTestSpec.

   Iterates over the function names in FLAGS.functions and, for each spec in
   FUNCTIONS_TO_UNIT_TEST_SPECS[name], stores the result of
   create_function_unit_test as an attribute named spec.unit_test_name.
   Specs with a complex input signature are only registered when
   FLAGS.test_complex is set.
   """
   super().__init__()
   for function_name in FLAGS.functions:
     specs = FUNCTIONS_TO_UNIT_TEST_SPECS[function_name]
     for spec in specs:
       # Complex-signature specs are opt-in via FLAGS.test_complex; the
       # is_complex check is only reached when that flag is falsy.
       keep = FLAGS.test_complex or not tf_utils.is_complex(
           spec.input_signature)
       if keep:
         unit_test = create_function_unit_test(function_name, spec)
         setattr(self, spec.unit_test_name, unit_test)
Ejemplo n.º 2
0
def main(argv):
  """Lists functions that have complex tests, or generates and runs the tests.

  If FLAGS.list_functions_with_complex_tests is set, prints every function
  name in FUNCTIONS_TO_UNIT_TEST_SPECS that has at least one UnitTestSpec with
  a complex input signature. Each name is printed once, as an indented,
  double-quoted, comma-terminated line. The function then returns without
  running any tests. Otherwise it requires FLAGS.functions, generates the
  TfMathTest unit tests from TfMathModule, and hands control to
  tf.test.main().

  Args:
    argv: Unused.

  Raises:
    flags.IllegalFlagValueError: if FLAGS.functions is None and
      FLAGS.list_functions_with_complex_tests is not set.
  """
  del argv  # Unused.
  # Only call enable_v2_behavior when this TF build provides it.
  if hasattr(tf, "enable_v2_behavior"):
    tf.enable_v2_behavior()

  if FLAGS.list_functions_with_complex_tests:
    for function_name, unit_test_specs in FUNCTIONS_TO_UNIT_TEST_SPECS.items():
      # Emit each function at most once, even if several of its specs have
      # complex input signatures.
      if any(tf_utils.is_complex(spec.input_signature)
             for spec in unit_test_specs):
        print(f'    "{function_name}",')
    return

  if FLAGS.functions is None:
    raise flags.IllegalFlagValueError(
        "'--functions' must be specified if "
        "'--list_functions_with_complex_tests' isn't")

  TfMathTest.generate_unit_tests(TfMathModule)
  tf.test.main()
Ejemplo n.º 3
0
def create_function_unit_test(
    function_name: str,
    unit_test_spec: tf_test_utils.UnitTestSpec,
    *,
    rtol: float = 1e-5,
    atol: float = 1e-5) -> tf.function:
  """Creates a tf_function_unit_test from the provided UnitTestSpec.

  Looks up `function_name` on `tf.math` and binds the spec's kwargs to it. It
  then decorates the result with `tf_test_utils.tf_function_unit_test`, using
  the spec's input signature, input generator, input args and test name.

  If the spec's signature is complex, the function and signature are first
  rewritten with `tf_utils.rewrite_complex_signature`. If FLAGS.dynamic_dims
  is set, the signature is passed through `tf_utils.make_dims_dynamic` via
  `tf_utils.apply_function`.

  Args:
    function_name: Name of an attribute of `tf.math` to test.
    unit_test_spec: Spec providing the input signature, input generator,
      input args, kwargs and unit test name.
    rtol: Relative tolerance forwarded to `tf_function_unit_test`.
      Defaults to 1e-5.
    atol: Absolute tolerance forwarded to `tf_function_unit_test`.
      Defaults to 1e-5.

  Returns:
    The result of applying the configured `tf_function_unit_test` decorator
    to the wrapped function.

  Raises:
    AttributeError: if `tf.math` has no attribute named `function_name`.
  """
  function = getattr(tf.math, function_name)
  signature = unit_test_spec.input_signature

  if tf_utils.is_complex(signature):
    # Both the callable and its signature are replaced together so they stay
    # consistent with each other.
    function, signature = tf_utils.rewrite_complex_signature(
        function, signature)
  # Bind the spec's kwargs so the test only has to supply positional inputs.
  wrapped_function = lambda *args: function(*args, **unit_test_spec.kwargs)

  if FLAGS.dynamic_dims:
    signature = tf_utils.apply_function(signature, tf_utils.make_dims_dynamic)

  return tf_test_utils.tf_function_unit_test(
      input_signature=signature,
      input_generator=unit_test_spec.input_generator,
      input_args=unit_test_spec.input_args,
      name=unit_test_spec.unit_test_name,
      rtol=rtol,
      atol=atol)(wrapped_function)