Пример #1
0
def build_control_computation(gan: GanFnsAndTypes,
                              disc_optimizer_fn: OptimizerBuilder,
                              gen_optimizer_fn: OptimizerBuilder, tau: float):
    """Returns a `tff.tf_computation` wrapping `client_control`.

    This is a thin wrapper around `gan_training_tf_fns.client_control`.

    Args:
      gan: A `GanFnsAndTypes` object.
      disc_optimizer_fn: Zero-argument callable returning the discriminator
        optimizer used inside the computation.
      gen_optimizer_fn: Zero-argument callable returning the generator
        optimizer used inside the computation.
      tau: Forwarded unchanged to `gan_training_tf_fns.client_control`.

    Returns:
      A `tff.tf_computation`.
    """
    @tff.tf_computation(tff.SequenceType(gan.gen_input_type),
                        tff.SequenceType(gan.real_data_type),
                        gan.from_server_type)
    def control_computation(gen_inputs, real_data, from_server):
        """Returns the output of `client_control` for one client."""
        generator = gan.generator_model_fn()
        discriminator = gan.discriminator_model_fn()
        # Zero-valued structures matching each model's weights. These are what
        # `zero_gen` / `zero_disc` should receive; previously they were
        # computed but discarded, and fresh model instances were passed
        # instead.
        zero_gen = tf.nest.map_structure(tf.zeros_like, generator.weights)
        zero_disc = tf.nest.map_structure(tf.zeros_like, discriminator.weights)
        return gan_training_tf_fns.client_control(
            gen_inputs_ds=gen_inputs,
            real_data_ds=real_data,
            from_server=from_server,
            generator=generator,
            discriminator=discriminator,
            disc_optimizer=disc_optimizer_fn(),
            gen_optimizer=gen_optimizer_fn(),
            zero_gen=zero_gen,
            zero_disc=zero_disc,
            tau=tau)

    return control_computation
Пример #2
0
def build_client_computation(gan: GanFnsAndTypes):
    """Builds the per-client `tff.tf_computation` for GAN training.

    Delegates directly to `gan_training_tf_fns.client_computation`. The
    generator and discriminator are constructed from `gan`'s model fns inside
    the computation body.

    Args:
      gan: A `GanFnsAndTypes` object.

    Returns:
      A `tff.tf_computation`.
    """
    parameter_types = (tff.SequenceType(gan.gen_input_type),
                       tff.SequenceType(gan.real_data_type),
                       gan.from_server_type)

    @tff.tf_computation(*parameter_types)
    def client_computation(gen_inputs, real_data, from_server):
        """Returns the client_output."""
        new_generator = gan.generator_model_fn()
        new_discriminator = gan.discriminator_model_fn()
        return gan_training_tf_fns.client_computation(
            gen_inputs_ds=gen_inputs,
            real_data_ds=real_data,
            from_server=from_server,
            generator=new_generator,
            discriminator=new_discriminator,
            train_discriminator_fn=gan.train_discriminator_fn)

    return client_computation
Пример #3
0
def _temperature_sensor_example_next_fn():
  """Builds the temperature-sensor example federated computation.

  The returned computation takes client-placed sequences of float32 readings
  and a server-placed float32 threshold. Each client counts its readings that
  are strictly greater than the threshold, and also its total number of
  readings. The result is `tff.federated_mean` of the per-client
  over-threshold counts, weighted by the per-client totals.
  """

  @tff.tf_computation(
      tff.SequenceType(tf.float32), tf.float32)
  def count_readings_over(ds, t):

    def accumulate(running_count, reading):
      is_over = tf.greater(reading, t)
      return running_count + tf.cast(is_over, tf.float32)

    return ds.reduce(np.float32(0), accumulate)

  @tff.tf_computation(tff.SequenceType(tf.float32))
  def count_readings(ds):

    def increment(running_count, _):
      return running_count + 1.0

    return ds.reduce(np.float32(0.0), increment)

  @tff.federated_computation(
      tff.FederatedType(tff.SequenceType(tf.float32), tff.CLIENTS),
      tff.FederatedType(tf.float32, tff.SERVER))
  def comp(temperatures, threshold):
    threshold_at_clients = tff.federated_broadcast(threshold)
    over_counts = tff.federated_map(
        count_readings_over,
        tff.federated_zip([temperatures, threshold_at_clients]))
    total_counts = tff.federated_map(count_readings, temperatures)
    return tff.federated_mean(over_counts, total_counts)

  return comp
Пример #4
0
def build_client_computation(gan: GanFnsAndTypes,
                             disc_optimizer_fn: OptimizerBuilder,
                             gen_optimizer_fn: OptimizerBuilder, tau: float):
    """Builds the per-client `tff.tf_computation` that takes control inputs.

    Wraps `gan_training_tf_fns.client_computation`. Besides the client data
    and the server message, the computation accepts two extra inputs. Their
    types are those of `gan.from_server_type.generator_weights` and
    `gan.from_server_type.discriminator_weights`.

    Args:
      gan: A `GanFnsAndTypes` object.
      disc_optimizer_fn: Zero-argument callable building the discriminator
        optimizer.
      gen_optimizer_fn: Zero-argument callable building the generator
        optimizer.
      tau: Forwarded unchanged to `gan_training_tf_fns.client_computation`.

    Returns:
      A `tff.tf_computation`.
    """
    server_type = gan.from_server_type

    @tff.tf_computation(tff.SequenceType(gan.gen_input_type),
                        tff.SequenceType(gan.real_data_type),
                        server_type,
                        server_type.generator_weights,
                        server_type.discriminator_weights)
    def client_computation(gen_inputs, real_data, from_server,
                           control_input_gen, control_input_disc):
        """Returns the client_output."""
        generator = gan.generator_model_fn()
        discriminator = gan.discriminator_model_fn()
        gen_optimizer = gen_optimizer_fn()
        disc_optimizer = disc_optimizer_fn()
        return gan_training_tf_fns.client_computation(
            gen_inputs_ds=gen_inputs,
            real_data_ds=real_data,
            from_server=from_server,
            generator=generator,
            discriminator=discriminator,
            gen_optimizer=gen_optimizer,
            disc_optimizer=disc_optimizer,
            control_input_gen=control_input_gen,
            control_input_disc=control_input_disc,
            tau=tau)

    return client_computation
