def test_constructor_with_type_mismatch(self):
        """Checks the TypeErrors raised for inconsistent initialize/next pairs.

        Every case pairs the state produced by `_build_initialize_comp(0)`
        (defined elsewhere; presumably an int32 state -- confirm) with a
        `next_fn` whose signature or result structure is wrong, and asserts
        that constructing the `MeasuredProcess` fails with a matching message.
        """
        init_comp = _build_initialize_comp(0)
        next_result_error = (
            'The return type of next_fn must be assignable to the first parameter')

        # The state parameter of next_fn does not accept initialize_fn's result.
        with self.assertRaisesRegex(
                TypeError,
                r'The return type of initialize_fn must be assignable.*'):

            @computations.federated_computation(tf.float32, tf.float32)
            def float_accumulator(current, val):
                return current + val

            measured_process.MeasuredProcess(
                initialize_fn=init_comp, next_fn=float_accumulator)

        # A scalar result that cannot be fed back into the int32 state param.
        with self.assertRaisesRegex(TypeError, next_result_error):

            @computations.federated_computation(tf.int32)
            def float_result_next(_):
                return 0.0

            measured_process.MeasuredProcess(
                initialize_fn=init_comp, next_fn=float_result_next)

        # A tuple result whose leading element is a float is rejected the same way.
        with self.assertRaisesRegex(TypeError, next_result_error):

            @computations.federated_computation(tf.int32)
            def float_first_pair_next(_):
                return 0.0, 0

            measured_process.MeasuredProcess(
                initialize_fn=init_comp, next_fn=float_first_pair_next)

        # A result that is not a named tuple at all.
        with self.assertRaisesRegex(
                TypeError, 'MeasuredProcess must return a NamedTupleType'):

            @computations.federated_computation(tf.int32)
            def scalar_next(_):
                return 0

            measured_process.MeasuredProcess(
                initialize_fn=init_comp, next_fn=scalar_next)

        # A three-element tuple lacking the state/result/measurements names.
        with self.assertRaisesRegex(
                TypeError,
                'must match type signature <state=A,result=B,measurements=C>'):

            @computations.federated_computation(tf.int32)
            def unnamed_triple_next(_):
                return 0, 0, 0

            measured_process.MeasuredProcess(
                initialize_fn=init_comp, next_fn=unnamed_triple_next)
Esempio n. 2
0
    def test_constructor_with_initialize_bad_type(self):
        """Checks that a malformed `initialize_fn` is rejected."""
        # `initialize_fn` must be a Computation; `None` is not.
        with self.assertRaisesRegex(TypeError,
                                    r'Expected .*\.Computation, .*'):
            measured_process.MeasuredProcess(
                initialize_fn=None, next_fn=add_int32)

        # An `initialize_fn` that takes an argument is rejected as well.
        with self.assertRaises(iterative_process.InitializeFnHasArgsError):

            @computations.federated_computation(tf.int32)
            def initialize_with_arg(one_arg):
                del one_arg  # Unused.
                return values.to_value(0)

            measured_process.MeasuredProcess(
                initialize_fn=initialize_with_arg, next_fn=add_int32)
Esempio n. 3
0
 def test_next_return_namedtuple_raises(self):
   """A plain namedtuple that merely mimics MeasuredProcessOutput is rejected."""
   lookalike_output = collections.namedtuple(
       'MeasuredProcessOutput', ['state', 'result', 'measurements'])

   @computations.tf_computation(tf.int32)
   def namedtuple_next_fn(state):
     return lookalike_output(state, (), ())

   with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
     measured_process.MeasuredProcess(test_initialize_fn, namedtuple_next_fn)
