Beispiel #1
0
 def test_construction_calls_model_fn(self):
     # Building the process should invoke `model_fn` only a bounded number of
     # times, since each call may be costly (loading weights, preprocessing,
     # etc). A mock that delegates to LinearRegression records every call.
     counting_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
     build_process = federated_averaging.build_federated_averaging_process
     build_process(model_fn=counting_model_fn,
                   client_optimizer_fn=tf.keras.optimizers.SGD)
     # TODO(b/186451541): reduce the number of calls to model_fn.
     expected_call_count = 4
     self.assertEqual(counting_model_fn.call_count, expected_call_count)
Beispiel #2
0
 def test_fails_stateful_aggregate_and_process(self):
     """Supplying both aggregation arguments must be rejected.

     `build_federated_averaging_process` is expected to raise
     `DisjointArgumentError` when both `stateful_delta_aggregate_fn` and
     `aggregation_process` are passed.
     """
     weights_type = model_utils.weights_type_from_model(
         model_examples.LinearRegression)

     def aggregate_next(state, value, weight=None):
         return state, tff.federated_mean(value, weight)

     with self.assertRaises(optimizer_utils.DisjointArgumentError):
         legacy_aggregate = tff.utils.StatefulAggregateFn(
             initialize_fn=lambda: (), next_fn=aggregate_next)
         mean_process = optimizer_utils.build_stateless_mean(
             model_delta_type=weights_type.trainable)
         federated_averaging.build_federated_averaging_process(
             model_fn=model_examples.LinearRegression,
             client_optimizer_fn=tf.keras.optimizers.SGD,
             stateful_delta_aggregate_fn=legacy_aggregate,
             aggregation_process=mean_process)
Beispiel #3
0
 def test_fails_stateful_broadcast_and_process(self):
     """Supplying both broadcast arguments must be rejected.

     `build_federated_averaging_process` is expected to raise
     `DisjointArgumentError` when both `stateful_model_broadcast_fn` and
     `broadcast_process` are passed.
     """
     weights_type = model_utils.weights_type_from_model(
         model_examples.LinearRegression)

     def broadcast_next(state, weights):
         return state, tff.federated_broadcast(weights)

     with self.assertRaises(optimizer_utils.DisjointArgumentError):
         legacy_broadcast = tff.utils.StatefulBroadcastFn(
             initialize_fn=lambda: (), next_fn=broadcast_next)
         broadcaster = optimizer_utils.build_stateless_broadcaster(
             model_weights_type=weights_type)
         federated_averaging.build_federated_averaging_process(
             model_fn=model_examples.LinearRegression,
             client_optimizer_fn=tf.keras.optimizers.SGD,
             stateful_model_broadcast_fn=legacy_broadcast,
             broadcast_process=broadcaster)
Beispiel #4
0
 def test_aggregation_process_deprecation(self):
     """Passing `aggregation_process` emits a matching DeprecationWarning.

     `assertWarnsRegex` records warnings with an 'always' filter for the
     expected category and passes if ANY recorded DeprecationWarning matches
     the pattern. Inspecting only `w[0]` would fail whenever an unrelated
     warning happened to be emitted first during process construction.
     """
     aggregation_process = mean.MeanFactory().create(
         computation_types.to_type([(tf.float32, (2, 1)), tf.float32]),
         computation_types.TensorType(tf.float32))
     with self.assertWarnsRegex(DeprecationWarning,
                                'aggregation_process .* is deprecated'):
         federated_averaging.build_federated_averaging_process(
             model_fn=model_examples.LinearRegression,
             client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
             aggregation_process=aggregation_process)
Beispiel #5
0
  def test_orchestration_execute_from_keras(self, build_keras_model_fn):
    """Training a wrapped Keras model for three rounds keeps lowering loss."""
    zero_batch = collections.OrderedDict(
        x=np.zeros([1, 2], np.float32),
        y=np.zeros([1, 1], np.float32))

    def model_fn():
      return keras_utils.from_keras_model(
          build_keras_model_fn(feature_dims=2),
          zero_batch,
          loss=tf.keras.losses.MeanSquaredError(),
          metrics=[])

    def client_optimizer_fn():
      return tf.keras.optimizers.SGD(learning_rate=0.01)

    process = federated_averaging.build_federated_averaging_process(
        model_fn=model_fn, client_optimizer_fn=client_optimizer_fn)

    client_data = tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict(
            x=[[1.0, 2.0], [3.0, 4.0]],
            y=[[5.0], [6.0]],
        )).batch(2)
    federated_data = [client_data] * 3

    state = process.initialize()

    last_loss = np.inf
    for _ in range(3):
      state, round_metrics = process.next(state, federated_data)
      self.assertLess(round_metrics.loss, last_loss)
      last_loss = round_metrics.loss
