def fn(self):
        """Returns the object stored in `self._fn` (presumably the wrapped Python function)."""
        return self._fn


class ContextForTest(context_base.Context):
    """Test context that runs `comp.fn` eagerly and describes the invocation.

    Rather than returning the raw result, `invoke` returns a string of the form
    '<arg> : <parameter type> -> <result>', so tests can check the argument,
    the parameter type, and the computed value in one assertion.
    """

    def ingest(self, val, type_spec):
        """Returns `val` unchanged; `type_spec` is ignored."""
        return val

    def invoke(self, comp, arg):
        """Calls `comp.fn` and returns a textual description of the call.

        Args:
          comp: The computation being invoked; it must expose `fn` and
            `type_signature.parameter`.
          arg: The argument to pass to `comp.fn` when the computation declares
            a parameter.

        Returns:
          A string '<arg> : <parameter type> -> <result>'.
        """
        parameter_type = comp.type_signature.parameter
        # Compare against `None` explicitly rather than relying on truthiness:
        # a type object whose class defines `__len__`/`__bool__` could be falsy
        # (e.g. when empty) even though the computation still takes a parameter.
        if parameter_type is not None:
            result = comp.fn(arg)
        else:
            result = comp.fn()
        return '{} : {} -> {}'.format(str(arg),
                                      str(parameter_type),
                                      str(result))


# Wrapper under test: a `ComputationWrapper` built around `WrappedForTest`,
# applied below as the `@test_wrap` decorator.
test_wrap = computation_wrapper.ComputationWrapper(WrappedForTest)


class ComputationWrapperTest(test.TestCase):

    # Note: Many tests below silence certain linter warnings. These warnings are
    # not applicable, since it's the wrapper code, not the dummy functions
    # that are being tested, so whether the specific function declarations used
    # here follow good practices is not really relevant. The purpose of the test
    # is to exercise various corner cases that the wrapper needs to be able to
    # correctly handle.

    def test_as_decorator_with_kwargs(self):
        with self.assertRaises(TypeError):

            @test_wrap(foo=1)
Esempio n. 2
0
    target_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    if not type_analysis.is_tensorflow_compatible_type(parameter_type):
        raise TypeError(
            '`tf_computation`s can accept only parameter types with '
            'constituents `SequenceType`, `StructType` '
            'and `TensorType`; you have attempted to create one '
            'with the type {}.'.format(parameter_type))
    ctx_stack = context_stack_impl.context_stack
    comp_pb, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        target_fn, parameter_type, ctx_stack)
    return computation_impl.ComputationImpl(comp_pb, ctx_stack,
                                            extra_type_spec)


# Module-level wrapper obtained by handing `_tf_wrapper_fn` to
# `ComputationWrapper`.
tensorflow_wrapper = computation_wrapper.ComputationWrapper(_tf_wrapper_fn)


def _tf2_wrapper_fn(target_fn, parameter_type, unpack, name=None):
    """Plugs TensorFlow 2.0 serialization logic into the TFF framework.

    This function is handed to `computation_wrapper.ComputationWrapper`; the
    meaning of its arguments is documented in that class. `name` is accepted
    for interface compatibility and otherwise ignored.

    Returns:
      A `computation_impl.ComputationImpl` built from the serialized
      computation and its extra type spec.
    """
    del name  # Unused.
    serialized = tensorflow_serialization.serialize_tf2_as_tf_computation(
        target_fn, parameter_type, unpack=unpack)
    proto, type_spec_extra = serialized
    stack = context_stack_impl.context_stack
    return computation_impl.ComputationImpl(proto, stack, type_spec_extra)
Esempio n. 3
0
        return val

    def invoke(self, zero_traced_fn, arg):
        """Records the invocation as a `Result` instead of executing anything.

        Args:
          zero_traced_fn: The traced function being invoked; its
            `type_signature.parameter` and `zero_result` are captured.
          arg: The argument supplied by the caller.

        Returns:
          A `Result` holding `arg`, the parameter type and `zero_result`.
        """
        param_type = zero_traced_fn.type_signature.parameter
        traced_zero = zero_traced_fn.zero_result
        return Result(arg=arg, arg_type=param_type, zero_result=traced_zero)


@attr.s
class Result:
    """Plain record of the values captured by an `invoke` call."""
    arg = attr.ib()  # Argument passed to `invoke`.
    arg_type = attr.ib()  # Parameter type taken from the fn's type signature.
    zero_result = attr.ib()  # The traced fn's `zero_result` attribute.


# Wrapper under test: a `ComputationWrapper` built around `_zero_tracer`,
# applied below as the `@test_wrap` decorator.
test_wrap = computation_wrapper.ComputationWrapper(_zero_tracer)