Esempio n. 4
0
def build_encoded_broadcast_process(value_type, encoders):
    """Builds `MeasuredProcess` for `value_type`, to be encoded by `encoders`.

    The returned `MeasuredProcess` has a next function with the TFF type
    signature:

    ```
    (<state_type@SERVER, value_type@SERVER> ->
     <state=state_type@SERVER, result=value_type@CLIENTS, measurements=()@SERVER>)
    ```

    Args:
      value_type: The type of values to be broadcasted by the `MeasuredProcess`.
        Either a `tff.TensorType` or a `tff.StructType`.
      encoders: A collection of `SimpleEncoder` objects to be used for encoding
        values of `value_type`. Must have the same structure as `value_type`.

    Returns:
      A `MeasuredProcess` of which `next_fn` encodes the input at `tff.SERVER`,
      broadcasts the encoded representation and decodes the encoded
      representation at `tff.CLIENTS`.

    Raises:
      ValueError: If `value_type` and `encoders` do not have the same structure.
      TypeError: If `value_type` is not a `tff.TensorType` or `tff.StructType`,
        if `encoders` are not instances of `SimpleEncoder`, or if `value_type`
        is not compatible with the expected input of the `encoders`.
    """
    py_typecheck.check_type(
        value_type,
        (computation_types.TensorType, computation_types.StructType))

    _validate_value_type_and_encoders(value_type, encoders,
                                      tensor_encoding.core.SimpleEncoder)

    initial_state_fn, state_type = _build_initial_state_tf_computation(
        encoders)

    @computations.federated_computation()
    def initial_state_comp():
        # The encoder state is produced by a TF computation evaluated at the
        # server.
        return intrinsics.federated_eval(initial_state_fn, placements.SERVER)

    encode_fn, decode_fn = _build_encode_decode_tf_computations_for_broadcast(
        state_type, value_type, encoders)

    @computations.federated_computation(
        initial_state_comp.type_signature.result,
        computation_types.FederatedType(value_type, placements.SERVER))
    def encoded_broadcast_comp(state, value):
        """Encodes `value` at the server, broadcasts it, decodes at clients."""
        # This process reports no measurements.
        empty_metrics = intrinsics.federated_value((), placements.SERVER)
        # Encoding runs at the server and also yields the updated encoder state.
        new_state, encoded_value = intrinsics.federated_map(
            encode_fn, (state, value))
        client_encoded_value = intrinsics.federated_broadcast(encoded_value)
        client_value = intrinsics.federated_map(decode_fn,
                                                client_encoded_value)
        return measured_process.MeasuredProcessOutput(
            state=new_state, result=client_value, measurements=empty_metrics)

    return measured_process.MeasuredProcess(initialize_fn=initial_state_comp,
                                            next_fn=encoded_broadcast_comp)
Esempio n. 5
0
 def test_measured_process_output_as_state_raises(self):
     """Using a MeasuredProcessOutput as the process state is rejected."""

     def empty_output():
         return MeasuredProcessOutput((), (), ())

     initialize_fn = computations.tf_computation(empty_output)
     state_type = initialize_fn.type_signature.result
     next_fn = computations.tf_computation(state_type)(
         lambda state: empty_output())
     with self.assertRaises(errors.TemplateStateNotAssignableError):
         measured_process.MeasuredProcess(initialize_fn, next_fn)
Esempio n. 6
0
 def test_construction_with_empty_state_does_not_raise(self):
   """An empty-struct `()` state is a valid MeasuredProcess state."""
   initialize_fn = computations.tf_computation()(lambda: ())
   next_fn = computations.tf_computation(
       ())(lambda x: MeasuredProcessOutput(x, (), ()))
   try:
     measured_process.MeasuredProcess(initialize_fn, next_fn)
   except Exception as e:  # pylint: disable=broad-except
     # Catch `Exception` rather than a bare except so KeyboardInterrupt and
     # SystemExit still propagate, and report the underlying cause.
     self.fail(f'Could not construct a MeasuredProcess with empty state: {e}')
Esempio n. 7
0
 def test_federated_init_state_not_assignable(self):
     """Server-placed initial state cannot feed a clients-placed next_fn."""

     def zero_at_server():
         return intrinsics.federated_value(0, placements.SERVER)

     initialize_fn = computations.federated_computation()(zero_at_server)
     clients_int_type = computation_types.FederatedType(tf.int32,
                                                        placements.CLIENTS)

     @computations.federated_computation(clients_int_type)
     def next_fn(state):
         return MeasuredProcessOutput(state, zero_at_server(), zero_at_server())

     with self.assertRaises(errors.TemplateStateNotAssignableError):
         measured_process.MeasuredProcess(initialize_fn, next_fn)
