def federated_map(mapping_fn, value):
  """Maps a federated value pointwise using a mapping function.

  The function `mapping_fn` is applied separately across the group of devices
  represented by the placement type of `value`. For example, if `value` has
  placement type `tff.CLIENTS`, then `mapping_fn` is applied to each client
  individually. In particular, this operation does not alter the placement of
  the federated value.

  Args:
    mapping_fn: A mapping function to apply pointwise to member constituents of
      `value`. The parameter of this function must be of the same type as the
      member constituents of `value`.
    value: A value of a TFF federated type (or a value that can be implicitly
      converted into a TFF federated type, e.g., by zipping) placed at
      `tff.CLIENTS` or `tff.SERVER`.

  Returns:
    A federated value with the same placement as `value` that represents the
    result of `mapping_fn` on the member constituent of `value`.

  Raises:
    TypeError: If the arguments are not of the appropriate types.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  factory = intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack)
  return factory.federated_map(mapping_fn, value)
def federated_mean(value, weight=None):
  """Computes a `tff.SERVER` mean of `value` placed on `tff.CLIENTS`.

  Given client values `v_1, ..., v_k` with weights `w_1, ..., w_k`, the result
  is `sum_{i=1}^k (w_i * v_i) / sum_{i=1}^k w_i`.

  Args:
    value: The value whose mean is to be computed; must be of a TFF federated
      type placed at `tff.CLIENTS`. It may be structured (e.g., its member
      constituents can be named tuples), but every tensor type it is composed
      of must be floating-point or complex.
    weight: An optional per-client weight; a TFF federated integer or
      floating-point tensor value, also placed at `tff.CLIENTS`.

  Returns:
    A `tff.SERVER`-placed representation of the mean of the member
    constituents of `value`, weighted by `weight` when one is supplied;
    otherwise every client's contribution is weighted equally.

  Raises:
    TypeError: If `value` is not a federated TFF value placed at `tff.CLIENTS`,
      or if `weight` is not a federated integer or a floating-point tensor
      with the matching placement.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_mean(value, weight)
def federated_reduce(value, zero, op):
  """Reduces `value` from `tff.CLIENTS` to `tff.SERVER` using a reduction `op`.

  The member constituents of `value` (of federated type `T@CLIENTS` for some
  `T`) are folded into a single `U`-typed result: starting from `zero` (the
  `U`-typed result of reducing an empty set), the operator `op` of type
  `(<U,T> -> U)` repeatedly absorbs one `T`-typed member constituent into the
  `U`-typed partial result.

  When `T` equals `U`, this is the classical reduction of a set with a
  commutative associative binary operator. When `T` differs from `U`, repeated
  application of `op` must yield the same `U`-typed result regardless of the
  order in which elements of `T` are processed during the reduction.

  Args:
    value: A value of a TFF federated type placed at the `tff.CLIENTS`.
    zero: The result of reducing a value with no constituents.
    op: An operator with type signature `(<U,T> -> U)`, where `T` is the type
      of the constituents of `value` and `U` is the type of `zero` to be used
      in performing the reduction.

  Returns:
    A representation on the `tff.SERVER` of the result of reducing the set of
    all member constituents of `value` using the operator `op` into a single
    item.

  Raises:
    TypeError: If the arguments are not of the types specified above.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_reduce(value, zero, op)
def test_allows_assignable_but_not_equal_zero_and_reduction_types(self):
  """The zero's type may be assignable-to (not equal to) the result type."""
  factory = intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack)
  element_type = tf.string
  # zero has a fixed shape [1], while the accumulated result is unbounded
  # ([None]); [1] is assignable to [None] but not equal to it.
  zero_type = computation_types.TensorType(tf.string, [1])
  reduced_type = computation_types.TensorType(tf.string, [None])

  # Reduction op: append one element onto the growing string vector.
  @computations.tf_computation(reduced_type, element_type)
  @computations.check_returns_type(reduced_type)
  def append(accumulator, element):
    return tf.concat([accumulator, [element]], 0)

  # Zero: a single-element string vector (shape [1], assignable to [None]).
  @computations.tf_computation
  @computations.check_returns_type(zero_type)
  def zero():
    return tf.convert_to_tensor(['The beginning'])

  @computations.federated_computation(
      computation_types.at_clients(element_type))
  @computations.check_returns_type(
      computation_types.at_server(reduced_type))
  def collect(client_values):
    return factory.federated_reduce(client_values, zero(), append)

  self.assertEqual(collect.type_signature.compact_representation(),
                   '({string}@CLIENTS -> string[?]@SERVER)')