Пример #5
0
  def test_build_with_preprocess_function(self):
    # Composing a preprocessing computation should change only the dataset
    # part of `next`'s parameter type.
    test_dataset = tf.data.Dataset.range(5)
    element_spec = test_dataset.element_spec
    client_datasets_type = tff.FederatedType(
        tff.SequenceType(element_spec), tff.CLIENTS)

    @tff.tf_computation(tff.SequenceType(element_spec))
    def preprocess_dataset(ds):

      def to_batch(x):
        pixels = tf.fill(dims=(784,), value=float(x) * 2.0)
        label = tf.expand_dims(tf.cast(x + 1, dtype=tf.int64), axis=0)
        return _Batch(pixels, label)

      return ds.map(to_batch).batch(2)

    base_process = _build_simple_fed_pa_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD)
    iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
        preprocess_dataset, base_process)

    with tf.Graph().as_default():
      test_model_for_types = _uncompiled_model_builder()

    model_weights = tff.learning.ModelWeights(
        test_model_for_types.trainable_variables,
        test_model_for_types.non_trainable_variables)
    server_state_type = tff.FederatedType(
        fed_pa_schedule.ServerState(
            model=tff.framework.type_from_tensors(model_weights),
            optimizer_state=(tf.int64,),
            round_num=tf.float32), tff.SERVER)
    metrics_type = tff.FederatedType(
        tff.StructType([('loss', tf.float32),
                        ('model_delta_zeros_percent', tf.float32),
                        ('model_delta_correction_l2_norm', tf.float32)]),
        tff.SERVER)

    expected_type = tff.FunctionType(
        parameter=collections.OrderedDict(
            server_state=server_state_type,
            federated_dataset=client_datasets_type,
        ),
        result=(server_state_type, metrics_type))
    actual_type = iterproc.next.type_signature
    self.assertTrue(
        actual_type.is_equivalent_to(expected_type),
        msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
Пример #6
0
    def __attrs_post_init__(self):
        """Derives TFF types and the discriminator aggregation process.

        Builds one generator and one discriminator from their model fns and
        checks that each is a `tf.keras.models.Model`. Each model is then
        called once on its dummy input, presumably so Keras creates the
        weights. Their weights are converted into `tf.TensorSpec` structures
        used for the federated types below.

        Raises:
          TypeError: If `generator_model_fn` or `discriminator_model_fn` does
            not return a `tf.keras.models.Model`.
        """
        self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
        self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

        # Model-weights based types. The type check runs *before* each model is
        # called. That way a wrong object fails with the intended TypeError,
        # not with whatever error (if any) calling it would produce.
        self._generator = self.generator_model_fn()
        if not isinstance(self._generator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._generator)))
        _ = self._generator(self.dummy_gen_input)
        self._discriminator = self.discriminator_model_fn()
        if not isinstance(self._discriminator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._discriminator)))
        _ = self._discriminator(self.dummy_real_data)

        def vars_to_type(var_struct):
            # TODO(b/131681951): read_value() shouldn't be needed
            return tf.nest.map_structure(
                lambda v: tf.TensorSpec.from_tensor(v.read_value()),
                var_struct)

        self.discriminator_weights_type = vars_to_type(
            self._discriminator.weights)
        self.generator_weights_type = vars_to_type(self._generator.weights)

        # meta_gen / meta_disc reuse the generator / discriminator weight types.
        self.from_server_type = gan_training_tf_fns.FromServer(
            generator_weights=self.generator_weights_type,
            discriminator_weights=self.discriminator_weights_type,
            meta_gen=self.generator_weights_type,
            meta_disc=self.discriminator_weights_type)

        self.client_gen_input_type = tff.type_at_clients(
            tff.SequenceType(self.gen_input_type))
        self.client_real_data_type = tff.type_at_clients(
            tff.SequenceType(self.real_data_type))
        self.server_gen_input_type = tff.type_at_server(
            tff.SequenceType(self.gen_input_type))

        # With a DP query, discriminator weights are aggregated through a
        # DifferentiallyPrivateFactory. Otherwise they are aggregated through
        # a MeanFactory weighted by a float32 value.
        if self.train_discriminator_dp_average_query is not None:
            self.aggregation_process = tff.aggregators.DifferentiallyPrivateFactory(
                query=self.train_discriminator_dp_average_query).create(
                    value_type=tff.to_type(self.discriminator_weights_type))
        else:
            self.aggregation_process = tff.aggregators.MeanFactory().create(
                value_type=tff.to_type(self.discriminator_weights_type),
                weight_type=tff.to_type(tf.float32))
Пример #7
0
    def __attrs_post_init__(self):
        """Derives TFF types (and, with DP, the averaging process).

        Builds one generator and one discriminator from their model fns and
        checks that each is a `tf.keras.models.Model`. Each model is then
        called once on its dummy input, presumably so Keras creates the
        weights. Their weights are converted into `tf.TensorSpec` structures
        used for the federated types below.

        Raises:
          TypeError: If `generator_model_fn` or `discriminator_model_fn` does
            not return a `tf.keras.models.Model`.
        """
        self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
        self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

        # Model-weights based types. The type check runs *before* each model is
        # called. That way a wrong object fails with the intended TypeError,
        # not with whatever error (if any) calling it would produce.
        self._generator = self.generator_model_fn()
        if not isinstance(self._generator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._generator)))
        _ = self._generator(self.dummy_gen_input)
        self._discriminator = self.discriminator_model_fn()
        if not isinstance(self._discriminator, tf.keras.models.Model):
            raise TypeError(
                'Expected `tf.keras.models.Model`, found {}.'.format(
                    type(self._discriminator)))
        _ = self._discriminator(self.dummy_real_data)

        def vars_to_type(var_struct):
            # TODO(b/131681951): read_value() shouldn't be needed
            return tf.nest.map_structure(
                lambda v: tf.TensorSpec.from_tensor(v.read_value()),
                var_struct)

        self.discriminator_weights_type = vars_to_type(
            self._discriminator.weights)
        self.generator_weights_type = vars_to_type(self._generator.weights)

        self.from_server_type = gan_training_tf_fns.FromServer(
            generator_weights=self.generator_weights_type,
            discriminator_weights=self.discriminator_weights_type)

        self.client_gen_input_type = tff.type_at_clients(
            tff.SequenceType(self.gen_input_type))
        self.client_real_data_type = tff.type_at_clients(
            tff.SequenceType(self.real_data_type))
        self.server_gen_input_type = tff.type_at_server(
            tff.SequenceType(self.gen_input_type))

        # Right now, the logic in this library is effectively "if DP use stateful
        # aggregator, else don't use stateful aggregator". An alternative
        # formulation would be to always use a stateful aggregator, but when not
        # using DP default the aggregator to be a stateless mean, e.g.,
        # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
        # NOTE: `self.dp_averaging_fn` is only assigned when a DP query is set.
        if self.train_discriminator_dp_average_query is not None:
            self.dp_averaging_fn = tff.utils.build_dp_aggregate_process(
                value_type=tff.to_type(self.discriminator_weights_type),
                query=self.train_discriminator_dp_average_query)
Пример #8
0
def build_client_computation(gan: GanFnsAndTypes):
    """Returns a `tff.tf_computation` for the `client_computation`.

    This is a thin wrapper around `gan_training_tf_fns.client_computation`,
    or around `gan_training_tf_fns.client_computation_fedadam` when
    `gan.disc_status == 'fedadam'`. The generator, the discriminator and
    their state optimizers are all constructed inside the computation.

    Args:
      gan: A `GanFnsAndTypes` object.

    Returns:
      A `tff.tf_computation`.
    """
    @tff.tf_computation(tff.SequenceType(gan.gen_input_type),
                        tff.SequenceType(gan.real_data_type),
                        gan.from_server_type)
    def client_computation(gen_inputs, real_data, from_server):
        """Returns the client_output."""
        # The generator optimizer's argument is scheduled on the server's round
        # counter. Per `PiecewiseConstantDecay`, it is 0.001 while
        # num_rounds <= 1000 and 0.0005 afterwards.
        steps = from_server.counters['num_rounds']
        scheduler = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            [1000], [0.001, 0.0005])
        generator = gan.generator_model_fn()
        state_gen_optimizer = gan.state_gen_optimizer_fn(scheduler(steps))
        gan_training_tf_fns.initialize_optimizer_vars(generator,
                                                      state_gen_optimizer)
        discriminator = gan.discriminator_model_fn()
        # Fixed argument for the discriminator optimizer (presumably its
        # learning rate).
        state_disc_optimizer = gan.state_disc_optimizer_fn(0.0002)
        gan_training_tf_fns.initialize_optimizer_vars(discriminator,
                                                      state_disc_optimizer)

        # Both client-update variants take the same keyword arguments; only
        # the function differs.
        if gan.disc_status == 'fedadam':
            client_update_fn = gan_training_tf_fns.client_computation_fedadam
        else:
            client_update_fn = gan_training_tf_fns.client_computation
        return client_update_fn(
            gen_inputs_ds=gen_inputs,
            real_data_ds=real_data,
            from_server=from_server,
            generator=generator,
            discriminator=discriminator,
            state_gen_optimizer=state_gen_optimizer,
            state_disc_optimizer=state_disc_optimizer)

    return client_computation
