def test_child_name_doesnt_conflict(self):
  """Descendants re-suggesting an ancestor's name get increasing suffixes."""
  stack = context_stack_impl.context_stack
  make_context = federated_computation_context.FederatedComputationContext
  root = make_context(stack, suggested_name='FOO')
  self.assertEqual(root.name, 'FOO')
  # Each new context is a child of the previous one and asks for 'FOO' again.
  current = root
  for expected_name in ('FOO_1', 'FOO_2'):
    current = make_context(stack, suggested_name='FOO', parent=current)
    self.assertEqual(current.name, expected_name)
def zero_or_one_arg_fn_to_building_block(
    fn,
    parameter_name: Optional[str],
    parameter_type: Optional[Any],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
) -> building_blocks.ComputationBuildingBlock:
  """Converts a zero- or one-argument `fn` into a computation building block.

  Args:
    fn: A function with 0 or 1 arguments that contains orchestration logic,
      i.e., that expects zero or one `values_base.Value` and returns a result
      convertible to the same.
    parameter_name: The name of the parameter, or `None` if there isn't any.
    parameter_type: The TFF type of the parameter, or `None` if there's none.
    context_stack: The context stack to use.
    suggested_name: The optional suggested name to use for the federated
      context that will be used to serialize this function's body (ideally the
      name of the underlying Python function). It might be modified to avoid
      conflicts.

  Returns:
    An instance of `building_blocks.ComputationBuildingBlock` that contains the
    logic from `fn`.

  Raises:
    ValueError: if `fn` is incompatible with `parameter_type`, or if `fn`
      returns `None`.
  """
  py_typecheck.check_callable(fn)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if suggested_name is not None:
    py_typecheck.check_type(suggested_name, str)
  parameter_type = computation_types.to_type(parameter_type)
  # If we are already inside a federated context, make it the parent of the
  # new one so the new context's name is chosen relative to its ancestors.
  if isinstance(context_stack.current,
                federated_computation_context.FederatedComputationContext):
    parent_context = context_stack.current
  else:
    parent_context = None
  context = federated_computation_context.FederatedComputationContext(
      context_stack, suggested_name=suggested_name, parent=parent_context)
  if parameter_name is not None:
    py_typecheck.check_type(parameter_name, str)
    # Prefix the parameter with the context's name, e.g. 'FOO_x'.
    parameter_name = '{}_{}'.format(context.name, str(parameter_name))
  with context_stack.install(context):
    if parameter_type is not None:
      result = fn(
          value_impl.ValueImpl(
              building_blocks.Reference(parameter_name, parameter_type),
              context_stack))
    else:
      result = fn()
    if result is None:
      # `check_callable` admits callables without `__code__` (partials,
      # objects defining `__call__`); don't let the error path itself crash.
      code = getattr(fn, '__code__', None)
      if code is not None:
        location = 'on line {} of file {}'.format(code.co_firstlineno,
                                                  code.co_filename)
      else:
        location = 'as {!r}'.format(fn)
      raise ValueError(
          'The function defined {} has returned a '
          '`NoneType`, but all TFF functions must return some non-`None` '
          'value.'.format(location))
    result = value_impl.to_value(result, None, context_stack)
    result_comp = value_impl.ValueImpl.get_comp(result)
    return building_blocks.Lambda(parameter_name, parameter_type, result_comp)
def test_invoke_returns_value_with_correct_type(self):
  """Invoking a no-argument TF computation yields an int32 `Value`."""
  ctx = federated_computation_context.FederatedComputationContext(
      context_stack_impl.context_stack)
  constant_fn = computations.tf_computation(lambda: tf.constant(10))
  invoked = ctx.invoke(constant_fn, None)
  self.assertIsInstance(invoked, value_base.Value)
  type_str = str(invoked.type_signature)
  self.assertEqual(type_str, 'int32')
