예제 #1
0
def create_function_unit_test(
    function_name: str,
    unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
  """Builds a unit test for the `tf.math` function named `function_name`.

  Args:
    function_name: name of the attribute of `tf.math` under test.
    unit_test_spec: supplies the input signature, input generator, input args,
      test name, and the kwargs forwarded to the function on every call.

  Returns:
    The result of applying `tf_test_utils.tf_function_unit_test(...)`
    (with rtol=atol=1e-5) to a wrapper around the `tf.math` function.
  """
  target = getattr(tf.math, function_name)
  input_signature = unit_test_spec.input_signature

  # Complex signatures are rewritten by tf_utils into a new
  # (function, signature) pair before the test is assembled.
  if tf_utils.is_complex(input_signature):
    target, input_signature = tf_utils.rewrite_complex_signature(
        target, input_signature)

  # Kept as a lambda; `unit_test_spec.kwargs` is looked up at call time.
  call_target = lambda *args: target(*args, **unit_test_spec.kwargs)

  if FLAGS.dynamic_dims:
    input_signature = tf_utils.apply_function(input_signature,
                                              tf_utils.make_dims_dynamic)

  make_unit_test = tf_test_utils.tf_function_unit_test(
      input_signature=input_signature,
      input_generator=unit_test_spec.input_generator,
      input_args=unit_test_spec.input_args,
      name=unit_test_spec.unit_test_name,
      rtol=1e-5,
      atol=1e-5)
  return make_unit_test(call_target)
예제 #2
0
def create_layer_unit_test(
        model: tf.keras.Model,
        unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
    """Builds a unit test that calls `model` through a tf.function.

    Args:
        model: the Keras model to invoke; it is called with
            `training=FLAGS.training` (read each time the wrapper runs).
        unit_test_spec: supplies the input signature, input generator,
            input args and test name.

    Returns:
        The result of applying `tf_test_utils.tf_function_unit_test(...)`
        to a wrapper around `model`.
    """
    spec_signature = unit_test_spec.input_signature

    # The signature actually traced may have its dims made dynamic; the
    # spec's original (static) signature is passed along separately.
    if FLAGS.dynamic_dims:
        traced_signature = tf_utils.apply_function(
            spec_signature, tf_utils.make_dims_dynamic)
    else:
        traced_signature = spec_signature

    # With more than one input, both signatures gain one extra list level.
    if len(spec_signature) > 1:
        spec_signature, traced_signature = [spec_signature], [traced_signature]

    invoke_model = lambda *args: model(*args, training=FLAGS.training)
    make_unit_test = tf_test_utils.tf_function_unit_test(
        input_signature=traced_signature,
        static_signature=spec_signature,
        input_generator=unit_test_spec.input_generator,
        input_args=unit_test_spec.input_args,
        name=unit_test_spec.unit_test_name)
    return make_unit_test(invoke_model)
예제 #3
0
def unit_test_specs_from_args(
    names_to_input_args: Dict[str, Sequence[Any]],
    kwargs_to_values: Dict[str, Sequence[Any]] = None) -> List[UnitTestSpec]:
  """Generates a Cartesian product of UnitTestSpecs from the given arguments.

  Every named set of input args is paired with every named kwargs
  combination produced by `_named_kwargs_product(kwargs_to_values)`.

  Args:
    names_to_input_args:
      A dict mapping names for input arguments to the arguments themselves.
    kwargs_to_values:
      A dict mapping kwarg names to sequences of values that they can take.
      It is passed unchanged to `_named_kwargs_product` for validation and
      expansion.

  Returns:
    A list of 'UnitTestSpec's generated from the provided arguments. Each
    spec's name is the non-empty parts of (input-args name, kwargs name)
    joined with "__"; its input signature is derived from the input args via
    `tf.TensorSpec.from_tensor(tf.convert_to_tensor(...))`, and its
    `input_generator` is None.
  """
  # Validate and parse 'kwargs_to_values'.
  names_to_kwargs = _named_kwargs_product(kwargs_to_values)

  def to_tensor_spec(value):
    return tf.TensorSpec.from_tensor(tf.convert_to_tensor(value))

  # A single product over (name, value) pairs replaces the previous pair of
  # parallel key/value products; iteration order is identical.
  unit_tests = []
  for (args_name, input_args), (kwargs_name, kwargs) in itertools.product(
      names_to_input_args.items(), names_to_kwargs.items()):
    # Empty names (e.g. when there are no kwargs) are skipped in the join.
    unit_test_name = "__".join(
        name for name in (args_name, kwargs_name) if name)
    unit_tests.append(
        UnitTestSpec(
            unit_test_name=unit_test_name,
            input_signature=tf_utils.apply_function(input_args,
                                                    to_tensor_spec),
            input_generator=None,
            input_args=input_args,
            kwargs=kwargs,
        ))
  return unit_tests
예제 #4
0
 def test_apply_function(self):
     """Checks apply_function maps over nested containers without mutating them."""
     import copy

     inputs = [1, [2, 3], (4, 5), {'6': 6, '78': [7, 8]}]
     # Snapshot taken before the call so in-place mutation can be detected.
     original_inputs = copy.deepcopy(inputs)
     expected = [0, [1, 2], (3, 4), {'6': 5, '78': [6, 7]}]
     result = tf_utils.apply_function(inputs, lambda x: x - 1)
     self.assertEqual(result, expected)
     # Comparing against `expected` (as before) only shows inputs differ from
     # the result; comparing against the snapshot shows inputs are unchanged.
     self.assertEqual(inputs, original_inputs)