Пример #9
0
    def __attrs_post_init__(self):
        """Derives TFF types (and, with DP, the averaging fn and state type).

        Builds one generator and one discriminator from their model fns and
        type-checks each with `py_typecheck.check_type`. Each model is then
        called once on its dummy input, presumably so Keras creates the
        weights. Their weights are converted into `tf.TensorSpec` structures
        used for the federated types below.

        Raises:
          TypeError: If `generator_model_fn` or `discriminator_model_fn` does
            not return a `tf.keras.models.Model`.
        """
        self.gen_input_type = tensor_spec_for_batch(self.dummy_gen_input)
        self.real_data_type = tensor_spec_for_batch(self.dummy_real_data)

        # Model-weights based types. The type check runs *before* each model is
        # called. That way a wrong object fails with the intended TypeError,
        # not with whatever error (if any) calling it would produce.
        self._generator = self.generator_model_fn()
        py_typecheck.check_type(self._generator, tf.keras.models.Model)
        _ = self._generator(self.dummy_gen_input)
        self._discriminator = self.discriminator_model_fn()
        py_typecheck.check_type(self._discriminator, tf.keras.models.Model)
        _ = self._discriminator(self.dummy_real_data)

        def vars_to_type(var_struct):
            # TODO(b/131681951): read_value() shouldn't be needed
            return tf.nest.map_structure(
                lambda v: tf.TensorSpec.from_tensor(v.read_value()),
                var_struct)

        self.discriminator_weights_type = vars_to_type(
            self._discriminator.weights)
        self.generator_weights_type = vars_to_type(self._generator.weights)

        self.from_server_type = gan_training_tf_fns.FromServer(
            generator_weights=self.generator_weights_type,
            discriminator_weights=self.discriminator_weights_type)

        self.client_gen_input_type = tff.FederatedType(
            tff.SequenceType(self.gen_input_type), tff.CLIENTS)
        self.client_real_data_type = tff.FederatedType(
            tff.SequenceType(self.real_data_type), tff.CLIENTS)
        self.server_gen_input_type = tff.FederatedType(
            tff.SequenceType(self.gen_input_type), tff.SERVER)

        # Right now, the logic in this library is effectively "if DP use stateful
        # aggregator, else don't use stateful aggregator". An alternative
        # formulation would be to always use a stateful aggregator, but when not
        # using DP default the aggregator to be a stateless mean, e.g.,
        # https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/framework/optimizer_utils.py#L283.
        # This change will be easier to make if the tff.StatefulAggregateFn is
        # modified to have a property that gives the type of the aggregation state
        # (i.e., what we're storing in self.dp_averaging_state_type).
        # NOTE: both dp_averaging_* attributes are only assigned with a DP query.
        if self.train_discriminator_dp_average_query is not None:
            self.dp_averaging_fn, self.dp_averaging_state_type = (
                tff.utils.build_dp_aggregate(
                    query=self.train_discriminator_dp_average_query,
                    value_type_fn=lambda value: self.discriminator_weights_type,
                    from_tff_result_fn=lambda record: list(record)))  # pylint: disable=unnecessary-lambda
Пример #10
0
def validator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
              client_state_fn: CLIENT_STATE_FN):
    """Builds a federated computation that validates server weights on clients.

    The server-placed weights are broadcast to every client. Each client runs
    `__validate_client` on its dataset and state. Client metrics are then
    aggregated with `model.federated_output_computation`.

    Args:
      coefficient_fn: Forwarded to `__validate_client`.
      model_fn: Zero-argument callable building the model; it is called once
        here to derive types and is also forwarded to `__validate_client`.
      client_state_fn: Zero-argument callable returning an example client
        state, used only to derive its TFF type.

    Returns:
      A `tff.federated_computation` over (server weights, client datasets,
      client states).
    """
    model = model_fn()
    client_state = client_state_fn()

    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state)
    weights_type = tff.learning.framework.weights_type_from_model(model)

    def validate_on_client(dataset, state, weights):
        return __validate_client(dataset, state, weights, coefficient_fn,
                                 model_fn, tf.function(client.validate))

    validate_client_tf = tff.tf_computation(
        validate_on_client, (dataset_type, client_state_type, weights_type))

    def validate(weights, datasets, client_states):
        client_weights = tff.federated_broadcast(weights)
        client_outputs = tff.federated_map(
            validate_client_tf, (datasets, client_states, client_weights))
        return model.federated_output_computation(client_outputs.metrics)

    parameter_types = (tff.type_at_server(weights_type),
                       tff.type_at_clients(dataset_type),
                       tff.type_at_clients(client_state_type))
    return tff.federated_computation(validate, parameter_types)
Пример #11
0
def validator(
  model_fn: MODEL_FN,
  client_state_fn: CLIENT_STATE_FN
):
  """Builds a federated computation that validates a model on client data.

  Each client runs `__validate_client` on its dataset and state. Client
  metrics are then aggregated with `model.federated_output_computation`.

  Args:
    model_fn: Zero-argument callable building the model; called once here to
      derive the dataset type and also forwarded to `__validate_client`.
    client_state_fn: Zero-argument callable returning an example client
      state, used only to derive its TFF type.

  Returns:
    A `tff.federated_computation` over (client datasets, client states).
  """
  model = model_fn()
  client_state = client_state_fn()

  dataset_type = tff.SequenceType(model.input_spec)
  client_state_type = tff.framework.type_from_tensors(client_state)

  def validate_on_client(dataset, state):
    return __validate_client(
      dataset, state, model_fn, tf.function(client.validate)
    )

  validate_client_tf = tff.tf_computation(
    validate_on_client, (dataset_type, client_state_type)
  )

  def validate(datasets, client_states):
    client_outputs = tff.federated_map(
      validate_client_tf, (datasets, client_states)
    )
    return model.federated_output_computation(client_outputs.metrics)

  parameter_types = (
    tff.type_at_clients(dataset_type),
    tff.type_at_clients(client_state_type),
  )
  return tff.federated_computation(validate, parameter_types)
