def test_execute(self):
        """Checks type signature and execution of a single stateful broadcast.

        A `StatefulBroadcastFn` built from `broadcast_initialize_fn` and
        `broadcast_next_fn` is invoked once inside a federated computation,
        with its state seeded from `initialize()` and placed at SERVER. The
        test then verifies the computation's type signature and the returned
        (state, value) pair.
        """
        stateful_broadcast = computation_utils.StatefulBroadcastFn(
            initialize_fn=broadcast_initialize_fn, next_fn=broadcast_next_fn)
        server_float_type = computation_types.FederatedType(
            tf.float32, placements.SERVER)

        @computations.federated_computation(server_float_type)
        def broadcast_once(server_value):
            seed = stateful_broadcast.initialize()
            initial_state = intrinsics.federated_value(seed, placements.SERVER)
            return stateful_broadcast(initial_state, server_value)

        # Result is (server-placed state with a call counter,
        # all-equal client copy of the broadcast float).
        state_type = computation_types.FederatedType(
            collections.OrderedDict([('call_count', tf.int32)]),
            placements.SERVER)
        client_value_type = computation_types.FederatedType(
            tf.float32, placements.CLIENTS, all_equal=True)
        self.assertEqual(
            broadcast_once.type_signature,
            computation_types.FunctionType(
                parameter=server_float_type,
                result=computation_types.NamedTupleType(
                    [state_type, client_value_type])))

        final_state, client_value = broadcast_once(1.0)
        self.assertAlmostEqual(client_value, 1.0)
        self.assertDictEqual(final_state._asdict(), {'call_count': 1})
Ejemplo n.º 2
0
 def test_fails_stateful_broadcast_and_process(self):
   """Passing both a stateful broadcast fn and a broadcast process must fail.

   `build_model_delta_optimizer_process` is expected to raise
   `DisjointArgumentError` when `stateful_model_broadcast_fn` and
   `broadcast_process` are supplied together.
   """
   weights_type = model_utils.weights_type_from_model(
       model_examples.LinearRegression)

   def passthrough_broadcast(state, weights):
     # Leaves the state untouched and broadcasts the weights as-is.
     return state, intrinsics.federated_broadcast(weights)

   with self.assertRaises(optimizer_utils.DisjointArgumentError):
     # Both arguments are built inside the assertion scope, in the same order
     # the original keyword arguments were evaluated.
     stateful_broadcast_fn = computation_utils.StatefulBroadcastFn(
         initialize_fn=lambda: (), next_fn=passthrough_broadcast)
     stateless_broadcaster = optimizer_utils.build_stateless_broadcaster(
         model_weights_type=weights_type)
     optimizer_utils.build_model_delta_optimizer_process(
         model_fn=model_examples.LinearRegression,
         model_to_client_delta_fn=DummyClientDeltaFn,
         server_optimizer_fn=tf.keras.optimizers.SGD,
         stateful_model_broadcast_fn=stateful_broadcast_fn,
         broadcast_process=stateless_broadcaster)
