def test_wrap_as_zero_or_one_arg_callable(self, unused_index, fn,
                                          parameter_type, unpack, arg,
                                          expected_result):
     """Checks the value produced by a wrapped zero-or-one-arg callable.

     `fn` is wrapped via `function_utils.wrap_as_zero_or_one_arg_callable`.
     The wrapper is then called with `arg` when `parameter_type` is truthy,
     or with no arguments otherwise, and the outcome must equal
     `expected_result`.
     """
     wrapped = function_utils.wrap_as_zero_or_one_arg_callable(
         fn, parameter_type, unpack)
     if parameter_type:
         outcome = wrapped(arg)
     else:
         outcome = wrapped()
     self.assertEqual(outcome, expected_result)
 def __init__(self, fn, parameter_type, unpack, name=None):
     """Wraps `fn` and initializes the base with a `tf.string` result type.

     Args:
       fn: The Python callable to wrap.
       parameter_type: Parameter type used both for wrapping `fn` and as the
         parameter of the `computation_types.FunctionType` signature.
       unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
       name: Ignored.
     """
     del name  # Unused.
     wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
         fn, parameter_type, unpack)
     self._fn = wrapped_fn
     signature = computation_types.FunctionType(parameter_type, tf.string)
     stack = context_stack_impl.context_stack
     super().__init__(signature, stack)
 def zero_or_one_arg_fn_to_building_block(self, fn, parameter_type, fn_str):
   """Asserts that `fn` converts to a building block rendering as `fn_str`.

   `parameter_type` is first normalized with `computation_types.to_type`.
   The parameter name is always 'foo', even when the normalized type is
   `None`.
   """
   normalized_type = computation_types.to_type(parameter_type)
   wrapped = function_utils.wrap_as_zero_or_one_arg_callable(
       fn, normalized_type)
   stack = context_stack_impl.context_stack
   block = federated_computation_utils.zero_or_one_arg_fn_to_building_block(
       wrapped, 'foo', normalized_type, stack)
   self.assertEqual(str(block), fn_str)
 def test_zero_or_one_arg_fn_to_building_block(self, fn, parameter_type,
                                               fn_str):
   """Checks the string form of the building block built from `fn`.

   A parameter name ('foo') is supplied only when `parameter_type` is not
   `None`. The first element of the returned pair must render with the
   prefix `fn_str`.
   """
   if parameter_type is None:
     parameter_name = None
   else:
     parameter_name = 'foo'
   wrapped = function_utils.wrap_as_zero_or_one_arg_callable(
       fn, parameter_type)
   stack = context_stack_impl.context_stack
   building_block, _ = (
       federated_computation_utils.zero_or_one_arg_fn_to_building_block(
           wrapped, parameter_name, parameter_type, stack))
   self.assertStartsWith(str(building_block), fn_str)
示例#5
0
    def test_raises_value_error_with_none_result(self):
        """A no-arg function returning `None` must be rejected.

        The conversion is expected to raise `ValueError` whose message
        matches 'must return some non-`None`'.
        """
        parameter_type = None
        returns_none = function_utils.wrap_as_zero_or_one_arg_callable(
            lambda: None, parameter_type)
        stack = context_stack_impl.context_stack

        with self.assertRaisesRegex(ValueError, 'must return some non-`None`'):
            federated_computation_utils.zero_or_one_arg_fn_to_building_block(
                returns_none, None, parameter_type, stack)