Esempio n. 8
0
def build_encoded_sum_process(value_type, encoders):
    """Creates a `MeasuredProcess` summing `value_type` values encoded by `encoders`.

    The `next_fn` of the returned process has the TFF type signature:

    ```
    (<state_type@SERVER, {value_type}@CLIENTS> ->
     <state=state_type@SERVER, result=value_type@SERVER, measurements=()@SERVER>)
    ```

    Args:
      value_type: A `tff.TensorType` or `tff.StructType` describing the values
        to be encoded and summed.
      encoders: A collection of `GatherEncoder` objects used to encode the
        values; must have the same structure as `value_type`.

    Returns:
      A `MeasuredProcess` whose `next_fn` encodes the inputs at `tff.CLIENTS`
      and computes their sum at `tff.SERVER`, automatically splitting the
      decoding part based on its commutativity with sum.

    Raises:
      ValueError: If `value_type` and `encoders` do not have the same structure.
      TypeError: If `value_type` is not a `TensorType` or `StructType`, if
        `encoders` are not instances of `GatherEncoder`, or if `value_type` is
        not compatible with the expected input of the `encoders`.
    """
    supported_types = (computation_types.TensorType,
                       computation_types.StructType)
    py_typecheck.check_type(value_type, supported_types)
    _validate_value_type_and_encoders(value_type, encoders,
                                      tensor_encoding.core.GatherEncoder)

    state_tf_fn, state_type = _build_initial_state_tf_computation(encoders)

    @computations.federated_computation()
    def initial_state_comp():
        return intrinsics.federated_eval(state_tf_fn, placements.SERVER)

    gather_encoder = _build_tf_computations_for_gather(state_type, value_type,
                                                       encoders)
    sum_fn = _build_encoded_sum_fn(gather_encoder)

    server_state_type = initial_state_comp.type_signature.result
    client_values_type = computation_types.FederatedType(value_type,
                                                         placements.CLIENTS)

    @computations.federated_computation(server_state_type, client_values_type)
    def encoded_sum_comp(state, values):
        """Encoded sum federated_computation."""
        # No measurements are reported by this process.
        empty_metrics = intrinsics.federated_value((), placements.SERVER)
        updated_state, summed = sum_fn(state, values)
        return collections.OrderedDict(state=updated_state,
                                       result=summed,
                                       measurements=empty_metrics)

    return measured_process.MeasuredProcess(initialize_fn=initial_state_comp,
                                            next_fn=encoded_sum_comp)
Esempio n. 9
0
    def test_constructor_with_next_result_param_type_mismatch(self):
        """next_fn must return a value assignable to its own state parameter."""
        init_comp = _build_initialize_comp(0)
        with self.assertRaises(iterative_process.NextMustReturnStateError):

            # Accepts an int32 state but returns a float.
            @computations.federated_computation(tf.int32)
            def float_result_next(_):
                return 0.0

            measured_process.MeasuredProcess(
                initialize_fn=init_comp, next_fn=float_result_next)
Esempio n. 10
0
 def test_not_finalizer_type_raises(self):
     """compose_learning_process rejects a generic MeasuredProcess as finalizer."""
     source_finalizer = test_finalizer()
     # Re-wrap the finalizer's computations in a plain MeasuredProcess so the
     # composer receives the wrong process type.
     plain_process = measured_process.MeasuredProcess(
         source_finalizer.initialize, source_finalizer.next)
     with self.assertRaisesRegex(TypeError, 'FinalizerProcess'):
         composers.compose_learning_process(
             test_init_model_weights_fn,
             test_distributor(),
             test_client_work(),
             test_aggregator(),
             plain_process)