Beispiel #6
0
  def test_orchestration_execute_from_keras(self, build_keras_model_fn):
    """A compiled Keras model trained via FedAvg shows decreasing loss."""
    sample_batch = collections.OrderedDict([
        ('x', np.zeros([1, 2], np.float32)),
        ('y', np.zeros([1, 1], np.float32)),
    ])

    def model_fn():
      model = build_keras_model_fn(feature_dims=2)
      model.compile(
          optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
          loss=tf.keras.losses.MeanSquaredError(),
          metrics=[])
      return keras_utils.from_compiled_keras_model(model, sample_batch)

    process = federated_averaging.build_federated_averaging_process(
        model_fn=model_fn)

    features = [[1., 2.], [3., 4.]]
    labels = [[5.], [6.]]
    dataset = tf.data.Dataset.from_tensor_slices({
        'x': features,
        'y': labels
    }).batch(2)
    federated_data = [dataset] * 3

    state = process.initialize()

    best_loss = np.inf
    for _ in range(3):
      state, outputs = process.next(state, federated_data)
      self.assertLess(outputs.loss, best_loss)
      best_loss = outputs.loss
Beispiel #7
0
async def fed_avg():
    """Runs Federated Averaging on `model_examples.LinearRegression`.

    Awaits `load_data()`, builds the FedAvg iterative process, runs
    `num_clients` rounds over `datasets`, then calls `save_weights` and
    `update_model` on the resulting state.

    NOTE(review): this is a module-level coroutine, yet it uses names it never
    defines: `self` (no such parameter), `num_clients` and `datasets`. Unless
    these exist as module globals (presumably set up by `load_data()` --
    verify), the calls below raise NameError.
    """
    await load_data()
    
    iterative_process = federated_averaging.build_federated_averaging_process(
        model_fn=model_examples.LinearRegression,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1))
    
    # Despite its name, `model` holds the server state returned by
    # `initialize()`; its weights are reached via `model.model.trainable`.
    model = iterative_process.initialize()
    
    # NOTE(review): `self` is not defined in this function; these unittest
    # assertions only work if a global `self` exists -- confirm or replace
    # with plain checks.
    self.assertIsInstance(
        iterative_process.get_model_weights(model), model_utils.ModelWeights)
    self.assertAllClose(model.model.trainable,
                        iterative_process.get_model_weights(model).trainable)
    
    # NOTE(review): the number of federated rounds run here equals
    # `num_clients`; confirm that a round count (not a client count) was
    # intended.
    for _ in range(num_clients):
        model, _ = iterative_process.next(model, datasets)
        # self.assertIsInstance(
        #     iterative_process.get_model_weights(model), model_utils.ModelWeights)
        # self.assertAllClose(model.model.trainable,
        #                     iterative_process.get_model_weights(model).trainable)
      
    
    # NOTE(review): `model` is still the server state returned by
    # `iterative_process.next`. Nothing in this block shows it has a
    # `save_weights` method (the `save_format='h5'` call looks like the Keras
    # model API). Confirm, or copy the weights into a Keras model first.
    model.save_weights('consolidated.h5', save_format = 'h5')
    update_model(model)
    
    # NOTE(review): the two nested functions below are never called inside
    # `fed_avg`. The second definition rebinds the same name, so the first one
    # is unreachable. They look like test methods pasted in by mistake.
    def test_orchestration_execute_from_keras_with_lookup(self):
        # skipTest raises unittest.SkipTest, so nothing below this line runs
        # when this is executed as a test.
        self.skipTest('https://github.com/tensorflow/federated/issues/783')

        def model_fn():
            dummy_batch = collections.OrderedDict([
                ('x', tf.constant([['R']], tf.string)),
                ('y', tf.zeros([1, 1], tf.float32)),
            ])
            keras_model = model_examples.build_lookup_table_keras_model()
            # NOTE(review): the model is not compiled in this block, yet it is
            # passed to `from_compiled_keras_model` along with `loss`/`metrics`
            # kwargs; confirm that this API accepts them.
            return keras_utils.from_compiled_keras_model(
                keras_model,
                dummy_batch,
                loss=tf.keras.losses.MeanSquaredError(),
                metrics=[])

        iterative_process = federated_averaging.build_federated_averaging_process(
            model_fn=model_fn,
            client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=
                                                                0.1))

        ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([
                ('x', [['R'], ['G'], ['B']]),
                ('y', [[1.0], [2.0], [3.0]]),
            ])).batch(2)
        federated_ds = [ds] * 3

        server_state = iterative_process.initialize()

        # Loss is expected to strictly decrease on each of three rounds.
        prev_loss = np.inf
        for _ in range(3):
            server_state, metrics = iterative_process.next(
                server_state, federated_ds)
            self.assertLess(metrics.loss, prev_loss)
            prev_loss = metrics.loss
    # Same name as the function above; this definition replaces it.
    def test_orchestration_execute_from_keras_with_lookup(self):
        # Three string-feature examples batched by 2 (batches of 2 and 1).
        ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(x=[['R'], ['G'], ['B']],
                                    y=[[1.0], [2.0], [3.0]])).batch(2)

        def model_fn():
            keras_model = model_examples.build_lookup_table_keras_model()
            # The input spec is taken from the dataset's own element_spec.
            return keras_utils.from_keras_model(
                keras_model,
                loss=tf.keras.losses.MeanSquaredError(),
                input_spec=ds.element_spec,
                metrics=[])

        iterative_process = federated_averaging.build_federated_averaging_process(
            model_fn=model_fn,
            client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=
                                                                0.1))

        federated_ds = [ds] * 3

        server_state = iterative_process.initialize()

        # Loss is expected to strictly decrease on each of three rounds.
        prev_loss = np.inf
        for _ in range(3):
            server_state, metrics = iterative_process.next(
                server_state, federated_ds)
            self.assertLess(metrics.loss, prev_loss)
            prev_loss = metrics.loss
