def test_sequence_sum(self):
  """Checks `tff.sequence_sum` type signatures for unplaced and placed input."""
  int_sequence = tff.SequenceType(tf.int32)
  # Each case: (parameter type, expected printed type signature).
  cases = [
      (int_sequence, '(int32* -> int32)'),
      (tff.FederatedType(int_sequence, tff.SERVER),
       '(int32*@SERVER -> int32@SERVER)'),
      (tff.FederatedType(int_sequence, tff.CLIENTS),
       '({int32*}@CLIENTS -> {int32}@CLIENTS)'),
  ]
  for parameter_type, expected_signature in cases:

    @tff.federated_computation(parameter_type)
    def summed(x):
      return tff.sequence_sum(x)

    self.assertEqual(str(summed.type_signature), expected_signature)
def test_run_encoded_sum(self):
  """Runs an identity-encoded sum over two clients and checks the totals."""
  value = np.array([0.0, 1.0, 2.0, -1.0])
  value_spec = tf.TensorSpec(value.shape, tf.as_dtype(value.dtype))
  value_type = tff.to_type(value_spec)
  encoder = te.encoders.as_gather_encoder(te.encoders.identity(), value_spec)
  gather_fn = encoding_utils.build_encoded_sum(value, encoder)
  initial_state = gather_fn.initialize()
  state_type = gather_fn._initialize_fn.type_signature.result  # pylint: disable=protected-access

  @tff.federated_computation(
      tff.FederatedType(state_type, tff.SERVER),
      tff.FederatedType(value_type, tff.CLIENTS))
  def call_gather(state, value):
    return gather_fn(state, value)

  # Each case: (second client's value, expected sum over both clients); the
  # first client always contributes `value`.
  cases = [
      (value, 2 * value),
      (-value, 0 * value),
      (2 * value, 3 * value),
  ]
  for other_value, expected_sum in cases:
    _, value_sum = call_gather(initial_state, [value, other_value])
    self.assertAllClose(expected_sum, value_sum)
def test_call_returned_directly_creates_canonical_form(self):
  """Canonical form extraction succeeds when `next` only calls another comp."""
  # Shared by `next_fn` and `nested_next_fn`, whose signatures must agree.
  server_state_type = tff.FederatedType(tf.int32, tff.SERVER)
  client_data_type = tff.FederatedType(
      tff.SequenceType(tf.float32), tff.CLIENTS)

  @tff.federated_computation
  def init_fn():
    return tff.federated_value(42, tff.SERVER)

  @tff.federated_computation(server_state_type, client_data_type)
  def next_fn(server_state, client_data):
    broadcast_state = tff.federated_broadcast(server_state)

    @tff.tf_computation(tf.int32, tff.SequenceType(tf.float32))
    @tf.function
    def some_transform(x, y):
      del y  # Unused
      return x + 1

    client_update = tff.federated_map(some_transform,
                                      (broadcast_state, client_data))
    aggregate_update = tff.federated_sum(client_update)
    server_output = tff.federated_value(1234, tff.SERVER)
    return aggregate_update, server_output

  # The pattern under test: a `next` whose body is nothing but a call to
  # another federated computation.
  @tff.federated_computation(server_state_type, client_data_type)
  def nested_next_fn(server_state, client_data):
    return next_fn(server_state, client_data)

  iterative_process = computation_utils.IterativeProcess(
      init_fn, nested_next_fn)
  cf = canonical_form_utils.get_canonical_form_for_iterative_process(
      iterative_process)
  self.assertIsInstance(cf, canonical_form.CanonicalForm)
def test_federated_max_on_nested_scalars(self):
  """Element-wise max over clients of a named tuple of int32 scalars."""
  nested_scalars = collections.namedtuple('NestedScalars', ['a', 'b'])
  client_type = tff.FederatedType(
      tff.NamedTupleType([('a', tf.int32), ('b', tf.int32)]), tff.CLIENTS)

  @tff.federated_computation(client_type)
  def call_federated_max(value):
    return federated_aggregations.federated_max(value)

  client_values = [
      nested_scalars(1, 5),
      nested_scalars(2, 3),
      nested_scalars(1, 8),
  ]
  result = call_federated_max(client_values)
  self.assertDictEqual(result._asdict(), {'a': 2, 'b': 8})
def test_federated_zip_with_twenty_elements_local_executor(self):
  """Zipping twenty CLIENTS-placed int32 values must run to completion."""
  num_elements = 20
  num_clients = 2
  arg_types = [tff.FederatedType(tf.int32, tff.CLIENTS)] * num_elements

  @tff.federated_computation(arg_types)
  def zip_all(x):
    return tff.federated_zip(x)

  client_ids = list(range(num_clients))
  data = [list(client_ids) for _ in range(num_elements)]
  # This call never returned back when the local executor scaled factorially
  # with the number of zipped elements; completing is the assertion.
  zip_all(data)
def test_federated_apply_raises_warning(self):
  """Tracing a `tff.federated_apply` call emits one DeprecationWarning."""
  with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter('always')

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER))
    def foo(x):
      return tff.federated_apply(
          tff.tf_computation(lambda x: x * x, tf.int32), x)

    self.assertLen(w, 1)
    # `category` already holds the warning class, so compare classes directly
    # instead of instantiating it for an isinstance check.
    self.assertTrue(issubclass(w[0].category, DeprecationWarning))
    self.assertIn('tff.federated_apply() is deprecated', str(w[0].message))
    self.assertEqual(
        str(foo.type_signature), '(int32@SERVER -> int32@SERVER)')
def test_sequence_reduce(self):
  """Checks `tff.sequence_reduce` type signatures for each placement."""
  add_numbers = tff.tf_computation(tf.add, [tf.int32, tf.int32])
  int_sequence = tff.SequenceType(tf.int32)
  expectations = (
      (int_sequence, '(int32* -> int32)'),
      (tff.FederatedType(int_sequence, tff.SERVER),
       '(int32*@SERVER -> int32@SERVER)'),
      (tff.FederatedType(int_sequence, tff.CLIENTS),
       '({int32*}@CLIENTS -> {int32}@CLIENTS)'),
  )
  for parameter_type, expected in expectations:

    @tff.federated_computation(parameter_type)
    def reduce_with_add(x):
      return tff.sequence_reduce(x, 0, add_numbers)

    self.assertEqual(str(reduce_with_add.type_signature), expected)
def test_federated_min_on_nested_scalars(self):
  """Element-wise min over clients of a named tuple of float32 scalars."""
  nested_scalars = collections.namedtuple('NestedScalars', ['x', 'y'])
  client_type = tff.FederatedType(
      tff.NamedTupleType([('x', tf.float32), ('y', tf.float32)]), tff.CLIENTS)

  @tff.federated_computation(client_type)
  def call_federated_min(value):
    return federated_aggregations.federated_min(value)

  result = call_federated_min([
      nested_scalars(0.0, 1.0),
      nested_scalars(-1.0, 5.0),
      nested_scalars(2.0, -10.0),
  ])
  self.assertDictEqual(result._asdict(), {'x': -1.0, 'y': -10.0})
def test_federated_max_nested_tensor_value(self):
  """Element-wise max over clients of a named tuple of int32 vectors."""
  tuple_type = tff.NamedTupleType([
      ('a', (tf.int32, [2])),
      ('b', (tf.int32, [3])),
  ])

  @tff.federated_computation(tff.FederatedType(tuple_type, tff.CLIENTS))
  def call_federated_max(value):
    return federated_aggregations.federated_max(value)

  test_type = collections.namedtuple('NestedScalars', ['a', 'b'])
  client1 = test_type(
      np.array([4, 5], dtype=np.int32), np.array([1, -2, 3], dtype=np.int32))
  client2 = test_type(
      np.array([9, 0], dtype=np.int32), np.array([5, 1, -2], dtype=np.int32))
  value = call_federated_max([client1, client2])
  # The max is taken position by position, so element order is part of the
  # expected result; an order-insensitive check would accept e.g. [5, 9].
  np.testing.assert_array_equal(value[0], [9, 5])
  np.testing.assert_array_equal(value[1], [5, 1, 3])