Пример #12
0
    def test_executes_dataset_concat_aggregation(self):
        # federated_aggregate over datasets: start from an empty dataset and
        # concatenate client datasets together; the result must be a dataset.
        element_spec = tf.TensorSpec(shape=[2], dtype=tf.float32)

        @tff.tf_computation
        def create_empty_ds():
            empty_shape = [0] + element_spec.shape
            empty_tensor = tf.zeros(shape=empty_shape,
                                    dtype=element_spec.dtype)
            return tf.data.Dataset.from_tensor_slices(empty_tensor)

        @tff.tf_computation
        def concat_datasets(ds1, ds2):
            return ds1.concatenate(ds2)

        @tff.tf_computation
        def passthrough(ds):
            return ds

        @tff.federated_computation(
            tff.type_at_clients(tff.SequenceType(element_spec)))
        def do_a_federated_aggregate(client_ds):
            return tff.federated_aggregate(
                value=client_ds,
                zero=create_empty_ds(),
                accumulate=concat_datasets,
                merge=concat_datasets,
                report=passthrough)

        client_data = [tf.data.Dataset.from_tensor_slices([[0.1, 0.2]])]
        result = do_a_federated_aggregate(client_data)
        self.assertIsInstance(result, tf.data.Dataset)
Пример #13
0
def iterator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
             client_state_fn: CLIENT_STATE_FN,
             server_optimizer_fn: OPTIMIZER_FN,
             client_optimizer_fn: OPTIMIZER_FN):
    """Builds the federated training `tff.templates.IterativeProcess`.

    Each round does the following:
    1. Derive a message from the server state and broadcast it.
    2. Run `__update_client` on every client.
    3. Average the client weight deltas, weighted by `client_weight`.
    4. Aggregate metrics with `model.federated_output_computation`.
    5. Apply the averaged delta via `__update_server`.

    `next` returns the new server state, the aggregated metrics and the
    updated client states.

    Args:
      coefficient_fn: Forwarded to `__update_client`.
      model_fn: Zero-argument callable building the model.
      client_state_fn: Zero-argument callable returning an example client
        state, used only to derive its TFF type.
      server_optimizer_fn: Forwarded to server initialization and update.
      client_optimizer_fn: Forwarded to `__update_client`.

    Returns:
      A `tff.templates.IterativeProcess`.
    """
    model = model_fn()
    client_state = client_state_fn()

    def initialize_server():
        return __initialize_server(model_fn, server_optimizer_fn)

    init_tf = tff.tf_computation(initialize_server)

    server_state_type = init_tf.type_signature.result
    client_state_type = tff.framework.type_from_tensors(client_state)

    def apply_server_update(state, weights_delta):
        return __update_server(state, weights_delta, model_fn,
                               server_optimizer_fn,
                               tf.function(server.update))

    update_server_tf = tff.tf_computation(
        apply_server_update,
        (server_state_type, server_state_type.model.trainable))

    def build_message(state):
        return __state_to_message(state, tf.function(server.to_message))

    state_to_message_tf = tff.tf_computation(build_message, server_state_type)

    dataset_type = tff.SequenceType(model.input_spec)
    server_message_type = state_to_message_tf.type_signature.result

    def apply_client_update(dataset, state, message):
        return __update_client(dataset, state, message, coefficient_fn,
                               model_fn, client_optimizer_fn,
                               tf.function(client.update))

    update_client_tf = tff.tf_computation(
        apply_client_update,
        (dataset_type, client_state_type, server_message_type))

    def init_tff():
        return tff.federated_value(init_tf(), tff.SERVER)

    def next_tff(server_state, datasets, client_states):
        # Server -> clients.
        server_message = tff.federated_map(state_to_message_tf, server_state)
        client_messages = tff.federated_broadcast(server_message)

        # Local client training.
        client_outputs = tff.federated_map(
            update_client_tf, (datasets, client_states, client_messages))

        # Clients -> server.
        mean_delta = tff.federated_mean(client_outputs.weights_delta,
                                        weight=client_outputs.client_weight)
        metrics = model.federated_output_computation(client_outputs.metrics)

        next_state = tff.federated_map(update_server_tf,
                                       (server_state, mean_delta))
        return next_state, metrics, client_outputs.client_state

    next_parameter_types = (tff.type_at_server(server_state_type),
                            tff.type_at_clients(dataset_type),
                            tff.type_at_clients(client_state_type))
    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(init_tff),
        next_fn=tff.federated_computation(next_tff, next_parameter_types))
Пример #14
0
def evaluator(coefficient_fn: COEFFICIENT_FN, model_fn: MODEL_FN,
              client_state_fn: CLIENT_STATE_FN):
    """Builds a federated computation that evaluates server weights on clients.

    The computation returns a tuple of three values:
    - the `tff.federated_sum` of the per-client confusion matrices;
    - the client metrics aggregated by `model.federated_output_computation`;
    - the per-client metrics gathered with `tff.federated_collect`.

    Args:
      coefficient_fn: Forwarded to `__evaluate_client`.
      model_fn: Zero-argument callable building the model; called once here
        to derive types and also forwarded to `__evaluate_client`.
      client_state_fn: Zero-argument callable returning an example client
        state, used only to derive its TFF type.

    Returns:
      A `tff.federated_computation` over (server weights, client datasets,
      client states).
    """
    model = model_fn()
    client_state = client_state_fn()

    dataset_type = tff.SequenceType(model.input_spec)
    client_state_type = tff.framework.type_from_tensors(client_state)
    model_weights = tff.learning.ModelWeights.from_model(model)
    weights_type = tff.framework.type_from_tensors(model_weights)

    def evaluate_on_client(dataset, state, weights):
        return __evaluate_client(dataset, state, weights, coefficient_fn,
                                 model_fn, tf.function(client.evaluate))

    evaluate_client_tf = tff.tf_computation(
        evaluate_on_client, (dataset_type, client_state_type, weights_type))

    def evaluate(weights, datasets, client_states):
        client_weights = tff.federated_broadcast(weights)
        client_outputs = tff.federated_map(
            evaluate_client_tf, (datasets, client_states, client_weights))

        confusion_matrix = tff.federated_sum(client_outputs.confusion_matrix)
        aggregated_metrics = model.federated_output_computation(
            client_outputs.metrics)
        collected_metrics = tff.federated_collect(client_outputs.metrics)
        return confusion_matrix, aggregated_metrics, collected_metrics

    parameter_types = (tff.type_at_server(weights_type),
                       tff.type_at_clients(dataset_type),
                       tff.type_at_clients(client_state_type))
    return tff.federated_computation(evaluate, parameter_types)
Пример #15
0
    def test_eval_fn_has_correct_type_signature(self):
        # The centralized evaluation takes server-placed weights and a
        # server-placed dataset and returns server-placed scalar metrics.
        def metrics_builder():
            return [tf.keras.metrics.MeanSquaredError()]

        eval_fn = evaluation.build_centralized_evaluation(
            tff_model_fn, metrics_builder)
        actual_type = eval_fn.type_signature

        trainable_types = (tff.TensorType(tf.float32, [1, 1]),
                           tff.TensorType(tf.float32, [1]))
        model_type = tff.FederatedType(
            tff.learning.ModelWeights(trainable=trainable_types,
                                      non_trainable=()),
            tff.SERVER)

        element_type = collections.OrderedDict(
            x=tff.TensorType(tf.float32, [None, 1]),
            y=tff.TensorType(tf.float32, [None, 1]))
        dataset_type = tff.FederatedType(tff.SequenceType(element_type),
                                         tff.SERVER)

        metrics_type = tff.FederatedType(
            collections.OrderedDict(
                mean_squared_error=tff.TensorType(tf.float32),
                num_examples=tff.TensorType(tf.float32)),
            tff.SERVER)

        expected_parameter = collections.OrderedDict(
            model_weights=model_type, centralized_dataset=dataset_type)
        expected_type = tff.FunctionType(parameter=expected_parameter,
                                         result=metrics_type)

        actual_type.check_assignable_from(expected_type)