示例#6
0
 def test_py_container_args(self, fn, parameter_type, exepcted_result_type):
   """Checks the inferred result type, including its Python container.

   The result of the returned type signature must have the same Python type,
   the identical `python_container`, and compare equal to the expected type.
   """
   # NOTE(review): `exepcted` is a typo for `expected`; the parameter name is
   # left unchanged in case test parameters bind it by keyword.
   wrapped = function_utils.wrap_as_zero_or_one_arg_callable(
       fn, parameter_type)
   stack = context_stack_impl.context_stack
   _, type_signature = (
       federated_computation_utils.zero_or_one_arg_fn_to_building_block(
           wrapped, 'foo', parameter_type, stack))
   inferred = type_signature.result
   self.assertIs(type(inferred), type(exepcted_result_type))
   self.assertIs(inferred.python_container,
                 exepcted_result_type.python_container)
   self.assertEqual(inferred, exepcted_result_type)
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Wrapper function to plug TensorFlow logic into the TFF framework.

  Args:
    target_fn: The Python function to wrap.
    parameter_type: The parameter type; must pass
      `type_utils.is_tensorflow_compatible_type`.
    unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
    name: Ignored.

  Returns:
    A `computation_impl.ComputationImpl` built from the serialized
    computation proto and the extra type spec returned by serialization.

  Raises:
    TypeError: If `parameter_type` is not TensorFlow-compatible.
  """
  del name  # Unused.
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  is_compatible = type_utils.is_tensorflow_compatible_type(parameter_type)
  if not is_compatible:
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `NamedTupleType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  stack = context_stack_impl.context_stack
  serialized_pb, extra_type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          wrapped_fn, parameter_type, stack))
  return computation_impl.ComputationImpl(serialized_pb, stack,
                                          extra_type_spec)
 def test_py_container_args(self, fn, parameter_type, result_type):
     """Checks the inferred result type, including its Python container.

     `parameter_type` is normalized with `computation_types.to_type`. The
     annotated result must have the same Python type as `result_type`, the
     same container reported by
     `NamedTupleTypeWithPyContainerType.get_container_type`, and compare
     equal to `result_type`.
     """
     normalized_type = computation_types.to_type(parameter_type)
     wrapped = function_utils.wrap_as_zero_or_one_arg_callable(
         fn, normalized_type)
     stack = context_stack_impl.context_stack
     _, annotated_type = (
         federated_computation_utils.zero_or_one_arg_fn_to_building_block(
             wrapped, 'foo', normalized_type, stack))
     inferred = annotated_type.result
     self.assertIs(type(inferred), type(result_type))
     get_container = NamedTupleTypeWithPyContainerType.get_container_type
     self.assertIs(get_container(inferred), get_container(result_type))
     self.assertEqual(inferred, result_type)
示例#9
0
def _federated_computation_wrapper_fn(target_fn,
                                      parameter_type,
                                      unpack,
                                      name=None):
    """Wrapper function to plug orchestration logic into the TFF framework.

    Args:
      target_fn: The Python function to wrap as a federated computation.
      parameter_type: The computation's parameter type, or `None` for a
        computation that takes no argument.
      unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
      name: Optional name, forwarded as `suggested_name`.

    Returns:
      A `computation_impl.ComputationImpl` built from the proto of the
      constructed building block.
    """
    target_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    ctx_stack = context_stack_impl.context_stack
    # Compare against `None` explicitly rather than relying on truthiness: a
    # type object that defines `__len__` (NOTE(review): presumably an empty
    # named-tuple type) is falsy, yet still requires a parameter name.
    parameter_name = 'arg' if parameter_type is not None else None
    target_lambda = (
        federated_computation_utils.zero_or_one_arg_fn_to_building_block(
            target_fn,
            parameter_name,
            parameter_type,
            ctx_stack,
            suggested_name=name))
    return computation_impl.ComputationImpl(target_lambda.proto, ctx_stack)
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Wrapper function to plug TensorFlow logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation for its arguments can be found inside the definition of that
  class.

  Returns:
    A `computation_impl.ComputationImpl` built from the serialized
    computation proto and the extra type spec returned by serialization.

  Raises:
    TypeError: If `type_analysis.is_tensorflow_compatible_type` rejects
      `parameter_type`.
  """
  del name  # Unused.
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  is_compatible = type_analysis.is_tensorflow_compatible_type(parameter_type)
  if not is_compatible:
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  stack = context_stack_impl.context_stack
  serialized_pb, extra_type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          wrapped_fn, parameter_type, stack))
  return computation_impl.ComputationImpl(serialized_pb, stack,
                                          extra_type_spec)
def _federated_computation_wrapper_fn(target_fn,
                                      parameter_type,
                                      unpack,
                                      name=None):
    """Wrapper function to plug orchestration logic into the TFF framework.

    This function is passed through `computation_wrapper.ComputationWrapper`.
    Documentation for its arguments can be found inside the definition of
    that class.

    Returns:
      A `computation_impl.ComputationImpl` built from the proto of the
      constructed building block.
    """
    target_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    ctx_stack = context_stack_impl.context_stack
    # Compare against `None` explicitly rather than relying on truthiness: a
    # type object that defines `__len__` (NOTE(review): presumably an empty
    # struct/named-tuple type) is falsy, yet still requires a parameter name.
    parameter_name = 'arg' if parameter_type is not None else None
    target_lambda = (
        federated_computation_utils.zero_or_one_arg_fn_to_building_block(
            target_fn,
            parameter_name,
            parameter_type,
            ctx_stack,
            suggested_name=name))
    return computation_impl.ComputationImpl(target_lambda.proto, ctx_stack)