def zero_or_one_arg_func_to_building_block(func,
                                           parameter_name,
                                           parameter_type,
                                           context_stack,
                                           suggested_name=None):
  """Converts a zero- or one-argument `func` into a computation building block.

  Args:
    func: A function with 0 or 1 arguments that contains orchestration logic,
      i.e., that expects zero or one `values_base.Value` and returns a result
      convertible to the same.
    parameter_name: The name of the parameter, or `None` if there isn't any.
    parameter_type: The TFF type of the parameter, or `None` if there's none.
    context_stack: The context stack to use.
    suggested_name: The optional suggested name to use for the federated
      context that will be used to serialize this function's body (ideally the
      name of the underlying Python function). It might be modified to avoid
      conflicts. If not `None`, it must be a string.

  Returns:
    An instance of `computation_building_blocks.ComputationBuildingBlock` that
    contains the logic from `func`. If `parameter_type` is `None`, this is the
    building block for `func`'s result itself; otherwise it is a
    `computation_building_blocks.Lambda` over `parameter_name` whose body is
    that result.

  Raises:
    ValueError: if `func` is incompatible with `parameter_type`, or if `func`
      returns `None`.
  """
  py_typecheck.check_callable(func)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if suggested_name is not None:
    py_typecheck.check_type(suggested_name, six.string_types)
  parameter_type = computation_types.to_type(parameter_type)
  # If we are already inside a federated context, make it the parent of the
  # new one so the new context's name is chosen relative to its ancestors.
  if isinstance(context_stack.current,
                federated_computation_context.FederatedComputationContext):
    parent_context = context_stack.current
  else:
    parent_context = None
  context = federated_computation_context.FederatedComputationContext(
      context_stack, suggested_name=suggested_name, parent=parent_context)
  if parameter_name is not None:
    py_typecheck.check_type(parameter_name, six.string_types)
    # Prefix the parameter with the context's name, e.g. 'FOO_x'.
    parameter_name = '{}_{}'.format(context.name, str(parameter_name))
  with context_stack.install(context):
    if parameter_type is not None:
      result = func(
          value_impl.ValueImpl(
              computation_building_blocks.Reference(parameter_name,
                                                    parameter_type),
              context_stack))
    else:
      result = func()
    if result is None:
      # Fail with a clear message instead of letting `None` reach `to_value`.
      # `check_callable` admits callables without `__code__`, so guard it.
      code = getattr(func, '__code__', None)
      if code is not None:
        location = 'on line {} of file {}'.format(code.co_firstlineno,
                                                  code.co_filename)
      else:
        location = 'as {!r}'.format(func)
      raise ValueError(
          'The function defined {} has returned a '
          '`NoneType`, but all TFF functions must return some non-`None` '
          'value.'.format(location))
    result = value_impl.to_value(result, None, context_stack)
    result_comp = value_impl.ValueImpl.get_comp(result)
    if parameter_type is None:
      return result_comp
    else:
      return computation_building_blocks.Lambda(parameter_name, parameter_type,
                                                result_comp)
def test_federated_value_raw_np_scalar(self):
  """`federated_value` accepts raw NumPy scalars and places them at SERVER."""
  stack = context_stack_impl.context_stack
  fc_context = federated_computation_context.FederatedComputationContext(stack)
  with stack.install(fc_context):
    cases = (
        (np.float64(0), 'float64@SERVER'),
        (np.int64(0), 'int64@SERVER'),
    )
    for raw_scalar, expected_type in cases:
      federated = intrinsics.federated_value(raw_scalar, placements.SERVER)
      self.assertEqual(str(federated.type_signature), expected_type)
def test_bind_single_computation_to_reference(self):
  """Binding one computation returns a matching `Reference` and records it."""
  context = federated_computation_context.FederatedComputationContext(
      context_stack_impl.context_stack)
  data = building_blocks.Data('x', tf.int32)
  ref = context.bind_computation_to_reference(data)
  symbol_bindings = context.symbol_bindings
  self.assertIsInstance(ref, building_blocks.Reference)
  self.assertEqual(ref.type_signature, data.type_signature)
  # Check the length before indexing, so an empty binding list fails with a
  # clear assertion message rather than an IndexError.
  self.assertLen(symbol_bindings, 1)
  bound_symbol_name = symbol_bindings[0][0]
  self.assertEqual(bound_symbol_name, ref.name)