Пример #16
0
    def test_execute_with_preprocess_function(self):
        # Training on a single constant example should reduce the loss, and
        # the improvement should shrink over later rounds.
        test_dataset = tf.data.Dataset.range(1)

        @tff.tf_computation(tff.SequenceType(test_dataset.element_spec))
        def preprocess_dataset(ds):
            def to_example(x):
                del x  # Every element maps to the same constant example.
                features = np.ones([784], dtype=np.float32)
                label = np.ones([1], dtype=np.int64)
                return _Batch(x=features, y=label)

            return ds.map(to_example).batch(1)

        base_process = fed_avg_schedule.build_fed_avg_process(
            _uncompiled_model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)
        iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
            preprocess_dataset, base_process)

        _, train_outputs, _ = self._run_rounds(iterproc, [test_dataset], 6)

        def loss_at(round_index):
            return train_outputs[round_index]['loss']

        self.assertLess(loss_at(-1), loss_at(0))
        train_gap_first_half = loss_at(0) - loss_at(2)
        train_gap_second_half = loss_at(3) - loss_at(5)
        self.assertLess(train_gap_second_half, train_gap_first_half)
Пример #17
0
def build_server_computation(gan: GanFnsAndTypes, server_state_type: tff.Type,
                             client_output_type: tff.Type):
    """Builds the server-side `tff.tf_computation`.

    Delegates to `gan_training_tf_fns.server_computation`. The generator, the
    discriminator and the server discriminator-update optimizer are
    constructed from `gan` inside the computation body.

    Args:
      gan: A `GanFnsAndTypes` object.
      server_state_type: The `tff.Type` of the ServerState.
      client_output_type: The `tff.Type` of the ClientOutput.

    Returns:
      A `tff.tf_computation`.
    """
    parameter_types = (server_state_type,
                       tff.SequenceType(gan.gen_input_type),
                       client_output_type,
                       gan.dp_averaging_state_type)

    @tff.tf_computation(*parameter_types)
    def server_computation(server_state, gen_inputs, client_output,
                           new_dp_averaging_state):
        """The wrapped server_computation."""
        generator = gan.generator_model_fn()
        discriminator = gan.discriminator_model_fn()
        disc_update_optimizer = gan.server_disc_update_optimizer_fn()
        return gan_training_tf_fns.server_computation(
            server_state=server_state,
            gen_inputs_ds=gen_inputs,
            client_output=client_output,
            generator=generator,
            discriminator=discriminator,
            server_disc_update_optimizer=disc_update_optimizer,
            train_generator_fn=gan.train_generator_fn,
            new_dp_averaging_state=new_dp_averaging_state)

    return server_computation
Пример #18
0
def get_federated_tokenize_fn(dataset_name, dataset_element_type_structure):
    """Returns a federated computation that tokenizes every client dataset.

    Args:
      dataset_name: Forwarded to `tokenize` for each dataset.
      dataset_element_type_structure: Element type of each client dataset.

    Returns:
      A `tff.federated_computation` that maps `tokenize` over client-placed
      datasets.
    """
    sequence_type = tff.SequenceType(dataset_element_type_structure)

    @tff.tf_computation(sequence_type)
    def tokenize_dataset(dataset):
        """The TF computation to tokenize a dataset."""
        return tokenize(dataset, dataset_name)

    @tff.federated_computation(tff.type_at_clients(sequence_type))
    def tokenize_datasets(datasets):
        """The TFF computation to compute tokenized datasets."""
        return tff.federated_map(tokenize_dataset, datasets)

    return tokenize_datasets