Beispiel #10
0
    def test_get_model_weights(self):
        """`get_model_weights` mirrors `state.model` before and after rounds."""
        process = federated_averaging.build_federated_averaging_process(
            model_fn=model_examples.LinearRegression,
            client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=
                                                                0.1))

        num_clients = 3
        dataset = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(
                x=[[1.0, 2.0], [3.0, 4.0]],
                y=[[5.0], [6.0]],
            )).batch(2)
        client_datasets = [dataset] * num_clients

        def check_weights(server_state):
            # Type check, then compare trainable weights to the state's model.
            self.assertIsInstance(process.get_model_weights(server_state),
                                  model_utils.ModelWeights)
            self.assertAllClose(
                server_state.model.trainable,
                process.get_model_weights(server_state).trainable)

        state = process.initialize()
        check_weights(state)

        for _ in range(3):
            state, _ = process.next(state, client_datasets)
            check_weights(state)
    def benchmark_simple_execution(self, executor_id):
        """Benchmarks building, initializing and running FedAvg rounds.

        Reports three benchmarks for `TrainableLinearRegression`: computation
        building time, initialization time, and the mean (with std dev) wall
        time of the timed training rounds.

        Args:
          executor_id: Identifier of the executor under test; used only to
            label the reported benchmark names.
        """
        num_clients = 10
        num_client_samples = 20
        batch_size = 4
        num_rounds = 10

        ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(x=[[1., 2.]] * num_client_samples,
                                    y=[[5.]] *
                                    num_client_samples)).batch(batch_size)

        federated_ds = [ds] * num_clients

        # perf_counter is monotonic and high-resolution; time.time() can jump
        # if the system clock is adjusted during a measurement.
        build_time_start = time.perf_counter()
        iterative_process = federated_averaging.build_federated_averaging_process(
            model_fn=model_examples.TrainableLinearRegression)
        building_time = time.perf_counter() - build_time_start
        name = ('computation_building_time, simple execution '
                'TrainableLinearRegression, executor {}'.format(executor_id))
        self.report_benchmark(name=name, wall_time=building_time, iters=1)

        initialization_start = time.perf_counter()
        next_state = iterative_process.initialize()
        initialization_time = time.perf_counter() - initialization_start
        name = ('computation_initialization_time, simple execution '
                'TrainableLinearRegression, executor {}'.format(executor_id))
        self.report_benchmark(name=name,
                              wall_time=initialization_time,
                              iters=1)

        execution_array = []
        for _ in range(num_rounds - 1):
            round_start = time.perf_counter()
            next_state, _ = iterative_process.next(next_state, federated_ds)
            execution_array.append(time.perf_counter() - round_start)
        name = (
            'Average per round time, {clients} clients, {examples} examples '
            'per client, batch size {batch_size}, TrainableLinearRegression, '
            'executor {executor}'.format(clients=num_clients,
                                         examples=num_client_samples,
                                         batch_size=batch_size,
                                         executor=executor_id))
        # `iters` must equal the number of timed rounds that `wall_time`
        # averages over (num_rounds - 1), not `num_rounds`.
        self.report_benchmark(name=name,
                              wall_time=np.mean(execution_array),
                              iters=len(execution_array),
                              extras={'std_dev': np.std(execution_array)})
