Esempio n. 1
0
def build_encoded_sum(values, encoders):
  """Creates a `StatefulAggregateFn` summing `values` encoded by `encoders`.

  The inputs are validated first: `encoders` must mirror the nested structure
  of `values`, and each encoder must be a `GatherEncoder` compatible with its
  corresponding value.

  Args:
    values: Values to be encoded by the `StatefulAggregateFn`. Must be
      convertible to `tff.Value`.
    encoders: A collection of `GatherEncoder` objects, arranged in the same
      nested structure as `values`, used to encode them.

  Returns:
    A `StatefulAggregateFn` whose `initialize_fn` produces the initial encoder
    state and whose `next_fn` encodes the input at `tff.CLIENTS` and computes
    the sum at `tff.SERVER`, splitting the decoding part based on its
    commutativity with sum.

  Raises:
    ValueError: If `values` and `encoders` do not have the same structure.
    TypeError: If `encoders` are not instances of `GatherEncoder`, or if
      `values` are not compatible with the expected input of the `encoders`.
  """

  def check_encoder(encoder, value):
    # Delegates per-leaf validation; the expected class is looked up lazily.
    return _validate_encoder(encoder, value,
                             tensor_encoding.core.GatherEncoder)

  tf.nest.assert_same_structure(values, encoders)
  tf.nest.map_structure(check_encoder, encoders, values)

  value_type = type_utils.type_from_tensors(values)

  # This version of the helper returns the state computation and its type.
  initial_state_fn, state_type = _build_initial_state_tf_computation(encoders)

  encoded_sum_fn = _build_encoded_sum_fn(
      _build_tf_computations_for_gather(state_type, value_type, encoders))

  return computation_utils.StatefulAggregateFn(
      initialize_fn=initial_state_fn, next_fn=encoded_sum_fn)
Esempio n. 2
0
def build_encoded_mean(values, encoders):
    """Creates a `StatefulAggregateFn` averaging `values` encoded by `encoders`.

    The returned `next_fn` has the signature `(state, values, weight)` and
    computes a weighted mean:
      1. each client's values are multiplied by its `weight` at `tff.CLIENTS`;
      2. the weighted values go through the encoded sum (encode at
         `tff.CLIENTS`, sum at `tff.SERVER`);
      3. the weights are summed with `federated_sum`;
      4. the summed values are divided by the summed weights at `tff.SERVER`.

    Args:
      values: Values to be encoded by the `StatefulAggregateFn`. Must be
        convertible to `tff.Value`.
      encoders: A collection of `GatherEncoder` objects, arranged in the same
        nested structure as `values`, used to encode them.

    Returns:
      A `StatefulAggregateFn` whose `next_fn` encodes the input at
      `tff.CLIENTS` and computes the weighted mean at `tff.SERVER`, splitting
      the decoding part based on its commutativity with sum.

    Raises:
      ValueError: If `values` and `encoders` do not have the same structure.
      TypeError: If `encoders` are not instances of `GatherEncoder`, or if
        `values` are not compatible with the expected input of the
        `encoders`.
    """

    def check_encoder(encoder, value):
        return _validate_encoder(encoder, value,
                                 tensor_encoding.core.GatherEncoder)

    tf.nest.assert_same_structure(values, encoders)
    tf.nest.map_structure(check_encoder, encoders, values)

    value_type = type_utils.type_from_tensors(values)

    initial_state_fn = _build_initial_state_tf_computation(encoders)
    encoded_sum_fn = _build_encoded_sum_fn(
        _build_tf_computations_for_gather(
            initial_state_fn.type_signature.result, value_type, encoders))

    # Both helper computations take a scalar float32 scaling factor.
    weight_type = computation_types.to_type(tf.float32)

    @computations.tf_computation(value_type, weight_type)
    def multiply_fn(value, weight):

        def scale(tensor):
            # Cast so the factor matches each tensor's own dtype.
            return tensor * tf.cast(weight, tensor.dtype)

        return tf.nest.map_structure(scale, value)

    @computations.tf_computation(value_type, weight_type)
    def divide_fn(value, denominator):

        def normalize(tensor):
            return tensor / tf.cast(denominator, tensor.dtype)

        return tf.nest.map_structure(normalize, value)

    def encoded_mean_fn(state, values, weight):
        # NOTE(review): no guard here if the summed weights are zero.
        weighted_values = intrinsics.federated_map(multiply_fn,
                                                   [values, weight])
        new_state, weighted_sum = encoded_sum_fn(state, weighted_values)
        total_weight = intrinsics.federated_sum(weight)
        mean_values = intrinsics.federated_map(divide_fn,
                                               [weighted_sum, total_weight])
        return new_state, mean_values

    return computation_utils.StatefulAggregateFn(
        initialize_fn=initial_state_fn, next_fn=encoded_mean_fn)