Esempio n. 11
0
def _wrap_in_measured_process(
        stateful_fn: Union[computation_utils.StatefulBroadcastFn,
                           computation_utils.StatefulAggregateFn],
        input_type: computation_types.Type
) -> measured_process.MeasuredProcess:
    """Adapts a `computation_utils.StatefulFn` into a `tff.templates.MeasuredProcess`.

    Args:
      stateful_fn: The `StatefulBroadcastFn` or `StatefulAggregateFn` to wrap.
      input_type: The unplaced type of the value argument of `stateful_fn`. It
        is placed at `tff.SERVER` for a broadcast function and at `tff.CLIENTS`
        for an aggregate function.

    Returns:
      A `MeasuredProcess` whose `initialize_fn` evaluates
      `stateful_fn.initialize` at `tff.SERVER`, and whose `next_fn` invokes
      `stateful_fn` and reports empty `()@SERVER` measurements.

    Raises:
      TypeError: If `stateful_fn` is neither a `StatefulBroadcastFn` nor a
        `StatefulAggregateFn`.
    """
    py_typecheck.check_type(stateful_fn,
                            (computation_utils.StatefulBroadcastFn,
                             computation_utils.StatefulAggregateFn))

    @computations.federated_computation()
    def initialize_comp():
        init_fn = stateful_fn.initialize
        # A plain Python callable is lifted to a TF computation before being
        # evaluated at the server.
        if not isinstance(init_fn, computation_base.Computation):
            init_fn = computations.tf_computation(init_fn)
        return intrinsics.federated_eval(init_fn, placements.SERVER)

    state_type = initialize_comp.type_signature.result

    if isinstance(stateful_fn, computation_utils.StatefulBroadcastFn):
        server_value_type = computation_types.FederatedType(
            input_type, placements.SERVER)

        @computations.federated_computation(state_type, server_value_type)
        def next_comp(state, value):
            empty_metrics = intrinsics.federated_value((), placements.SERVER)
            new_state, result = stateful_fn(state, value)
            return collections.OrderedDict(
                state=new_state, result=result, measurements=empty_metrics)

    elif isinstance(stateful_fn, computation_utils.StatefulAggregateFn):
        client_value_type = computation_types.FederatedType(
            input_type, placements.CLIENTS)
        client_weight_type = computation_types.FederatedType(
            tf.float32, placements.CLIENTS)

        @computations.federated_computation(state_type, client_value_type,
                                            client_weight_type)
        def next_comp(state, value, weight):
            empty_metrics = intrinsics.federated_value((), placements.SERVER)
            new_state, result = stateful_fn(state, value, weight)
            return collections.OrderedDict(
                state=new_state, result=result, measurements=empty_metrics)

    else:
        # Defensive; presumably unreachable because check_type above already
        # restricts `stateful_fn` to the two supported types.
        raise TypeError(
            'Received a {t}, expected either a computation_utils.StatefulAggregateFn or a '
            'computation_utils.StatefulBroadcastFn.'.format(
                t=type(stateful_fn)))

    return measured_process.MeasuredProcess(initialize_fn=initialize_comp,
                                            next_fn=next_comp)
Esempio n. 12
0
    def test_constructor_with_state_only(self):
        """A state-only next_fn yields empty results/measurements every round.

        Starting from `_build_initialize_comp(0)`, the state must equal the
        number of rounds run.
        """
        process = measured_process.MeasuredProcess(_build_initialize_comp(0),
                                                   count_int32)
        num_rounds = 10
        state = process.initialize()
        for _ in range(num_rounds):
            output = process.next(state)
            state, result, measurements = attr.astuple(output)
            self.assertLen(result, 0)
            self.assertLen(measurements, 0)
        self.assertEqual(state, num_rounds)
Esempio n. 13
0
  def test_federated_next_state_not_assignable(self):
    """next_fn may not return its server-placed state broadcast to clients."""

    @computations.federated_computation()
    def initialize_fn():
      return intrinsics.federated_value(0, placements.SERVER)

    state_type = initialize_fn.type_signature.result

    @computations.federated_computation(state_type)
    def next_fn(state):
      clients_state = intrinsics.federated_broadcast(state)
      return MeasuredProcessOutput(clients_state, (), ())

    with self.assertRaises(errors.TemplateStateNotAssignableError):
      measured_process.MeasuredProcess(initialize_fn, next_fn)