Beispiel #12
0
    def benchmark_simple_execution(self):
        """Benchmarks building, initializing and running FedAvg rounds.

        Reports, for `TrainableLinearRegression`, the computation building
        time, the initialization time, and the mean (with std dev) wall time
        of the timed training rounds. One untimed round runs first, presumably
        so that first-call overhead does not skew the per-round average.
        """
        num_clients = 10
        num_client_samples = 20
        batch_size = 4
        num_rounds = 10

        ds = tf.data.Dataset.from_tensor_slices({
            "x": [[1., 2.]] * num_client_samples,
            "y": [[5.]] * num_client_samples
        }).batch(batch_size)

        federated_ds = [ds] * num_clients

        # perf_counter is monotonic and high-resolution; time.time() can jump
        # if the system clock is adjusted during a measurement.
        build_time_start = time.perf_counter()
        iterative_process = federated_averaging.build_federated_averaging_process(
            model_fn=model_examples.TrainableLinearRegression)
        building_time = time.perf_counter() - build_time_start
        self.report_benchmark(
            name="computation_building_time, simple execution "
            "TrainableLinearRegression",
            wall_time=building_time,
            iters=1)

        initialization_start = time.perf_counter()
        server_state = iterative_process.initialize()
        initialization_time = time.perf_counter() - initialization_start
        self.report_benchmark(
            name="computation_initialization_time, simple execution "
            "TrainableLinearRegression",
            wall_time=initialization_time,
            iters=1)

        # Untimed first round.
        next_state, _ = iterative_process.next(server_state, federated_ds)

        execution_array = []
        for _ in range(num_rounds - 1):
            round_start = time.perf_counter()
            next_state, _ = iterative_process.next(next_state, federated_ds)
            execution_array.append(time.perf_counter() - round_start)
        # `iters` is the number of timed rounds that `wall_time` averages over;
        # the untimed first round is excluded.
        self.report_benchmark(name="Time to execute {} rounds, {} clients, "
                              "{} examples per client, batch size {}, "
                              "TrainableLinearRegression".format(
                                  num_rounds, num_clients, num_client_samples,
                                  batch_size),
                              wall_time=np.mean(execution_array),
                              iters=len(execution_array),
                              extras={"std_dev": np.std(execution_array)})
  def test_orchestration_typecheck(self):
    """Checks the TFF type signatures of `initialize` and `next`."""
    iterative_process = federated_averaging.build_federated_averaging_process(
        model_fn=model_examples.TrainableLinearRegression)

    trainable_type = collections.OrderedDict([
        ('a', tff.TensorType(tf.float32, [2, 1])),
        ('b', tf.float32),
    ])
    non_trainable_type = collections.OrderedDict([('c', tf.float32)])
    model_weights_type = model_utils.ModelWeights(trainable_type,
                                                  non_trainable_type)

    # The server state pairs the model weights with an optimizer state. That
    # optimizer state comes from TensorFlow, so any type is accepted for it.
    server_state_type = tff.FederatedType(
        optimizer_utils.ServerState(model_weights_type, test.AnyType()),
        placement=tff.SERVER,
        all_equal=True)

    batch_type = model_examples.TrainableLinearRegression.make_batch(
        tff.TensorType(tf.float32, [None, 2]),
        tff.TensorType(tf.float32, [None, 1]))
    dataset_type = tff.FederatedType(
        tff.SequenceType(batch_type), tff.CLIENTS, all_equal=False)

    metrics_type = tff.FederatedType(
        collections.OrderedDict([
            ('loss', tff.TensorType(tf.float32, [])),
            ('num_examples', tff.TensorType(tf.int32, [])),
        ]),
        tff.SERVER,
        all_equal=True)

    # `initialize` takes no arguments and produces the server state.
    expected_initialize_type = tff.FunctionType(
        parameter=None, result=server_state_type)
    self.assertEqual(expected_initialize_type,
                     iterative_process.initialize.type_signature)

    # `next` maps (server state, client datasets) to (server state, metrics).
    expected_next_type = tff.FunctionType(
        parameter=[server_state_type, dataset_type],
        result=(server_state_type, metrics_type))
    self.assertEqual(expected_next_type, iterative_process.next.type_signature)