Ejemplo n.º 3
0
def build_encoded_broadcast(values, encoders):
    """Builds `StatefulBroadcastFn` for `values`, to be encoded by `encoders`.

    Deprecated: use `tff.utils.build_encoded_broadcast_process()` instead.
    Each call emits a `DeprecationWarning`, attributed to the caller's line.

    Args:
      values: Values to be broadcasted by the `StatefulBroadcastFn`. Must be
        convertible to `tff.Value`.
      encoders: A collection of `SimpleEncoder` objects to be used for encoding
        `values`. Must have the same structure as `values`.

    Returns:
      A `StatefulBroadcastFn` of which `next_fn` encodes the input at
      `tff.SERVER`, broadcasts the encoded representation and decodes the
      encoded representation at `tff.CLIENTS`.

    Raises:
      ValueError: If `values` and `encoders` do not have the same structure.
      TypeError: If `encoders` are not instances of `SimpleEncoder`, or if
        `values` are not compatible with the expected input of the `encoders`.
    """
    # stacklevel=2 attributes the warning to the caller rather than to this
    # function. Otherwise the default filters (which only show
    # DeprecationWarning for code in __main__) hide it from the code that
    # actually has to migrate.
    warnings.warn(
        'Deprecation warning: tff.utils.build_encoded_broadcast() is deprecated, '
        'use tff.utils.build_encoded_broadcast_process() instead.',
        DeprecationWarning,
        stacklevel=2)

    tf.nest.assert_same_structure(values, encoders)
    # Validate each encoder against its matching value, leaf by leaf.
    tf.nest.map_structure(
        lambda e, v: _validate_encoder(e, v,
                                       tensor_encoding.core.SimpleEncoder),
        encoders, values)

    value_type = type_conversions.type_from_tensors(values)

    # The state type produced alongside the initializer is needed to build the
    # encode/decode computations below.
    initial_state_fn, state_type = _build_initial_state_tf_computation(
        encoders)

    encode_fn, decode_fn = _build_encode_decode_tf_computations_for_broadcast(
        state_type, value_type, encoders)

    def encoded_broadcast_fn(state, value):
        """Encoded broadcast federated_computation.

        Encodes `(state, value)` with `encode_fn` (yielding the next state and
        the encoded value), broadcasts the encoded value, and decodes it with
        `decode_fn` on the receiving side.
        """
        new_state, encoded_value = intrinsics.federated_map(
            encode_fn, (state, value))
        client_encoded_value = intrinsics.federated_broadcast(encoded_value)
        client_value = intrinsics.federated_map(decode_fn,
                                                client_encoded_value)
        return new_state, client_value

    return computation_utils.StatefulBroadcastFn(
        initialize_fn=initial_state_fn, next_fn=encoded_broadcast_fn)
Ejemplo n.º 4
0
def _state_incrementing_mean_next(server_state, client_value, weight=None):
    """Aggregation step that advances the server state and averages values.

    Returns a `(new_state, mean)` pair:
    - `new_state` is `server_state` mapped through `_add_one`.
    - `mean` is the federated mean of `client_value`, weighted by `weight`
      when one is given.
    """
    incremented_state = intrinsics.federated_map(_add_one, server_state)
    client_mean = intrinsics.federated_mean(client_value, weight=weight)
    return incremented_state, client_mean


# Stateful mean aggregator: its state starts at `tf.constant(0)` and each call
# advances it via `_state_incrementing_mean_next`.
state_incrementing_mean = computation_utils.StatefulAggregateFn(
    initialize_fn=lambda: tf.constant(0),
    next_fn=_state_incrementing_mean_next)


def _state_incrementing_broadcast_next(server_state, server_value):
    """Broadcast step that advances the server state and broadcasts a value.

    Returns a `(new_state, broadcast_value)` pair:
    - `new_state` is `server_state` mapped through `_add_one`.
    - `broadcast_value` is the result of `federated_broadcast(server_value)`.
    """
    incremented_state = intrinsics.federated_map(_add_one, server_state)
    broadcast_value = intrinsics.federated_broadcast(server_value)
    return incremented_state, broadcast_value


# Stateful broadcaster: its state starts at `tf.constant(0)` and each call
# advances it via `_state_incrementing_broadcast_next`.
state_incrementing_broadcaster = computation_utils.StatefulBroadcastFn(
    initialize_fn=lambda: tf.constant(0),
    next_fn=_state_incrementing_broadcast_next)


def _build_test_measured_broadcast(
    model_weights_type: computation_types.StructType
) -> measured_process.MeasuredProcess:
    """Builds a test `MeasuredProcess` that has state and metrics."""
    @computations.federated_computation()
    def initialize_comp():
        return intrinsics.federated_value(0, placements.SERVER)

    @computations.federated_computation(
        computation_types.FederatedType(tf.int32, placements.SERVER),
        computation_types.FederatedType(model_weights_type, placements.SERVER))
    def next_comp(state, value):
        return collections.OrderedDict(
Ejemplo n.º 5
0
 def federated_broadcast_test(values):
   """Runs one stateful broadcast of `values` starting from the initial state.

   Builds a `StatefulBroadcastFn` from `broadcast_initialize_fn` and
   `broadcast_next_fn`, places its initial state at `tff.SERVER`, and returns
   the result of invoking it on `(state, values)`.
   """
   stateful_broadcast = computation_utils.StatefulBroadcastFn(
       initialize_fn=broadcast_initialize_fn, next_fn=broadcast_next_fn)
   seed = stateful_broadcast.initialize()
   server_state = tff.federated_value(seed, tff.SERVER)
   return stateful_broadcast(server_state, values)