def get_iterative_process_for_canonical_form(cf):
  """Creates `tff.utils.IterativeProcess` from a canonical form.

  The `next` computation wires the components of `cf` together in the order
  `prepare` -> broadcast -> `work` -> aggregate (`zero`, `accumulate`,
  `merge`, `report`) -> `update`.

  Args:
    cf: An instance of `tff.backends.mapreduce.CanonicalForm`.

  Returns:
    An instance of `tff.utils.IterativeProcess` that corresponds to `cf`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_type(cf, canonical_form.CanonicalForm)

  @tff.federated_computation
  def init_computation():
    return tff.federated_value(cf.initialize(), tff.SERVER)

  @tff.federated_computation(
      init_computation.type_signature.result,
      tff.FederatedType(cf.work.type_signature.parameter[0], tff.CLIENTS))
  def next_computation(arg):
    """The logic of a single MapReduce processing round."""
    s1 = arg[0]
    c1 = arg[1]
    # `tff.federated_map` accepts SERVER-placed values and supersedes the
    # deprecated `tff.federated_apply`.
    s2 = tff.federated_map(cf.prepare, s1)
    c2 = tff.federated_broadcast(s2)
    c3 = tff.federated_zip([c1, c2])
    c4 = tff.federated_map(cf.work, c3)
    c5 = c4[0]
    c6 = c4[1]
    s3 = tff.federated_aggregate(c5, cf.zero(), cf.accumulate, cf.merge,
                                 cf.report)
    s4 = tff.federated_zip([s1, s3])
    s5 = tff.federated_map(cf.update, s4)
    s6 = s5[0]
    s7 = s5[1]
    return s6, s7, c6

  return computation_utils.IterativeProcess(init_computation, next_computation)
def test_federated_broadcast_with_client_int(self):
  """Broadcasting a value that is already placed at CLIENTS is rejected."""
  with self.assertRaises(TypeError):
    client_int_type = tff.FederatedType(tf.int32, tff.CLIENTS, all_equal=True)

    @tff.federated_computation(client_int_type)
    def _(x):
      return tff.federated_broadcast(x)
def _is_valid_before_aggregate_type_signature(type_spec,
                                              previously_packed_types):
  """Returns whether `type_spec` fits the `before_aggregate` contract.

  The expected form is `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`.
  Each structural property (function type, tuple-typed parameter and result,
  element counts, function-typed `report`) is checked before the elements it
  guards are indexed or have attributes read. A malformed `type_spec`
  therefore yields `False` rather than an unrelated `AttributeError`,
  `IndexError` or `TypeError`.

  Args:
    type_spec: The candidate `type_signature` of `before_aggregate`.
    previously_packed_types: Dict providing at least `s1_type`, `c1_type` and
      `s2_type`.

  Returns:
    `True` if `type_spec` matches the expected form, `False` otherwise.
  """
  if not isinstance(type_spec, tff.FunctionType):
    return False

  parameter = type_spec.parameter
  if not (isinstance(parameter, tff.NamedTupleType) and len(parameter) >= 2):
    return False

  # `<s1,c1>`, compared against the previously packed `s1` and `c1` types.
  s1_and_c1 = parameter[0]
  if not (isinstance(s1_and_c1, tff.NamedTupleType) and
          len(s1_and_c1) == 2 and
          s1_and_c1[0] == previously_packed_types['s1_type'] and
          s1_and_c1[1] == previously_packed_types['c1_type']):
    return False

  # `c2` must be placed at CLIENTS with the member type of `s2`.
  c2_type = parameter[1]
  if not (isinstance(c2_type, tff.FederatedType) and
          c2_type.placement == tff.CLIENTS and
          c2_type.member == previously_packed_types['s2_type'].member):
    return False

  result = type_spec.result
  if not (isinstance(result, tff.NamedTupleType) and len(result) == 5):
    return False
  c5_type = result[0]
  accumulator_type = result[1]
  report_type = result[4]
  return (isinstance(c5_type, tff.FederatedType) and
          c5_type.placement == tff.CLIENTS and
          tff_framework.is_tensorflow_compatible_type(accumulator_type) and
          result[2] == tff.FunctionType([accumulator_type, c5_type.member],
                                        accumulator_type) and
          result[3] == tff.FunctionType([accumulator_type, accumulator_type],
                                        accumulator_type) and
          isinstance(report_type, tff.FunctionType) and
          report_type.parameter == accumulator_type and
          tff_framework.is_tensorflow_compatible_type(report_type.result))


def check_and_pack_before_aggregate_type_signature(type_spec,
                                                   previously_packed_types):
  """Checks types inferred from `before_aggregate` and packs in `previously_packed_types`.

  After splitting the `after_broadcast` portion of a
  `tff.utils.IterativeProcess` into `before_aggregate` and `after_aggregate`,
  `before_aggregate` should have type signature
  `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`. This function
  validates `s1`, `c1` and `c2` against the existing entries in
  `previously_packed_types`. It then packs `c2`, `c3`, `c4`, `c5`, `s3`,
  `zero`, `accumulate`, `merge`, `report` and `work`.

  Args:
    type_spec: The `type_signature` attribute of the `before_aggregate` portion
      of the `tff.utils.IterativeProcess` from which we are looking to extract
      an instance of `canonical_form.CanonicalForm`.
    previously_packed_types: Dict containing the information from `next` and
      `before_broadcast` in the iterative process we are parsing. It must
      provide `s1_type`, `c1_type`, `s2_type` and `c6_type`.

  Returns:
    A `dict` holding every entry of `previously_packed_types` plus the types
    which can be inferred from `type_spec`.

  Raises:
    TypeError: If `type_spec` is not of the expected form or is incompatible
      with `previously_packed_types`.
  """
  if not _is_valid_before_aggregate_type_signature(type_spec,
                                                   previously_packed_types):
    # TODO(b/121290421): These error messages, and indeed the 'validate and
    # raise once' logic of these methods as well, is intended to be
    # provisional and revisited when we've seen the compilation pipeline fail
    # more clearly, or maybe preferably iteratively improved as new failure
    # modes are encountered.
    raise TypeError(
        'Encountered a type error while checking '
        '`before_aggregate`. Expected a type signature of the '
        'form `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`, '
        'where `s1` matches {}, `c1` matches {}, and `c2` matches '
        'the result of broadcasting {}, as defined in '
        '`canonical_form.CanonicalForm`. Found type signature {}.'.format(
            previously_packed_types['s1_type'],
            previously_packed_types['c1_type'],
            previously_packed_types['s2_type'], type_spec))

  newly_determined_types = {}
  c2_type = type_spec.parameter[1]
  newly_determined_types['c2_type'] = c2_type
  c3_type = tff.FederatedType(
      [previously_packed_types['c1_type'].member, c2_type.member], tff.CLIENTS)
  newly_determined_types['c3_type'] = c3_type
  c5_type = type_spec.result[0]
  zero_type = tff.FunctionType(None, type_spec.result[1])
  accumulate_type = type_spec.result[2]
  merge_type = type_spec.result[3]
  report_type = type_spec.result[4]
  newly_determined_types['c5_type'] = c5_type
  newly_determined_types['zero_type'] = zero_type
  newly_determined_types['accumulate_type'] = accumulate_type
  newly_determined_types['merge_type'] = merge_type
  newly_determined_types['report_type'] = report_type
  newly_determined_types['s3_type'] = tff.FederatedType(
      report_type.result, tff.SERVER)
  c4_type = tff.FederatedType([
      newly_determined_types['c5_type'].member,
      previously_packed_types['c6_type'].member
  ], tff.CLIENTS)
  newly_determined_types['c4_type'] = c4_type
  newly_determined_types['work_type'] = tff.FunctionType(
      c3_type.member, c4_type.member)
  # Later entries win, so newly determined types override any stale ones.
  return dict(
      itertools.chain(
          six.iteritems(previously_packed_types),
          six.iteritems(newly_determined_types)))
def _is_valid_after_aggregate_type_signature(type_spec,
                                             previously_packed_types):
  """Returns whether `type_spec` fits the `after_aggregate` contract.

  The expected form is `<<<s1,c1>,c2>,s3> -> <s6,s7,c6>`, where the trailing
  `c6` may be omitted. The nested tuple structure is checked before any
  element is indexed. A malformed `type_spec` therefore yields `False` rather
  than an unrelated `AttributeError`, `IndexError` or `TypeError`.

  Args:
    type_spec: The candidate `type_signature` of `after_aggregate`.
    previously_packed_types: Dict providing at least `s1_type`, `c1_type`,
      `c2_type`, `s3_type`, `s6_type`, `s7_type` and `c6_type`.

  Returns:
    `True` if `type_spec` matches the expected form, `False` otherwise.
  """

  def _is_tuple_with_at_least(candidate, count):
    # Guards positional indexing into `candidate`.
    return isinstance(candidate, tff.NamedTupleType) and len(candidate) >= count

  if not isinstance(type_spec, tff.FunctionType):
    return False
  parameter = type_spec.parameter
  if not (_is_tuple_with_at_least(parameter, 2) and
          _is_tuple_with_at_least(parameter[0], 2) and
          _is_tuple_with_at_least(parameter[0][0], 2)):
    return False
  result = type_spec.result
  if not _is_tuple_with_at_least(result, 2):
    return False

  if not (parameter[0][0][0] == previously_packed_types['s1_type'] and
          parameter[0][0][1] == previously_packed_types['c1_type'] and
          parameter[0][1] == previously_packed_types['c2_type'] and
          parameter[1] == previously_packed_types['s3_type']):
    return False
  if not (result[0] == previously_packed_types['s6_type'] and
          result[1] == previously_packed_types['s7_type']):
    return False
  # `c6` is only checked when the result actually carries it.
  if len(result) == 3 and result[2] != previously_packed_types['c6_type']:
    return False
  return True


def check_and_pack_after_aggregate_type_signature(type_spec,
                                                  previously_packed_types):
  """Checks types inferred from `after_aggregate` and packs in `previously_packed_types`.

  After splitting the `next` portion of a `tff.utils.IterativeProcess` all the
  way down, `after_aggregate` should have type signature
  `<<<s1,c1>,c2>,s3> -> <s6,s7,c6>`. A two-element result `<s6,s7>` is also
  accepted; when `c6` is present it must match `c6_type`. This function
  validates every element of the above. It additionally extracts and packs
  the types of `s4`, `s5`, `update` and `c3`.

  Args:
    type_spec: The `type_signature` attribute of the `after_aggregate` portion
      of the `tff.utils.IterativeProcess` from which we are looking to extract
      an instance of `canonical_form.CanonicalForm`.
    previously_packed_types: Dict containing the information from `next`,
      `before_broadcast` and `before_aggregate` in the iterative process we
      are parsing.

  Returns:
    A `dict` holding every entry of `previously_packed_types` plus the types
    which can be inferred from `type_spec`.

  Raises:
    TypeError: If `type_spec` is not of the expected form or is incompatible
      with `previously_packed_types`.
  """
  if not _is_valid_after_aggregate_type_signature(type_spec,
                                                  previously_packed_types):
    # TODO(b/121290421): These error messages, and indeed the 'validate and
    # raise once' logic of these methods as well, is intended to be
    # provisional and revisited when we've seen the compilation pipeline fail
    # more clearly, or maybe preferably iteratively improved as new failure
    # modes are encountered.
    raise TypeError(
        'Encountered a type error while checking `after_aggregate`; '
        'expected a type signature of the form '
        '`<<<s1,c1>,c2>,s3> -> <s6,s7,c6>`, where s1 matches {}, '
        'c1 matches {}, c2 matches {}, s3 matches {}, s6 matches '
        '{}, s7 matches {}, c6 matches {}, as defined in '
        '`canonical_form.CanonicalForm`. Encountered a type signature '
        '{}.'.format(previously_packed_types['s1_type'],
                     previously_packed_types['c1_type'],
                     previously_packed_types['c2_type'],
                     previously_packed_types['s3_type'],
                     previously_packed_types['s6_type'],
                     previously_packed_types['s7_type'],
                     previously_packed_types['c6_type'], type_spec))

  s4_type = tff.FederatedType([
      previously_packed_types['s1_type'].member,
      previously_packed_types['s3_type'].member
  ], tff.SERVER)
  s5_type = tff.FederatedType([
      previously_packed_types['s6_type'].member,
      previously_packed_types['s7_type'].member
  ], tff.SERVER)
  newly_determined_types = {}
  newly_determined_types['s4_type'] = s4_type
  newly_determined_types['s5_type'] = s5_type
  newly_determined_types['update_type'] = tff.FunctionType(
      s4_type.member, s5_type.member)
  c3_type = tff.FederatedType([
      previously_packed_types['c1_type'].member,
      previously_packed_types['c2_type'].member
  ], tff.CLIENTS)
  newly_determined_types['c3_type'] = c3_type
  # Later entries win, so newly determined types override any stale ones.
  return dict(
      itertools.chain(
          six.iteritems(previously_packed_types),
          six.iteritems(newly_determined_types)))
def zip_selection_as_argument_to_lower_level_lambda(comp, selected_index_lists):
  r"""Binds selections from the param of `comp` as params to lower-level lambda.

  Notice that `comp` must be a `tff_framework.Lambda`.

  The returned pattern is quite important here. Given an input lambda `Comp`,
  we will return an equivalent structure of the form:

                    Lambda(x)
                       |
                      Call
                     /    \
               Lambda      <Selections from x>

  Where <Selections from x> represents a tuple of selections from the
  parameter `x`, as specified by `selected_index_lists`. This transform is
  necessary in order to isolate spurious dependence on arguments that are not
  in fact used. For example, after we have separated processing on the server
  from that which happens on the clients, the server-processing may still
  declare some parameters placed at the clients.

  `selected_index_lists` must be a list of lists. Each list represents a
  sequence of selections to the parameter of `comp`. For example, if `x` is
  the parameter of `comp`, the list `[0, 1, 0]` would represent the selection
  `x[0][1][0]`. The elements of these inner lists must be integers; that is,
  the selections must be positional. Notice we do not allow for tuples due to
  automatic unwrapping.

  Args:
    comp: Instance of `tff_framework.Lambda`, whose parameters we wish to
      rebind to a different lambda.
    selected_index_lists: 2-d list of `int`s, specifying the parameters of
      `comp` which we wish to rebind as the parameter to a lower-level lambda.

  Returns:
    An instance of `tff_framework.Lambda`, equivalent to `comp`, satisfying
    the pattern above.

  Raises:
    TypeError: If `comp` is not a `tff_framework.Lambda`, or if
      `selected_index_lists` is not a list of lists of `int`s. Also raised if
      a selection does not exist in the parameter type of `comp`, or if a
      selected element is not of federated type.
    ValueError: If `selected_index_lists` is empty, or if the selected
      elements are not all at the same placement.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(selected_index_lists, list)
  for selection_list in selected_index_lists:
    py_typecheck.check_type(selection_list, list)
    for selected_element in selection_list:
      py_typecheck.check_type(selected_element, int)
  # At least one selection is needed to determine the zip placement below.
  if not selected_index_lists:
    raise ValueError(
        '`selected_index_lists` must contain at least one selection.')
  original_comp = comp
  comp = _prepare_for_rebinding(comp)

  top_level_parameter_type = comp.type_signature.parameter
  name_generator = tff_framework.unique_name_generator(comp)
  top_level_parameter_name = comp.parameter_name
  top_level_parameter_reference = tff_framework.Reference(
      top_level_parameter_name, comp.parameter_type)

  type_list = []
  for selection_list in selected_index_lists:
    try:
      selected_type = top_level_parameter_type
      for selection in selection_list:
        selected_type = selected_type[selection]
      type_list.append(selected_type)
    except (TypeError, IndexError):
      # Indexing a non-tuple type or using an out-of-range position may raise
      # either exception; both mean the selection is inadmissible.
      six.reraise(
          TypeError,
          TypeError(
              'You have tried to bind a variable to a nonexistent index in your '
              'lambda parameter type; the selection defined by {} is '
              'inadmissible for the lambda parameter type {}, in the comp {}.'
              .format(selection_list, top_level_parameter_type,
                      original_comp)),
          sys.exc_info()[2])

  if not all(isinstance(x, tff.FederatedType) for x in type_list):
    raise TypeError(
        'All selected arguments should be of federated type; your selections '
        'have resulted in the list of types {}'.format(type_list))
  placement = type_list[0].placement
  if not all(x.placement is placement for x in type_list):
    raise ValueError(
        'In order to zip the argument to the lower-level lambda together, all '
        'selected arguments should be at the same placement. Your selections '
        'have resulted in the list of types {}'.format(type_list))

  arg_to_lower_level_lambda_list = []
  for selection_list in selected_index_lists:
    selected_comp = top_level_parameter_reference
    for selection in selection_list:
      selected_comp = tff_framework.Selection(selected_comp, index=selection)
    arg_to_lower_level_lambda_list.append(selected_comp)
  zip_arg = tff_framework.create_federated_zip(
      tff_framework.Tuple(arg_to_lower_level_lambda_list))

  zip_type = tff.FederatedType([x.member for x in type_list],
                               placement=placement)
  ref_to_zip = tff_framework.Reference(six.next(name_generator), zip_type)

  selections_from_zip = [
      _construct_selection_from_federated_tuple(ref_to_zip, x, name_generator)
      for x in range(len(selected_index_lists))
  ]

  def _replace_selections_with_new_bindings(inner_comp):
    """Identifies selection pattern and replaces with new binding.

    Detecting this pattern is the most brittle part of this rebinding
    function. It relies on pattern-matching, and right now we cannot
    guarantee that this pattern is present in every situation we wish to
    replace with a new binding.

    Args:
      inner_comp: Instance of `tff_framework.ComputationBuildingBlock` in
        which we wish to replace the selections specified by
        `selected_index_lists` with the parallel new bindings from
        `selections_from_zip`.

    Returns:
      A possibly transformed version of `inner_comp` with nodes matching the
      selection patterns replaced by their new bindings.
    """
    # TODO(b/135541729): Either come up with a preprocessing way to enforce
    # this is sufficient, or rework the should_transform predicate.
    for idx, tup in enumerate(selected_index_lists):
      # Peel selections off `inner_comp` from the outermost (last index)
      # inward; a match must bottom out at the top-level parameter.
      selection = inner_comp
      tuple_pattern_matched = True
      for selected_index in tup[::-1]:
        if isinstance(
            selection,
            tff_framework.Selection) and selection.index == selected_index:
          selection = selection.source
        else:
          tuple_pattern_matched = False
          break
      if tuple_pattern_matched:
        if isinstance(selection, tff_framework.Reference
                     ) and selection.name == top_level_parameter_name:
          return selections_from_zip[idx], True
    return inner_comp, False

  variables_rebound_in_result, _ = tff_framework.transform_postorder(
      comp.result, _replace_selections_with_new_bindings)
  lambda_with_zipped_param = tff_framework.Lambda(ref_to_zip.name,
                                                  ref_to_zip.type_signature,
                                                  variables_rebound_in_result)
  _check_for_missed_binding(comp, lambda_with_zipped_param)

  zipped_lambda_called = tff_framework.Call(lambda_with_zipped_param, zip_arg)
  constructed_lambda = tff_framework.Lambda(comp.parameter_name,
                                            comp.parameter_type,
                                            zipped_lambda_called)
  names_uniquified, _ = tff_framework.uniquify_reference_names(
      constructed_lambda)
  return names_uniquified