Beispiel #14
0
  def benchmark_learning_keras_model_mnist(self, executor_id):
    """Benchmarks FedAvg on a compiled Keras MNIST-style model.

    Adapted from the MNIST learning tutorial notebook. Reports the computation
    building time and the mean per-round execution time, labelled with
    `executor_id`.
    """
    federated_train_data = generate_fake_mnist_data()
    n_rounds = 10
    build_start = time.time()

    # pylint: disable=missing-docstring
    def model_fn():
      keras_model = tf.keras.models.Sequential([
          tf.keras.layers.Flatten(input_shape=(784,)),
          tf.keras.layers.Dense(
              10,
              kernel_initializer="zeros",
              bias_initializer="zeros",
              activation=tf.nn.softmax),
      ])
      keras_model.compile(
          loss=tf.keras.losses.SparseCategoricalCrossentropy(),
          optimizer=tf.keras.optimizers.SGD(0.1))
      sample_batch = federated_train_data[0][0]
      return keras_utils.from_compiled_keras_model(keras_model, sample_batch)

    process = federated_averaging.build_federated_averaging_process(model_fn)
    building_time = time.time() - build_start
    self.report_benchmark(
        name="computation_building_time, "
        "tff.learning Keras model, executor {}".format(executor_id),
        wall_time=building_time,
        iters=1)

    state = process.initialize()

    round_durations = []
    for _ in range(n_rounds):
      tick = time.time()
      state, _ = process.next(state, federated_train_data)
      round_durations.append(time.time() - tick)

    self.report_benchmark(
        name="Average per round execution time, "
        "tff.learning Keras model, executor {}".format(executor_id),
        wall_time=np.mean(round_durations),
        iters=n_rounds,
        extras={"std_dev": np.std(round_durations)})
  def benchmark_learning_keras_model_mnist(self, executor_id):
    """Benchmarks FedAvg on an uncompiled Keras MNIST-style model.

    Adapted from the MNIST learning tutorial notebook. The client optimizer is
    handed to the process builder rather than set via `compile`. Reports the
    computation building time and the mean per-round execution time.
    """
    federated_train_data = generate_fake_mnist_data()
    input_spec = collections.OrderedDict(
        x=tf.TensorSpec(shape=(None, 784), dtype=tf.float32),
        y=tf.TensorSpec(shape=(None,), dtype=tf.int32))
    n_rounds = 10
    build_start = time.time()

    # pylint: disable=missing-docstring
    def model_fn():
      keras_model = tf.keras.models.Sequential([
          tf.keras.layers.Flatten(input_shape=(784,)),
          tf.keras.layers.Dense(
              10,
              kernel_initializer='zeros',
              bias_initializer='zeros',
              activation=tf.nn.softmax),
      ])
      mnist_loss = tf.keras.losses.SparseCategoricalCrossentropy()
      return keras_utils.from_keras_model(
          keras_model, input_spec=input_spec, loss=mnist_loss)

    def client_optimizer_fn():
      return tf.keras.optimizers.SGD(0.1)

    process = federated_averaging.build_federated_averaging_process(
        model_fn, client_optimizer_fn=client_optimizer_fn)
    building_time = time.time() - build_start
    self.report_benchmark(
        name=('computation_building_time, tff.learning Keras model, executor {}'
              .format(executor_id)),
        wall_time=building_time,
        iters=1)

    state = process.initialize()

    round_durations = []
    for _ in range(n_rounds):
      tick = time.time()
      state, _ = process.next(state, federated_train_data)
      round_durations.append(time.time() - tick)

    self.report_benchmark(
        name=('Average per round execution time, tff.learning Keras model, '
              'executor {}'.format(executor_id)),
        wall_time=np.mean(round_durations),
        iters=n_rounds,
        extras={'std_dev': np.std(round_durations)})
    def test_basic_orchestration_execute(self):
        """Runs FedAvg on LinearRegression over three identical clients."""

        def client_optimizer_fn():
            return tf.keras.optimizers.SGD(learning_rate=0.1)

        process = federated_averaging.build_federated_averaging_process(
            model_fn=model_examples.LinearRegression,
            client_optimizer_fn=client_optimizer_fn)

        client_ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(
                x=[[1.0, 2.0], [3.0, 4.0]],
                y=[[5.0], [6.0]],
            )).batch(2)

        num_clients = 3
        examples_per_client = 2
        self._run_test(
            process,
            datasets=[client_ds] * num_clients,
            expected_num_examples=examples_per_client * num_clients)
