def test_named_n_tuple_federated_zip(self, n, fed_type):
  """Checks `tff.federated_zip` on fully named and partially named tuples.

  Builds the expected function type strings for zipping `n` copies of
  `fed_type` into a `tff.CLIENTS`-placed tuple whose elements are either all
  named (`'0'`, `'1'`, ...) or named only at even positions, and compares them
  with the type signatures of two traced federated computations.
  """
  member = fed_type.member
  initial_tuple_type = tff.NamedTupleType([fed_type] * n)

  all_named_members = [(str(k), member) for k in range(n)]
  # Even positions carry a name; odd positions are left anonymous.
  partly_named_members = [
      (str(k), member) if k % 2 == 0 else member for k in range(n)
  ]
  expected_named = str(
      tff.FunctionType(initial_tuple_type,
                       tff.FederatedType(all_named_members, tff.CLIENTS)))
  expected_mixed = str(
      tff.FunctionType(initial_tuple_type,
                       tff.FederatedType(partly_named_members, tff.CLIENTS)))

  @tff.federated_computation([fed_type] * n)
  def foo(x):
    return tff.federated_zip({str(k): x[k] for k in range(n)})

  self.assertEqual(str(foo.type_signature), expected_named)

  @tff.federated_computation([fed_type] * n)
  def bar(x):
    elements = []
    for k in range(n):
      elements.append((str(k) if k % 2 == 0 else None, x[k]))
    return tff.federated_zip(anonymous_tuple.AnonymousTuple(elements))

  self.assertEqual(str(bar.type_signature), expected_mixed)
def test_named_n_tuple_federated_zip(self, n, fed_type):
  """Checks `tff.federated_zip` on fully named and partially named tuples.

  This variant builds the partially named `AnonymousTuple` elements through a
  small local helper.
  """
  member_type = fed_type.member
  parameter_type = tff.NamedTupleType([fed_type] * n)

  def _expected_signature(elements):
    # String form of a function from `n` values of `fed_type` to a
    # `tff.CLIENTS`-placed tuple with the given element specs.
    return str(
        tff.FunctionType(parameter_type,
                         tff.FederatedType(elements, tff.CLIENTS)))

  named_type_string = _expected_signature(
      [(str(k), member_type) for k in range(n)])
  mixed_type_string = _expected_signature(
      [(str(k), member_type) if k % 2 == 0 else member_type
       for k in range(n)])

  @tff.federated_computation([fed_type] * n)
  def foo(x):
    return tff.federated_zip({str(k): x[k] for k in range(n)})

  self.assertEqual(str(foo.type_signature), named_type_string)

  def _name_and_value(x, k):
    """Returns `(name, x[k])`, named by `k` when `k` is even, else unnamed."""
    name = str(k) if k % 2 == 0 else None
    return name, x[k]

  @tff.federated_computation([fed_type] * n)
  def bar(x):
    zip_input = anonymous_tuple.AnonymousTuple(
        _name_and_value(x, k) for k in range(n))
    return tff.federated_zip(zip_input)

  self.assertEqual(str(bar.type_signature), mixed_type_string)
def test_fed_comp_typical_usage_as_decorator_with_labeled_type(self):
  """Checks a federated computation declared with a labeled parameter type.

  `foo` applies its function argument `f` twice to `x`. It is invoked with a
  one-argument TF computation (accepted) and with a two-argument one, which
  must be rejected with a `TypeError`.
  """
  labeled_parameter = (
      ('f', tff.FunctionType(tf.int32, tf.int32)),
      ('x', tf.int32),
  )

  @tff.federated_computation(labeled_parameter)
  def foo(f, x):
    return f(f(x))

  @tff.tf_computation(tf.int32)
  def square(x):
    return x**2

  @tff.tf_computation(tf.int32, tf.int32)
  def square_drop_y(x, y):  # pylint: disable=unused-argument
    return x * x

  self.assertEqual(
      str(foo.type_signature), '(<f=(int32 -> int32),x=int32> -> int32)')
  ten_to_the_fourth = int(1e4)
  self.assertEqual(foo(square, 10), ten_to_the_fourth)
  # `y` is ignored, so the second arguments can be swapped without effect.
  self.assertEqual(square_drop_y(square_drop_y(10, 5), 100), ten_to_the_fourth)
  self.assertEqual(square_drop_y(square_drop_y(10, 100), 5), ten_to_the_fourth)
  # A two-argument computation does not fit the `(int32 -> int32)` slot.
  with self.assertRaisesRegexp(TypeError,
                               'is not assignable from source type'):
    foo(square_drop_y, 10)