class IntrinsicsTest(parameterized.TestCase): def test_federated_broadcast_with_server_all_equal_int(self): @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER)) def foo(x): return tff.federated_broadcast(x) self.assertEqual(str(foo.type_signature), '(int32@SERVER -> int32@CLIENTS)') def test_federated_broadcast_with_server_non_all_equal_int(self): with self.assertRaises(TypeError): @tff.federated_computation( tff.FederatedType(tf.int32, tff.SERVER, all_equal=False)) def _(x): return tff.federated_broadcast(x) def test_federated_broadcast_with_client_int(self): with self.assertRaises(TypeError): @tff.federated_computation( tff.FederatedType(tf.int32, tff.CLIENTS, True)) def _(x): return tff.federated_broadcast(x) def test_federated_broadcast_with_non_federated_val(self): with self.assertRaises(TypeError): @tff.federated_computation(tf.int32) def _(x): return tff.federated_broadcast(x) def test_federated_map_with_client_all_equal_int(self): @tff.federated_computation( tff.FederatedType(tf.int32, tff.CLIENTS, True)) def foo(x): return tff.federated_map( tff.tf_computation(lambda x: x > 10, tf.int32), x) self.assertEqual(str(foo.type_signature), '(int32@CLIENTS -> {bool}@CLIENTS)') def test_federated_map_with_client_non_all_equal_int(self): @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS)) def foo(x): return tff.federated_map( tff.tf_computation(lambda x: x > 10, tf.int32), x) self.assertEqual(str(foo.type_signature), '({int32}@CLIENTS -> {bool}@CLIENTS)') def test_federated_map_with_server_int(self): @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER)) def foo(x): return tff.federated_map( tff.tf_computation(lambda x: x > 10, tf.int32), x) self.assertEqual(str(foo.type_signature), '(int32@SERVER -> bool@SERVER)') def test_federated_map_injected_zip_with_server_int(self): @tff.federated_computation([ tff.FederatedType(tf.int32, tff.SERVER), tff.FederatedType(tf.int32, tff.SERVER) ]) def foo(x, y): return 
tff.federated_map( tff.tf_computation(lambda x, y: x > 10, [tf.int32, tf.int32]), [x, y]) self.assertEqual(str(foo.type_signature), '(<int32@SERVER,int32@SERVER> -> bool@SERVER)') def test_federated_map_injected_zip_fails_different_placements(self): def foo(x, y): return tff.federated_map( tff.tf_computation(lambda x, y: x > 10, [tf.int32, tf.int32]), [x, y]) with self.assertRaisesRegex( TypeError, 'You cannot apply federated_map on nested values with mixed ' 'placements.'): tff.federated_computation(foo, [ tff.FederatedType(tf.int32, tff.SERVER), tff.FederatedType(tf.int32, tff.CLIENTS) ]) def test_federated_map_with_non_federated_val(self): with self.assertRaises(TypeError): @tff.federated_computation(tf.int32) def _(x): return tff.federated_map( tff.tf_computation(lambda x: x > 10, tf.int32), x) def test_federated_sum_with_client_int(self): @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS)) def foo(x): return tff.federated_sum(x) self.assertEqual(str(foo.type_signature), '({int32}@CLIENTS -> int32@SERVER)') def test_federated_sum_with_client_string(self): with self.assertRaises(TypeError): @tff.federated_computation( tff.FederatedType(tf.string, tff.CLIENTS)) def _(x): return tff.federated_sum(x) def test_federated_sum_with_server_int(self): with self.assertRaises(TypeError): @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER)) def _(x): return tff.federated_sum(x) def test_federated_zip_with_client_non_all_equal_int_and_bool(self): @tff.federated_computation([ tff.FederatedType(tf.int32, tff.CLIENTS), tff.FederatedType(tf.bool, tff.CLIENTS, True) ]) def foo(x, y): return tff.federated_zip([x, y]) self.assertEqual( str(foo.type_signature), '(<{int32}@CLIENTS,bool@CLIENTS> -> {<int32,bool>}@CLIENTS)') def test_federated_zip_with_single_unnamed_int_client(self): @tff.federated_computation([ tff.FederatedType(tf.int32, tff.CLIENTS), ]) def foo(x): return tff.federated_zip(x) self.assertEqual(str(foo.type_signature), 
'(<{int32}@CLIENTS> -> {<int32>}@CLIENTS)') def test_federated_zip_with_single_unnamed_int_server(self): @tff.federated_computation([ tff.FederatedType(tf.int32, tff.SERVER), ]) def foo(x): return tff.federated_zip(x) self.assertEqual(str(foo.type_signature), '(<int32@SERVER> -> <int32>@SERVER)') def test_federated_zip_with_single_named_bool_clients(self): @tff.federated_computation([ ('a', tff.FederatedType(tf.bool, tff.CLIENTS)), ]) def foo(x): return tff.federated_zip(x) self.assertEqual(str(foo.type_signature), '(<a={bool}@CLIENTS> -> {<a=bool>}@CLIENTS)') def test_federated_zip_with_single_named_bool_server(self): @tff.federated_computation([ ('a', tff.FederatedType(tf.bool, tff.SERVER)), ]) def foo(x): return tff.federated_zip(x) self.assertEqual(str(foo.type_signature), '(<a=bool@SERVER> -> <a=bool>@SERVER)') def test_federated_zip_with_names_client_non_all_equal_int_and_bool(self): @tff.federated_computation([ tff.FederatedType(tf.int32, tff.CLIENTS), tff.FederatedType(tf.bool, tff.CLIENTS, True) ]) def foo(x, y): a = {'x': x, 'y': y} return tff.federated_zip(a) self.assertEqual( str(foo.type_signature), '(<{int32}@CLIENTS,bool@CLIENTS> -> {<x=int32,y=bool>}@CLIENTS)') def test_federated_zip_with_client_all_equal_int_and_bool(self): @tff.federated_computation([ tff.FederatedType(tf.int32, tff.CLIENTS, True), tff.FederatedType(tf.bool, tff.CLIENTS, True) ]) def foo(x, y): return tff.federated_zip([x, y]) self.assertEqual( str(foo.type_signature), '(<int32@CLIENTS,bool@CLIENTS> -> {<int32,bool>}@CLIENTS)') def test_federated_zip_with_names_client_all_equal_int_and_bool(self): @tff.federated_computation([ tff.FederatedType(tf.int32, tff.CLIENTS, True), tff.FederatedType(tf.bool, tff.CLIENTS, True) ]) def foo(arg): a = {'x': arg[0], 'y': arg[1]} return tff.federated_zip(a) self.assertEqual( str(foo.type_signature), '(<int32@CLIENTS,bool@CLIENTS> -> {<x=int32,y=bool>}@CLIENTS)') def test_federated_zip_with_server_int_and_bool(self): @tff.federated_computation([ 
tff.FederatedType(tf.int32, tff.SERVER), tff.FederatedType(tf.bool, tff.SERVER) ]) def foo(x, y): return tff.federated_zip([x, y]) self.assertEqual( str(foo.type_signature), '(<int32@SERVER,bool@SERVER> -> <int32,bool>@SERVER)') def test_federated_zip_with_names_server_int_and_bool(self): @tff.federated_computation([ ('a', tff.FederatedType(tf.int32, tff.SERVER)), ('b', tff.FederatedType(tf.bool, tff.SERVER)), ]) def foo(arg): return tff.federated_zip(arg) self.assertEqual( str(foo.type_signature), '(<a=int32@SERVER,b=bool@SERVER> -> <a=int32,b=bool>@SERVER)') def test_federated_zip_error_different_placements(self): with self.assertRaisesRegex( TypeError, r'The elements .* must be placed at SERVER. ' r'Element placements: \(SERVER,CLIENTS\)'): @tff.federated_computation([ ('a', tff.FederatedType(tf.int32, tff.SERVER)), ('b', tff.FederatedType(tf.bool, tff.CLIENTS)), ]) def _(arg): return tff.federated_zip(arg) def test_federated_collect_with_client_int(self): @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS)) def foo(x): return tff.federated_collect(x) self.assertEqual(str(foo.type_signature), '({int32}@CLIENTS -> int32*@SERVER)') def test_federated_collect_with_server_int_fails(self): with self.assertRaises(TypeError): @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER)) def _(x): return tff.federated_collect(x) def test_federated_mean_with_client_float32_without_weight(self): @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS)) def foo(x): return tff.federated_mean(x) self.assertEqual(str(foo.type_signature), '({float32}@CLIENTS -> float32@SERVER)') def test_federated_mean_with_all_equal_client_float32_without_weight(self): federated_all_equal_float = tff.FederatedType(tf.float32, tff.CLIENTS, all_equal=True) @tff.federated_computation(federated_all_equal_float) def foo(x): return tff.federated_mean(x) self.assertEqual(str(foo.type_signature), '(float32@CLIENTS -> float32@SERVER)') def 
test_federated_mean_with_all_equal_client_float32_with_weight(self): federated_all_equal_float = tff.FederatedType(tf.float32, tff.CLIENTS, all_equal=True) @tff.federated_computation(federated_all_equal_float) def foo(x): return tff.federated_mean(x, x) self.assertEqual(str(foo.type_signature), '(float32@CLIENTS -> float32@SERVER)') def test_federated_mean_with_client_tuple_with_int32_weight(self): @tff.federated_computation([ tff.FederatedType([('x', tf.float64), ('y', tf.float64)], tff.CLIENTS), tff.FederatedType(tf.int32, tff.CLIENTS) ]) def foo(x, y): return tff.federated_mean(x, y) self.assertEqual( str(foo.type_signature), '(<{<x=float64,y=float64>}@CLIENTS,{int32}@CLIENTS> ' '-> <x=float64,y=float64>@SERVER)') def test_federated_mean_with_client_int32_fails(self): with self.assertRaises(TypeError): @tff.federated_computation(tff.FederatedType( tf.int32, tff.CLIENTS)) def _(x): return tff.federated_mean(x) def test_federated_mean_with_string_weight_fails(self): with self.assertRaises(TypeError): @tff.federated_computation([ tff.FederatedType(tf.float32, tff.CLIENTS), tff.FederatedType(tf.string, tff.CLIENTS) ]) def _(x, y): return tff.federated_mean(x, y) def test_federated_aggregate_with_client_int(self): # The representation used during the aggregation process will be a named # tuple with 2 elements - the integer 'total' that represents the sum of # elements encountered, and the integer element 'count'. # pylint: disable=invalid-name Accumulator = collections.namedtuple('Accumulator', 'total count') # pylint: enable=invalid-name accumulator_type = tff.NamedTupleType(Accumulator(tf.int32, tf.int32)) # The operator to use during the first stage simply adds an element to the # total and updates the count. @tff.tf_computation([accumulator_type, tf.int32]) def accumulate(accu, elem): return Accumulator(accu.total + elem, accu.count + 1) # The operator to use during the second stage simply adds total and count. 
@tff.tf_computation([accumulator_type, accumulator_type]) def merge(x, y): return Accumulator(x.total + y.total, x.count + y.count) # The operator to use during the final stage simply computes the ratio. @tff.tf_computation(accumulator_type) def report(accu): return tf.cast(accu.total, tf.float32) / tf.cast( accu.count, tf.float32) @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS)) def foo(x): return tff.federated_aggregate(x, Accumulator(0, 0), accumulate, merge, report) self.assertEqual(str(foo.type_signature), '({int32}@CLIENTS -> float32@SERVER)') def test_federated_aggregate_with_federated_zero_fails(self): @tff.federated_computation() def build_federated_zero(): return tff.federated_value(0, tff.SERVER) @tff.tf_computation([tf.int32, tf.int32]) def accumulate(accu, elem): return accu + elem # The operator to use during the second stage simply adds total and count. @tff.tf_computation([tf.int32, tf.int32]) def merge(x, y): return x + y # The operator to use during the final stage simply computes the ratio. @tff.tf_computation(tf.int32) def report(accu): return accu def foo(x): return tff.federated_aggregate(x, build_federated_zero(), accumulate, merge, report) with self.assertRaisesRegex( TypeError, 'Expected `zero` to be assignable to type int32, ' 'but was of incompatible type int32@SERVER'): tff.federated_computation(foo, tff.FederatedType(tf.int32, tff.CLIENTS)) def test_federated_aggregate_with_unknown_dimension(self): Accumulator = collections.namedtuple('Accumulator', ['samples']) # pylint: disable=invalid-name accumulator_type = tff.NamedTupleType( Accumulator(samples=tff.TensorType(dtype=tf.int32, shape=[None]))) @tff.tf_computation() def build_empty_accumulator(): return Accumulator(samples=tf.zeros(shape=[0], dtype=tf.int32)) # The operator to use during the first stage simply adds an element to the # tensor, increasing its size. 
@tff.tf_computation([accumulator_type, tf.int32]) def accumulate(accu, elem): return Accumulator(samples=tf.concat( [accu.samples, tf.expand_dims(elem, axis=0)], axis=0)) # The operator to use during the second stage simply adds total and count. @tff.tf_computation([accumulator_type, accumulator_type]) def merge(x, y): return Accumulator( samples=tf.concat([x.samples, y.samples], axis=0)) # The operator to use during the final stage simply computes the ratio. @tff.tf_computation(accumulator_type) def report(accu): return accu @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS)) def foo(x): return tff.federated_aggregate(x, build_empty_accumulator(), accumulate, merge, report) self.assertEqual(str(foo.type_signature), '({int32}@CLIENTS -> <samples=int32[?]>@SERVER)') def test_federated_reduce_with_tf_add_raw_constant(self): @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS)) def foo(x): plus = tff.tf_computation(tf.add, [tf.int32, tf.int32]) return tff.federated_reduce(x, 0, plus) self.assertEqual(str(foo.type_signature), '({int32}@CLIENTS -> int32@SERVER)') def test_num_over_temperature_threshold_example(self): @tff.federated_computation([ tff.FederatedType(tf.float32, tff.CLIENTS), tff.FederatedType(tf.float32, tff.SERVER) ]) def foo(temperatures, threshold): return tff.federated_sum( tff.federated_map( tff.tf_computation( lambda x, y: tf.cast(tf.greater(x, y), tf.int32), [tf.float32, tf.float32]), [temperatures, tff.federated_broadcast(threshold)])) self.assertEqual( str(foo.type_signature), '(<{float32}@CLIENTS,float32@SERVER> -> int32@SERVER)') @parameterized.named_parameters(('test_n_2', 2), ('test_n_3', 3), ('test_n_5', 5)) def test_n_tuple_federated_zip_tensor_args(self, n): fed_type = tff.FederatedType(tf.int32, tff.CLIENTS) initial_tuple_type = tff.NamedTupleType([fed_type] * n) final_fed_type = tff.FederatedType([tf.int32] * n, tff.CLIENTS) function_type = tff.FunctionType(initial_tuple_type, final_fed_type) type_string = 
str(function_type) @tff.federated_computation([tff.FederatedType(tf.int32, tff.CLIENTS)] * n) def foo(x): return tff.federated_zip(x) self.assertEqual(str(foo.type_signature), type_string) @parameterized.named_parameters( ('test_n_2_int', 2, tff.FederatedType(tf.int32, tff.CLIENTS)), ('test_n_3_int', 3, tff.FederatedType(tf.int32, tff.CLIENTS)), ('test_n_5_int', 5, tff.FederatedType(tf.int32, tff.CLIENTS)), ('test_n_2_tuple', 2, tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)), ('test_n_3_tuple', 3, tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)), ('test_n_5_tuple', 5, tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS))) def test_named_n_tuple_federated_zip(self, n, fed_type): initial_tuple_type = tff.NamedTupleType([fed_type] * n) named_fed_type = tff.FederatedType([(str(k), fed_type.member) for k in range(n)], tff.CLIENTS) mixed_fed_type = tff.FederatedType( [(str(k), fed_type.member) if k % 2 == 0 else fed_type.member for k in range(n)], tff.CLIENTS) named_function_type = tff.FunctionType(initial_tuple_type, named_fed_type) mixed_function_type = tff.FunctionType(initial_tuple_type, mixed_fed_type) named_type_string = str(named_function_type) mixed_type_string = str(mixed_function_type) @tff.federated_computation([fed_type] * n) def foo(x): arg = {str(k): x[k] for k in range(n)} return tff.federated_zip(arg) self.assertEqual(str(foo.type_signature), named_type_string) @tff.federated_computation([fed_type] * n) def bar(x): arg = anonymous_tuple.AnonymousTuple([ (str(k), x[k]) if k % 2 == 0 else (None, x[k]) for k in range(n) ]) return tff.federated_zip(arg) self.assertEqual(str(bar.type_signature), mixed_type_string) @parameterized.named_parameters([ ('test_n_' + str(n) + '_m_' + str(m), n, m) for n, m in itertools.product([1, 2, 3], [1, 2, 3]) ]) def test_n_tuple_federated_zip_mixed_args(self, n, m): tuple_fed_type = tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS) single_fed_type = tff.FederatedType(tf.int32, tff.CLIENTS) initial_tuple_type = 
tff.NamedTupleType([tuple_fed_type] * n + [single_fed_type] * m) final_fed_type = tff.FederatedType([[tf.int32, tf.int32]] * n + [tf.int32] * m, tff.CLIENTS) function_type = tff.FunctionType(initial_tuple_type, final_fed_type) type_string = str(function_type) @tff.federated_computation([ tff.FederatedType(tff.NamedTupleType([tf.int32, tf.int32]), tff.CLIENTS) ] * n + [tff.FederatedType(tf.int32, tff.CLIENTS)] * m) def baz(x): return tff.federated_zip(x) self.assertEqual(str(baz.type_signature), type_string) def test_federated_apply_raises_warning(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER)) def foo(x): return tff.federated_apply( tff.tf_computation(lambda x: x * x, tf.int32), x) self.assertLen(w, 1) self.assertIsInstance(w[0].category(), DeprecationWarning) self.assertIn('tff.federated_apply() is deprecated', str(w[0].message)) self.assertEqual(str(foo.type_signature), '(int32@SERVER -> int32@SERVER)') def test_federated_value_with_bool_on_clients(self): @tff.federated_computation(tf.bool) def foo(x): return tff.federated_value(x, tff.CLIENTS) self.assertEqual(str(foo.type_signature), '(bool -> bool@CLIENTS)') def test_federated_value_raw_np_scalar(self): @tff.federated_computation def test_np_values(): floatv = np.float64(0) tff_float = tff.federated_value(floatv, tff.SERVER) self.assertEqual(str(tff_float.type_signature), 'float64@SERVER') intv = np.int64(0) tff_int = tff.federated_value(intv, tff.SERVER) self.assertEqual(str(tff_int.type_signature), 'int64@SERVER') return (tff_float, tff_int) floatv, intv = test_np_values() self.assertEqual(floatv, 0.0) self.assertEqual(intv, 0) def test_federated_value_raw_tf_scalar_variable(self): v = tf.Variable(initial_value=0., name='test_var') with self.assertRaisesRegex( TypeError, 'TensorFlow construct (.*) has been ' 'encountered in a federated context.'): _ = tff.federated_value(v, tff.SERVER) def 
test_federated_value_with_bool_on_server(self): @tff.federated_computation(tf.bool) def foo(x): return tff.federated_value(x, tff.SERVER) self.assertEqual(str(foo.type_signature), '(bool -> bool@SERVER)') def test_sequence_sum(self): @tff.federated_computation(tff.SequenceType(tf.int32)) def foo1(x): return tff.sequence_sum(x) self.assertEqual(str(foo1.type_signature), '(int32* -> int32)') @tff.federated_computation( tff.FederatedType(tff.SequenceType(tf.int32), tff.SERVER)) def foo2(x): return tff.sequence_sum(x) self.assertEqual(str(foo2.type_signature), '(int32*@SERVER -> int32@SERVER)') @tff.federated_computation( tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS)) def foo3(x): return tff.sequence_sum(x) self.assertEqual(str(foo3.type_signature), '({int32*}@CLIENTS -> {int32}@CLIENTS)') def test_sequence_map(self): @tff.tf_computation(tf.int32) def over_threshold(x): return x > 10 @tff.federated_computation(tff.SequenceType(tf.int32)) def foo1(x): return tff.sequence_map(over_threshold, x) self.assertEqual(str(foo1.type_signature), '(int32* -> bool*)') @tff.federated_computation( tff.FederatedType(tff.SequenceType(tf.int32), tff.SERVER)) def foo2(x): return tff.sequence_map(over_threshold, x) self.assertEqual(str(foo2.type_signature), '(int32*@SERVER -> bool*@SERVER)') @tff.federated_computation( tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS)) def foo3(x): return tff.sequence_map(over_threshold, x) self.assertEqual(str(foo3.type_signature), '({int32*}@CLIENTS -> {bool*}@CLIENTS)') def test_sequence_reduce(self): add_numbers = tff.tf_computation(tf.add, [tf.int32, tf.int32]) @tff.federated_computation(tff.SequenceType(tf.int32)) def foo1(x): return tff.sequence_reduce(x, 0, add_numbers) self.assertEqual(str(foo1.type_signature), '(int32* -> int32)') @tff.federated_computation( tff.FederatedType(tff.SequenceType(tf.int32), tff.SERVER)) def foo2(x): return tff.sequence_reduce(x, 0, add_numbers) self.assertEqual(str(foo2.type_signature), 
'(int32*@SERVER -> int32@SERVER)') @tff.federated_computation( tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS)) def foo3(x): return tff.sequence_reduce(x, 0, add_numbers) self.assertEqual(str(foo3.type_signature), '({int32*}@CLIENTS -> {int32}@CLIENTS)') @core_test.executors( ('local', executor_stacks.create_local_executor()), ) def test_federated_zip_with_twenty_elements_local_executor(self): n = 20 n_clients = 2 @tff.federated_computation([tff.FederatedType(tf.int32, tff.CLIENTS)] * n) def foo(x): return tff.federated_zip(x) data = [list(range(n_clients)) for _ in range(n)] # This would not have ever returned when local executor was scaling # factorially with number of elements zipped foo(data)
def test_federated_sum_with_client_string(self):
  """Summing string-typed CLIENTS values must be rejected with TypeError."""

  def _sum_at_server(client_values):
    return tff.federated_sum(client_values)

  with self.assertRaises(TypeError):
    # Wrapping (tracing) the function against the string type is the step
    # expected to raise; no invocation is needed.
    tff.federated_computation(
        _sum_at_server, tff.FederatedType(tf.string, tff.CLIENTS))