Esempio n. 14
0
    def test_constructor_with_init_next_type_mismatch(self):
        """next_fn's state parameter must accept initialize_fn's result."""
        init_comp = _build_initialize_comp(0)
        with self.assertRaises(
                iterative_process.NextMustAcceptStateFromInitializeError):

            # next_fn declares float32 parameters.
            @computations.federated_computation(tf.float32, tf.float32)
            def float_adder(current, val):
                return current + val

            measured_process.MeasuredProcess(
                initialize_fn=init_comp, next_fn=float_adder)
Esempio n. 15
0
    def test_constructor_with_init_next_type_mismatch(self):
        """initialize_fn's result must be assignable to next_fn's state param."""
        init_comp = _build_initialize_comp(0)
        expected_message = r'The return type of initialize_fn must be assignable.*'
        with self.assertRaisesRegex(TypeError, expected_message):

            # next_fn declares float32 parameters.
            @computations.federated_computation(tf.float32, tf.float32)
            def float_adder(current, val):
                return current + val

            measured_process.MeasuredProcess(
                initialize_fn=init_comp, next_fn=float_adder)
Esempio n. 16
0
    def test_is_valid_broadcast_process_bad_placement(self):
        """is_valid_broadcast_process rejects processes with misplaced values."""

        def build_process(value_placement):
            # Empty server state; `value` is echoed back as `result`, so the
            # result lives wherever the second parameter is placed.
            @federated_computation.federated_computation()
            def stateless_init():
                return intrinsics.federated_value((), placements.SERVER)

            @federated_computation.federated_computation(
                computation_types.FederatedType((), placements.SERVER),
                computation_types.FederatedType((), value_placement),
            )
            def echo_broadcast(state, value):
                empty_metrics = intrinsics.federated_value(1.0, placements.SERVER)
                return measured_process.MeasuredProcessOutput(
                    state=state, result=value, measurements=empty_metrics)

            return measured_process.MeasuredProcess(
                initialize_fn=stateless_init, next_fn=echo_broadcast)

        # Invalid: the `result` of `next` stays on the server.
        self.assertFalse(
            optimizer_utils.is_valid_broadcast_process(
                build_process(placements.SERVER)))

        # Invalid: the second parameter of `next` is placed at the clients.
        self.assertFalse(
            optimizer_utils.is_valid_broadcast_process(
                build_process(placements.CLIENTS)))
Esempio n. 17
0
    def test_constructor_with_next_result_param_type_mismatch(self):
        """A float result cannot be returned as next_fn's int32 state."""
        init_comp = _build_initialize_comp(0)
        expected_message = (
            'The return type of next_fn must be assignable to the first parameter')
        with self.assertRaisesRegex(TypeError, expected_message):

            @computations.federated_computation(tf.int32)
            def float_result_next(_):
                return 0.0

            measured_process.MeasuredProcess(
                initialize_fn=init_comp, next_fn=float_result_next)
Esempio n. 18
0
    def test_constructor_with_state_tuple_arg(self):
        """Runs add_int32 for several rounds and checks the final output.

        The state accumulates every fed value, the result echoes the last value,
        and the measurement is compared against the expected running average.
        """
        process = measured_process.MeasuredProcess(_build_initialize_comp(0),
                                                   add_int32)
        num_rounds = 10
        state = process.initialize()
        for round_value in range(num_rounds):
            output = process.next(state, round_value)
            state = output.state
        self.assertEqual(output.state, sum(range(num_rounds)))
        self.assertEqual(output.result, round_value)
        expected_measurement = sum(range(num_rounds - 1)) / num_rounds
        self.assertAllClose(output.measurements, [expected_measurement])