def federated_aggregate(value, zero, accumulate, merge, report):
  """Aggregates `value` from `tff.CLIENTS` to `tff.SERVER`.

  This generalized aggregation function admits multi-layered architectures that
  involve one or more intermediate stages to handle scalable aggregation across
  a very large number of participants. The multi-stage aggregation process is
  defined as follows:

  * Clients are organized into groups. Within each group, the member
    constituents of `value` contributed by the group's clients are first
    reduced in a manner similar to `tff.federated_reduce`, using `accumulate`
    as the reduction operator and `zero` as the zero in the algebra. As
    described in the documentation for `tff.federated_reduce`, if members of
    `value` are of type `T` and `zero` (the result of reducing an empty set)
    is of type `U`, the operator `accumulate` used at this stage should be of
    type `(<U,T> -> U)`. This stage yields a set of `U`-typed items, one per
    group of clients.

  * Next, the `U`-typed items produced by the preceding stage are merged using
    the binary commutative associative operator `merge` of type
    `(<U,U> -> U)`. This can be interpreted as a `tff.federated_reduce` using
    `merge` as the reduction operator and the same `zero` in the algebra. The
    result is a single top-level `U` that emerges at the root of the hierarchy
    at the `tff.SERVER`. Actual implementations may structure this step as a
    cascade of multiple layers.

  * Finally, the `U`-typed result of the merge is projected into the result
    value using `report` as the mapping function (for example, if the merged
    structures consist of counters, this final step might compute their
    ratios).

  Args:
    value: A value of a TFF federated type placed at `tff.CLIENTS` to
      aggregate.
    zero: The zero of type `U` in the algebra of reduction operators, as
      described above.
    accumulate: The reduction operator to use in the first stage of the
      process. If `value` is of type `{T}@CLIENTS`, and `zero` is of type `U`,
      this operator should be of type `(<U,T> -> U)`.
    merge: The reduction operator to employ in the second stage of the
      process. Must be of type `(<U,U> -> U)`, where `U` is as defined above.
    report: The projection operator to use at the final stage of the process
      to compute the final result of aggregation. If the intended result to be
      returned by `tff.federated_aggregate` is of type `R@SERVER`, this
      operator must be of type `(U -> R)`.

  Returns:
    A representation on the `tff.SERVER` of the result of aggregating `value`
    using the multi-stage process described above.

  Raises:
    TypeError: If the arguments are not of the types specified above.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_aggregate(
          value, zero, accumulate, merge, report)
def federated_secure_sum(value, bitwidth):
  """Computes a sum at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`.

  This function computes a sum such that it should not be possible for the
  server to learn any client's individual value. The specific algorithm and
  mechanism used to compute the secure sum may vary depending on the target
  runtime environment the computation is compiled for or executed on. See
  https://research.google/pubs/pub47246/ for more information.

  Not all executors support `tff.federated_secure_sum()`; consult the
  documentation for the specific executor or executor stack you plan on using
  for the specifics of how it's handled by that executor.

  The `bitwidth` argument represents the bitwidth of the aggregand, that is the
  bitwidth of the input `value`. The federated secure sum bitwidth (i.e., the
  bitwidth of the *sum* of the input `value`s over all clients) will be a
  function of this bitwidth and the number of participating clients.

  Example:

  ```python
  value = tff.federated_value(1, tff.CLIENTS)
  result = tff.federated_secure_sum(value, 2)

  value = tff.federated_value([1, 1], tff.CLIENTS)
  result = tff.federated_secure_sum(value, [2, 4])

  value = tff.federated_value([1, [1, 1]], tff.CLIENTS)
  result = tff.federated_secure_sum(value, [2, [4, 8]])
  ```

  Note: To sum non-integer values or to sum integers with fewer constraints and
  weaker privacy properties, consider using `federated_sum`.

  Args:
    value: An integer value of a TFF federated type placed at the
      `tff.CLIENTS`, in the range [0, 2^bitwidth - 1].
    bitwidth: An integer or nested structure of integers matching the structure
      of `value`. If integer `bitwidth` is used with a nested `value`, the same
      integer is used for each tensor in `value`.

  Returns:
    A representation of the sum of the member constituents of `value` placed
    on the `tff.SERVER`.

  Raises:
    TypeError: If the argument is not a federated TFF value placed at
      `tff.CLIENTS`.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  factory = intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack)
  return factory.federated_secure_sum(value, bitwidth)
def test_type_signature_with_non_federated_type(self):
  """A plain sequence reduces to an unplaced scalar of the zero's type."""
  factory = intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack)

  # Simple two-argument TF computation used as the reduction operator.
  @computations.tf_computation(np.int32, np.int32)
  def add(x, y):
    return x + y

  # Reduce a non-federated int32 sequence down to a single int32.
  @computations.federated_computation(
      computation_types.SequenceType(np.int32))
  def reduce_comp(sequence):
    return factory.sequence_reduce(sequence, 0, add)

  self.assertEqual(reduce_comp.type_signature.compact_representation(),
                   '(int32* -> int32)')
def federated_sum(value):
  """Computes a sum at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`.

  Args:
    value: A value of a TFF federated type placed at the `tff.CLIENTS`.

  Returns:
    A representation of the sum of the member constituents of `value` placed
    on the `tff.SERVER`.

  Raises:
    TypeError: if the argument is not a federated TFF value placed at
      `tff.CLIENTS`.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_sum(value)
def federated_collect(value):
  """Returns a federated value from `tff.CLIENTS` as a `tff.SERVER` sequence.

  Args:
    value: A value of a TFF federated type placed at the `tff.CLIENTS`.

  Returns:
    A stream of the same type as the member constituents of `value` placed at
    the `tff.SERVER`.

  Raises:
    TypeError: if the argument is not a federated TFF value placed at
      `tff.CLIENTS`.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_collect(value)
def federated_eval(fn, placement):
  """Evaluates a federated computation at `placement`, returning the result.

  Args:
    fn: A no-arg TFF computation.
    placement: The desired result placement (either `tff.SERVER` or
      `tff.CLIENTS`).

  Returns:
    A federated value with the given placement `placement`.

  Raises:
    TypeError: If the arguments are not of the appropriate types.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_eval(fn, placement)