def test_federated_collect_with_server_int_fails(self):
  """Collecting a SERVER-placed int must be rejected with TypeError."""

  def _collect(server_value):
    return tff.federated_collect(server_value)

  with self.assertRaises(TypeError):
    # The error is expected while wrapping/tracing the computation.
    tff.federated_computation(
        _collect, tff.FederatedType(tf.int32, tff.SERVER))
def test_federated_mean_with_client_int32_fails(self):
  """An unweighted mean over int32 CLIENTS values must raise TypeError."""

  def _average(client_values):
    return tff.federated_mean(client_values)

  with self.assertRaises(TypeError):
    # Tracing against an integer member type is the step expected to fail.
    tff.federated_computation(
        _average, tff.FederatedType(tf.int32, tff.CLIENTS))
class IntrinsicsTest(parameterized.TestCase):
  """Type-signature tests for TFF federated and sequence intrinsics.

  Each test traces a `tff.federated_computation` and either compares the
  string form of its `type_signature` with the expected signature, or checks
  that tracing an ill-typed use of an intrinsic raises `TypeError`.
  """

  def test_federated_broadcast_with_server_all_equal_int(self):
    """Broadcasting a SERVER int yields an all-equal CLIENTS int."""

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER))
    def foo(x):
      return tff.federated_broadcast(x)

    self.assertEqual(
        str(foo.type_signature), '(int32@SERVER -> int32@CLIENTS)')

  def test_federated_broadcast_with_server_non_all_equal_int(self):
    """Broadcasting a non-all-equal SERVER value must raise TypeError."""
    with self.assertRaises(TypeError):

      @tff.federated_computation(
          tff.FederatedType(tf.int32, tff.SERVER, all_equal=False))
      def _(x):
        return tff.federated_broadcast(x)

  def test_federated_broadcast_with_client_int(self):
    """Broadcasting a CLIENTS-placed value must raise TypeError."""
    with self.assertRaises(TypeError):

      @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS, True))
      def _(x):
        return tff.federated_broadcast(x)

  def test_federated_broadcast_with_non_federated_val(self):
    """Broadcasting an unplaced value must raise TypeError."""
    with self.assertRaises(TypeError):

      @tff.federated_computation(tf.int32)
      def _(x):
        return tff.federated_broadcast(x)

  def test_federated_map_with_client_all_equal_int(self):
    """Mapping over an all-equal CLIENTS int yields non-all-equal output."""

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS, True))
    def foo(x):
      return tff.federated_map(
          tff.tf_computation(lambda x: x > 10, tf.int32), x)

    self.assertEqual(
        str(foo.type_signature), '(int32@CLIENTS -> {bool}@CLIENTS)')

  def test_federated_map_with_client_non_all_equal_int(self):
    """Mapping over non-all-equal CLIENTS ints keeps the CLIENTS placement."""

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
    def foo(x):
      return tff.federated_map(
          tff.tf_computation(lambda x: x > 10, tf.int32), x)

    self.assertEqual(
        str(foo.type_signature), '({int32}@CLIENTS -> {bool}@CLIENTS)')

  def test_federated_map_with_non_federated_val(self):
    """Mapping over an unplaced value must raise TypeError."""
    with self.assertRaises(TypeError):

      @tff.federated_computation(tf.int32)
      def _(x):
        return tff.federated_map(
            tff.tf_computation(lambda x: x > 10, tf.int32), x)

  def test_federated_sum_with_client_int(self):
    """Summing CLIENTS ints yields a SERVER int."""

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
    def foo(x):
      return tff.federated_sum(x)

    self.assertEqual(
        str(foo.type_signature), '({int32}@CLIENTS -> int32@SERVER)')

  def test_federated_sum_with_client_string(self):
    """Summing CLIENTS strings must raise TypeError."""
    with self.assertRaises(TypeError):

      @tff.federated_computation(tff.FederatedType(tf.string, tff.CLIENTS))
      def _(x):
        return tff.federated_sum(x)

  def test_federated_sum_with_server_int(self):
    """Summing a SERVER-placed int must raise TypeError."""
    with self.assertRaises(TypeError):

      @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER))
      def _(x):
        return tff.federated_sum(x)

  def test_federated_zip_with_client_non_all_equal_int_and_bool(self):
    """Zipping mixed all-equal/non-all-equal CLIENTS values."""

    @tff.federated_computation([
        tff.FederatedType(tf.int32, tff.CLIENTS),
        tff.FederatedType(tf.bool, tff.CLIENTS, True)
    ])
    def foo(x, y):
      return tff.federated_zip([x, y])

    self.assertEqual(
        str(foo.type_signature),
        '(<{int32}@CLIENTS,bool@CLIENTS> -> {<int32,bool>}@CLIENTS)')

  def test_federated_zip_with_single_unnamed_int_client(self):
    """Zipping a one-element unnamed CLIENTS tuple."""

    @tff.federated_computation([
        tff.FederatedType(tf.int32, tff.CLIENTS),
    ])
    def foo(x):
      return tff.federated_zip(x)

    self.assertEqual(
        str(foo.type_signature), '(<{int32}@CLIENTS> -> {<int32>}@CLIENTS)')

  def test_federated_zip_with_single_unnamed_int_server(self):
    """Zipping a one-element unnamed SERVER tuple."""

    @tff.federated_computation([
        tff.FederatedType(tf.int32, tff.SERVER),
    ])
    def foo(x):
      return tff.federated_zip(x)

    self.assertEqual(
        str(foo.type_signature), '(<int32@SERVER> -> <int32>@SERVER)')

  def test_federated_zip_with_single_named_bool_clients(self):
    """Zipping a one-element named CLIENTS tuple keeps the name."""

    @tff.federated_computation([
        ('a', tff.FederatedType(tf.bool, tff.CLIENTS)),
    ])
    def foo(x):
      return tff.federated_zip(x)

    self.assertEqual(
        str(foo.type_signature), '(<a={bool}@CLIENTS> -> {<a=bool>}@CLIENTS)')

  def test_federated_zip_with_single_named_bool_server(self):
    """Zipping a one-element named SERVER tuple keeps the name."""

    @tff.federated_computation([
        ('a', tff.FederatedType(tf.bool, tff.SERVER)),
    ])
    def foo(x):
      return tff.federated_zip(x)

    self.assertEqual(
        str(foo.type_signature), '(<a=bool@SERVER> -> <a=bool>@SERVER)')

  def test_federated_zip_with_names_client_non_all_equal_int_and_bool(self):
    """Zipping a dict of CLIENTS values names the zipped fields."""

    @tff.federated_computation([
        tff.FederatedType(tf.int32, tff.CLIENTS),
        tff.FederatedType(tf.bool, tff.CLIENTS, True)
    ])
    def foo(x, y):
      a = {'x': x, 'y': y}
      return tff.federated_zip(a)

    self.assertEqual(
        str(foo.type_signature),
        '(<{int32}@CLIENTS,bool@CLIENTS> -> {<x=int32,y=bool>}@CLIENTS)')

  def test_federated_zip_with_client_all_equal_int_and_bool(self):
    """Zipping all-equal CLIENTS values yields non-all-equal output."""

    @tff.federated_computation([
        tff.FederatedType(tf.int32, tff.CLIENTS, True),
        tff.FederatedType(tf.bool, tff.CLIENTS, True)
    ])
    def foo(x, y):
      return tff.federated_zip([x, y])

    self.assertEqual(
        str(foo.type_signature),
        '(<int32@CLIENTS,bool@CLIENTS> -> {<int32,bool>}@CLIENTS)')

  def test_federated_zip_with_names_client_all_equal_int_and_bool(self):
    """Zipping a dict built from all-equal CLIENTS values names the fields."""

    @tff.federated_computation([
        tff.FederatedType(tf.int32, tff.CLIENTS, True),
        tff.FederatedType(tf.bool, tff.CLIENTS, True)
    ])
    def foo(arg):
      a = {'x': arg[0], 'y': arg[1]}
      return tff.federated_zip(a)

    self.assertEqual(
        str(foo.type_signature),
        '(<int32@CLIENTS,bool@CLIENTS> -> {<x=int32,y=bool>}@CLIENTS)')

  def test_federated_zip_with_server_int_and_bool(self):
    """Zipping two SERVER values yields a SERVER tuple."""

    @tff.federated_computation([
        tff.FederatedType(tf.int32, tff.SERVER),
        tff.FederatedType(tf.bool, tff.SERVER)
    ])
    def foo(x, y):
      return tff.federated_zip([x, y])

    self.assertEqual(
        str(foo.type_signature),
        '(<int32@SERVER,bool@SERVER> -> <int32,bool>@SERVER)')

  def test_federated_zip_with_names_server_int_and_bool(self):
    """Zipping a named SERVER tuple keeps the field names."""

    @tff.federated_computation([
        ('a', tff.FederatedType(tf.int32, tff.SERVER)),
        ('b', tff.FederatedType(tf.bool, tff.SERVER)),
    ])
    def foo(arg):
      return tff.federated_zip(arg)

    self.assertEqual(
        str(foo.type_signature),
        '(<a=int32@SERVER,b=bool@SERVER> -> <a=int32,b=bool>@SERVER)')

  def test_federated_collect_with_client_int(self):
    """Collecting CLIENTS ints yields a SERVER sequence."""

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
    def foo(x):
      return tff.federated_collect(x)

    self.assertEqual(
        str(foo.type_signature), '({int32}@CLIENTS -> int32*@SERVER)')

  def test_federated_collect_with_server_int_fails(self):
    """Collecting a SERVER-placed int must raise TypeError."""
    with self.assertRaises(TypeError):

      @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER))
      def _(x):
        return tff.federated_collect(x)

  def test_federated_mean_with_client_float32_without_weight(self):
    """An unweighted mean of CLIENTS float32 values yields a SERVER float32."""

    @tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
    def foo(x):
      return tff.federated_mean(x)

    self.assertEqual(
        str(foo.type_signature), '({float32}@CLIENTS -> float32@SERVER)')

  def test_federated_mean_with_client_tuple_with_int32_weight(self):
    """A weighted mean of a named float64 tuple with int32 weights."""

    @tff.federated_computation([
        tff.FederatedType([('x', tf.float64), ('y', tf.float64)], tff.CLIENTS),
        tff.FederatedType(tf.int32, tff.CLIENTS)
    ])
    def foo(x, y):
      return tff.federated_mean(x, y)

    self.assertEqual(
        str(foo.type_signature),
        '(<{<x=float64,y=float64>}@CLIENTS,{int32}@CLIENTS> '
        '-> <x=float64,y=float64>@SERVER)')

  def test_federated_mean_with_client_int32_fails(self):
    """An unweighted mean of CLIENTS int32 values must raise TypeError."""
    with self.assertRaises(TypeError):

      @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
      def _(x):
        return tff.federated_mean(x)

  def test_federated_mean_with_string_weight_fails(self):
    """A string-typed weight must raise TypeError."""
    with self.assertRaises(TypeError):

      @tff.federated_computation([
          tff.FederatedType(tf.float32, tff.CLIENTS),
          tff.FederatedType(tf.string, tff.CLIENTS)
      ])
      def _(x, y):
        return tff.federated_mean(x, y)

  def test_federated_aggregate_with_client_int(self):
    """A three-stage aggregate (accumulate/merge/report) computing a mean."""
    # The representation used during the aggregation process will be a named
    # tuple with 2 elements - the integer 'total' that represents the sum of
    # elements encountered, and the integer element 'count'.
    # pylint: disable=invalid-name
    Accumulator = collections.namedtuple('Accumulator', 'total count')
    # pylint: enable=invalid-name
    accumulator_type = tff.NamedTupleType(Accumulator(tf.int32, tf.int32))

    # The operator to use during the first stage simply adds an element to the
    # total and updates the count.
    @tff.tf_computation([accumulator_type, tf.int32])
    def accumulate(accu, elem):
      return Accumulator(accu.total + elem, accu.count + 1)

    # The operator to use during the second stage simply adds total and count.
    @tff.tf_computation([accumulator_type, accumulator_type])
    def merge(x, y):
      return Accumulator(x.total + y.total, x.count + y.count)

    # The operator to use during the final stage simply computes the ratio.
    @tff.tf_computation(accumulator_type)
    def report(accu):
      # `tf.to_float` is not part of the TF 2.x API (only `tf.compat.v1`);
      # `tf.cast` is the supported spelling of the same conversion.
      return tf.cast(accu.total, tf.float32) / tf.cast(
          accu.count, tf.float32)

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
    def foo(x):
      return tff.federated_aggregate(x, Accumulator(0, 0), accumulate, merge,
                                     report)

    self.assertEqual(
        str(foo.type_signature), '({int32}@CLIENTS -> float32@SERVER)')

  def test_federated_reduce_with_tf_add_raw_constant(self):
    """federated_reduce accepts a raw Python constant as the zero."""

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
    def foo(x):
      plus = tff.tf_computation(tf.add, [tf.int32, tf.int32])
      return tff.federated_reduce(x, 0, plus)

    self.assertEqual(
        str(foo.type_signature), '({int32}@CLIENTS -> int32@SERVER)')

  def test_num_over_temperature_threshold_example(self):
    """Counts client temperatures above a broadcast SERVER threshold."""

    @tff.federated_computation([
        tff.FederatedType(tf.float32, tff.CLIENTS),
        tff.FederatedType(tf.float32, tff.SERVER)
    ])
    def foo(temperatures, threshold):
      # `tf.to_int32` is not part of the TF 2.x API; use `tf.cast` instead.
      return tff.federated_sum(
          tff.federated_map(
              tff.tf_computation(
                  lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
                  [tf.float32, tf.float32]),
              [temperatures, tff.federated_broadcast(threshold)]))

    self.assertEqual(
        str(foo.type_signature),
        '(<{float32}@CLIENTS,float32@SERVER> -> int32@SERVER)')

  @parameterized.named_parameters(('test_n_2', 2), ('test_n_3', 3),
                                  ('test_n_5', 5))
  def test_n_tuple_federated_zip_tensor_args(self, n):
    """Zipping n CLIENTS int32 values yields a CLIENTS n-tuple."""
    fed_type = tff.FederatedType(tf.int32, tff.CLIENTS)
    initial_tuple_type = tff.NamedTupleType([fed_type] * n)
    final_fed_type = tff.FederatedType([tf.int32] * n, tff.CLIENTS)
    function_type = tff.FunctionType(initial_tuple_type, final_fed_type)
    type_string = str(function_type)

    @tff.federated_computation([tff.FederatedType(tf.int32, tff.CLIENTS)] * n)
    def foo(x):
      return tff.federated_zip(x)

    self.assertEqual(str(foo.type_signature), type_string)

  @parameterized.named_parameters(
      ('test_n_2_int', 2, tff.FederatedType(tf.int32, tff.CLIENTS)),
      ('test_n_3_int', 3, tff.FederatedType(tf.int32, tff.CLIENTS)),
      ('test_n_5_int', 5, tff.FederatedType(tf.int32, tff.CLIENTS)),
      ('test_n_2_tuple', 2,
       tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)),
      ('test_n_3_tuple', 3,
       tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)),
      ('test_n_5_tuple', 5,
       tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)))
  def test_named_n_tuple_federated_zip(self, n, fed_type):
    """Zipping fully-named and partially-named n-tuples preserves names."""
    initial_tuple_type = tff.NamedTupleType([fed_type] * n)
    named_fed_type = tff.FederatedType(
        [(str(k), fed_type.member) for k in range(n)], tff.CLIENTS)
    # Even positions are named, odd positions are left unnamed.
    mixed_fed_type = tff.FederatedType(
        [(str(k), fed_type.member) if k % 2 == 0 else fed_type.member
         for k in range(n)], tff.CLIENTS)
    named_function_type = tff.FunctionType(initial_tuple_type, named_fed_type)
    mixed_function_type = tff.FunctionType(initial_tuple_type, mixed_fed_type)
    named_type_string = str(named_function_type)
    mixed_type_string = str(mixed_function_type)

    @tff.federated_computation([fed_type] * n)
    def foo(x):
      arg = {str(k): x[k] for k in range(n)}
      return tff.federated_zip(arg)

    self.assertEqual(str(foo.type_signature), named_type_string)

    @tff.federated_computation([fed_type] * n)
    def bar(x):
      arg = anonymous_tuple.AnonymousTuple([
          (str(k), x[k]) if k % 2 == 0 else (None, x[k]) for k in range(n)
      ])
      return tff.federated_zip(arg)

    self.assertEqual(str(bar.type_signature), mixed_type_string)

  @parameterized.named_parameters([
      ('test_n_' + str(n) + '_m_' + str(m), n, m)
      for n, m in itertools.product([1, 2, 3], [1, 2, 3])
  ])
  def test_n_tuple_federated_zip_mixed_args(self, n, m):
    """Zipping n CLIENTS pairs followed by m CLIENTS scalars."""
    tuple_fed_type = tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)
    single_fed_type = tff.FederatedType(tf.int32, tff.CLIENTS)
    initial_tuple_type = tff.NamedTupleType([tuple_fed_type] * n +
                                            [single_fed_type] * m)
    final_fed_type = tff.FederatedType([[tf.int32, tf.int32]] * n +
                                       [tf.int32] * m, tff.CLIENTS)
    function_type = tff.FunctionType(initial_tuple_type, final_fed_type)
    type_string = str(function_type)

    @tff.federated_computation([
        tff.FederatedType(
            tff.NamedTupleType([tf.int32, tf.int32]), tff.CLIENTS)
    ] * n + [tff.FederatedType(tf.int32, tff.CLIENTS)] * m)
    def baz(x):
      return tff.federated_zip(x)

    self.assertEqual(str(baz.type_signature), type_string)

  def test_federated_apply_with_int(self):
    """Applying a function to a SERVER int keeps the SERVER placement."""

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER))
    def foo(x):
      return tff.federated_apply(
          tff.tf_computation(lambda x: x > 10, tf.int32), x)

    self.assertEqual(str(foo.type_signature), '(int32@SERVER -> bool@SERVER)')

  def test_federated_apply_injected_zip_int(self):
    """federated_apply accepts a list of SERVER values as its argument."""

    @tff.federated_computation([
        tff.FederatedType(tf.int32, tff.SERVER),
        tff.FederatedType(tf.int32, tff.SERVER)
    ])
    def foo(x, y):
      # `y` is deliberately unused by the lambda; only the argument typing of
      # the list `[x, y]` is under test here.
      return tff.federated_apply(
          tff.tf_computation(lambda x, y: x > 10, [tf.int32, tf.int32]),
          [x, y])

    self.assertEqual(
        str(foo.type_signature), '(<int32@SERVER,int32@SERVER> -> bool@SERVER)')

  def test_federated_value_with_bool_on_clients(self):
    """Placing an unplaced bool at CLIENTS."""

    @tff.federated_computation(tf.bool)
    def foo(x):
      return tff.federated_value(x, tff.CLIENTS)

    self.assertEqual(str(foo.type_signature), '(bool -> bool@CLIENTS)')

  def test_federated_value_raw_np_scalar(self):
    """NumPy scalars can be placed at SERVER and keep their dtype."""

    @tff.federated_computation
    def test_np_values():
      floatv = np.float64(0)
      tff_float = tff.federated_value(floatv, tff.SERVER)
      self.assertEqual(str(tff_float.type_signature), 'float64@SERVER')
      intv = np.int64(0)
      tff_int = tff.federated_value(intv, tff.SERVER)
      self.assertEqual(str(tff_int.type_signature), 'int64@SERVER')
      return (tff_float, tff_int)

    floatv, intv = test_np_values()
    self.assertEqual(floatv, 0.0)
    self.assertEqual(intv, 0)

  def test_federated_value_raw_tf_scalar_variable(self):
    """Placing a raw tf.Variable outside TF context must raise TypeError."""
    v = tf.Variable(initial_value=0., name='test_var')
    with self.assertRaisesRegex(
        TypeError, 'TensorFlow construct (.*) has been '
        'encountered in a federated context.'):
      _ = tff.federated_value(v, tff.SERVER)

  def test_federated_value_with_bool_on_server(self):
    """Placing an unplaced bool at SERVER."""

    @tff.federated_computation(tf.bool)
    def foo(x):
      return tff.federated_value(x, tff.SERVER)

    self.assertEqual(str(foo.type_signature), '(bool -> bool@SERVER)')

  def test_sequence_sum(self):
    """sequence_sum on unplaced, SERVER and CLIENTS sequences."""

    @tff.federated_computation(tff.SequenceType(tf.int32))
    def foo1(x):
      return tff.sequence_sum(x)

    self.assertEqual(str(foo1.type_signature), '(int32* -> int32)')

    @tff.federated_computation(
        tff.FederatedType(tff.SequenceType(tf.int32), tff.SERVER))
    def foo2(x):
      return tff.sequence_sum(x)

    self.assertEqual(
        str(foo2.type_signature), '(int32*@SERVER -> int32@SERVER)')

    @tff.federated_computation(
        tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS))
    def foo3(x):
      return tff.sequence_sum(x)

    self.assertEqual(
        str(foo3.type_signature), '({int32*}@CLIENTS -> {int32}@CLIENTS)')

  def test_sequence_map(self):
    """sequence_map on unplaced, SERVER and CLIENTS sequences."""

    @tff.tf_computation(tf.int32)
    def over_threshold(x):
      return x > 10

    @tff.federated_computation(tff.SequenceType(tf.int32))
    def foo1(x):
      return tff.sequence_map(over_threshold, x)

    self.assertEqual(str(foo1.type_signature), '(int32* -> bool*)')

    @tff.federated_computation(
        tff.FederatedType(tff.SequenceType(tf.int32), tff.SERVER))
    def foo2(x):
      return tff.sequence_map(over_threshold, x)

    self.assertEqual(
        str(foo2.type_signature), '(int32*@SERVER -> bool*@SERVER)')

    @tff.federated_computation(
        tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS))
    def foo3(x):
      return tff.sequence_map(over_threshold, x)

    self.assertEqual(
        str(foo3.type_signature), '({int32*}@CLIENTS -> {bool*}@CLIENTS)')

  def test_sequence_reduce(self):
    """sequence_reduce on unplaced, SERVER and CLIENTS sequences."""
    add_numbers = tff.tf_computation(tf.add, [tf.int32, tf.int32])

    @tff.federated_computation(tff.SequenceType(tf.int32))
    def foo1(x):
      return tff.sequence_reduce(x, 0, add_numbers)

    self.assertEqual(str(foo1.type_signature), '(int32* -> int32)')

    @tff.federated_computation(
        tff.FederatedType(tff.SequenceType(tf.int32), tff.SERVER))
    def foo2(x):
      return tff.sequence_reduce(x, 0, add_numbers)

    self.assertEqual(
        str(foo2.type_signature), '(int32*@SERVER -> int32@SERVER)')

    @tff.federated_computation(
        tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS))
    def foo3(x):
      return tff.sequence_reduce(x, 0, add_numbers)

    self.assertEqual(
        str(foo3.type_signature), '({int32*}@CLIENTS -> {int32}@CLIENTS)')
def test_federated_broadcast_with_server_non_all_equal_int(self):
  """Broadcasting a non-all-equal SERVER value must raise TypeError.

  `all_equal=False` must be passed explicitly: a plain
  `tff.FederatedType(tf.int32, tff.SERVER)` is all-equal and broadcasts
  successfully (`int32@SERVER -> int32@CLIENTS`), which would make the
  `assertRaises` below fail instead of exercising the error path.
  """
  with self.assertRaises(TypeError):

    @tff.federated_computation(
        tff.FederatedType(tf.int32, tff.SERVER, all_equal=False))
    def _(x):
      return tff.federated_broadcast(x)