  def test_simple_training(self):
    it_process = flars_fedavg.build_federated_averaging_process(
        _keras_model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
    server_state = it_process.initialize()

    # Test out manually setting weights:
    keras_model = tff.simulation.models.mnist.create_keras_model(
        compile_model=True)

    @tf.function
    def deterministic_batch():
      return collections.OrderedDict(
          x=tf.ones([1, 784], dtype=tf.float32),
          y=tf.ones([1, 1], dtype=tf.int64))

    batch = deterministic_batch()
    # A single client holding a single batch of data.
    federated_data = [[batch]]

    def keras_evaluate(state):
      state.model.assign_weights_to(keras_model)
      # N.B. The loss computed here won't match the loss computed by TFF
      # because of the Dropout layer.
      keras_model.test_on_batch(**batch)

    loss_list = []
    for _ in range(10):
      keras_evaluate(server_state)
      server_state, output = it_process.next(server_state, federated_data)
      loss_list.append(output['loss'])
    keras_evaluate(server_state)

    # The training loss should decrease over the ten rounds.
    self.assertLess(np.mean(loss_list[7:]), np.mean(loss_list[:2]))

  def test_construction(self):
    it_process = flars_fedavg.build_federated_averaging_process(
        _keras_model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
    self.assertIsInstance(it_process, tff.templates.IterativeProcess)
    federated_data_type = it_process.next.type_signature.parameter[1]
    self.assertEqual(
        str(federated_data_type),
        '{<x=float32[?,784],y=int64[?,1]>*}@CLIENTS')

  def test_self_contained_example_keras_model(self):
    client_data = create_client_data()
    train_data = [client_data()]
    trainer = flars_fedavg.build_federated_averaging_process(
        _keras_model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
    state = trainer.initialize()
    losses = []
    for _ in range(2):
      state, outputs = trainer.next(state, train_data)
      # Track the loss.
      losses.append(outputs['loss'])
    self.assertLess(losses[1], losses[0])
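

# The tests above rely on module-level helpers `_keras_model_fn` and
# `create_client_data` defined elsewhere in this file. The sketch below is a
# hypothetical illustration of what such helpers might look like, assuming the
# MNIST model from `tff.simulation.models.mnist` and the batch spec asserted in
# `test_construction` (x: float32[?, 784], y: int64[?, 1]); it is not the
# canonical definition.
def _example_input_spec():
  return collections.OrderedDict(
      x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
      y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64))


def _example_keras_model_fn():
  keras_model = tff.simulation.models.mnist.create_keras_model(
      compile_model=False)
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=_example_input_spec(),
      loss=tf.keras.losses.SparseCategoricalCrossentropy())


def _example_create_client_data():
  # A single client whose dataset consists of one constant batch.
  flat_examples = collections.OrderedDict(
      x=np.ones([2, 784], dtype=np.float32),
      y=np.ones([2, 1], dtype=np.int64))

  def client_data():
    return tf.data.Dataset.from_tensor_slices(flat_examples).batch(2)

  return client_data
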
def _federated_averaging_training_loop(model_fn,
                                       client_optimizer_fn,
                                       server_optimizer_fn,
                                       client_datasets_fn,
                                       evaluate_fn,
                                       total_rounds=500,
                                       rounds_per_eval=1,
                                       metrics_hook=None):
  """A simple example of training loop for the Federated Averaging algorithm.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    client_datasets_fn: A function that takes the round number and returns a
      list of `tf.data.Dataset`, one per client.
    evaluate_fn: A function that takes state, performs evaluation, and returns
      evaluation metrics.
    total_rounds: Number of rounds to train.
    rounds_per_eval: How often to call the `metrics_hook` function.
    metrics_hook: An optional function taking three arguments: the training
      metrics, the evaluation metrics, and the round number.

  Returns:
    Final `ServerState`.
  """
  logging.info('Starting federated training loop')

  checkpoint_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name)
  checkpoint_manager_obj = checkpoint_manager.FileCheckpointManager(
      checkpoint_dir)

  if FLAGS.server_optimizer != 'flars':
    raise ValueError(
        'Unsupported server_optimizer: {!s}'.format(FLAGS.server_optimizer))

  iterative_process = flars_fedavg.build_federated_averaging_process(
      model_fn,
      client_optimizer_fn=client_optimizer_fn,
      server_optimizer_fn=server_optimizer_fn)

  # Construct an initial state here to act as a checkpoint template.
  initial_state = iterative_process.initialize()

  logging.info('Looking for checkpoints in \'%s\'', checkpoint_dir)
  state, round_num = checkpoint_manager_obj.load_latest_checkpoint(
      initial_state)

  if state is None:
    logging.info('No previous checkpoints, initializing experiment')
    state = initial_state
    round_num = 0
    if metrics_hook is not None:
      eval_metrics = evaluate_fn(state)
      metrics_hook({}, eval_metrics, round_num)
    checkpoint_manager_obj.save_checkpoint(state, 0)
  else:
    logging.info('Restarted from checkpoint round %d', round_num)

  while round_num < total_rounds:
    round_num += 1
    train_metrics = {}

    # Reset the executor to clear the cache, and clear the default graph to
    # garbage collect tf.Functions that will no longer be used.
    tff.backends.native.set_local_execution_context(max_fanout=25)
    tf.compat.v1.reset_default_graph()

    round_start_time = time.time()
    data_prep_start_time = time.time()
    train_data = client_datasets_fn(round_num)
    train_metrics['prepare_datasets_secs'] = time.time() - data_prep_start_time

    training_start_time = time.time()
    state, tff_train_metrics = iterative_process.next(state, train_data)
    tff_train_metrics = tff_train_metrics._asdict(recursive=True)
    train_metrics.update(tff_train_metrics)
    train_metrics['training_secs'] = time.time() - training_start_time

    logging.info('Round %2d elapsed time: %.2fs.', round_num,
                 time.time() - round_start_time)
    train_metrics['total_round_secs'] = time.time() - round_start_time

    if (round_num % FLAGS.rounds_per_checkpoint == 0 or
        round_num == total_rounds):
      save_checkpoint_start_time = time.time()
      checkpoint_manager_obj.save_checkpoint(state, round_num)
      train_metrics['save_checkpoint_secs'] = (
          time.time() - save_checkpoint_start_time)

    if round_num % rounds_per_eval == 0 or round_num == total_rounds:
      if metrics_hook is not None:
        eval_metrics = evaluate_fn(state)
        metrics_hook(train_metrics, eval_metrics, round_num)

  return state
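

# A minimal, hypothetical sketch of how the training loop above could be
# driven. Everything here is an assumption for illustration: the `_example_*`
# helpers come from the sketch earlier in this file, the FLARS server
# optimizer is assumed to live at `flars_optimizer.FLARSOptimizer` (module
# name and constructor signature unverified), and the absl flags used by the
# loop (`root_output_dir`, `exp_name`, `server_optimizer`,
# `rounds_per_checkpoint`) must already be defined and set, with
# `--server_optimizer=flars`.
def _example_training_loop_driver(total_rounds=5):

  def client_datasets_fn(round_num):
    del round_num  # This sketch reuses the same single client every round.
    return [_example_create_client_data()()]

  def evaluate_fn(state):
    del state  # A real evaluate_fn would run the model on held-out data.
    return {'sparse_categorical_accuracy': 0.0}

  def metrics_hook(train_metrics, eval_metrics, round_num):
    logging.info('Round %d: train=%s eval=%s', round_num, train_metrics,
                 eval_metrics)

  return _federated_averaging_training_loop(
      model_fn=_example_keras_model_fn,
      client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
      server_optimizer_fn=lambda: flars_optimizer.FLARSOptimizer(
          learning_rate=1.0),
      client_datasets_fn=client_datasets_fn,
      evaluate_fn=evaluate_fn,
      total_rounds=total_rounds,
      metrics_hook=metrics_hook)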