def test_wrap_as_zero_or_one_arg_callable(self, unused_index, fn,
                                          parameter_type, unpack, arg,
                                          expected_result):
  """Checks that the wrapped callable produces `expected_result`.

  The callable returned by `function_utils.wrap_as_zero_or_one_arg_callable`
  is invoked with `arg` when `parameter_type` is truthy, and with no
  arguments otherwise.
  """
  del unused_index  # Unused.
  callable_under_test = function_utils.wrap_as_zero_or_one_arg_callable(
      fn, parameter_type, unpack)
  if parameter_type:
    result = callable_under_test(arg)
  else:
    result = callable_under_test()
  self.assertEqual(result, expected_result)
def __init__(self, fn, parameter_type, unpack, name=None):
  """Wraps `fn` and registers a `parameter_type -> tf.string` signature.

  Args:
    fn: The Python callable to wrap.
    parameter_type: The parameter type of the computation.
    unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
    name: Ignored.
  """
  del name  # Unused.
  # The wrapped callable is stored before the base-class constructor runs.
  self._fn = function_utils.wrap_as_zero_or_one_arg_callable(
      fn, parameter_type, unpack)
  signature = computation_types.FunctionType(parameter_type, tf.string)
  super().__init__(signature, context_stack_impl.context_stack)
def zero_or_one_arg_fn_to_building_block(self, fn, parameter_type, fn_str):
  """Asserts that `fn` converts to a building block whose string is `fn_str`.

  `parameter_type` is normalized with `computation_types.to_type`, and the
  parameter is always named 'foo'.

  NOTE(review): this method has no `test_` prefix, so unittest-style runners
  will not collect it automatically. Confirm whether that is intended.
  """
  type_spec = computation_types.to_type(parameter_type)
  wrapped = function_utils.wrap_as_zero_or_one_arg_callable(fn, type_spec)
  building_block = (
      federated_computation_utils.zero_or_one_arg_fn_to_building_block(
          wrapped, 'foo', type_spec, context_stack_impl.context_stack))
  self.assertEqual(str(building_block), fn_str)
def test_zero_or_one_arg_fn_to_building_block(self, fn, parameter_type,
                                              fn_str):
  """Checks the string form of the building block produced from `fn`.

  A parameter name ('foo') is supplied only when `parameter_type` is not
  `None`. The rendered building block must start with `fn_str`.
  """
  if parameter_type is None:
    parameter_name = None
  else:
    parameter_name = 'foo'
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      fn, parameter_type)
  building_block, _ = (
      federated_computation_utils.zero_or_one_arg_fn_to_building_block(
          wrapped_fn, parameter_name, parameter_type,
          context_stack_impl.context_stack))
  self.assertStartsWith(str(building_block), fn_str)
def test_raises_value_error_with_none_result(self):
  """Checks that a no-arg function returning `None` raises `ValueError`."""
  no_param_type = None
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      lambda: None, no_param_type)
  with self.assertRaisesRegex(ValueError, 'must return some non-`None`'):
    federated_computation_utils.zero_or_one_arg_fn_to_building_block(
        wrapped_fn, None, no_param_type, context_stack_impl.context_stack)
