예제 #1
0
def create_function_unit_test(
        function_name: str,
        unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
    """Build a tf_function_unit_test for the `tf.math` op named `function_name`.

    The op is fetched from `tf.math` by attribute name.
    - If the spec's input signature is complex (per `tf_utils.is_complex`),
      both the op and the signature are replaced by the pair returned from
      `tf_utils.rewrite_complex_signature`.
    - The `top_k` op is then additionally wrapped with `_wrap_top_k`.
    - When `FLAGS.dynamic_dims` is set, the signature is mapped through
      `tf_utils.make_dims_dynamic`.

    Args:
      function_name: Name of the attribute to look up on `tf.math`.
      unit_test_spec: Provides the input signature, input generator/args,
        the kwargs forwarded to the op on every call, and the test name.

    Returns:
      The result of applying `tf_test_utils.tf_function_unit_test` (with
      rtol=1e-5 and atol=1e-5) to a callable that invokes the op.
    """
    math_op = getattr(tf.math, function_name)
    input_signature = unit_test_spec.input_signature

    if tf_utils.is_complex(input_signature):
        # The rewrite yields a new (op, signature) pair; replace both.
        math_op, input_signature = tf_utils.rewrite_complex_signature(
            math_op, input_signature)

    if function_name == "top_k":
        math_op = _wrap_top_k(math_op)

    # The spec's kwargs are looked up when the op is invoked, not captured now.
    invoke_op = lambda *args: math_op(*args, **unit_test_spec.kwargs)

    if FLAGS.dynamic_dims:
        input_signature = tf_utils.apply_function(
            input_signature, tf_utils.make_dims_dynamic)

    unit_test_decorator = tf_test_utils.tf_function_unit_test(
        input_signature=input_signature,
        input_generator=unit_test_spec.input_generator,
        input_args=unit_test_spec.input_args,
        name=unit_test_spec.unit_test_name,
        rtol=1e-5,
        atol=1e-5)
    return unit_test_decorator(invoke_op)
예제 #2
0
def create_layer_unit_test(
        model: tf.keras.Model,
        unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
    """Wrap the model's __call__ in a tf_function_unit_test for testing.

    Args:
      model: The Keras model under test. It is called with
        `training=FLAGS.training`, which is read at call time.
      unit_test_spec: Provides the (static) input signature, input
        generator/args and the test name.

    Returns:
      The result of applying `tf_test_utils.tf_function_unit_test` to a
      callable that invokes `model`. The (possibly dynamic) signature is used
      as `input_signature`, and the original one as `static_signature`.
    """
    static_signature = unit_test_spec.input_signature
    dynamic_signature = (
        tf_utils.apply_function(static_signature, tf_utils.make_dims_dynamic)
        if FLAGS.dynamic_dims else static_signature)

    # With several inputs, nest each signature inside a one-element list so
    # the traced callable takes all inputs as a single nested argument.
    if len(static_signature) > 1:
        static_signature, dynamic_signature = (
            [static_signature], [dynamic_signature])

    run_model = lambda *args: model(*args, training=FLAGS.training)

    unit_test_decorator = tf_test_utils.tf_function_unit_test(
        input_signature=dynamic_signature,
        static_signature=static_signature,
        input_generator=unit_test_spec.input_generator,
        input_args=unit_test_spec.input_args,
        name=unit_test_spec.unit_test_name)
    return unit_test_decorator(run_model)
예제 #3
0
def unit_test_specs_from_args(
        names_to_input_args: Dict[str, Sequence[Any]],
        kwargs_to_values: Dict[str,
                               Sequence[Any]] = None) -> List[UnitTestSpec]:
    """Generates a Cartesian product of UnitTestSpecs from the given arguments.

    Args:
      names_to_input_args:
        A dict mapping names for input arguments to the arguments themselves.
      kwargs_to_values:
        A dict mapping kwarg names to sequences of values that they can take.
        It is validated and expanded into named kwargs by
        `_named_kwargs_product`.

    Returns:
      A list of 'UnitTestSpec's, one per (input args, kwargs) pair.
      - Each spec's name joins the non-empty input-args name and kwargs name
        with "__".
      - Its input_signature is built by mapping the input args through
        `tf.convert_to_tensor` followed by `tf.TensorSpec.from_tensor`.
      - Its input_generator is None.
    """
    # Validate and parse 'kwargs_to_values'.
    names_to_kwargs = _named_kwargs_product(kwargs_to_values)

    def _spec_from_value(value):
        # Infer a TensorSpec by materializing the concrete value as a tensor.
        return tf.TensorSpec.from_tensor(tf.convert_to_tensor(value))

    # Pair every named set of input args with every named set of kwargs.
    # Taking the product over (name, value) items keeps each name attached to
    # its own value. Previously, separate key and value products were zipped.
    pairs = itertools.product(names_to_input_args.items(),
                              names_to_kwargs.items())

    unit_tests = []
    for (args_name, input_args), (kwargs_name, kwargs) in pairs:
        # Empty names are dropped so they don't produce stray separators.
        unit_test_name = "__".join(
            name for name in (args_name, kwargs_name) if name)
        input_signature = tf_utils.apply_function(input_args, _spec_from_value)
        unit_tests.append(
            UnitTestSpec(
                unit_test_name=unit_test_name,
                input_signature=input_signature,
                input_generator=None,
                input_args=input_args,
                kwargs=kwargs,
            ))
    return unit_tests
예제 #4
0
 def test_apply_function(self):
   """apply_function maps leaves of nested lists/tuples/dicts without mutating them.

   The comparison against `expected` also pins container types: the tuple
   (4, 5) must come back as the tuple (3, 4), not as a list.
   """
   inputs = [1, [2, 3], (4, 5), {'6': 6, '78': [7, 8]}]
   expected = [0, [1, 2], (3, 4), {'6': 5, '78': [6, 7]}]
   result = tf_utils.apply_function(inputs, lambda x: x - 1)
   self.assertEqual(result, expected)
   self.assertNotEqual(inputs, expected)
   # assertNotEqual alone misses a partial in-place mutation (e.g. only the
   # nested dict being rewritten), so also require `inputs` to be unchanged.
   self.assertEqual(inputs, [1, [2, 3], (4, 5), {'6': 6, '78': [7, 8]}])