def check_and_pack_before_broadcast_type_signature(type_spec,
                                                   previously_packed_types):
  """Checks types inferred from `before_broadcast` and packs in `previously_packed_types`.

  After splitting the `next` portion of a `tff.utils.IterativeProcess` into
  `before_broadcast` and `after_broadcast`, `before_broadcast` should have
  type signature `<s1, c1> -> s2`. This function validates `s1` and `c1`
  against the existing entries in `previously_packed_types` and checks that
  `s2` is a `tff.SERVER`-placed federated type. It then packs `s2` and the
  derived `prepare` type `s1.member -> s2.member`.

  Args:
    type_spec: The `type_signature` attribute of the `before_broadcast` portion
      of the `tff.utils.IterativeProcess` from which we are looking to extract
      an instance of `canonical_form.CanonicalForm`.
    previously_packed_types: Dict containing the information from `next` in
      the iterative process we are parsing. It must contain `s1_type` and
      `c1_type`.

  Returns:
    A new `dict` containing every entry of `previously_packed_types` plus
    `s2_type` and `prepare_type`.

  Raises:
    TypeError: If `type_spec` is not a function type of the form above, or is
      incompatible with `previously_packed_types`.
  """
  s1_type = previously_packed_types['s1_type']
  c1_type = previously_packed_types['c1_type']
  # Parameter and result are only inspected once `type_spec` is known to be a
  # function type. That way malformed input raises the documented `TypeError`
  # instead of an `AttributeError` from `type_spec.result`.
  is_valid = (
      isinstance(type_spec, tff.FunctionType) and
      isinstance(type_spec.parameter, tff.NamedTupleType) and
      len(type_spec.parameter) == 2 and
      type_spec.parameter[0] == s1_type and
      type_spec.parameter[1] == c1_type and
      isinstance(type_spec.result, tff.FederatedType) and
      type_spec.result.placement == tff.SERVER)
  if not is_valid:
    # TODO(b/121290421): These error messages, and indeed the 'track boolean and
    # raise once' logic of these methods as well, is intended to be provisional
    # and revisited when we've seen the compilation pipeline fail more clearly,
    # or maybe preferably iteratively improved as new failure modes are
    # encountered.
    raise TypeError(
        'We have encountered an error checking the type signature '
        'of `before_broadcast`; expected it to have the form '
        '`<s1,c1> -> s2`, with `s1` matching {} and `c1` matching '
        '{}, as defined in `canonical_form.CanonicalForm`, but '
        'encountered a type spec {}'.format(s1_type, c1_type, type_spec))
  s2_type = type_spec.result
  # Copy first so the caller's dict is not mutated. New entries are added
  # after the existing ones and override any entries with the same key.
  packed_types = dict(previously_packed_types)
  packed_types['s2_type'] = s2_type
  packed_types['prepare_type'] = tff.FunctionType(s1_type.member,
                                                  s2_type.member)
  return packed_types
def test_n_tuple_federated_zip_tensor_args(self, n):
  """Checks that zipping `n` client-placed int32 values yields an n-tuple."""
  client_int = tff.FederatedType(tf.int32, tff.CLIENTS)
  zipped_type = tff.FederatedType([tf.int32] * n, tff.CLIENTS)
  expected_type_string = str(
      tff.FunctionType(tff.NamedTupleType([client_int] * n), zipped_type))

  @tff.federated_computation([tff.FederatedType(tf.int32, tff.CLIENTS)] * n)
  def foo(x):
    return tff.federated_zip(x)

  self.assertEqual(str(foo.type_signature), expected_type_string)