class ComputationWrapperTest(test_utils.TestCase):

    # Note: Many tests below silence certain linter warnings. These warnings are
    # not applicable, since it's the wrapper code, not the dummy functions
    # that are being tested, so whether the specific function declarations used
    # here follow good practices is not really relevant. The purpose of the test
    # is to exercise various corner cases that the wrapper needs to be able to
    # correctly handle.

    def test_as_decorator_with_kwargs(self):
        with self.assertRaises(TypeError):

            @test_wrap(foo=1)
from tensorflow_federated.python.core.impl.context_stack import context_stack_impl
from tensorflow_federated.python.core.impl.utils import function_utils
from tensorflow_federated.python.core.impl.wrappers import computation_wrapper


def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
    """Serializes a Python function containing JAX code as a TFF computation.

    Args:
      fn_to_wrap: The Python function containing JAX code to be serialized as
        a computation containing XLA.
      fn_name: The name for the constructed computation (currently ignored).
      parameter_type: An instance of `computation_types.Type` that represents
        the TFF type of the computation parameter, or `None` if there's none.
      unpack: See `unpack` in `function_utils.create_argument_unpacking_fn`.

    Returns:
      An instance of `computation_impl.ComputationImpl` with the constructed
      computation.
    """
    del fn_name  # Accepted for interface compatibility; not used.
    arg_unpacker = function_utils.create_argument_unpacking_fn(
        fn_to_wrap, parameter_type, unpack=unpack)
    stack = context_stack_impl.context_stack
    serialized_comp = jax_serialization.serialize_jax_computation(
        fn_to_wrap, arg_unpacker, parameter_type, stack)
    return computation_impl.ComputationImpl(serialized_comp, stack)


# Module-level wrapper obtained by handing `_jax_strategy_fn` to
# `ComputationWrapper`.
jax_wrapper = computation_wrapper.ComputationWrapper(_jax_strategy_fn)
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  ctx_stack = context_stack_impl.context_stack
  tf_serializer = tensorflow_serialization.tf_computation_serializer(
      parameter_type, ctx_stack)
  arg = next(tf_serializer)
  try:
    result = yield arg
  except Exception as e:  # pylint: disable=broad-except
    tf_serializer.throw(e)
  comp_pb, extra_type_spec = tf_serializer.send(result)
  yield computation_impl.ComputationImpl(comp_pb, ctx_stack, extra_type_spec)


# Module-level wrapper: `_tf_wrapper_fn` is wrapped in a
# `PythonTracingStrategy` before being handed to `ComputationWrapper`.
tensorflow_wrapper = computation_wrapper.ComputationWrapper(
    computation_wrapper.PythonTracingStrategy(_tf_wrapper_fn))


def _federated_computation_wrapper_fn(parameter_type, name):
  """Wrapper function to plug orchestration logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation of its arguments can be found inside the definition of that class.
  """
  ctx_stack = context_stack_impl.context_stack
  if parameter_type is None:
    parameter_name = None
  else:
    parameter_name = 'arg'
  fn_generator = federated_computation_utils.federated_computation_serializer(
      parameter_name=parameter_name,
        return val

    def invoke(self, zero_traced_fn, arg):
        """Records the invocation as a `Result` instead of executing anything.

        Args:
          zero_traced_fn: The traced function being invoked; its
            `type_signature.parameter` and `zero_result` are captured.
          arg: The argument supplied by the caller.

        Returns:
          A `Result` holding `arg`, the parameter type and `zero_result`.
        """
        param_type = zero_traced_fn.type_signature.parameter
        traced_zero = zero_traced_fn.zero_result
        return Result(arg=arg, arg_type=param_type, zero_result=traced_zero)


@attr.s
class Result:
    """Plain record of the values captured by an `invoke` call."""
    arg = attr.ib()  # Argument passed to `invoke`.
    arg_type = attr.ib()  # Parameter type taken from the fn's type signature.
    zero_result = attr.ib()  # The traced fn's `zero_result` attribute.


# Wrapper under test: `_zero_tracer` wrapped in a `PythonTracingStrategy` and
# handed to `ComputationWrapper`; applied below as the `@test_wrap` decorator.
test_wrap = computation_wrapper.ComputationWrapper(
    computation_wrapper.PythonTracingStrategy(_zero_tracer))


class ComputationWrapperTest(test_case.TestCase):

    # Note: Many tests below silence certain linter warnings. These warnings are
    # not applicable, since it's the wrapper code, not the whimsy functions
    # that are being tested, so whether the specific function declarations used
    # here follow good practices is not really relevant. The purpose of the test
    # is to exercise various corner cases that the wrapper needs to be able to
    # correctly handle.

    def test_as_decorator_with_kwargs(self):
        with self.assertRaises(TypeError):

            @test_wrap(foo=1)