Пример #19
0
    def test_iterative_process_type_signature(self):
        # Client and server each get their own, identically configured,
        # reduce-LR-on-plateau callback.
        def make_lr_callback():
            return callbacks.create_reduce_lr_on_plateau(
                learning_rate=0.1,
                min_delta=0.5,
                window_size=2,
                decay_factor=1.0,
                cooldown=0)

        client_lr_callback = make_lr_callback()
        server_lr_callback = make_lr_callback()

        iterative_process = adaptive_fed_avg.build_fed_avg_process(
            _uncompiled_model_builder,
            client_lr_callback,
            server_lr_callback,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        lr_callback_type = tff.framework.type_from_tensors(client_lr_callback)
        model_weights_type = tff.learning.ModelWeights(
            trainable=(tff.TensorType(tf.float32, [1, 1]),
                       tff.TensorType(tf.float32, [1])),
            non_trainable=())
        server_state_type = tff.FederatedType(
            adaptive_fed_avg.ServerState(
                model=model_weights_type,
                optimizer_state=[tf.int64],
                client_lr_callback=lr_callback_type,
                server_lr_callback=lr_callback_type),
            tff.SERVER)

        self.assertEqual(
            iterative_process.initialize.type_signature,
            tff.FunctionType(parameter=None, result=server_state_type))

        element_type = collections.OrderedDict(
            x=tff.TensorType(tf.float32, [None, 1]),
            y=tff.TensorType(tf.float32, [None, 1]))
        dataset_type = tff.FederatedType(tff.SequenceType(element_type),
                                         tff.CLIENTS)

        metrics_type = tff.FederatedType(
            collections.OrderedDict(loss=tff.TensorType(tf.float32)),
            tff.SERVER)
        output_type = collections.OrderedDict(before_training=metrics_type,
                                              during_training=metrics_type)

        expected_type = tff.FunctionType(
            parameter=collections.OrderedDict(server_state=server_state_type,
                                              federated_dataset=dataset_type),
            result=(server_state_type, output_type))

        actual_type = iterative_process.next.type_signature
        self.assertEqual(actual_type,
                         expected_type,
                         msg='{s}\n!={t}'.format(s=actual_type,
                                                 t=expected_type))
Пример #20
0
    def test_executes_passthru_dataset(self):
        # An identity computation over a sequence must hand back a dataset.
        @tff.tf_computation(tff.SequenceType(tf.int64))
        def passthru_dataset(ds):
            return ds

        result = passthru_dataset(tf.data.Dataset.range(10))
        self.assertIsInstance(result, tf.data.Dataset)

    def test_build_with_preprocess_function(self):
        # Composing a preprocessing computation should change only the dataset
        # part of `next`'s parameter type.
        test_dataset = tf.data.Dataset.range(5)
        element_spec = test_dataset.element_spec
        client_datasets_type = tff.type_at_clients(
            tff.SequenceType(element_spec))

        @tff.tf_computation(tff.SequenceType(element_spec))
        def preprocess_dataset(ds):
            def create_batch(x):
                features = [tf.cast(x, dtype=tf.float32)]
                return collections.OrderedDict(x=features, y=[2.0])

            return ds.map(create_batch).batch(2)

        base_process = fed_avg_schedule.build_fed_avg_process(
            model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            client_lr=0.01,
            server_optimizer_fn=tf.keras.optimizers.SGD)
        iterproc = tff.simulation.compose_dataset_computation_with_iterative_process(
            preprocess_dataset, base_process)

        with tf.Graph().as_default():
            test_model_for_types = model_builder()

        model_weights_type = tff.framework.type_from_tensors(
            tff.learning.ModelWeights(
                test_model_for_types.trainable_variables,
                test_model_for_types.non_trainable_variables))
        server_state_type = tff.FederatedType(
            fed_avg_schedule.ServerState(model=model_weights_type,
                                         optimizer_state=(tf.int64, ),
                                         round_num=tf.float32),
            tff.SERVER)
        output_computation = test_model_for_types.federated_output_computation
        metrics_type = output_computation.type_signature.result

        expected_type = tff.FunctionType(
            parameter=collections.OrderedDict(
                server_state=server_state_type,
                federated_dataset=client_datasets_type,
            ),
            result=(server_state_type, metrics_type))
        actual_type = iterproc.next.type_signature
        self.assertTrue(
            actual_type.is_equivalent_to(expected_type),
            msg='{s}\n!={t}'.format(s=actual_type, t=expected_type))
Пример #22
0
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int = MAX_CLIENT_DATASET_SIZE,
    emnist_task: str = 'digit_recognition',
    num_parallel_calls: int = tf.data.experimental.AUTOTUNE
) -> tff.Computation:
    """Creates a preprocessing function for EMNIST client datasets.

  The preprocessing shuffles, repeats, batches, and then reshapes, using
  the `shuffle`, `repeat`, `batch`, and `map` attributes of a
  `tf.data.Dataset`, in that order.

  Args:
    num_epochs: An integer representing the number of epochs to repeat the
      client datasets. Must be at least 1.
    batch_size: An integer representing the batch size on clients. The final
      batch of a client may be smaller (`drop_remainder=False`).
    shuffle_buffer_size: An integer representing the shuffle buffer size on
      clients. If set to a number <= 1, a buffer of size 1 is used, so no
      shuffling occurs.
    emnist_task: A string indicating the EMNIST task being performed. Must be
      one of 'digit_recognition' or 'autoencoder'. If the former, then elements
      are mapped to tuples of the form (pixels, label), if the latter then
      elements are mapped to tuples of the form (pixels, pixels).
    num_parallel_calls: An integer representing the number of parallel calls
      used when performing `tf.data.Dataset.map`. Defaults to
      `tf.data.experimental.AUTOTUNE`.

  Returns:
    A `tff.Computation` performing the preprocessing discussed above.

  Raises:
    ValueError: If `num_epochs` is less than 1, or if `emnist_task` is not
      one of the supported values.
  """
    if num_epochs < 1:
        raise ValueError('num_epochs must be a positive integer.')
    if shuffle_buffer_size <= 1:
        # A buffer of size 1 makes `shuffle` a no-op.
        shuffle_buffer_size = 1

    if emnist_task == 'digit_recognition':
        mapping_fn = _reshape_for_digit_recognition
    elif emnist_task == 'autoencoder':
        mapping_fn = _reshape_for_autoencoder
    else:
        raise ValueError('emnist_task must be one of "digit_recognition" or '
                         '"autoencoder".')

    # Features are intentionally sorted lexicographically by key for consistency
    # across datasets.
    feature_dtypes = collections.OrderedDict(
        label=tff.TensorType(tf.int32),
        pixels=tff.TensorType(tf.float32, shape=(28, 28)))

    @tff.tf_computation(tff.SequenceType(feature_dtypes))
    def preprocess_fn(dataset):
        # The reshaping `map` runs after `batch`, so `mapping_fn` receives
        # whole batches rather than single examples.
        return dataset.shuffle(shuffle_buffer_size).repeat(num_epochs).batch(
            batch_size, drop_remainder=False).map(
                mapping_fn, num_parallel_calls=num_parallel_calls)

    return preprocess_fn
Пример #23
0
  def test_build_with_preprocess_function(self):
    """Checks `next`'s type signature when a preprocess comp is supplied."""
    source_ds = tf.data.Dataset.range(5)
    expected_clients_type = tff.FederatedType(
        tff.SequenceType(source_ds.element_spec), tff.CLIENTS)

    @tff.tf_computation(tff.SequenceType(source_ds.element_spec))
    def preprocess_dataset(ds):

      def make_batch(value):
        features = tf.fill(dims=(784,), value=float(value) * 2.0)
        labels = tf.expand_dims(tf.cast(value + 1, dtype=tf.int64), axis=0)
        return _Batch(features, labels)

      return ds.map(make_batch).batch(2)

    adapter = fed_avg_schedule.build_fed_avg_process(
        _uncompiled_model_builder,
        client_optimizer_fn=tf.keras.optimizers.SGD,
        server_optimizer_fn=tf.keras.optimizers.SGD,
        dataset_preprocess_comp=preprocess_dataset)

    # A throwaway model, built in its own graph, used only to read types.
    with tf.Graph().as_default():
      reference_model = _uncompiled_model_builder()

    process = adapter._iterative_process

    weights_type = tff.framework.type_from_tensors(
        fed_avg_schedule.ModelWeights(
            reference_model.trainable_variables,
            reference_model.non_trainable_variables))
    expected_state_type = tff.FederatedType(
        fed_avg_schedule.ServerState(
            model=weights_type,
            optimizer_state=(tf.int64,),
            round_num=tf.float32), tff.SERVER)
    expected_metrics_type = (
        reference_model.federated_output_computation.type_signature.result)

    expected_signature = tff.FunctionType(
        parameter=(expected_state_type, expected_clients_type),
        result=(expected_state_type, expected_metrics_type))
    actual_signature = process.next.type_signature
    self.assertEqual(
        actual_signature,
        expected_signature,
        msg=f'{actual_signature}\n!={expected_signature}')
Пример #24
0
def build_server_computation(gan: GanFnsAndTypes, server_state_type: tff.Type,
                             client_output_type: tff.Type):
    """Returns a `tff.tf_computation` for the `server_computation`.

  This is a thin wrapper around `gan_training_tf_fns.server_computation`, or
  around `gan_training_tf_fns.server_computation_fedadam` when
  `gan.disc_status == 'fedadam'`.

  Args:
    gan: A `GanFnsAndTypes` object.
    server_state_type: The `tff.Type` of the ServerState.
    client_output_type: The `tff.Type` of the ClientOutput.

  Returns:
    A `tff.tf_computation.`
  """
    @tff.tf_computation(server_state_type,
                        tff.SequenceType(gan.gen_input_type),
                        client_output_type, gan.dp_averaging_state_type)
    def server_computation(server_state, gen_inputs, client_output,
                           new_dp_averaging_state):
        """The wrapped server_computation."""
        # Initialize the models and optimizers beforehand so they are not
        # created within the tf.function.
        steps = server_state.counters['num_rounds']
        # Generator learning rate: 0.001 while num_rounds <= 1000, then 0.0005.
        scheduler = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            [1000], [0.001, 0.0005])
        state_gen_optimizer = gan.state_gen_optimizer_fn(scheduler(steps))
        generator = gan.generator_model_fn()
        gan_training_tf_fns.initialize_optimizer_vars(generator,
                                                      state_gen_optimizer)
        # NOTE(review): the discriminator builder receives 0.0002 while the
        # generator builder takes no arguments; confirm what this value means.
        discriminator = gan.discriminator_model_fn(0.0002)
        # NOTE(review): unlike the generator optimizer (which gets a scheduled
        # learning rate), this receives the raw round counter; confirm that
        # `state_disc_optimizer_fn` expects a step count.
        state_disc_optimizer = gan.state_disc_optimizer_fn(steps)
        gan_training_tf_fns.initialize_optimizer_vars(discriminator,
                                                      state_disc_optimizer)

        # Both server update functions share one signature; only the choice of
        # function depends on the discriminator aggregation mode.
        if gan.disc_status == 'fedadam':
            server_fn = gan_training_tf_fns.server_computation_fedadam
        else:
            server_fn = gan_training_tf_fns.server_computation
        return server_fn(
            server_state=server_state,
            gen_inputs_ds=gen_inputs,
            client_output=client_output,
            generator=generator,
            discriminator=discriminator,
            state_disc_optimizer=state_disc_optimizer,
            state_gen_optimizer=state_gen_optimizer,
            new_dp_averaging_state=new_dp_averaging_state)

    return server_computation