def _normalize_intrinsic_bit(comp):
  """Replaces a `federated_map_all_equal` intrinsic with `federated_map`.

  Args:
    comp: The computation to inspect. Its `uri` attribute is read
      unconditionally, so callers presumably pass intrinsics only (TODO
      confirm against the caller).

  Returns:
    `(comp, False)` when `comp.uri` differs from the URI of
    `FEDERATED_MAP_ALL_EQUAL`. Otherwise `(new_intrinsic, True)`, where
    `new_intrinsic` is a `FEDERATED_MAP` intrinsic that keeps the original
    first parameter type. Its second parameter and result are rebuilt as
    `tff.CLIENTS`-placed `tff.FederatedType`s over the original member types.
  """
  if comp.uri == tff_framework.FEDERATED_MAP_ALL_EQUAL.uri:
    signature = comp.type_signature
    mapping_fn_type = signature.parameter[0]
    # The member types are re-wrapped without passing any `all_equal`
    # argument to `tff.FederatedType`.
    clients_arg_type = tff.FederatedType(signature.parameter[1].member,
                                         tff.CLIENTS)
    clients_result_type = tff.FederatedType(signature.result.member,
                                            tff.CLIENTS)
    map_type = tff.FunctionType([mapping_fn_type, clients_arg_type],
                                clients_result_type)
    replacement = tff_framework.Intrinsic(tff_framework.FEDERATED_MAP.uri,
                                          map_type)
    return replacement, True
  return comp, False
def test_n_tuple_federated_zip_mixed_args(self, n, m):
  """Checks zipping `n` client-placed int32 pairs followed by `m` int32s."""
  pair_at_clients = tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)
  int_at_clients = tff.FederatedType(tf.int32, tff.CLIENTS)
  parameter_type = tff.NamedTupleType([pair_at_clients] * n +
                                      [int_at_clients] * m)
  zipped_type = tff.FederatedType([[tf.int32, tf.int32]] * n + [tf.int32] * m,
                                  tff.CLIENTS)
  expected_type_string = str(tff.FunctionType(parameter_type, zipped_type))

  pair_arg = tff.FederatedType(
      tff.NamedTupleType([tf.int32, tf.int32]), tff.CLIENTS)
  scalar_arg = tff.FederatedType(tf.int32, tff.CLIENTS)

  @tff.federated_computation([pair_arg] * n + [scalar_arg] * m)
  def baz(x):
    return tff.federated_zip(x)

  self.assertEqual(str(baz.type_signature), expected_type_string)
def test_fed_comp_typical_usage_as_decorator_with_unlabeled_type(self):
  """Checks a federated computation declared with an unlabeled parameter type.

  The assertions inside `foo` run while it is being traced. They check that
  the arguments and the intermediate result are `tff.Value`s with the expected
  type signatures.
  """
  parameter_spec = (tff.FunctionType(tf.int32, tf.int32), tf.int32)

  @tff.federated_computation(parameter_spec)
  def foo(f, x):
    assert isinstance(f, tff.Value)
    assert isinstance(x, tff.Value)
    assert str(f.type_signature) == '(int32 -> int32)'
    assert str(x.type_signature) == 'int32'
    twice_applied = f(f(x))
    assert isinstance(twice_applied, tff.Value)
    assert str(twice_applied.type_signature) == 'int32'
    return twice_applied

  self.assertEqual(
      str(foo.type_signature), '(<(int32 -> int32),int32> -> int32)')

  @tff.tf_computation(tf.int32)
  def third_power(x):
    return x**3

  # (10**3)**3 == 10**9 and (1**3)**3 == 1.
  self.assertEqual(foo(third_power, 10), int(1e9))
  self.assertEqual(foo(third_power, 1), 1)