Esempio n. 19
0
    def test_constructor_with_next_result_not_measuredprocessoutput(self):
        """next_fn results that are not a MeasuredProcessOutput are rejected."""
        init_comp = _build_initialize_comp(0)
        expected_error = 'MeasuredProcess must return a MeasuredProcessOutput'

        # A bare scalar result.
        with self.assertRaisesRegex(TypeError, expected_error):

            @computations.federated_computation(tf.int32)
            def scalar_next(_):
                return 0

            measured_process.MeasuredProcess(
                initialize_fn=init_comp, next_fn=scalar_next)

        # An unnamed three-element tuple, despite having the right arity.
        with self.assertRaisesRegex(TypeError, expected_error):

            @computations.federated_computation(tf.int32)
            def unnamed_triple_next(_):
                return 0, 0, 0

            measured_process.MeasuredProcess(
                initialize_fn=init_comp, next_fn=unnamed_triple_next)
def _create_test_measured_process_state_at_clients():
  """Builds a test MeasuredProcess whose state is placed at `tff.CLIENTS`.

  The state starts as 0 at the clients and is passed through unchanged. The
  result is the server-side sum of the client values, and the measurement is
  the constant 1 at the server.
  """
  clients_int32 = computation_types.at_clients(tf.int32)

  @federated_computation.federated_computation(clients_int32, clients_int32)
  def next_fn(state, values):
    summed = intrinsics.federated_sum(values)
    one_at_server = intrinsics.federated_value(1, placements.SERVER)
    return measured_process.MeasuredProcessOutput(state, summed, one_at_server)

  @federated_computation.federated_computation
  def initialize_fn():
    return intrinsics.federated_value(0, placements.CLIENTS)

  return measured_process.MeasuredProcess(
      initialize_fn=initialize_fn, next_fn=next_fn)
Esempio n. 21
0
  def test_construction_with_unknown_dimension_does_not_raise(self):
    """State tensors with a statically unknown dimension are accepted."""
    initialize_fn = computations.tf_computation()(
        lambda: tf.constant([], dtype=tf.string))

    @computations.tf_computation(
        computation_types.TensorType(shape=[None], dtype=tf.string))
    def next_fn(strings):
      return MeasuredProcessOutput(
          tf.concat([strings, tf.constant(['abc'])], axis=0), (), ())

    try:
      measured_process.MeasuredProcess(initialize_fn, next_fn)
    except Exception as e:  # pylint: disable=broad-except
      # Catch `Exception` (not a bare except) so KeyboardInterrupt/SystemExit
      # still propagate, and include the underlying cause in the failure.
      self.fail('Could not construct a MeasuredProcess with parameter types '
                f'with statically unknown shape: {e}')
Esempio n. 22
0
    def test_measured_process_output_as_state_raises(self):
        """A federated-zipped MeasuredProcessOutput cannot serve as state."""

        def no_value():
            return intrinsics.federated_value((), placements.SERVER)

        # The initial state is itself a zipped MeasuredProcessOutput.
        @computations.federated_computation()
        def initialize_fn():
            return intrinsics.federated_zip(
                MeasuredProcessOutput(no_value(), no_value(), no_value()))

        state_type = initialize_fn.type_signature.result

        @computations.federated_computation(state_type, CLIENTS_FLOAT)
        def next_fn(state, value):
            del state, value  # Unused.
            return MeasuredProcessOutput(no_value(), no_value(), no_value())

        with self.assertRaises(errors.TemplateStateNotAssignableError):
            measured_process.MeasuredProcess(initialize_fn, next_fn)
Esempio n. 23
0
    def test_federated_measured_process_output_raises(self):
        """next_fn returning a zipped (federated) MeasuredProcessOutput fails."""
        initialize_fn = computations.federated_computation()(
            lambda: intrinsics.federated_value(0, placements.SERVER))

        def empty():
            return intrinsics.federated_value((), placements.SERVER)

        # federated_zip places the FederatedType at the top of the hierarchy.
        @computations.federated_computation(initialize_fn.type_signature.result)
        def next_fn(state):
            return intrinsics.federated_zip(
                MeasuredProcessOutput(state, empty(), empty()))

        # A MeasuredProcessOutput of three `FederatedType`s differs from a
        # `FederatedType` holding a MeasuredProcessOutput; currently only the
        # former is considered valid.
        with self.assertRaises(errors.TemplateStateNotAssignableError):
            measured_process.MeasuredProcess(initialize_fn, next_fn)