Beispiel #17
0
    def test_recommended_aggregations_execute(self, default_aggregation):
        """One round with a recommended aggregation factory reports its metrics."""

        def client_optimizer_fn():
            return tf.keras.optimizers.SGD(learning_rate=0.1)

        process = federated_averaging.build_federated_averaging_process(
            model_fn=model_examples.LinearRegression,
            client_optimizer_fn=client_optimizer_fn,
            model_update_aggregation_factory=default_aggregation())

        client_ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(
                x=[[1.0, 2.0], [3.0, 4.0]],
                y=[[5.0], [6.0]],
            )).batch(2)
        client_datasets = [client_ds] * 3

        initial_state = process.initialize()
        _, round_metrics = process.next(initial_state, client_datasets)
        self.assertNotEmpty(round_metrics['aggregation'])
Beispiel #18
0
  def test_orchestration_execute(self):
    """Loss drops and every example is counted in each of three rounds."""
    process = federated_averaging.build_federated_averaging_process(
        model_fn=model_examples.TrainableLinearRegression)

    dataset = tf.data.Dataset.from_tensor_slices({
        'x': [[1., 2.], [3., 4.]],
        'y': [[5.], [6.]]
    }).batch(2)
    clients = [dataset] * 3
    expected_examples = 2 * len(clients)

    state = process.initialize()

    loss_so_far = np.inf
    for _ in range(3):
      state, outputs = process.next(state, clients)
      self.assertEqual(outputs.num_examples, expected_examples)
      self.assertLess(outputs.loss, loss_so_far)
      loss_so_far = outputs.loss
Beispiel #19
0
    def test_execute_empty_data(self, client_optimizer):
        """A round over clients with no batches keeps trainable weights at zero."""
        process = federated_averaging.build_federated_averaging_process(
            model_fn=model_examples.LinearRegression,
            client_optimizer_fn=client_optimizer())

        # A single example batched by 5 with drop_remainder=True yields no
        # batches, while keeping the correct element types and shapes.
        empty_ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(
                x=[[1.0, 2.0]],
                y=[[5.0]],
            )).batch(5, drop_remainder=True)

        initial_state = process.initialize()

        new_state, outputs = process.next(initial_state, [empty_ds, empty_ds])
        self.assertAllClose(list(new_state.model.trainable),
                            [[[0.0], [0.0]], 0.0])
        train_metrics = outputs['train']
        self.assertEqual(train_metrics['num_examples'], 0)
        self.assertTrue(tf.math.is_nan(train_metrics['loss']))
    def test_execute_empty_data(self):
        """Empty client data: weights stay zero, 0 examples, NaN loss."""
        process = federated_averaging.build_federated_averaging_process(
            model_fn=model_examples.TrainableLinearRegression)

        # One example batched by 5 with drop_remainder=True produces a dataset
        # with no batches but the correct element types and shapes.
        empty_ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([
                ('x', [[1.0, 2.0]]),
                ('y', [[5.0]]),
            ])).batch(5, drop_remainder=True)
        client_datasets = [empty_ds, empty_ds]

        state = process.initialize()

        state_after_round, outputs = process.next(state, client_datasets)
        self.assertAllClose(list(state_after_round.model.trainable),
                            [[[0.0], [0.0]], 0.0])
        self.assertEqual(outputs.num_examples, 0)
        self.assertTrue(tf.math.is_nan(outputs.loss))