def check_and_pack_after_aggregate_type_signature(type_spec,
                                                  previously_packed_types):
  """Checks types inferred from `after_aggregate` and packs in `previously_packed_types`.

  After splitting the `next` portion of a `tff.utils.IterativeProcess` all the
  way down, `after_aggregate` should have type signature
  `<<<s1,c1>,c2>,s3> -> <s6,s7,c6>`. This function validates every element of
  the above against `previously_packed_types`. `c6` is only compared when the
  result actually has a third element. The function then packs the derived
  types `s4`, `s5`, `update` and `c3`.

  Args:
    type_spec: The `type_signature` attribute of the `after_aggregate` portion
      of the `tff.utils.IterativeProcess` from which we are looking to extract
      an instance of `canonical_form.CanonicalForm`.
    previously_packed_types: Dict containing the information from `next`,
      `before_broadcast` and `before_aggregate` in the iterative process we are
      parsing. It must contain `s1_type`, `c1_type`, `c2_type`, `s3_type`,
      `s6_type`, `s7_type` and `c6_type`.

  Returns:
    A new `dict` containing every entry of `previously_packed_types` plus
    `s4_type`, `s5_type`, `update_type` and `c3_type`.

  Raises:
    TypeError: If `type_spec` does not have the structure above, or is
      incompatible with `previously_packed_types`.
  """
  # Confirm the nesting before indexing into it. Otherwise a malformed
  # `type_spec` would surface as an `AttributeError`/`IndexError` from the
  # comparisons below instead of the documented `TypeError`.
  is_well_formed = (
      isinstance(type_spec, tff.FunctionType) and
      isinstance(type_spec.parameter, tff.NamedTupleType) and
      len(type_spec.parameter) == 2 and
      isinstance(type_spec.parameter[0], tff.NamedTupleType) and
      len(type_spec.parameter[0]) == 2 and
      isinstance(type_spec.parameter[0][0], tff.NamedTupleType) and
      len(type_spec.parameter[0][0]) == 2 and
      isinstance(type_spec.result, tff.NamedTupleType) and
      len(type_spec.result) in (2, 3))
  should_raise = not is_well_formed
  if not should_raise:
    if not (type_spec.parameter[0][0][0] == previously_packed_types['s1_type']
            and
            type_spec.parameter[0][0][1] == previously_packed_types['c1_type']
            and type_spec.parameter[0][1] == previously_packed_types['c2_type']
            and type_spec.parameter[1] == previously_packed_types['s3_type']):
      should_raise = True
    if not (type_spec.result[0] == previously_packed_types['s6_type'] and
            type_spec.result[1] == previously_packed_types['s7_type']):
      should_raise = True
    # A two-element result is accepted without a `c6` entry.
    if (len(type_spec.result) == 3 and
        type_spec.result[2] != previously_packed_types['c6_type']):
      should_raise = True
  if should_raise:
    # TODO(b/121290421): These error messages, and indeed the 'track boolean and
    # raise once' logic of these methods as well, is intended to be provisional
    # and revisited when we've seen the compilation pipeline fail more clearly,
    # or maybe preferably iteratively improved as new failure modes are
    # encountered.
    raise TypeError(
        'Encountered a type error while checking `after_aggregate`; '
        'expected a type signature of the form '
        '`<<<s1,c1>,c2>,s3> -> <s6,s7,c6>`, where s1 matches {}, '
        'c1 matches {}, c2 matches {}, s3 matches {}, s6 matches '
        '{}, s7 matches {}, c6 matches {}, as defined in '
        '`canonical_form.CanonicalForm`. Encountered a type signature '
        '{}.'.format(previously_packed_types['s1_type'],
                     previously_packed_types['c1_type'],
                     previously_packed_types['c2_type'],
                     previously_packed_types['s3_type'],
                     previously_packed_types['s6_type'],
                     previously_packed_types['s7_type'],
                     previously_packed_types['c6_type'], type_spec))
  # `update` maps <s1, s3> members to <s6, s7> members on the server.
  s4_type = tff.FederatedType([
      previously_packed_types['s1_type'].member,
      previously_packed_types['s3_type'].member
  ], tff.SERVER)
  s5_type = tff.FederatedType([
      previously_packed_types['s6_type'].member,
      previously_packed_types['s7_type'].member
  ], tff.SERVER)
  c3_type = tff.FederatedType([
      previously_packed_types['c1_type'].member,
      previously_packed_types['c2_type'].member
  ], tff.CLIENTS)
  # Copy first so the caller's dict is not mutated. New entries are added
  # after the existing ones and override any entries with the same key.
  packed_types = dict(previously_packed_types)
  packed_types['s4_type'] = s4_type
  packed_types['s5_type'] = s5_type
  packed_types['update_type'] = tff.FunctionType(s4_type.member,
                                                 s5_type.member)
  packed_types['c3_type'] = c3_type
  return packed_types