def test_something(self):
  """Checks `invoke` result typing, the default name and suffixed names."""
  stack = context_stack_impl.context_stack
  make_context = federated_computation_context.FederatedComputationContext
  root = make_context(stack)
  tf_comp = computations.tf_computation(lambda: tf.constant(10))
  invoked = root.invoke(tf_comp, None)
  self.assertIsInstance(invoked, value_base.Value)
  self.assertEqual(str(invoked.type_signature), 'int32')
  # A context built without `suggested_name` is named 'FEDERATED'.
  self.assertEqual(root.name, 'FEDERATED')
  # Chain three descendants that all ask for 'FOO'.
  current = root
  for expected_name in ('FOO', 'FOO_1', 'FOO_2'):
    current = make_context(stack, suggested_name='FOO', parent=current)
    self.assertEqual(current.name, expected_name)
def test_bind_two_computations_to_reference(self):
  """Two bound computations yield two references, recorded in bind order."""
  context = federated_computation_context.FederatedComputationContext(
      context_stack_impl.context_stack)
  comps = [
      building_blocks.Data('x', tf.int32),
      building_blocks.Data('x', tf.float32),
  ]
  refs = [context.bind_computation_to_reference(comp) for comp in comps]
  bindings = context.symbol_bindings
  for ref in refs:
    self.assertIsInstance(ref, building_blocks.Reference)
  for ref, comp in zip(refs, comps):
    self.assertEqual(ref.type_signature, comp.type_signature)
  self.assertLen(bindings, 2)
  for index, ref in enumerate(refs):
    self.assertEqual(bindings[index][0], ref.name)
computation_types.NamedTupleType( (computation_types.FederatedType(tf.int32, placement_literals.CLIENTS), computation_types.FederatedType(tf.int32, placement_literals.CLIENTS)))) def _(x): x = value_impl.to_value(x, None, _context_stack) value_utils.ensure_federated_value(x) return x def test_ensure_federated_value_fails_on_unzippable(self): @computations.federated_computation( computation_types.NamedTupleType( (computation_types.FederatedType(tf.int32, placement_literals.CLIENTS), computation_types.FederatedType(tf.int32, placement_literals.SERVER)))) def _(x): x = value_impl.to_value(x, None, _context_stack) with self.assertRaises(TypeError): value_utils.ensure_federated_value(x) return x if __name__ == '__main__': with context_stack_impl.context_stack.install( federated_computation_context.FederatedComputationContext( context_stack_impl.context_stack)): absltest.main()
def run(self, result=None):
  """Runs the test with a `FederatedComputationContext` installed.

  The context stays installed for the whole duration of the base class's
  `run`, so individual tests do not need to install one themselves.

  Args:
    result: Optional result object, forwarded unchanged to the base `run`.

  Returns:
    Whatever the base class's `run` returns. For `unittest.TestCase`, this is
    the result object that was used.
  """
  fc_context = federated_computation_context.FederatedComputationContext(
      context_stack_impl.context_stack)
  with context_stack_impl.context_stack.install(fc_context):
    return super(ValueImplTest, self).run(result)
def test_parent_populated_correctly(self):
  """A context constructed with `parent=` exposes that same object."""
  stack = context_stack_impl.context_stack
  parent = federated_computation_context.FederatedComputationContext(stack)
  child = federated_computation_context.FederatedComputationContext(
      stack, parent=parent)
  self.assertIs(child.parent, parent)
def test_suggested_name_populates_name_attribute(self):
  """A parentless context takes its suggested name as-is."""
  named_context = federated_computation_context.FederatedComputationContext(
      context_stack_impl.context_stack, suggested_name='FOO')
  self.assertEqual(named_context.name, 'FOO')
def test_construction_populates_name(self):
  """Without a suggested name, a context is named 'FEDERATED'."""
  ctx = federated_computation_context.FederatedComputationContext(
      context_stack_impl.context_stack)
  expected_name = 'FEDERATED'
  self.assertEqual(ctx.name, expected_name)