def create_function_unit_test(
    function_name: str,
    unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
  """Build a tf_function_unit_test around the `tf.math` op `function_name`.

  Args:
    function_name: Name of the op, looked up as an attribute of `tf.math`.
    unit_test_spec: Supplies the input signature, input generator, input
      args, extra kwargs for the op, and the unit test's name.

  Returns:
    `tf_test_utils.tf_function_unit_test(...)` applied to a callable that
    forwards its positional args plus `unit_test_spec.kwargs` to the op.
  """
  op = getattr(tf.math, function_name)
  input_signature = unit_test_spec.input_signature

  # Complex signatures are rewritten by tf_utils together with the op itself.
  if tf_utils.is_complex(input_signature):
    op, input_signature = tf_utils.rewrite_complex_signature(
        op, input_signature)

  if function_name == "top_k":
    op = _wrap_top_k(op)

  call_op = lambda *args: op(*args, **unit_test_spec.kwargs)

  if FLAGS.dynamic_dims:
    input_signature = tf_utils.apply_function(input_signature,
                                              tf_utils.make_dims_dynamic)

  decorator = tf_test_utils.tf_function_unit_test(
      input_signature=input_signature,
      input_generator=unit_test_spec.input_generator,
      input_args=unit_test_spec.input_args,
      name=unit_test_spec.unit_test_name,
      rtol=1e-5,
      atol=1e-5)
  return decorator(call_op)
def create_layer_unit_test(
    model: tf.keras.Model,
    unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
  """Expose `model(...)` as a tf_function_unit_test.

  Args:
    model: The Keras model whose call is under test; it is invoked with
      `training=FLAGS.training`.
    unit_test_spec: Supplies the input signature, input generator, input
      args and the unit test's name.

  Returns:
    `tf_test_utils.tf_function_unit_test(...)` applied to a callable that
    forwards its positional args to `model`.
  """
  static_signature = unit_test_spec.input_signature
  dynamic_signature = (
      tf_utils.apply_function(static_signature, tf_utils.make_dims_dynamic)
      if FLAGS.dynamic_dims else static_signature)

  # With more than one entry, both signatures are nested in an outer list
  # (presumably so the model receives its inputs as one list arg — confirm).
  if len(static_signature) > 1:
    static_signature, dynamic_signature = ([static_signature],
                                           [dynamic_signature])

  call_model = lambda *args: model(*args, training=FLAGS.training)

  decorator = tf_test_utils.tf_function_unit_test(
      input_signature=dynamic_signature,
      static_signature=static_signature,
      input_generator=unit_test_spec.input_generator,
      input_args=unit_test_spec.input_args,
      name=unit_test_spec.unit_test_name)
  return decorator(call_model)
def unit_test_specs_from_args(
    names_to_input_args: Dict[str, Sequence[Any]],
    kwargs_to_values: Dict[str, Sequence[Any]] = None) -> List[UnitTestSpec]:
  """Generates a Cartesian product of UnitTestSpecs from the given arguments.

  Args:
    names_to_input_args: A dict mapping names for input arguments to the
      arguments themselves.
    kwargs_to_values: A dict mapping kwarg names to sequences of values that
      they can take. Defaults to None; the value is handed as-is to
      `_named_kwargs_product`.

  Returns:
    A list of 'UnitTestSpec's, one per (input_args, kwargs) combination, in
    `itertools.product` order (input args outer, kwargs inner).
  """
  # Validate and parse 'kwargs_to_values'.
  names_to_kwargs = _named_kwargs_product(kwargs_to_values)

  unit_tests = []
  # Pair every named set of input args with every named kwargs combination.
  for (args_name, input_args), (kwargs_name, kwargs) in itertools.product(
      names_to_input_args.items(), names_to_kwargs.items()):
    # Falsy names are dropped so the joined name has no empty segments.
    unit_test_name = "__".join(
        name for name in (args_name, kwargs_name) if name)
    input_signature = tf_utils.apply_function(
        input_args,
        lambda x: tf.TensorSpec.from_tensor(tf.convert_to_tensor(x)))
    unit_tests.append(
        UnitTestSpec(
            unit_test_name=unit_test_name,
            input_signature=input_signature,
            input_generator=None,
            input_args=input_args,
            kwargs=kwargs,
        ))
  return unit_tests
def test_apply_function(self):
  """apply_function maps nested leaves and leaves its input untouched."""

  def make_inputs():
    # Fresh copy each call, used both as the argument and as the reference
    # for the non-mutation check below.
    return [1, [2, 3], (4, 5), {'6': 6, '78': [7, 8]}]

  inputs = make_inputs()
  expected = [0, [1, 2], (3, 4), {'6': 5, '78': [6, 7]}]
  result = tf_utils.apply_function(inputs, lambda x: x - 1)
  self.assertEqual(result, expected)
  # Compare against a pristine copy rather than asserting `inputs != expected`,
  # which would not catch partial in-place mutation of nested lists/dicts.
  self.assertEqual(inputs, make_inputs())