def check_and_pack_before_aggregate_type_signature(type_spec,
                                                   previously_packed_types):
  """Checks types inferred from `before_aggregate` and packs in `previously_packed_types`.

  After splitting the `after_broadcast` portion of a
  `tff.utils.IterativeProcess` into `before_aggregate` and `after_aggregate`,
  `before_aggregate` should have type signature
  `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`. This function validates
  `s1` and `c1` against the existing entries in `previously_packed_types`. It
  also checks that `c2` is the `tff.CLIENTS`-placed broadcast of `s2`'s
  member. It then packs `c2`, `c3`, `c5`, `zero`, `accumulate`, `merge`,
  `report`, `s3`, `c4` and `work`.

  Args:
    type_spec: The `type_signature` attribute of the `before_aggregate` portion
      of the `tff.utils.IterativeProcess` from which we are looking to extract
      an instance of `canonical_form.CanonicalForm`.
    previously_packed_types: Dict containing the information from `next` and
      `before_broadcast` in the iterative process we are parsing. It must
      contain `s1_type`, `c1_type`, `s2_type` and `c6_type`.

  Returns:
    A new `dict` containing every entry of `previously_packed_types` plus the
    newly determined types listed above.

  Raises:
    TypeError: If `type_spec` does not have the structure above, or is
      incompatible with `previously_packed_types`.
  """
  # Each group of checks below only runs once the structure it indexes into
  # has been confirmed. A malformed `type_spec` therefore raises the
  # documented `TypeError`, not an `AttributeError`/`IndexError`.
  should_raise = not (isinstance(type_spec, tff.FunctionType) and
                      isinstance(type_spec.parameter, tff.NamedTupleType) and
                      len(type_spec.parameter) == 2)
  if not should_raise:
    s1_c1_type = type_spec.parameter[0]
    c2_candidate = type_spec.parameter[1]
    result = type_spec.result
    if not (isinstance(s1_c1_type, tff.NamedTupleType) and
            len(s1_c1_type) == 2 and
            s1_c1_type[0] == previously_packed_types['s1_type'] and
            s1_c1_type[1] == previously_packed_types['c1_type']):
      should_raise = True
    if not (isinstance(c2_candidate, tff.FederatedType) and
            c2_candidate.placement == tff.CLIENTS and
            c2_candidate.member == previously_packed_types['s2_type'].member):
      should_raise = True
    # `result[1]` is the accumulator type shared by zero, accumulate, merge
    # and report.
    if not (isinstance(result, tff.NamedTupleType) and len(result) == 5 and
            isinstance(result[0], tff.FederatedType) and
            result[0].placement == tff.CLIENTS and
            tff_framework.is_tensorflow_compatible_type(result[1]) and
            result[2] == tff.FunctionType([result[1], result[0].member],
                                          result[1]) and
            result[3] == tff.FunctionType([result[1], result[1]], result[1])
            and isinstance(result[4], tff.FunctionType) and
            result[4].parameter == result[1] and
            tff_framework.is_tensorflow_compatible_type(result[4].result)):
      should_raise = True
  if should_raise:
    # TODO(b/121290421): These error messages, and indeed the 'track boolean and
    # raise once' logic of these methods as well, is intended to be provisional
    # and revisited when we've seen the compilation pipeline fail more clearly,
    # or maybe preferably iteratively improved as new failure modes are
    # encountered.
    raise TypeError(
        'Encountered a type error while checking '
        '`before_aggregate`. Expected a type signature of the '
        'form `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`, '
        'where `s1` matches {}, `c1` matches {}, and `c2` matches '
        'the result of broadcasting {}, as defined in '
        '`canonical_form.CanonicalForm`. Found type signature {}.'.format(
            previously_packed_types['s1_type'],
            previously_packed_types['c1_type'],
            previously_packed_types['s2_type'], type_spec))
  c2_type = type_spec.parameter[1]
  c3_type = tff.FederatedType(
      [previously_packed_types['c1_type'].member, c2_type.member], tff.CLIENTS)
  c5_type = type_spec.result[0]
  report_type = type_spec.result[4]
  c4_type = tff.FederatedType(
      [c5_type.member, previously_packed_types['c6_type'].member], tff.CLIENTS)
  # Copy first so the caller's dict is not mutated. New entries are added
  # after the existing ones and override any entries with the same key.
  packed_types = dict(previously_packed_types)
  packed_types['c2_type'] = c2_type
  packed_types['c3_type'] = c3_type
  packed_types['c5_type'] = c5_type
  packed_types['zero_type'] = tff.FunctionType(None, type_spec.result[1])
  packed_types['accumulate_type'] = type_spec.result[2]
  packed_types['merge_type'] = type_spec.result[3]
  packed_types['report_type'] = report_type
  packed_types['s3_type'] = tff.FederatedType(report_type.result, tff.SERVER)
  packed_types['c4_type'] = c4_type
  packed_types['work_type'] = tff.FunctionType(c3_type.member, c4_type.member)
  return packed_types