Пример #25
0
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int = 50,
    sequence_length: int = SEQUENCE_LENGTH,
    num_parallel_calls: int = tf.data.experimental.AUTOTUNE
) -> tff.Computation:
    """Creates a preprocessing function for Shakespeare client datasets.

  This function maps a dataset of string snippets to a dataset of input/output
  character ID sequences. This is done by first repeating the dataset and
  shuffling (according to `num_epochs` and `shuffle_buffer_size`), mapping
  the string sequences to tokens, and packing them into input/output
  sequences of length `sequence_length`.

  Args:
    num_epochs: An integer representing the number of epochs to repeat the
      client datasets. Must be at least 1.
    batch_size: An integer representing the batch size on clients.
    shuffle_buffer_size: An integer representing the shuffle buffer size on
      clients. If set to a number <= 1, a buffer of size 1 is used, so no
      shuffling occurs.
    sequence_length: The length of each example in the batch. Must be at
      least 1.
    num_parallel_calls: An integer representing the number of parallel calls
      used when performing `tf.data.Dataset.map`.

  Returns:
    A `tff.Computation` performing the preprocessing described above.

  Raises:
    ValueError: If `num_epochs` or `sequence_length` is less than 1.
  """
    if num_epochs < 1:
        raise ValueError('num_epochs must be a positive integer.')
    if sequence_length < 1:
        raise ValueError('sequence_length must be a positive integer.')
    if shuffle_buffer_size <= 1:
        # A buffer of size 1 makes `shuffle` a no-op.
        shuffle_buffer_size = 1

    feature_dtypes = collections.OrderedDict(snippets=tf.string)

    @tff.tf_computation(tff.SequenceType(feature_dtypes))
    def preprocess_fn(dataset):
        # Each packed sequence holds `sequence_length + 1` tokens, so that
        # `_split_target` can derive input/target pairs from it.
        to_tokens = _build_tokenize_fn(split_length=sequence_length + 1)
        return (
            dataset.shuffle(shuffle_buffer_size).repeat(num_epochs)
            # Convert snippets to int64 tokens and pad.
            .map(to_tokens, num_parallel_calls=num_parallel_calls)
            # Separate into individual tokens.
            .unbatch()
            # Join into sequences of the desired length. The previous call of
            # map(to_tokens, ...) ensures that the collection of tokens has
            # length divisible by sequence_length + 1, so no batch dropping is
            # expected.
            .batch(sequence_length + 1, drop_remainder=True)
            # Batch sequences together for mini-batching purposes.
            .batch(batch_size)
            # Convert batches into training examples.
            .map(_split_target, num_parallel_calls=num_parallel_calls))

    return preprocess_fn
def _create_tff_parallel_clients_with_dataset_reduce():
    """Builds a federated computation that reduces each client's dataset.

    Each client sums its sequence of int64 values with `tf.data.Dataset.reduce`,
    starting from an initial value of 1, so a client's result is
    `1 + sum(dataset)`.

    Returns:
      A `tff.federated_computation` that takes a CLIENTS-placed sequence of
      int64 values and maps the per-client reduction over it.
    """
    @tf.function
    def reduce_fn(x, y):
        # Reduction step: running total `x` plus the next element `y`.
        return x + y

    @tf.function
    def dataset_reduce_fn(ds, initial_val):
        return ds.reduce(initial_val, reduce_fn)

    @tff.tf_computation(tff.SequenceType(tf.int64))
    def dataset_reduce_fn_wrapper(ds):
        # NOTE(review): the initial value is held in a `tf.Variable` created
        # inside the TFF computation rather than a constant; presumably this is
        # deliberate (to exercise a variable flowing into `Dataset.reduce`) —
        # confirm before simplifying.
        initial_val = tf.Variable(np.int64(1.0))
        return dataset_reduce_fn(ds, initial_val)

    @tff.federated_computation(tff.at_clients(tff.SequenceType(tf.int64)))
    def parallel_client_run(client_datasets):
        # Applies the per-client reduction independently on every client.
        return tff.federated_map(dataset_reduce_fn_wrapper, client_datasets)

    return parallel_client_run
Пример #27
0
def build_federated_evaluation(
    model_fn: Callable[[], tff.learning.Model],
    metrics_builder: Callable[[], Sequence[tf.keras.metrics.Metric]]
) -> tff.Computation:
    """Builds a federated evaluation `tff.federated_computation`.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    metrics_builder: A no-arg function that returns a sequence of
      `tf.keras.metrics.Metric` objects. These metrics must have a callable
      `update_state` accepting `y_true` and `y_pred` arguments, corresponding to
      the true and predicted label, respectively.

  Returns:
    A `tff.Computation` (built with `tff.federated_computation`) that accepts
    SERVER-placed model weights and CLIENTS-placed data, and returns an
    `OrderedDict` keyed by `AggregationMethods.EXAMPLE_WEIGHTED.value` and
    `AggregationMethods.UNIFORM_WEIGHTED.value`. The values are the client
    metrics averaged with per-client `num_examples` weights and uniformly,
    respectively.
  """
    # Wrap model construction in a graph to avoid polluting the global context
    # with variables created for this model.
    with tf.Graph().as_default():
        placeholder_model = model_fn()
        model_weights_type = tff.learning.framework.weights_type_from_model(
            placeholder_model)
        model_input_type = tff.SequenceType(placeholder_model.input_spec)

    @tff.tf_computation(model_weights_type, model_input_type)
    def compute_client_metrics(model_weights, federated_dataset):
        # Runs on a single client: `federated_dataset` is that client's local
        # dataset. Parameter names are left unchanged since TFF may reflect
        # them in the computation's type signature.
        model = model_fn()
        metrics = metrics_builder()
        return compute_metrics(model, model_weights, metrics,
                               federated_dataset)

    @tff.federated_computation(tff.type_at_server(model_weights_type),
                               tff.type_at_clients(model_input_type))
    def federated_evaluate(model_weights, federated_dataset):
        client_model = tff.federated_broadcast(model_weights)
        client_metrics = tff.federated_map(compute_client_metrics,
                                           (client_model, federated_dataset))
        # Per-client example counts are the weights for the example-weighted
        # average.
        num_examples = client_metrics.num_examples
        uniform_weighted_metrics = tff.federated_mean(client_metrics,
                                                      weight=None)
        example_weighted_metrics = tff.federated_mean(client_metrics,
                                                      weight=num_examples)
        # Aggregate the metrics in a single nested dictionary.
        return collections.OrderedDict([
            (AggregationMethods.EXAMPLE_WEIGHTED.value,
             example_weighted_metrics),
            (AggregationMethods.UNIFORM_WEIGHTED.value,
             uniform_weighted_metrics),
        ])

    return federated_evaluate