def federated_value(value, placement):
  """Returns a federated value at `placement`, with `value` as the constituent.

  Args:
    value: A value of a non-federated TFF type to be placed.
    placement: The desired result placement (either `tff.SERVER` or
      `tff.CLIENTS`).

  Returns:
    A federated value with the given placement `placement`, and the member
    constituent `value` equal at all locations.

  Raises:
    TypeError: If the arguments are not of the appropriate types.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_value(value, placement)
def test_federated_map_all_equal(self):
  """Mapping over an all-equal CLIENTS value applies the fn per client."""
  factory = intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack)

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @computations.federated_computation
  def comp():
    # An all-equal value at CLIENTS: every client holds the constant 10.
    value = intrinsics.federated_value(10, placement_literals.CLIENTS)
    return factory.federated_map_all_equal(add_one, value)

  # _create_test_executor / _invoke are module-level test helpers.
  executor, _ = _create_test_executor()
  result = _invoke(executor, comp)
  # Each client's constituent should have been incremented by add_one.
  for value in result:
    self.assertEqual(value.numpy(), 10 + 1)
def federated_secure_sum(value, bitwidth):
  """Computes a sum at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`.

  This function computes a sum such that it should not be possible for the
  server to learn any client's individual value. The specific algorithm and
  mechanism used to compute the secure sum may vary depending on the target
  runtime environment the computation is compiled for or executed on. See
  https://research.google/pubs/pub47246/ for more information.

  Not all executors support `tff.federated_secure_sum()`; consult the
  documentation for the specific executor or executor stack you plan on using
  for the specifics of how it's handled by that executor.

  TODO(b/148147384): Describe the semantics of secure sum intrinsic.

  Example:

  ```python
  value = tff.federated_value(1, tff.CLIENTS)
  result = tff.federated_secure_sum(value, 2)

  value = tff.federated_value([1, 1], tff.CLIENTS)
  result = tff.federated_secure_sum(value, [2, 4])

  value = tff.federated_value([1, [1, 1]], tff.CLIENTS)
  result = tff.federated_secure_sum(value, [2, [4, 8]])
  ```

  NOTE: To sum non-integer values or to sum integers with fewer constraints and
  weaker privacy properties, consider using `federated_sum`.

  Args:
    value: A value of a TFF federated type placed at the `tff.CLIENTS`.
    bitwidth: An integer or nested structure of integers.

  Returns:
    A representation of the sum of the member constituents of `value` placed
    on the `tff.SERVER`.

  Raises:
    TypeError: if the argument is not a federated TFF value placed at
      `tff.CLIENTS`.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  factory = intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack)
  return factory.federated_secure_sum(value, bitwidth)
def test_type_signature_with_federated_type(self):
  """A federated sequence reduces point-wise to a CLIENTS-placed result."""
  factory = intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack)

  @computations.tf_computation(np.int32, np.int32)
  def add(x, y):
    return x + y

  # When the input sequence is federated, the reduction happens locally at
  # each client, so the zero must also be placed at CLIENTS.
  @computations.federated_computation(
      computation_types.FederatedType(
          computation_types.SequenceType(np.int32),
          placement_literals.CLIENTS))
  def foo(value):
    zero = intrinsics.federated_value(0, placement_literals.CLIENTS)
    return factory.sequence_reduce(value, zero, add)

  self.assertEqual(foo.type_signature.compact_representation(),
                   '({int32*}@CLIENTS -> {int32}@CLIENTS)')
def federated_broadcast(value):
  """Broadcasts a federated value from the `tff.SERVER` to the `tff.CLIENTS`.

  Args:
    value: A value of a TFF federated type placed at the `tff.SERVER`, all
      members of which are equal (the `tff.FederatedType.all_equal` property
      of `value` is `True`).

  Returns:
    A representation of the result of broadcasting: a value of a TFF federated
    type placed at the `tff.CLIENTS`, all members of which are equal.

  Raises:
    TypeError: if the argument is not a federated TFF value placed at the
      `tff.SERVER`.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_broadcast(value)
def federated_zip(value):
  """Converts an N-tuple of federated values into a federated N-tuple value.

  Args:
    value: A value of a TFF named tuple type, the elements of which are
      federated values with the same placement.

  Returns:
    A federated value placed at the same location as the members of `value`,
    in which every member component is a named tuple that consists of the
    corresponding member components of the elements of `value`.

  Raises:
    TypeError: if the argument is not a named tuple of federated values with
      the same placement.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_zip(value)
def sequence_sum(value):
  """Computes a sum of elements in a sequence.

  Args:
    value: A value of a TFF type that is either a sequence, or a federated
      sequence.

  Returns:
    The sum of elements in the sequence. If the argument `value` is of a
    federated type, the result is also of a federated type, with the sum
    computed locally and independently at each location (see also a discussion
    on `sequence_map` and `sequence_reduce`).

  Raises:
    TypeError: If the arguments are of wrong or unsupported types.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).sequence_sum(value)
def federated_apply(func, arg):
  """Applies a given function to a federated value on the `tff.SERVER`.

  Args:
    func: A function to apply to the member content of `arg` on the
      `tff.SERVER`. The parameter of this function must be of the same type as
      the member constituent of `arg`.
    arg: A value of a TFF federated type placed at the `tff.SERVER`, and with
      the `all_equal` bit set.

  Returns:
    A federated value on the `tff.SERVER` that represents the result of
    applying `func` to the member constituent of `arg`.

  Raises:
    TypeError: If the arguments are not of the appropriate types.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_apply(func, arg)
def federated_sum(value):
  """Computes a sum at `tff.SERVER` of a `value` placed on the `tff.CLIENTS`.

  To sum integer values with stronger privacy properties, consider using
  `tff.federated_secure_sum`.

  Args:
    value: A value of a TFF federated type placed at the `tff.CLIENTS`.

  Returns:
    A representation of the sum of the member constituents of `value` placed
    on the `tff.SERVER`.

  Raises:
    TypeError: If the argument is not a federated TFF value placed at
      `tff.CLIENTS`.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_sum(value)
def federated_map(mapping_fn, value):
  """Maps a federated value on `tff.CLIENTS` pointwise using a mapping function.

  Args:
    mapping_fn: A mapping function to apply pointwise to member constituents
      of `value` on each of the participants in `tff.CLIENTS`. The parameter
      of this function must be of the same type as the member constituents of
      `value`.
    value: A value of a TFF federated type placed at the `tff.CLIENTS`, or a
      value that can be implicitly converted into a TFF federated type, e.g.,
      by zipping.

  Returns:
    A federated value on `tff.CLIENTS` that represents the result of mapping.

  Raises:
    TypeError: If the arguments are not of the appropriate types.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).federated_map(mapping_fn, value)
def sequence_reduce(value, zero, op):
  """Reduces a TFF sequence `value` given a `zero` and reduction operator `op`.

  The elements of the TFF sequence `value` are folded into a single result:
  starting from `zero` (the `U`-typed result of reducing an empty sequence in
  the algebra), the operator `op` with type signature `(<U,T> -> U)`
  repeatedly incorporates one `T`-typed element of `value` into the `U`-typed
  partial result.

  When `T` equals `U`, this corresponds to the classical notion of reduction
  of a set using a commutative associative binary operator. The generalized
  reduction (with `T` not equal to `U`) requires that repeated application of
  `op` to reduce a set of `T` always yields the same `U`-typed result,
  regardless of the order in which elements of `T` are processed in the course
  of the reduction.

  One can also invoke `sequence_reduce` on a federated sequence, in which case
  the reductions are performed pointwise; under the hood, we construct an
  expression of the form
  `federated_map(x -> sequence_reduce(x, zero, op), value)`. See also the
  discussion on `sequence_map`.

  Note: When applied to a federated value this function does the reduce
  point-wise.

  Args:
    value: A value that is either a TFF sequence, or a federated sequence.
    zero: The result of reducing a sequence with no elements.
    op: An operator with type signature `(<U,T> -> U)`, where `T` is the type
      of the elements of the sequence, and `U` is the type of `zero` to be
      used in performing the reduction.

  Returns:
    The `U`-typed result of reducing elements in the sequence, or if the
    `value` is federated, a federated `U` that represents the result of
    locally reducing each member constituent of `value`.

  Raises:
    TypeError: If the arguments are not of the types specified above.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  return intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack).sequence_reduce(value, zero, op)
def sequence_map(mapping_fn, value):
  """Maps a TFF sequence `value` pointwise using a given function `mapping_fn`.

  This function supports two modes of usage:

  * When applied to a non-federated sequence, it maps individual elements of
    the sequence pointwise. If the supplied `mapping_fn` is of type `T->U` and
    the sequence `value` is of type `T*` (a sequence of `T`-typed elements),
    the result is a sequence of type `U*` (a sequence of `U`-typed elements),
    with each element of the input sequence individually mapped by
    `mapping_fn`. In this mode of usage, `sequence_map` behaves like a
    computation with type signature `<T->U,T*> -> U*`.

  * When applied to a federated sequence, `sequence_map` behaves as if it were
    individually applied to each member constituent. In this mode of usage,
    one can think of `sequence_map` as a specialized variant of
    `federated_map` that is designed to work with sequences and allows one to
    specify a `mapping_fn` that operates at the level of individual elements.
    Indeed, under the hood, when `sequence_map` is invoked on a federated
    type, it injects `federated_map`, thus emitting expressions like
    `federated_map(x -> sequence_map(mapping_fn, x), value)`.

  Args:
    mapping_fn: A mapping function to apply pointwise to elements of `value`.
    value: A value of a TFF type that is either a sequence, or a federated
      sequence.

  Returns:
    A sequence with the result of applying `mapping_fn` pointwise to each
    element of `value`, or if `value` was federated, a federated sequence with
    the result of invoking `sequence_map` on member sequences locally and
    independently at each location.

  Raises:
    TypeError: If the arguments are not of the appropriate types.
  """
  # Delegate to the intrinsic factory bound to the current context stack.
  factory = intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack)
  return factory.sequence_map(mapping_fn, value)
def test_federated_map_all_equal(self):
  """Mapping an all-equal CLIENTS value yields one result per client."""
  factory = intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack)

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @computations.federated_computation
  def comp():
    # An all-equal value at CLIENTS: every client holds the constant 10.
    value = intrinsics.federated_value(10, placements.CLIENTS)
    return factory.federated_map_all_equal(add_one, value)

  # _run_test_comp is a module-level test helper.
  val = _run_test_comp(comp, num_clients=3)
  self.assertIsInstance(val, federating_executor.FederatingExecutorValue)
  # Result stays at CLIENTS (all-equal), hence no braces in the signature.
  self.assertEqual(val.type_signature.compact_representation(),
                   'int32@CLIENTS')
  # Internally, one eagerly-evaluated value per simulated client.
  self.assertIsInstance(val.internal_representation, list)
  self.assertLen(val.internal_representation, 3)
  for v in val.internal_representation:
    self.assertIsInstance(v, eager_tf_executor.EagerValue)
    self.assertEqual(v.internal_representation.numpy(), 11)
def get_intrinsic_bodies(context_stack):
  """Returns a dictionary of intrinsic bodies.

  Args:
    context_stack: The context stack to use.
  """
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  intrinsics = intrinsic_factory.IntrinsicFactory(context_stack)

  # TODO(b/122728050): Implement reductions that follow roughly the following
  # breakdown in order to minimize the number of intrinsics that backends need
  # to support and maximize opportunities for merging processing logic to keep
  # the number of communication phases as small as it is practical. Perform
  # these reductions before FEDERATED_SUM (more reductions documented below).
  #
  # - FEDERATED_AGGREGATE(x, zero, accu, merge, report) :=
  #     GENERIC_MAP(
  #         GENERIC_REDUCE(
  #             GENERIC_PARTIAL_REDUCE(x, zero, accu, INTERMEDIATE_AGGREGATORS),
  #             zero, merge, SERVER),
  #         report)
  #
  # - FEDERATED_APPLY(f, x) := GENERIC_APPLY(f, x)
  #
  # - FEDERATED_BROADCAST(x) := GENERIC_BROADCAST(x, CLIENTS)
  #
  # - FEDERATED_COLLECT(x) := GENERIC_COLLECT(x, SERVER)
  #
  # - FEDERATED_MAP(f, x) := GENERIC_MAP(f, x)
  #
  # - FEDERATED_VALUE_AT_CLIENTS(x) := GENERIC_PLACE(x, CLIENTS)
  #
  # - FEDERATED_VALUE_AT_SERVER(x) := GENERIC_PLACE(x, SERVER)
  #
  # - FEDERATED_AVERAGE(x) := FEDERATED_WEIGHTED_AVERAGE(
  #     x, FEDERATED_VALUE_AT_CLIENTS(GENERIC_ONE))
  #
  # - FEDERATED_WEIGHTED_AVERAGE(x, w) :=
  #     GENERIC_DIVIDE(FEDERATED_SUM(GENERIC_MULTIPLY(x, w), w))

  def federated_sum(x):
    # Reduce FEDERATED_SUM to FEDERATED_REDUCE with a type-appropriate zero
    # and plus operator derived from the member type of `x`.
    zero = intrinsic_utils.zero_for(x.type_signature.member, context_stack)
    plus = intrinsic_utils.plus_for(x.type_signature.member, context_stack)
    return intrinsics.federated_reduce(x, zero, plus)

  # TODO(b/122728050): Implement the remaining (post-FEDERATED_SUM) reductions
  # as defined below, in the order listed here:
  #
  # - FEDERATED_SUM(x) := FEDERATED_REDUCE(x, GENERIC_ZERO, GENERIC_PLUS)
  #
  # - FEDERATED_REDUCE(x, zero, op) :=
  #     FEDERATED_APPLY(a -> SEQUENCE_REDUCE(a, zero, op), FEDERATED_COLLECT(x))
  #
  # - FEDERATED_ZIP(x, y) := GENERIC_ZIP(x, y)
  #
  # - GENERIC_AVERAGE(x: {T}@p, q: placement) :=
  #     GENERIC_WEIGHTED_AVERAGE(x, GENERIC_ONE, q)
  #
  # - GENERIC_WEIGHTED_AVERAGE(x: {T}@p, w: {U}@p, q: placement) :=
  #     GENERIC_MAP(GENERIC_DIVIDE, GENERIC_SUM(
  #         GENERIC_MAP(GENERIC_MULTIPLY, GENERIC_ZIP(x, w)), p))
  #
  #   NOTE: The above formula does not account for type casting issues that
  #   arise due to the interplay between the types of values and weights and
  #   how they relate to types of products and ratios, and either the formula
  #   or the type signatures may need to be tweaked.
  #
  # - GENERIC_SUM(x: {T}@p, q: placement) :=
  #     GENERIC_REDUCE(x, GENERIC_ZERO, GENERIC_PLUS, q)
  #
  # - GENERIC_PARTIAL_SUM(x: {T}@p, q: placement) :=
  #     GENERIC_PARTIAL_REDUCE(x, GENERIC_ZERO, GENERIC_PLUS, q)
  #
  # - GENERIC_AGGREGATE(
  #     x: {T}@p, zero: U, accu: <U,T>->U, merge: <U,U>=>U, report: U->R,
  #     q: placement) :=
  #     GENERIC_MAP(report, GENERIC_REDUCE(x, zero, accu, q))
  #
  # - GENERIC_REDUCE(x: {T}@p, zero: U, op: <U,T>->U, q: placement) :=
  #     GENERIC_MAP((a -> SEQUENCE_REDUCE(a, zero, op)), GENERIC_COLLECT(x, q))
  #
  # - GENERIC_PARTIAL_REDUCE(x: {T}@p, zero: U, op: <U,T>->U, q: placement) :=
  #     GENERIC_MAP(
  #         (a -> SEQUENCE_REDUCE(a, zero, op)), GENERIC_PARTIAL_COLLECT(x, q))
  #
  # - SEQUENCE_SUM(x: T*) :=
  #     SEQUENCE_REDUCE(x, GENERIC_ZERO, GENERIC_PLUS)
  #
  # After performing the full set of reductions, we should only see instances
  # of the following intrinsics in the result, all of which are currently
  # considered non-reducible, and intrinsics such as GENERIC_PLUS should apply
  # only to non-federated, non-sequence types (with the appropriate calls to
  # GENERIC_MAP or SEQUENCE_MAP injected).
  #
  # - GENERIC_APPLY
  # - GENERIC_BROADCAST
  # - GENERIC_COLLECT
  # - GENERIC_DIVIDE
  # - GENERIC_MAP
  # - GENERIC_MULTIPLY
  # - GENERIC_ONE
  # - GENERIC_ONLY
  # - GENERIC_PARTIAL_COLLECT
  # - GENERIC_PLACE
  # - GENERIC_PLUS
  # - GENERIC_ZERO
  # - GENERIC_ZIP
  # - SEQUENCE_MAP
  # - SEQUENCE_REDUCE

  return collections.OrderedDict([(intrinsic_defs.FEDERATED_SUM.uri,
                                   federated_sum)])
def get_intrinsic_bodies(
    context_stack
) -> Dict[str, Callable[[value_impl.ValueImpl], value_impl.ValueImpl]]:
  """Returns map from intrinsic to reducing function.

  The returned dictionary is a `collections.OrderedDict` which maps intrinsic
  URIs to functions from intrinsic argument values to a version of the
  intrinsic call which has been reduced to a smaller, more fundamental set of
  intrinsics.

  Bodies generated by later dictionary entries will not contain references to
  intrinsics whose entries appear earlier in the dictionary. This property is
  useful for simple reduction of an entire computation by iterating through
  the map of intrinsics, substituting calls to each.

  Args:
    context_stack: The context stack to use.
  """
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  intrinsics = intrinsic_factory.IntrinsicFactory(context_stack)

  # TODO(b/122728050): Implement reductions that follow roughly the following
  # breakdown in order to minimize the number of intrinsics that backends need
  # to support and maximize opportunities for merging processing logic to keep
  # the number of communication phases as small as it is practical. Perform
  # these reductions before FEDERATED_SUM (more reductions documented below).
  #
  # - FEDERATED_AGGREGATE(x, zero, accu, merge, report) :=
  #     GENERIC_MAP(
  #         GENERIC_REDUCE(
  #             GENERIC_PARTIAL_REDUCE(x, zero, accu, INTERMEDIATE_AGGREGATORS),
  #             zero, merge, SERVER),
  #         report)
  #
  # - FEDERATED_APPLY(f, x) := GENERIC_APPLY(f, x)
  #
  # - FEDERATED_BROADCAST(x) := GENERIC_BROADCAST(x, CLIENTS)
  #
  # - FEDERATED_COLLECT(x) := GENERIC_COLLECT(x, SERVER)
  #
  # - FEDERATED_MAP(f, x) := GENERIC_MAP(f, x)
  #
  # - FEDERATED_VALUE_AT_CLIENTS(x) := GENERIC_PLACE(x, CLIENTS)
  #
  # - FEDERATED_VALUE_AT_SERVER(x) := GENERIC_PLACE(x, SERVER)

  def federated_weighted_mean(arg):
    # FEDERATED_WEIGHTED_MEAN(x, w) := divide(sum(multiply(x, w)), sum(w)).
    w = arg[1]
    multiplied = generic_multiply(arg)
    summed = federated_sum(intrinsics.federated_zip([multiplied, w]))
    return generic_divide(summed)

  def federated_mean(arg):
    # FEDERATED_MEAN(x) is a weighted mean with a constant weight of one.
    one = value_impl.ValueImpl(
        building_block_factory.create_generic_constant(arg.type_signature, 1),
        context_stack)
    arg = value_impl.to_value([arg, one], None, context_stack)
    return federated_weighted_mean(arg)

  def federated_sum(x):
    # FEDERATED_SUM := FEDERATED_REDUCE with a generic zero and a TF `add`
    # operator upcast to the member type of `x`.
    operand_type = x.type_signature.member
    zero = value_impl.ValueImpl(
        building_block_factory.create_generic_constant(operand_type, 0),
        context_stack)
    plus_op = value_impl.ValueImpl(
        building_block_factory.create_tensorflow_binary_operator_with_upcast(
            computation_types.StructType([operand_type, operand_type]),
            tf.add), context_stack)
    return federated_reduce([x, zero, plus_op])

  def federated_reduce(arg):
    # FEDERATED_REDUCE := FEDERATED_AGGREGATE with `op` used as both
    # accumulate and merge, and an identity as the report step.
    x = arg[0]
    zero = arg[1]
    op = arg[2]
    identity = building_block_factory.create_compiled_identity(
        op.type_signature.result)
    return intrinsics.federated_aggregate(x, zero, op, op, identity)

  def generic_divide(arg):
    """Divides two arguments when possible."""
    return _apply_generic_op(tf.divide, arg[0], arg[1], intrinsics,
                             context_stack)

  def generic_multiply(arg):
    """Multiplies two arguments when possible."""
    return _apply_generic_op(tf.multiply, arg[0], arg[1], intrinsics,
                             context_stack)

  def generic_plus(arg):
    """Adds two arguments when possible."""
    return _apply_generic_op(tf.add, arg[0], arg[1], intrinsics, context_stack)

  # - FEDERATED_ZIP(x, y) := GENERIC_ZIP(x, y)
  #
  # - GENERIC_AVERAGE(x: {T}@p, q: placement) :=
  #     GENERIC_WEIGHTED_AVERAGE(x, GENERIC_ONE, q)
  #
  # - GENERIC_WEIGHTED_AVERAGE(x: {T}@p, w: {U}@p, q: placement) :=
  #     GENERIC_MAP(GENERIC_DIVIDE, GENERIC_SUM(
  #         GENERIC_MAP(GENERIC_MULTIPLY, GENERIC_ZIP(x, w)), p))
  #
  #   Note: The above formula does not account for type casting issues that
  #   arise due to the interplay between the types of values and weights and
  #   how they relate to types of products and ratios, and either the formula
  #   or the type signatures may need to be tweaked.
  #
  # - GENERIC_SUM(x: {T}@p, q: placement) :=
  #     GENERIC_REDUCE(x, GENERIC_ZERO, GENERIC_PLUS, q)
  #
  # - GENERIC_PARTIAL_SUM(x: {T}@p, q: placement) :=
  #     GENERIC_PARTIAL_REDUCE(x, GENERIC_ZERO, GENERIC_PLUS, q)
  #
  # - GENERIC_AGGREGATE(
  #     x: {T}@p, zero: U, accu: <U,T>->U, merge: <U,U>=>U, report: U->R,
  #     q: placement) :=
  #     GENERIC_MAP(report, GENERIC_REDUCE(x, zero, accu, q))
  #
  # - GENERIC_REDUCE(x: {T}@p, zero: U, op: <U,T>->U, q: placement) :=
  #     GENERIC_MAP((a -> SEQUENCE_REDUCE(a, zero, op)), GENERIC_COLLECT(x, q))
  #
  # - GENERIC_PARTIAL_REDUCE(x: {T}@p, zero: U, op: <U,T>->U, q: placement) :=
  #     GENERIC_MAP(
  #         (a -> SEQUENCE_REDUCE(a, zero, op)), GENERIC_PARTIAL_COLLECT(x, q))
  #
  # - SEQUENCE_SUM(x: T*) :=
  #     SEQUENCE_REDUCE(x, GENERIC_ZERO, GENERIC_PLUS)
  #
  # After performing the full set of reductions, we should only see instances
  # of the following intrinsics in the result, all of which are currently
  # considered non-reducible, and intrinsics such as GENERIC_PLUS should apply
  # only to non-federated, non-sequence types (with the appropriate calls to
  # GENERIC_MAP or SEQUENCE_MAP injected).
  #
  # - GENERIC_APPLY
  # - GENERIC_BROADCAST
  # - GENERIC_COLLECT
  # - GENERIC_DIVIDE
  # - GENERIC_MAP
  # - GENERIC_MULTIPLY
  # - GENERIC_ONE
  # - GENERIC_ONLY
  # - GENERIC_PARTIAL_COLLECT
  # - GENERIC_PLACE
  # - GENERIC_PLUS
  # - GENERIC_ZERO
  # - GENERIC_ZIP
  # - SEQUENCE_MAP
  # - SEQUENCE_REDUCE

  # Ordering matters: later entries' bodies must not reference intrinsics
  # whose entries appear earlier in the dictionary (see docstring).
  return collections.OrderedDict([
      (intrinsic_defs.FEDERATED_MEAN.uri, federated_mean),
      (intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri, federated_weighted_mean),
      (intrinsic_defs.FEDERATED_SUM.uri, federated_sum),
      (intrinsic_defs.GENERIC_DIVIDE.uri, generic_divide),
      (intrinsic_defs.GENERIC_MULTIPLY.uri, generic_multiply),
      (intrinsic_defs.GENERIC_PLUS.uri, generic_plus),
  ])
def get_intrinsic_bodies(context_stack):
  """Returns a `collections.OrderedDict` of intrinsic bodies.

  Each body is a Python callable that reduces a higher-level federated
  intrinsic (mean, weighted mean, sum, generic divide/multiply/plus) to a
  composition of lower-level building blocks, ultimately bottoming out in
  `federated_aggregate`.

  This dictionary respects the invariant that no body may refer to an
  intrinsic whose body appears previously in the `dict`.

  Args:
    context_stack: The context stack to use.

  Returns:
    A `collections.OrderedDict` mapping intrinsic URIs to the callables that
    implement their bodies.

  Raises:
    TypeError: If `context_stack` is not a `ContextStack`.
  """
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  intrinsics = intrinsic_factory.IntrinsicFactory(context_stack)

  # TODO(b/122728050): Implement reductions that follow roughly the following
  # breakdown in order to minimize the number of intrinsics that backends need
  # to support and maximize opportunities for merging processing logic to keep
  # the number of communication phases as small as it is practical. Perform
  # these reductions before FEDERATED_SUM (more reductions documented below).
  #
  # - FEDERATED_AGGREGATE(x, zero, accu, merge, report) :=
  #     GENERIC_MAP(
  #       GENERIC_REDUCE(
  #         GENERIC_PARTIAL_REDUCE(x, zero, accu, INTERMEDIATE_AGGREGATORS),
  #         zero, merge, SERVER),
  #       report)
  #
  # - FEDERATED_APPLY(f, x) := GENERIC_APPLY(f, x)
  #
  # - FEDERATED_BROADCAST(x) := GENERIC_BROADCAST(x, CLIENTS)
  #
  # - FEDERATED_COLLECT(x) := GENERIC_COLLECT(x, SERVER)
  #
  # - FEDERATED_MAP(f, x) := GENERIC_MAP(f, x)
  #
  # - FEDERATED_VALUE_AT_CLIENTS(x) := GENERIC_PLACE(x, CLIENTS)
  #
  # - FEDERATED_VALUE_AT_SERVER(x) := GENERIC_PLACE(x, SERVER)

  def _pack_binary_operator_args(x, y):
    """Packs arguments to binary operator into a single arg.

    Args:
      x: First argument; a value containing only tuples and tensors, or a
        federated value.
      y: Second argument; same constraints as `x`.

    Returns:
      A single two-tuple value `<x, y>`; if the arguments are federated with
      matching placements, the pair is federated-zipped so the binary operator
      can be mapped over the member constituents.

    Raises:
      TypeError: If the arguments cannot be packed together.
    """

    def _only_tuple_or_tensor(value):
      return type_analysis.contains_only_types(
          value.type_signature,
          (computation_types.NamedTupleType, computation_types.TensorType))

    if _only_tuple_or_tensor(x) and _only_tuple_or_tensor(y):
      # Unplaced values: a plain tuple of the two computations suffices.
      arg = value_impl.ValueImpl(
          building_blocks.Tuple([
              value_impl.ValueImpl.get_comp(x),
              value_impl.ValueImpl.get_comp(y)
          ]), context_stack)
    elif (isinstance(x.type_signature, computation_types.FederatedType) and
          isinstance(y.type_signature, computation_types.FederatedType) and
          x.type_signature.placement == y.type_signature.placement):
      if not type_utils.is_binary_op_with_upcast_compatible_pair(
          x.type_signature.member, y.type_signature.member):
        # This helper serves divide, multiply AND plus, so the message talks
        # about binary-operation compatibility rather than a single operator.
        raise TypeError(
            'The members of the federated types {} and {} are not '
            'binary-operation '
            'compatible; see `type_utils.is_binary_op_with_upcast_compatible_pair` '
            'for more details.'.format(x.type_signature, y.type_signature))
      packed_arg = value_impl.ValueImpl(
          building_blocks.Tuple([
              value_impl.ValueImpl.get_comp(x),
              value_impl.ValueImpl.get_comp(y)
          ]), context_stack)
      # Zip the pair of federated values into a single federated tuple.
      arg = intrinsics.federated_zip(packed_arg)
    else:
      raise TypeError(
          'Cannot pack arguments to a binary operator; expected two values '
          'containing only tuples and tensors, or two federated values with '
          'identical placement; got types {} and {}.'.format(
              x.type_signature, y.type_signature))
    return arg

  def _check_top_level_compatibility_with_generic_operators(x, y, op_name):
    """Performs non-recursive check on the types of `x` and `y`.

    Verifies both arguments contain only federated/tuple/tensor types and
    that their top-level structures are compatible for a generic binary
    operator; raises `TypeError` with a diagnostic naming `op_name` on
    failure.
    """
    x_compatible = type_analysis.contains_only_types(
        x.type_signature,
        (computation_types.NamedTupleType, computation_types.TensorType,
         computation_types.FederatedType))
    y_compatible = type_analysis.contains_only_types(
        y.type_signature,
        (computation_types.NamedTupleType, computation_types.TensorType,
         computation_types.FederatedType))

    def _make_bad_type_tree_string(index, type_spec):
      return ('{} is only implemented for pairs of '
              'arguments both containing only federated, tuple and '
              'tensor types; you have passed argument at index {} of type {} '
              .format(op_name, index, type_spec))

    if not (x_compatible and y_compatible):
      if y_compatible:
        raise TypeError(_make_bad_type_tree_string(0, x.type_signature))
      elif x_compatible:
        raise TypeError(_make_bad_type_tree_string(1, y.type_signature))
      else:
        raise TypeError(
            '{} is only implemented for pairs of '
            'arguments both containing only federated, tuple and '
            'tensor types; both your arguments fail this condition. '
            'You have passed first argument of type {} '
            'and second argument of type {}.'.format(op_name, x.type_signature,
                                                     y.type_signature))

    top_level_mismatch_string = (
        '{} does not accept arguments of type {} and '
        '{}, as they are mismatched at the top level.'.format(
            op_name, x.type_signature, y.type_signature))
    if isinstance(x.type_signature, computation_types.FederatedType):
      # Federated arguments must agree on placement and have member types
      # usable together under binary-operator upcasting.
      if (not isinstance(y.type_signature, computation_types.FederatedType) or
          x.type_signature.placement != y.type_signature.placement or
          not type_utils.is_binary_op_with_upcast_compatible_pair(
              x.type_signature.member, y.type_signature.member)):
        raise TypeError(top_level_mismatch_string)
    if isinstance(x.type_signature, computation_types.NamedTupleType):
      if type_utils.is_binary_op_with_upcast_compatible_pair(
          x.type_signature, y.type_signature):
        return None
      # NOTE(review): relies on `dir()` exposing a NamedTupleType's element
      # names so this compares field names -- confirm in `computation_types`.
      elif not isinstance(y.type_signature,
                          computation_types.NamedTupleType) or dir(
                              x.type_signature) != dir(y.type_signature):
        raise TypeError(top_level_mismatch_string)

  def federated_weighted_mean(arg):
    """Computes sum(v * w) / sum(w) from a zipped `<values, weights>` pair."""
    w = arg[1]
    multiplied = generic_multiply(arg)
    summed = federated_sum(intrinsics.federated_zip([multiplied, w]))
    return generic_divide(summed)

  def federated_mean(arg):
    """Reduces an unweighted mean to a weighted mean with all-ones weights."""
    one = value_impl.ValueImpl(
        building_block_factory.create_generic_constant(arg.type_signature, 1),
        context_stack)
    arg = value_impl.to_value([arg, one], None, context_stack)
    return federated_weighted_mean(arg)

  def federated_sum(x):
    """Reduces a federated sum to `federated_reduce` with a generic plus op."""
    zero = value_impl.ValueImpl(
        building_block_factory.create_generic_constant(x.type_signature.member,
                                                       0), context_stack)
    plus_op = value_impl.ValueImpl(
        building_block_factory.create_binary_operator_with_upcast(
            computation_types.NamedTupleType(
                [x.type_signature.member, x.type_signature.member]), tf.add),
        context_stack)
    return federated_reduce([x, zero, plus_op])

  def federated_reduce(arg):
    """Expresses reduce as `federated_aggregate` with an identity report."""
    x = arg[0]
    zero = arg[1]
    op = arg[2]
    # `op` serves as both accumulate and merge; report is the identity.
    identity = building_block_factory.create_compiled_identity(
        op.type_signature.result)
    return intrinsics.federated_aggregate(x, zero, op, op, identity)

  def _generic_op_can_be_applied(x, y):
    """True if a binary op applies directly (upcast-compatible or federated)."""
    return type_utils.is_binary_op_with_upcast_compatible_pair(
        x.type_signature, y.type_signature) or isinstance(
            x.type_signature, computation_types.FederatedType)

  def _apply_generic_op(op, x, y):
    """Packs `x` and `y` and applies binary `op` with upcasting."""
    arg = _pack_binary_operator_args(x, y)
    arg_comp = value_impl.ValueImpl.get_comp(arg)
    result = building_block_factory.apply_binary_operator_with_upcast(
        arg_comp, op)
    return value_impl.ValueImpl(result, context_stack)

  def generic_divide(arg):
    """Divides two arguments when possible."""
    x = arg[0]
    y = arg[1]
    _check_top_level_compatibility_with_generic_operators(
        x, y, 'Generic divide')
    if _generic_op_can_be_applied(x, y):
      return _apply_generic_op(tf.divide, x, y)
    elif isinstance(x.type_signature, computation_types.NamedTupleType):
      # This case is needed if federated types are nested deeply; recurse
      # element-wise and reassemble a named tuple with the original names.
      names = [t[0] for t in anonymous_tuple.iter_elements(x.type_signature)]
      divided = [
          value_impl.ValueImpl.get_comp(generic_divide([x[i], y[i]]))
          for i in range(len(names))
      ]
      named_divided = building_block_factory.create_named_tuple(
          building_blocks.Tuple(divided), names)
      return value_impl.ValueImpl(named_divided, context_stack)
    else:
      raise TypeError(
          'Generic divide encountered unexpected type {}, {}'.format(
              x.type_signature, y.type_signature))

  def generic_multiply(arg):
    """Multiplies two arguments when possible."""
    x = arg[0]
    y = arg[1]
    _check_top_level_compatibility_with_generic_operators(
        x, y, 'Generic multiply')
    if _generic_op_can_be_applied(x, y):
      return _apply_generic_op(tf.multiply, x, y)
    elif isinstance(x.type_signature, computation_types.NamedTupleType):
      # This case is needed if federated types are nested deeply; recurse
      # element-wise and reassemble a named tuple with the original names.
      names = [t[0] for t in anonymous_tuple.iter_elements(x.type_signature)]
      multiplied = [
          value_impl.ValueImpl.get_comp(generic_multiply([x[i], y[i]]))
          for i in range(len(names))
      ]
      named_multiplied = building_block_factory.create_named_tuple(
          building_blocks.Tuple(multiplied), names)
      return value_impl.ValueImpl(named_multiplied, context_stack)
    else:
      raise TypeError(
          'Generic multiply encountered unexpected type {}, {}'.format(
              x.type_signature, y.type_signature))

  def generic_plus(arg):
    """Adds two arguments when possible."""
    x = arg[0]
    y = arg[1]
    _check_top_level_compatibility_with_generic_operators(x, y, 'Generic plus')
    if _generic_op_can_be_applied(x, y):
      return _apply_generic_op(tf.add, x, y)
    # TODO(b/136587334): Push this logic down a level
    elif isinstance(x.type_signature, computation_types.NamedTupleType):
      # This case is needed if federated types are nested deeply; recurse
      # element-wise and reassemble a named tuple with the original names.
      names = [t[0] for t in anonymous_tuple.iter_elements(x.type_signature)]
      added = [
          value_impl.ValueImpl.get_comp(generic_plus([x[i], y[i]]))
          for i in range(len(names))
      ]
      named_added = building_block_factory.create_named_tuple(
          building_blocks.Tuple(added), names)
      return value_impl.ValueImpl(named_added, context_stack)
    else:
      raise TypeError(
          'Generic plus encountered unexpected type {}, {}'.format(
              x.type_signature, y.type_signature))

  # - FEDERATED_ZIP(x, y) := GENERIC_ZIP(x, y)
  #
  # - GENERIC_AVERAGE(x: {T}@p, q: placement) :=
  #     GENERIC_WEIGHTED_AVERAGE(x, GENERIC_ONE, q)
  #
  # - GENERIC_WEIGHTED_AVERAGE(x: {T}@p, w: {U}@p, q: placement) :=
  #     GENERIC_MAP(GENERIC_DIVIDE, GENERIC_SUM(
  #       GENERIC_MAP(GENERIC_MULTIPLY, GENERIC_ZIP(x, w)), p))
  #
  #   Note: The above formula does not account for type casting issues that
  #   arise due to the interplay between the types of values and weights and
  #   how they relate to types of products and ratios, and either the formula
  #   or the type signatures may need to be tweaked.
  #
  # - GENERIC_SUM(x: {T}@p, q: placement) :=
  #     GENERIC_REDUCE(x, GENERIC_ZERO, GENERIC_PLUS, q)
  #
  # - GENERIC_PARTIAL_SUM(x: {T}@p, q: placement) :=
  #     GENERIC_PARTIAL_REDUCE(x, GENERIC_ZERO, GENERIC_PLUS, q)
  #
  # - GENERIC_AGGREGATE(
  #     x: {T}@p, zero: U, accu: <U,T>->U, merge: <U,U>=>U, report: U->R,
  #     q: placement) :=
  #     GENERIC_MAP(report, GENERIC_REDUCE(x, zero, accu, q))
  #
  # - GENERIC_REDUCE(x: {T}@p, zero: U, op: <U,T>->U, q: placement) :=
  #     GENERIC_MAP((a -> SEQUENCE_REDUCE(a, zero, op)), GENERIC_COLLECT(x, q))
  #
  # - GENERIC_PARTIAL_REDUCE(x: {T}@p, zero: U, op: <U,T>->U, q: placement) :=
  #     GENERIC_MAP(
  #       (a -> SEQUENCE_REDUCE(a, zero, op)), GENERIC_PARTIAL_COLLECT(x, q))
  #
  # - SEQUENCE_SUM(x: T*) :=
  #     SEQUENCE_REDUCE(x, GENERIC_ZERO, GENERIC_PLUS)
  #
  # After performing the full set of reductions, we should only see instances
  # of the following intrinsics in the result, all of which are currently
  # considered non-reducible, and intrinsics such as GENERIC_PLUS should apply
  # only to non-federated, non-sequence types (with the appropriate calls to
  # GENERIC_MAP or SEQUENCE_MAP injected).
  #
  # - GENERIC_APPLY
  # - GENERIC_BROADCAST
  # - GENERIC_COLLECT
  # - GENERIC_DIVIDE
  # - GENERIC_MAP
  # - GENERIC_MULTIPLY
  # - GENERIC_ONE
  # - GENERIC_ONLY
  # - GENERIC_PARTIAL_COLLECT
  # - GENERIC_PLACE
  # - GENERIC_PLUS
  # - GENERIC_ZERO
  # - GENERIC_ZIP
  # - SEQUENCE_MAP
  # - SEQUENCE_REDUCE

  # Order matters: no body may refer to an intrinsic appearing earlier here.
  return collections.OrderedDict([
      (intrinsic_defs.FEDERATED_MEAN.uri, federated_mean),
      (intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri, federated_weighted_mean),
      (intrinsic_defs.FEDERATED_SUM.uri, federated_sum),
      (intrinsic_defs.GENERIC_DIVIDE.uri, generic_divide),
      (intrinsic_defs.GENERIC_MULTIPLY.uri, generic_multiply),
      (intrinsic_defs.GENERIC_PLUS.uri, generic_plus),
  ])