def __init__(self, initialize, prepare, work, zero, accumulate, merge, report,
             update):
  """Constructs a representation of a MapReduce-like iterative process.

  NOTE: Every argument must be a TensorFlow computation, i.e. a
  `tff.Computation` produced by the `tff.tf_computation` decorator/wrapper.
  The type signatures of the arguments are cross-checked against each other
  before they are stored.

  Args:
    initialize: The computation that produces the initial server state.
    prepare: The computation that prepares the input for the clients.
    work: The client-side work computation.
    zero: The computation that produces the initial state for accumulators.
    accumulate: The computation that adds a client update to an accumulator.
    merge: The computation to use for merging pairs of accumulators.
    report: The computation that produces the final server-side aggregate for
      the top level accumulator (the global update).
    update: The computation that takes the global update and the server state
      and produces the new server state, as well as server-side output.

  Raises:
    TypeError: If the Python or TFF types of the arguments are invalid or not
      compatible with each other.
    AssertionError: If the manner in which the given TensorFlow computations
      are represented by TFF does not match what this code is expecting (this
      is an internal error that requires code update).
  """
  labeled_computations = (
      ('initialize', initialize),
      ('prepare', prepare),
      ('work', work),
      ('zero', zero),
      ('accumulate', accumulate),
      ('merge', merge),
      ('report', report),
      ('update', update),
  )
  for label, computation in labeled_computations:
    py_typecheck.check_type(computation, tff.Computation, label)
    # TODO(b/130633916): Remove private access once an appropriate API for it
    # becomes available.
    proto = computation._computation_proto  # pylint: disable=protected-access
    if not isinstance(proto, computation_pb2.Computation):
      # Raised explicitly rather than via `assert`, so the check also runs
      # when Python assertions are disabled.
      raise AssertionError('Cannot find the embedded computation definition.')
    computation_kind = proto.WhichOneof('computation')
    if computation_kind != 'tensorflow':
      raise TypeError(
          'Expected all computations supplied as arguments to '
          'be plain TensorFlow, found {}.'.format(computation_kind))

  initialize_sig = initialize.type_signature
  prepare_sig = prepare.type_signature
  work_sig = work.type_signature
  zero_sig = zero.type_signature
  accumulate_sig = accumulate.type_signature
  merge_sig = merge.type_signature
  report_sig = report.type_signature
  update_sig = update.type_signature
  server_state_type = initialize_sig.result

  # `prepare` consumes the server state produced by `initialize`.
  if prepare_sig.parameter != server_state_type:
    raise TypeError(
        'The `prepare` computation expects an argument of type {}, '
        'which does not match the result type {} of `initialize`.'.format(
            prepare_sig.parameter, server_state_type))

  # `work` takes a two-tuple whose second element is the output of `prepare`.
  work_parameter = work_sig.parameter
  if not (isinstance(work_parameter, tff.NamedTupleType) and
          len(work_parameter) == 2):
    raise TypeError(
        'The `work` computation expects an argument of type {} that is not '
        'a two-tuple.'.format(work_parameter))
  if work_parameter[1] != prepare_sig.result:
    raise TypeError(
        'The `work` computation expects an argument tuple with type {} as '
        'the second element (the initial client state from the server), '
        'which does not match the result type {} of `prepare`.'.format(
            work_parameter[1], prepare_sig.result))

  work_result = work_sig.result
  if not (isinstance(work_result, tff.NamedTupleType) and
          len(work_result) == 2):
    raise TypeError(
        'The `work` computation returns a result of type {} that is not a '
        'two-tuple.'.format(work_result))

  # `accumulate` folds the first element of `work`'s result into the
  # accumulator type produced by `zero`.
  accumulator_from_zero = zero_sig.result
  accumulate_expected = tff.FunctionType(
      [accumulator_from_zero, work_result[0]], accumulator_from_zero)
  if accumulate_sig != accumulate_expected:
    raise TypeError(
        'The `accumulate` computation has type signature {}, which does '
        'not match the expected {} as implied by the type signatures of '
        '`zero` and `work`.'.format(accumulate_sig, accumulate_expected))

  # `merge` combines two accumulators into one.
  accumulator_type = accumulate_sig.result
  merge_expected = tff.FunctionType([accumulator_type, accumulator_type],
                                    accumulator_type)
  if merge_sig != merge_expected:
    raise TypeError(
        'The `merge` computation has type signature {}, which does '
        'not match the expected {} as implied by the type signature '
        'of `accumulate`.'.format(merge_sig, merge_expected))

  if report_sig.parameter != merge_sig.result:
    raise TypeError(
        'The `report` computation expects an argument of type {}, '
        'which does not match the result type {} of `merge`.'.format(
            report_sig.parameter, merge_sig.result))

  # `update` takes <server state, report output>.
  update_parameter_expected = tff.to_type(
      [server_state_type, report_sig.result])
  if update_sig.parameter != update_parameter_expected:
    raise TypeError(
        'The `update` computation expects an argument of type {}, '
        'which does not match the expected {} as implied by the type '
        'signatures of `initialize` and `report`.'.format(
            update_sig.parameter, update_parameter_expected))

  update_result = update_sig.result
  if not (isinstance(update_result, tff.NamedTupleType) and
          len(update_result) == 2):
    raise TypeError(
        'The `update` computation returns a result of type {} that is not '
        'a two-tuple.'.format(update_result))
  if update_result[0] != server_state_type:
    raise TypeError(
        'The `update` computation returns a result tuple with type {} as '
        'the first element (the updated state of the server), which does '
        'not match the result type {} of `initialize`.'.format(
            update_result[0], server_state_type))

  self._initialize = initialize
  self._prepare = prepare
  self._work = work
  self._zero = zero
  self._accumulate = accumulate
  self._merge = merge
  self._report = report
  self._update = update