Пример #28
0
def build_federated_averaging_process_attacked(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    aggregation_process=None,
    client_update_tf=None):
    """Builds the TFF computations for federated averaging with malicious clients.

  Optimization uses federated averaging where some clients may be malicious.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, used during local client training.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, used to apply updates to the global
      model.
    aggregation_process: A 'tff.templates.MeasuredProcess' that aggregates model
      deltas placed@CLIENTS to an aggregated model delta placed@SERVER. If
      `None`, a stateless mean over the trainable weights is built.
    client_update_tf: A 'tf.function' that computes the ClientOutput. If
      `None`, defaults to `ClientExplicitBoosting(boost_factor=1.0)`, created
      anew for each call.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
    # Built here rather than as a default argument: a default-argument
    # expression runs once at import time, and the resulting instance would be
    # shared by every process built with the default.
    if client_update_tf is None:
        client_update_tf = ClientExplicitBoosting(boost_factor=1.0)

    # The metadata model is built in a throwaway graph so its variables do not
    # pollute the current context; only its types/specs are used below.
    with tf.Graph().as_default():
        dummy_model_for_metadata = model_fn()
        weights_type = tff.learning.framework.weights_type_from_model(
            dummy_model_for_metadata)

    if aggregation_process is None:
        aggregation_process = tff.learning.framework.build_stateless_mean(
            model_delta_type=weights_type.trainable)

    server_init = build_server_init_fn(model_fn, server_optimizer_fn,
                                       aggregation_process.initialize)
    # `server_init` returns a SERVER-placed value; take its member type.
    server_state_type = server_init.type_signature.result.member
    server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                              server_state_type,
                                              server_state_type.model)
    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)

    client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                              client_update_tf,
                                              tf_dataset_type,
                                              server_state_type.model)

    federated_server_state_type = tff.type_at_server(server_state_type)

    federated_dataset_type = tff.type_at_clients(tf_dataset_type)

    run_one_round_tff = build_run_one_round_fn_attacked(
        server_update_fn, client_update_fn, aggregation_process,
        dummy_model_for_metadata, federated_server_state_type,
        federated_dataset_type)

    return tff.templates.IterativeProcess(initialize_fn=server_init,
                                          next_fn=run_one_round_tff)
Пример #29
0
def create_preprocess_fn(
    num_epochs: int,
    batch_size: int,
    shuffle_buffer_size: int = NUM_EXAMPLES_PER_CLIENT,
    crop_shape: Tuple[int, int, int] = CIFAR_SHAPE,
    distort_image=False,
    num_parallel_calls: int = tf.data.experimental.AUTOTUNE) -> tff.Computation:
  """Builds a `tff.Computation` that preprocesses CIFAR-100 client datasets.

  The returned computation shuffles, repeats, maps each example through the
  image map, and finally batches.

  Args:
    num_epochs: Number of times each client dataset is repeated. Must be at
      least 1.
    batch_size: Batch size used on clients.
    shuffle_buffer_size: Shuffle buffer size used on clients. Values <= 1 are
      replaced by 1, which disables shuffling.
    crop_shape: A tuple (crop_height, crop_width, num_channels) giving the crop
      shape used during preprocessing. No element may exceed the matching
      element of (32, 32, 3); the last element should be 3 to keep the RGB
      image structure.
    distort_image: Whether preprocessing includes image distortion, such as
      random crops and flips.
    num_parallel_calls: Number of parallel calls used by
      `tf.data.Dataset.map`.

  Returns:
    A `tff.Computation` performing the preprocessing described above.

  Raises:
    ValueError: If `num_epochs` is less than 1.
  """
  if num_epochs < 1:
    raise ValueError('num_epochs must be a positive integer.')
  # A buffer of size 1 makes `shuffle` a no-op.
  shuffle_buffer_size = 1 if shuffle_buffer_size <= 1 else shuffle_buffer_size

  # Keys are listed in lexicographic order so feature ordering stays
  # consistent across datasets.
  element_type = collections.OrderedDict(
      coarse_label=tff.TensorType(tf.int64),
      image=tff.TensorType(tf.uint8, shape=(32, 32, 3)),
      label=tff.TensorType(tf.int64))

  map_image = build_image_map(crop_shape, distort_image)

  @tff.tf_computation(tff.SequenceType(element_type))
  def preprocess_fn(dataset):
    repeated = dataset.shuffle(shuffle_buffer_size).repeat(num_epochs)
    # Map per example, before batching, so each image gets its own crop
    # rather than one crop being shared by every image in a batch.
    per_example = repeated.map(map_image, num_parallel_calls=num_parallel_calls)
    return per_example.batch(batch_size)

  return preprocess_fn
Пример #30
0
def build_federated_averaging_process_attacked(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    stateful_delta_aggregate_fn=None,
    client_update_tf=None):
    """Builds the TFF computations for federated averaging with malicious clients.

  Optimization uses federated averaging where some clients may be malicious.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, used during local client training.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, used to apply updates to the global
      model.
    stateful_delta_aggregate_fn: An aggregator object exposing an
      `initialize()` method, which aggregates model deltas placed@CLIENTS to an
      aggregated model delta placed@SERVER. If `None`, defaults to
      `build_stateless_mean()`, created anew for each call.
    client_update_tf: A 'tf.function' that computes the ClientOutput. If
      `None`, defaults to `ClientExplicitBoosting(boost_factor=1.0)`, created
      anew for each call.

  Returns:
    A `tff.templates.IterativeProcess`.
  """
    # Defaults are built here rather than in the signature: default-argument
    # expressions run once at import time, and the resulting objects would be
    # shared by every process built with the defaults.
    if stateful_delta_aggregate_fn is None:
        stateful_delta_aggregate_fn = build_stateless_mean()
    if client_update_tf is None:
        client_update_tf = ClientExplicitBoosting(boost_factor=1.0)

    # NOTE(review): this metadata model is created in the current context,
    # not under `tf.Graph().as_default()`, so its variables are created there;
    # confirm this is intended.
    dummy_model_for_metadata = model_fn()

    server_init_tf = build_server_init_fn(
        model_fn, server_optimizer_fn,
        stateful_delta_aggregate_fn.initialize())
    server_state_type = server_init_tf.type_signature.result
    server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                              server_state_type,
                                              server_state_type.model)
    tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)

    client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                              client_update_tf,
                                              tf_dataset_type,
                                              server_state_type.model)

    federated_server_state_type = tff.FederatedType(server_state_type,
                                                    tff.SERVER)

    federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

    run_one_round_tff = build_run_one_round_fn_attacked(
        server_update_fn, client_update_fn, stateful_delta_aggregate_fn,
        dummy_model_for_metadata, federated_server_state_type,
        federated_dataset_type)

    # `server_init_tf` is a non-federated computation; place its result on the
    # server to form the process's initialize function.
    return tff.templates.IterativeProcess(
        initialize_fn=tff.federated_computation(
            lambda: tff.federated_value(server_init_tf(), tff.SERVER)),
        next_fn=run_one_round_tff)