Ejemplo n.º 1
0
def create_function_unit_test(
        function_name: str,
        unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
    """Build a tf_function_unit_test for the `tf.math` op named `function_name`.

    Args:
      function_name: Name of the op, looked up as an attribute of `tf.math`.
      unit_test_spec: Provides the input signature, input generator/args,
        extra keyword arguments forwarded to the op, and the test name.

    Returns:
      The result of applying `tf_test_utils.tf_function_unit_test` (with
      rtol=1e-5 and atol=1e-5) to a wrapper that calls the op.
    """
    math_op = getattr(tf.math, function_name)
    spec_signature = unit_test_spec.input_signature

    # Complex-valued signatures: tf_utils rewrites both the op and the
    # signature together.
    if tf_utils.is_complex(spec_signature):
        math_op, spec_signature = tf_utils.rewrite_complex_signature(
            math_op, spec_signature)

    # top_k is routed through _wrap_top_k (defined elsewhere in this file).
    if function_name == "top_k":
        math_op = _wrap_top_k(math_op)

    # Optionally make the signature's dimensions dynamic.
    if FLAGS.dynamic_dims:
        spec_signature = tf_utils.apply_function(spec_signature,
                                                 tf_utils.make_dims_dynamic)

    decorate = tf_test_utils.tf_function_unit_test(
        input_signature=spec_signature,
        input_generator=unit_test_spec.input_generator,
        input_args=unit_test_spec.input_args,
        name=unit_test_spec.unit_test_name,
        rtol=1e-5,
        atol=1e-5)
    # The spec's kwargs are looked up each time the wrapper is invoked.
    return decorate(lambda *args: math_op(*args, **unit_test_spec.kwargs))
Ejemplo n.º 2
0
    def __init__(self):
        """Build the flag-selected model and expose it as a `call` unit test."""
        super().__init__()
        mode = MODE_ENUM_TO_MODE[FLAGS.mode]
        self.m = utils.get_model_with_default_params(FLAGS.model, mode)

        # One TensorSpec per model input, copying only that input's shape.
        decorate = tf_test_utils.tf_function_unit_test(
            input_signature=[
                tf.TensorSpec(model_input.shape)
                for model_input in self.m.inputs
            ],
            name="call",
            atol=1e-5)
        # Inference mode: the model is always invoked with training=False.
        self.call = decorate(lambda *args: self.m(*args, training=False))
Ejemplo n.º 3
0
  def __init__(self):
    """Initialize the model and expose it as a statically-shaped `call` test."""
    super().__init__()
    self.m = initialize_model()

    # Prepend BATCH_SIZE to the first input's non-batch dimensions. Some
    # models accept dynamic image dimensions by default, so any unknown
    # (None) dimension is replaced with IMAGE_DIM as a stand-in.
    input_shape = [
        IMAGE_DIM if dim is None else dim
        for dim in [BATCH_SIZE] + self.m.inputs[0].shape[1:]
    ]

    # Specify input shape with a static batch size.
    # TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
    decorate = tf_test_utils.tf_function_unit_test(
        input_signature=[tf.TensorSpec(input_shape)],
        name="call",
        rtol=1e-5,
        atol=1e-5)
    self.call = decorate(lambda x: self.m(x, training=False))
Ejemplo n.º 4
0
def create_layer_unit_test(
        model: tf.keras.Model,
        unit_test_spec: tf_test_utils.UnitTestSpec) -> tf.function:
    """Wrap `model`'s call in a tf_function_unit_test described by `unit_test_spec`.

    Args:
      model: The Keras model whose call is being tested.
      unit_test_spec: Provides the input signature, input generator/args
        and the test name.

    Returns:
      The result of applying `tf_test_utils.tf_function_unit_test` to a
      wrapper around `model`. `FLAGS.training` is read each time the
      wrapper is invoked.
    """
    static_signature = unit_test_spec.input_signature
    if FLAGS.dynamic_dims:
        dynamic_signature = tf_utils.apply_function(
            static_signature, tf_utils.make_dims_dynamic)
    else:
        dynamic_signature = static_signature

    # With more than one input spec, nest both signatures one level deeper.
    # The length check uses the un-nested static signature.
    if len(static_signature) > 1:
        static_signature, dynamic_signature = (
            [static_signature], [dynamic_signature])

    decorate = tf_test_utils.tf_function_unit_test(
        input_signature=dynamic_signature,
        static_signature=static_signature,
        input_generator=unit_test_spec.input_generator,
        input_args=unit_test_spec.input_args,
        name=unit_test_spec.unit_test_name)
    return decorate(lambda *args: model(*args, training=FLAGS.training))