Esempio n. 24
0
def build_stateless_mean(
    *, model_delta_type: Union[computation_types.StructType,
                               computation_types.TensorType]
) -> measured_process.MeasuredProcess:
    """Builds a `MeasuredProcess` that wraps `tff.federated_mean`.

    The state (of type `NONE_SERVER_TYPE`, defined elsewhere) is passed through
    unchanged, measurements are an empty server-placed struct, and the result
    is the weighted mean of the client values.

    Args:
      model_delta_type: The unplaced type of the client values to average.

    Returns:
      A `MeasuredProcess` initialized by `_empty_server_initialization`.
    """
    client_values_type = computation_types.FederatedType(model_delta_type,
                                                         placements.CLIENTS)
    client_weights_type = computation_types.FederatedType(tf.float32,
                                                          placements.CLIENTS)

    @computations.federated_computation(NONE_SERVER_TYPE, client_values_type,
                                        client_weights_type)
    def stateless_mean(state, value, weight):
        empty_metrics = intrinsics.federated_value((), placements.SERVER)
        weighted_mean = intrinsics.federated_mean(value, weight=weight)
        return measured_process.MeasuredProcessOutput(
            state=state, result=weighted_mean, measurements=empty_metrics)

    return measured_process.MeasuredProcess(
        initialize_fn=_empty_server_initialization, next_fn=stateless_mean)
Esempio n. 25
0
    def test_constructor_with_tensors_unknown_dimensions_succeeds(self):
        """Construction succeeds when state tensors have an unknown dimension."""

        @computations.tf_computation
        def init():
            return tf.constant([], dtype=tf.string)

        @computations.tf_computation(
            computation_types.TensorType(shape=[None], dtype=tf.string))
        def next_fn(strings):
            return MeasuredProcessOutput(
                state=tf.concat([strings, tf.constant(['abc'])], axis=0),
                result=(),
                measurements=())

        try:
            measured_process.MeasuredProcess(init, next_fn)
        except Exception as e:  # pylint: disable=broad-except
            # Catch `Exception` (not a bare except) so KeyboardInterrupt and
            # SystemExit still propagate, and surface the cause in the failure.
            self.fail(
                'Could not construct a MeasuredProcess with parameter types '
                f'including unknown dimension tensors: {e}')
Esempio n. 26
0
    def test_constructor_with_next_struct_of_different_placedresult(self):
        """A MeasuredProcessOutput whose components differ in placement is valid."""

        @computations.federated_computation
        def initialize_comp():
            return intrinsics.federated_value(0, placements.SERVER)

        state_type = initialize_comp.type_signature.result

        # State and measurements are at the server, the result at the clients.
        @computations.federated_computation(state_type)
        def next_comp(state):
            client_result = intrinsics.federated_value(0, placements.CLIENTS)
            server_measurements = intrinsics.federated_value(
                (), placements.SERVER)
            return measured_process.MeasuredProcessOutput(
                state=state,
                result=client_result,
                measurements=server_measurements)

        try:
            measured_process.MeasuredProcess(initialize_fn=initialize_comp,
                                             next_fn=next_comp)
        except Exception as e:  # pylint: disable=broad-except
            self.fail(f'Failed to construct MeasuredProcess: {e}')
