return self._func


class ContextForTest(context_base.Context):
    """Minimal test context that echoes values and describes invocations.

    `ingest` hands values back untouched, and `invoke` runs the wrapped
    function and renders the call as a human-readable string.
    """

    def ingest(self, val, type_spec):
        """Returns `val` as-is; `type_spec` is not consulted."""
        return val

    def invoke(self, comp, arg):
        """Calls `comp.func` and formats a description of the call.

        Args:
          comp: An object exposing `func` and `type_signature.parameter`.
          arg: Passed to `comp.func` only when `comp.type_signature.parameter`
            is truthy; it is rendered in the output string either way.

        Returns:
          A string of the form '<arg> : <parameter> -> <result>'.
        """
        # A falsy parameter signature means the function takes no argument.
        if not comp.type_signature.parameter:
            outcome = comp.func()
        else:
            outcome = comp.func(arg)
        arg_text = str(arg)
        param_text = str(comp.type_signature.parameter)
        outcome_text = str(outcome)
        return '{} : {} -> {}'.format(arg_text, param_text, outcome_text)


# Computation wrapper built around `WrappedForTest`; the tests apply it as a
# decorator (e.g. `@test_wrap(...)`).
test_wrap = computation_wrapper.ComputationWrapper(WrappedForTest)


class ComputationWrapperTest(test.TestCase):

    # NOTE: Many tests below silence certain linter warnings. These warnings are
    # not applicable, since it is the wrapper code, not the dummy functions,
    # that is being tested, so whether the specific function declarations used
    # here follow good practices is not really relevant. The purpose of the tests
    # is to exercise various corner cases that the wrapper needs to be able to
    # correctly handle.

    def test_as_decorator_with_kwargs(self):
        with self.assertRaises(TypeError):

            @test_wrap(foo=1)
    target_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    if not type_utils.is_tensorflow_compatible_type(parameter_type):
        raise TypeError(
            '`tf_computation`s can accept only parameter types with '
            'constituents `SequenceType`, `NamedTupleType` '
            'and `TensorType`; you have attempted to create one '
            'with the type {}.'.format(parameter_type))
    ctx_stack = context_stack_impl.context_stack
    comp_pb, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        target_fn, parameter_type, ctx_stack)
    return computation_impl.ComputationImpl(comp_pb, ctx_stack,
                                            extra_type_spec)


# Computation wrapper whose per-function wrapping logic is `_tf_wrapper_fn`.
tensorflow_wrapper = computation_wrapper.ComputationWrapper(_tf_wrapper_fn)


def _tf2_wrapper_fn(target_fn, parameter_type, unpack, name=None):
    """Serializes `target_fn` with the TF2 serializer and wraps the result.

    Args:
      target_fn: The Python function to serialize.
      parameter_type: Forwarded positionally to
        `tensorflow_serialization.serialize_tf2_as_tf_computation`.
      unpack: Forwarded as the `unpack` keyword to the same serializer.
      name: Accepted for signature compatibility; ignored.

    Returns:
      A `computation_impl.ComputationImpl` built from the serialized proto,
      the global `context_stack_impl.context_stack`, and the second value the
      serializer returns.
    """
    del name  # Unused.
    serialize = tensorflow_serialization.serialize_tf2_as_tf_computation
    comp_proto, type_spec = serialize(target_fn, parameter_type, unpack=unpack)
    stack = context_stack_impl.context_stack
    return computation_impl.ComputationImpl(comp_proto, stack, type_spec)


# Computation wrapper whose per-function wrapping logic is `_tf2_wrapper_fn`,
# which serializes via `serialize_tf2_as_tf_computation`.
tf2_wrapper = computation_wrapper.ComputationWrapper(_tf2_wrapper_fn)