Esempio n. 3
0
def wrap_aggregate_fn(dp_aggregate_fn, sample_value):
    """Wraps `dp_aggregate_fn` into a pair of federated computations.

    Args:
      dp_aggregate_fn: An object exposing `initialize()` that is also callable
        as `dp_aggregate_fn(global_state, client_values)`.
      sample_value: Tensors whose type (via `type_utils.type_from_tensors`)
        determines the member type of the `tff.CLIENTS`-placed input of the
        aggregation computation.

    Returns:
      A tuple `(run_initialize, run_aggregate)`:
        - `run_initialize` places the result of `dp_aggregate_fn.initialize()`
          at `tff.SERVER`;
        - `run_aggregate` takes that server state plus `tff.CLIENTS`-placed
          values and returns `dp_aggregate_fn(global_state, client_values)`.
    """
    value_type = type_utils.type_from_tensors(sample_value)

    @computations.federated_computation
    def run_initialize():
        initial_state = dp_aggregate_fn.initialize()
        return intrinsics.federated_value(initial_state,
                                          placement_literals.SERVER)

    # The aggregation's state parameter reuses the initializer's result type.
    state_type = run_initialize.type_signature.result
    client_values_type = computation_types.FederatedType(
        value_type, placement_literals.CLIENTS)

    @computations.federated_computation(state_type, client_values_type)
    def run_aggregate(global_state, client_values):
        return dp_aggregate_fn(global_state, client_values)

    return run_initialize, run_aggregate
Esempio n. 4
0
def build_encoded_broadcast(values, encoders):
    """Creates a `StatefulBroadcastFn` broadcasting `values` via `encoders`.

    Args:
      values: Values to be broadcasted by the `StatefulBroadcastFn`. Must be
        convertible to `tff.Value`.
      encoders: A collection of `SimpleEncoder` objects, arranged in the same
        nested structure as `values`, used to encode them.

    Returns:
      A `StatefulBroadcastFn` whose `next_fn(state, value)` encodes `value` at
      `tff.SERVER` (also producing the updated state), broadcasts the encoded
      representation, and decodes it at `tff.CLIENTS`. It returns
      `(new_state, client_value)`.

    Raises:
      ValueError: If `values` and `encoders` do not have the same structure.
      TypeError: If `encoders` are not instances of `SimpleEncoder`, or if
        `values` are not compatible with the expected input of the
        `encoders`.
    """

    def check_encoder(encoder, value):
        return _validate_encoder(encoder, value,
                                 tensor_encoding.core.SimpleEncoder)

    tf.nest.assert_same_structure(values, encoders)
    tf.nest.map_structure(check_encoder, encoders, values)

    value_type = type_utils.type_from_tensors(values)

    initial_state_fn = _build_initial_state_tf_computation(encoders)
    encode_fn, decode_fn = _build_encode_decode_tf_computations_for_broadcast(
        initial_state_fn.type_signature.result, value_type, encoders)

    def encoded_broadcast_fn(state, value):
        """Encoded broadcast federated_computation."""
        # Server side: encoding yields both the next state and the payload.
        server_output = intrinsics.federated_map(encode_fn, (state, value))
        new_state, encoded_value = server_output
        # Only the encoded payload is sent; decoding happens per client.
        client_value = intrinsics.federated_map(
            decode_fn, intrinsics.federated_broadcast(encoded_value))
        return new_state, client_value

    return computation_utils.StatefulBroadcastFn(
        initialize_fn=initial_state_fn, next_fn=encoded_broadcast_fn)