def test_py_container_args(self, fn, parameter_type, exepcted_result_type):
  """Checks the result type, including its Python container, built from `fn`.

  The result type of the returned type signature must have the same class,
  the same `python_container`, and compare equal to the expected type.
  """
  # NOTE(review): "exepcted" is a typo. It is kept because parameterized
  # cases may pass this argument by keyword.
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      fn, parameter_type)
  _, type_signature = (
      federated_computation_utils.zero_or_one_arg_fn_to_building_block(
          wrapped_fn, 'foo', parameter_type, context_stack_impl.context_stack))
  actual = type_signature.result
  expected = exepcted_result_type
  self.assertIs(type(actual), type(expected))
  self.assertIs(actual.python_container, expected.python_container)
  self.assertEqual(actual, expected)
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Wrapper function to plug TensorFlow logic into the TFF framework.

  Args:
    target_fn: The Python function to serialize as a TensorFlow computation.
    parameter_type: The parameter type of the computation.
    unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
    name: Ignored.

  Returns:
    A `computation_impl.ComputationImpl` built from the serialized computation
    proto and the extra type spec returned by the serializer.

  Raises:
    TypeError: If `type_utils.is_tensorflow_compatible_type` rejects
      `parameter_type`.
  """
  del name  # Unused.
  # NOTE(review): the callable is wrapped before the compatibility check, so
  # any error raised while wrapping surfaces before the TypeError below.
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  if not type_utils.is_tensorflow_compatible_type(parameter_type):
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `NamedTupleType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  stack = context_stack_impl.context_stack
  serialized_pb, extra_type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          wrapped_fn, parameter_type, stack))
  return computation_impl.ComputationImpl(serialized_pb, stack,
                                          extra_type_spec)
def test_py_container_args(self, fn, parameter_type, result_type):
  """Checks the annotated result type, including its container type.

  `parameter_type` is normalized with `computation_types.to_type`. The
  result type must have the same class and the same container type as
  reported by `NamedTupleTypeWithPyContainerType.get_container_type`, and
  must compare equal to `result_type`.
  """
  type_spec = computation_types.to_type(parameter_type)
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(fn, type_spec)
  _, annotated_type = (
      federated_computation_utils.zero_or_one_arg_fn_to_building_block(
          wrapped_fn, 'foo', type_spec, context_stack_impl.context_stack))
  actual_result = annotated_type.result
  get_container = NamedTupleTypeWithPyContainerType.get_container_type
  self.assertIs(type(actual_result), type(result_type))
  self.assertIs(get_container(actual_result), get_container(result_type))
  self.assertEqual(actual_result, result_type)
def _federated_computation_wrapper_fn(target_fn,
                                      parameter_type,
                                      unpack,
                                      name=None):
  """Wrapper function to plug orchestration logic into the TFF framework.

  Args:
    target_fn: The Python function being wrapped.
    parameter_type: The parameter type of the computation; may be `None` for
      a computation that takes no argument.
    unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
    name: Forwarded as `suggested_name` when building the lambda.

  Returns:
    A `computation_impl.ComputationImpl` wrapping the proto of the building
    block returned by
    `federated_computation_utils.zero_or_one_arg_fn_to_building_block`.
  """
  target_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  # Compare against `None` explicitly rather than using truthiness: a
  # struct-like type with no elements may be falsy (via `__len__`), yet it is
  # still a real parameter that needs a name.
  parameter_name = None if parameter_type is None else 'arg'
  ctx_stack = context_stack_impl.context_stack
  target_lambda = (
      federated_computation_utils.zero_or_one_arg_fn_to_building_block(
          target_fn,
          parameter_name,
          parameter_type,
          ctx_stack,
          suggested_name=name))
  return computation_impl.ComputationImpl(target_lambda.proto, ctx_stack)
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Wrapper function to plug TensorFlow logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation of its arguments can be found inside the definition of that
  class.

  Returns:
    A `computation_impl.ComputationImpl` built from the serialized computation
    proto and the extra type spec returned by the serializer.

  Raises:
    TypeError: If `type_analysis.is_tensorflow_compatible_type` rejects
      `parameter_type`.
  """
  del name  # Unused.
  # NOTE(review): the callable is wrapped before the compatibility check, so
  # any error raised while wrapping surfaces before the TypeError below.
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  if not type_analysis.is_tensorflow_compatible_type(parameter_type):
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  stack = context_stack_impl.context_stack
  serialized_pb, extra_type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          wrapped_fn, parameter_type, stack))
  return computation_impl.ComputationImpl(serialized_pb, stack,
                                          extra_type_spec)
def _federated_computation_wrapper_fn(target_fn,
                                      parameter_type,
                                      unpack,
                                      name=None):
  """Wrapper function to plug orchestration logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation of its arguments can be found inside the definition of that
  class. `name` is forwarded as `suggested_name` when building the lambda.

  Returns:
    A `computation_impl.ComputationImpl` wrapping the proto of the building
    block returned by
    `federated_computation_utils.zero_or_one_arg_fn_to_building_block`.
  """
  target_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  # Compare against `None` explicitly rather than using truthiness: a
  # struct type with no elements may be falsy (via `__len__`), yet it is
  # still a real parameter that needs a name.
  parameter_name = None if parameter_type is None else 'arg'
  ctx_stack = context_stack_impl.context_stack
  target_lambda = (
      federated_computation_utils.zero_or_one_arg_fn_to_building_block(
          target_fn,
          parameter_name,
          parameter_type,
          ctx_stack,
          suggested_name=name))
  return computation_impl.ComputationImpl(target_lambda.proto, ctx_stack)