Esempio n. 27
0
def build_stateless_broadcaster(
    *, model_weights_type: Union[computation_types.StructType,
                                 computation_types.TensorType]
) -> measured_process.MeasuredProcess:
    """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`.

    Args:
      model_weights_type: The unplaced type of the server value to broadcast.

    Returns:
      A `MeasuredProcess` with an empty `()@SERVER` state parameter, initialized
      by `_empty_server_initialization`, whose `next_fn` broadcasts the value
      to the clients and reports empty server-placed measurements.
    """
    server_empty_type = computation_types.FederatedType((), placements.SERVER)
    server_weights_type = computation_types.FederatedType(model_weights_type,
                                                          placements.SERVER)

    @computations.federated_computation(server_empty_type, server_weights_type)
    def stateless_broadcast(state, value):
        empty_metrics = intrinsics.federated_value((), placements.SERVER)
        broadcast_value = intrinsics.federated_broadcast(value)
        return measured_process.MeasuredProcessOutput(
            state=state, result=broadcast_value, measurements=empty_metrics)

    return measured_process.MeasuredProcess(
        initialize_fn=_empty_server_initialization,
        next_fn=stateless_broadcast)
  def test_federated_evaluation_fails_stateful_broadcast(self):
    """Federated evaluation rejects a broadcast process that carries state."""

    # A do-nothing measured process whose state is a float32 at the server;
    # only the presence of state matters for this test.
    @computations.federated_computation
    def init_fn():
      zero_fn = computations.tf_computation(
          lambda: tf.zeros(shape=[], dtype=tf.float32))
      return intrinsics.federated_eval(zero_fn, placements.SERVER)

    server_float = computation_types.at_server(tf.float32)
    clients_int = computation_types.at_clients(tf.int32)

    @computations.federated_computation(server_float, clients_int)
    def next_fn(state, value):
      return measured_process.MeasuredProcessOutput(state, value, state)

    stateful_broadcaster = measured_process.MeasuredProcess(init_fn, next_fn)
    with self.assertRaisesRegex(ValueError, 'stateful broadcast'):
      federated_evaluation.build_federated_evaluation(
          TestModelQuant, broadcast_process=stateful_broadcaster)
Esempio n. 29
0
    def test_constructor_with_next_federated_same_placed_struct_result(self):
        """A zipped, top-level FederatedType next_fn result is rejected.

        All output components share the SERVER placement and are zipped, so
        `next_fn` returns one federated struct rather than a struct of
        federated values; that result is not the int32@SERVER state produced
        by `initialize_comp`.
        """

        @computations.federated_computation
        def initialize_comp():
            return intrinsics.federated_value(0, placements.SERVER)

        # A `next` function that returns all the same placement and is zipped so
        # the FederatedType is at the top of the type hierarchy. There is
        # deliberately no trailing comma after `federated_zip(...)`: one would
        # wrap the zipped value in an extra 1-tuple and the test would no longer
        # exercise a top-level FederatedType result.
        @computations.federated_computation(
            initialize_comp.type_signature.result)
        def next_comp(state):
            return intrinsics.federated_zip(
                measured_process.MeasuredProcessOutput(
                    state=state,
                    result=intrinsics.federated_value(0, placements.SERVER),
                    measurements=intrinsics.federated_value(
                        (), placements.SERVER)))

        with self.assertRaises(iterative_process.NextMustReturnStateError):
            measured_process.MeasuredProcess(initialize_fn=initialize_comp,
                                             next_fn=next_comp)
    def test_is_stateful_process_true(self):
        """is_stateful_process detects a process with float32 server state."""

        @computations.federated_computation()
        def stateful_init():
            return intrinsics.federated_value(2.0, placements.SERVER)

        server_float_type = computation_types.FederatedType(tf.float32,
                                                            placements.SERVER)
        server_empty_type = computation_types.FederatedType((),
                                                            placements.SERVER)

        @computations.federated_computation(server_float_type,
                                            server_empty_type)
        def stateful_broadcast(state, value):
            metrics = intrinsics.federated_value(1.0, placements.SERVER)
            clients_value = intrinsics.federated_broadcast(value)
            return measured_process.MeasuredProcessOutput(
                state=state, result=clients_value, measurements=metrics)

        process = measured_process.MeasuredProcess(
            initialize_fn=stateful_init, next_fn=stateful_broadcast)

        self.assertTrue(optimizer_utils.is_stateful_process(process))