Beispiel #21
0
  def test_orchestration_execute(self):
    """LinearRegression FedAvg: loss drops and examples are counted per round."""

    def client_optimizer_fn():
      return tf.keras.optimizers.SGD(learning_rate=0.1)

    process = federated_averaging.build_federated_averaging_process(
        model_fn=model_examples.LinearRegression,
        client_optimizer_fn=client_optimizer_fn)

    dataset = tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict(
            x=[[1.0, 2.0], [3.0, 4.0]],
            y=[[5.0], [6.0]],
        )).batch(2)
    clients = [dataset] * 3
    expected_examples = 2 * len(clients)

    state = process.initialize()

    loss_so_far = np.inf
    for _ in range(3):
      state, outputs = process.next(state, clients)
      self.assertEqual(outputs.num_examples, expected_examples)
      self.assertLess(outputs.loss, loss_so_far)
      loss_so_far = outputs.loss
Beispiel #22
0
    def test_orchestration_execute_from_keras_with_lookup(
            self, client_optimizer):
        """FedAvg over a string-lookup Keras model counts every example."""
        client_ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(x=[['R'], ['G'], ['B']],
                                    y=[[1.0], [2.0], [3.0]])).batch(2)

        def model_fn():
            # The input spec is read from the dataset's own element_spec.
            return keras_utils.from_keras_model(
                model_examples.build_lookup_table_keras_model(),
                loss=tf.keras.losses.MeanSquaredError(),
                input_spec=client_ds.element_spec,
                metrics=[NumExamplesCounter()])

        process = federated_averaging.build_federated_averaging_process(
            model_fn=model_fn, client_optimizer_fn=client_optimizer())

        num_clients = 3
        examples_per_client = 3
        self._run_test(
            process,
            datasets=[client_ds] * num_clients,
            expected_num_examples=examples_per_client * num_clients)
Beispiel #23
0
    def test_execute_empty_data(self):
        """A round over clients with no batches leaves all weights at zero.

        Also checks that no examples are counted and that the reported loss is
        NaN.
        """
        iterative_process = federated_averaging.build_federated_averaging_process(
            model_fn=model_examples.TrainableLinearRegression)

        # Results in empty dataset with correct types and shapes: one example
        # batched by 5 with drop_remainder=True yields no batches.
        ds = tf.data.Dataset.from_tensor_slices({
            'x': [[1., 2.]],
            'y': [[5.]]
        }).batch(5, drop_remainder=True)

        federated_ds = [ds] * 2

        server_state = iterative_process.initialize()

        first_state, metric_outputs = iterative_process.next(
            server_state, federated_ds)
        trainable = first_state.model.trainable
        # Sum absolute values so that non-zero weights of opposite sign cannot
        # cancel each other out and make the check pass spuriously.
        self.assertEqual(
            self.evaluate(tf.reduce_sum(tf.abs(trainable.a))) +
            self.evaluate(tf.reduce_sum(tf.abs(trainable.b))), 0)
        self.assertEqual(metric_outputs.num_examples, 0)
        # `tf.math.is_nan` replaces the deprecated TF1 alias `tf.is_nan`.
        # Evaluating the result makes the truth test work in graph mode too,
        # where a raw Tensor cannot be used as a Python bool.
        self.assertTrue(self.evaluate(tf.math.is_nan(metric_outputs.loss)))
Beispiel #24
0
    def test_orchestration_execute_from_keras(self, build_keras_model_fn):
        """FedAvg over a Keras model counts every example from all clients."""
        client_ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(
                x=[[1.0, 2.0], [3.0, 4.0]],
                y=[[5.0], [6.0]],
            )).batch(2)

        def model_fn():
            # The input spec is read from the dataset's own element_spec.
            return keras_utils.from_keras_model(
                build_keras_model_fn(feature_dims=2),
                loss=tf.keras.losses.MeanSquaredError(),
                input_spec=client_ds.element_spec,
                metrics=[NumExamplesCounter()])

        def client_optimizer_fn():
            return tf.keras.optimizers.SGD(learning_rate=0.01)

        process = federated_averaging.build_federated_averaging_process(
            model_fn=model_fn, client_optimizer_fn=client_optimizer_fn)

        num_clients = 3
        examples_per_client = 2
        self._run_test(
            process,
            datasets=[client_ds] * num_clients,
            expected_num_examples=examples_per_client * num_clients)