Example #1
  def test_simple_training(self):
    it_process = flars_fedavg.build_federated_averaging_process(
        _keras_model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
    server_state = it_process.initialize()

    # Test out manually setting weights:
    keras_model = tff.simulation.models.mnist.create_keras_model(
        compile_model=True)

    @tf.function
    def deterministic_batch():
      return collections.OrderedDict(
          x=np.ones([1, 784], dtype=np.float32),
          y=np.ones([1, 1], dtype=np.int64))

    batch = deterministic_batch()
    federated_data = [[batch]]

    def keras_evaluate(state):
      tff.learning.assign_weights_to_keras_model(keras_model, state.model)
      # N.B. The loss computed here won't match the loss computed by TFF
      # because of the Dropout layer.
      keras_model.test_on_batch(**batch)

    loss_list = []
    for _ in range(3):
      keras_evaluate(server_state)
      server_state, output = it_process.next(server_state, federated_data)
      loss_list.append(output.loss)
    keras_evaluate(server_state)

    self.assertLess(np.mean(loss_list[1:]), loss_list[0])
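
Note: the tests in these snippets reference a `_keras_model_fn` helper that is not shown. The following is a minimal sketch of what such a helper might look like, assuming the compiled-Keras path that Example #6 uses (`tff.learning.from_compiled_keras_model` with the MNIST model from `tff.simulation.models.mnist`); the exact wiring is an assumption, not the repository's actual helper.

import collections

import numpy as np
import tensorflow_federated as tff


def _keras_model_fn():
  # Sketch only: build a compiled Keras MNIST model and wrap it as a
  # `tff.learning.Model`. The dummy batch matches the
  # {<x=float32[?,784],y=int64[?,1]>*}@CLIENTS type asserted in Examples #2/#4.
  keras_model = tff.simulation.models.mnist.create_keras_model(
      compile_model=True)
  dummy_batch = collections.OrderedDict(
      x=np.zeros([1, 784], dtype=np.float32),
      y=np.zeros([1, 1], dtype=np.int64))
  return tff.learning.from_compiled_keras_model(keras_model, dummy_batch)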
Example #2
  def test_construction(self):
    it_process = flars_fedavg.build_federated_averaging_process(
        _keras_model_fn)
    self.assertIsInstance(it_process, tff.utils.IterativeProcess)
    federated_data_type = it_process.next.type_signature.parameter[1]
    self.assertEqual(str(federated_data_type),
                     '{<x=float32[?,784],y=int64[?,1]>*}@CLIENTS')
Example #3
    def test_simple_training(self):
        it_process = flars_fedavg.build_federated_averaging_process(
            _keras_model_fn)
        server_state = it_process.initialize()
        Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name

        # Test out manually setting weights:
        keras_model = tff.simulation.models.mnist.create_keras_model(
            compile_model=True)

        def deterministic_batch():
            return Batch(x=np.ones([1, 784], dtype=np.float32),
                         y=np.ones([1, 1], dtype=np.int64))

        batch = tff.tf_computation(deterministic_batch)()
        federated_data = [[batch]]

        def keras_evaluate(state):
            tff.learning.assign_weights_to_keras_model(keras_model,
                                                       state.model)
            # N.B. The loss computed here won't match the
            # loss computed by TFF because of the Dropout layer.
            keras_model.test_on_batch(batch.x, batch.y)

        loss_list = []
        for _ in range(3):
            keras_evaluate(server_state)
            server_state, loss = it_process.next(server_state, federated_data)
            loss_list.append(loss)
        keras_evaluate(server_state)

        self.assertLess(np.mean(loss_list[1:]), loss_list[0])
Example #4
  def test_construction(self):
    it_process = flars_fedavg.build_federated_averaging_process(
        _keras_model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
    self.assertIsInstance(it_process, tff.templates.IterativeProcess)
    federated_data_type = it_process.next.type_signature.parameter[1]
    self.assertEqual(
        str(federated_data_type), '{<x=float32[?,784],y=int64[?,1]>*}@CLIENTS')
Example #5
  def test_self_contained_example_keras_model(self):
    client_data = create_client_data()
    train_data = [client_data()]
    trainer = flars_fedavg.build_federated_averaging_process(
        _keras_model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
    state = trainer.initialize()
    losses = []
    for _ in range(2):
      state, outputs = trainer.next(state, train_data)
      # Track the loss.
      losses.append(outputs.loss)
    self.assertLess(losses[1], losses[0])
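
The `create_client_data` helper used above is likewise not shown. A minimal sketch, assuming it only needs to return a no-arg callable that builds a small `tf.data.Dataset` matching the `{<x=float32[?,784],y=int64[?,1]>*}@CLIENTS` type from the construction tests:

import collections

import numpy as np
import tensorflow as tf


def create_client_data():
  # Two flattened 28x28 examples with integer labels; after batching, the
  # element type is <x=float32[?,784], y=int64[?,1]>.
  emnist_batch = collections.OrderedDict(
      x=np.random.rand(2, 784).astype(np.float32),
      y=np.ones([2, 1], dtype=np.int64))

  def client_data():
    return tf.data.Dataset.from_tensor_slices(emnist_batch).repeat(2).batch(2)

  return client_data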
Example #6
    def test_self_contained_example_keras_model(self):
        def model_fn():
            return tff.learning.from_compiled_keras_model(
                tff.simulation.models.mnist.create_simple_keras_model(),
                sample_batch)

        client_data = create_client_data()
        train_data = [client_data()]
        sample_batch = self.evaluate(next(iter(train_data[0])))

        trainer = flars_fedavg.build_federated_averaging_process(model_fn)
        state = trainer.initialize()
        losses = []
        for _ in range(2):
            state, outputs = trainer.next(state, train_data)
            # Track the loss.
            losses.append(outputs.loss)
        self.assertLess(losses[1], losses[0])
Example #7
def _federated_averaging_training_loop(model_fn,
                                       client_optimizer_fn,
                                       server_optimizer_fn,
                                       client_datasets_fn,
                                       evaluate_fn,
                                       total_rounds=500,
                                       rounds_per_eval=1,
                                       metrics_hook=None):
    """A simple example of training loop for the Federated Averaging algorithm.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    client_datasets_fn: A function that takes the round number, and returns a
      list of `tf.data.Datset`, one per client.
    evaluate_fn: A function that takes state, performs evaluation, and returns
      evaluations metrics.
    total_rounds: Number of rounds to train.
    rounds_per_eval: How often to call the  `metrics_hook` function.
    metrics_hook: A function taking arguments (training metrics,
      evaluation metrics, and round number). Optional.

  Returns:
    Final `ServerState`.
  """
    logging.info('Starting federated training loop')

    checkpoint_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name)
    checkpoint_manager_obj = checkpoint_manager.FileCheckpointManager(
        checkpoint_dir)

    if FLAGS.server_optimizer != 'flars':
        raise ValueError(
            'Unsupported server_optimizer: %s' % FLAGS.server_optimizer)

    iterative_process = flars_fedavg.build_federated_averaging_process(
        model_fn,
        client_optimizer_fn=client_optimizer_fn,
        server_optimizer_fn=server_optimizer_fn)
    ServerState = flars_fedavg.ServerState  # pylint: disable=invalid-name

    # Construct an initial state here to act as a checkpoint template.
    initial_state = iterative_process.initialize()
    initial_state = ServerState.from_tff_result(initial_state)

    logging.info('Looking for checkpoints in \'%s\'', checkpoint_dir)
    state, round_num = checkpoint_manager_obj.load_latest_checkpoint(
        initial_state)

    if state is None:
        logging.info('No previous checkpoints, initializing experiment')
        state = initial_state
        round_num = 0
        if metrics_hook is not None:
            eval_metrics = evaluate_fn(state)
            metrics_hook({}, eval_metrics, round_num)
        checkpoint_manager_obj.save_checkpoint(state, 0)
    else:
        logging.info('Restarted from checkpoint round %d', round_num)

    while round_num < total_rounds:
        round_num += 1
        train_metrics = {}

        # Reset the executor to clear the cache, and clear the default graph to
        # garbage collect tf.Functions that will no longer be used.
        tff.framework.set_default_executor(
            tff.framework.local_executor_factory(max_fanout=25))
        tf.compat.v1.reset_default_graph()

        round_start_time = time.time()
        data_prep_start_time = time.time()
        train_data = client_datasets_fn(round_num)
        train_metrics['prepare_datasets_secs'] = (
            time.time() - data_prep_start_time)

        training_start_time = time.time()
        state, tff_train_metrics = iterative_process.next(state, train_data)
        state = ServerState.from_tff_result(state)
        tff_train_metrics = tff_train_metrics._asdict(recursive=True)
        train_metrics.update(tff_train_metrics)
        train_metrics['training_secs'] = time.time() - training_start_time

        logging.info('Round {:2d} elapsed time: {:.2f}s .'.format(
            round_num, (time.time() - round_start_time)))
        train_metrics['total_round_secs'] = time.time() - round_start_time

        if (round_num % FLAGS.rounds_per_checkpoint == 0
                or round_num == total_rounds):
            save_checkpoint_start_time = time.time()
            checkpoint_manager_obj.save_checkpoint(state, round_num)
            train_metrics['save_checkpoint_secs'] = (
                time.time() - save_checkpoint_start_time)

        if round_num % rounds_per_eval == 0 or round_num == total_rounds:
            if metrics_hook is not None:
                eval_metrics = evaluate_fn(state)
                metrics_hook(train_metrics, eval_metrics, round_num)
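
A `client_datasets_fn` compatible with the loop above can be built from a `tff.simulation.ClientData`-like object. The following is a minimal sketch under that assumption; the function and variable names here are hypothetical and not part of the original code.

import numpy as np


def build_client_datasets_fn(train_clientdata, clients_per_round):
  # Returns a function of the round number that samples `clients_per_round`
  # client ids and materializes one `tf.data.Dataset` per sampled client.
  def client_datasets_fn(round_num):
    # Seeding with the round number keeps the sampling reproducible across
    # job restarts, which matters because the loop checkpoints `round_num`.
    rng = np.random.RandomState(round_num)
    client_ids = rng.choice(
        train_clientdata.client_ids, size=clients_per_round, replace=False)
    return [
        train_clientdata.create_tf_dataset_for_client(client_id)
        for client_id in client_ids
    ]

  return client_datasets_fn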
Example #8
def federated_averaging_training_loop(model_fn,
                                      server_optimizer_fn,
                                      client_datasets_fn,
                                      total_rounds=500,
                                      rounds_per_eval=1,
                                      metrics_hook=lambda *args: None):
    """A simple example of training loop for the Federated Averaging algorithm.

  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    client_datasets_fn: A function that takes the round number, and returns a
      list of `tf.data.Datset`, one per client.
    total_rounds: Number of rounds to train.
    rounds_per_eval: How often to call the  `metrics_hook` function.
    metrics_hook: A function taking arguments (server_state, train_metrics,
      round_num) and performs evaluation. Optional.

  Returns:
    Final `ServerState`.
  """
    logging.info('Starting federated_training_loop')
    checkpoint_dir = os.path.join(FLAGS.root_output_dir, FLAGS.exp_name)

    if FLAGS.server_optimizer != 'flars':
        iterative_process = simple_fedavg.build_federated_averaging_process(
            model_fn, server_optimizer_fn=server_optimizer_fn)
        ServerState = simple_fedavg.ServerState  # pylint: disable=invalid-name
    else:
        iterative_process = flars_fedavg.build_federated_averaging_process(
            model_fn, server_optimizer_fn=server_optimizer_fn)
        ServerState = flars_fedavg.ServerState  # pylint: disable=invalid-name

    server_state = ServerState.from_anon_tuple(iterative_process.initialize())
    round_num = None
    train_metrics = {}

    latest_checkpoint_dir = checkpoint_utils.latest_checkpoint(
        checkpoint_dir, CHECKPOINT_PREFIX)
    logging.info('Looking for checkpoints in [%s/%s]', checkpoint_dir,
                 CHECKPOINT_PREFIX)
    while latest_checkpoint_dir is not None:
        # Restart from a previous round.
        logging.info('Loading a previous checkpoint')
        try:
            server_state, metrics_hook.results, round_num = read_checkpoint(
                latest_checkpoint_dir, server_state)
            break
        except OSError as e:
            # Likely a corrupted checkpoint (e.g. the job died while writing).
            # Delete the checkpoint directory and try again.
            logging.error('Exception [%s]', e)
            logging.warning('Deleting likely corrupted checkpoint at [%s]',
                            latest_checkpoint_dir)
            tf.io.gfile.rmtree(latest_checkpoint_dir)
            latest_checkpoint_dir = checkpoint_utils.latest_checkpoint(
                checkpoint_dir, CHECKPOINT_PREFIX)

    if round_num is not None:
        logging.info('Restarted from checkpoint round %d', round_num)
    else:
        # Write the initial checkpoint
        logging.info('No previous checkpoints, initializing experiment')
        round_num = 0
        metrics_hook(server_state, train_metrics, round_num)
        write_checkpoint(checkpoint_dir, server_state, metrics_hook.results,
                         round_num)

    while round_num < total_rounds:
        round_num += 1
        # Reset the executor to clear the cache, and clear the default graph to
        # garbage collect tf.Functions that will no longer be used.
        tff.framework.set_default_executor(
            tff.framework.create_local_executor(max_fanout=25))
        tf.compat.v1.reset_default_graph()

        round_start_time = time.time()
        data_prep_start_time = time.time()
        federated_train_data = client_datasets_fn(round_num)
        train_metrics['prepare_datasets_secs'] = (
            time.time() - data_prep_start_time)

        training_start_time = time.time()
        anon_tuple_server_state, tff_train_metrics = iterative_process.next(
            server_state, federated_train_data)
        server_state = ServerState.from_anon_tuple(anon_tuple_server_state)
        train_metrics.update(tff_train_metrics._asdict(recursive=True))
        train_metrics['training_secs'] = time.time() - training_start_time

        logging.info('Round {:2d} elapsed time: {:.2f}s .'.format(
            round_num, (time.time() - round_start_time)))
        train_metrics['total_round_secs'] = time.time() - round_start_time

        if round_num % FLAGS.rounds_per_checkpoint == 0:
            write_checkpoint_start_time = time.time()
            write_checkpoint(checkpoint_dir, server_state,
                             metrics_hook.results, round_num)
            train_metrics['write_checkpoint_secs'] = (
                time.time() - write_checkpoint_start_time)

        if round_num % rounds_per_eval == 0:
            metrics_hook(server_state, train_metrics, round_num)

    metrics_hook(server_state, train_metrics, total_rounds)
    write_checkpoint(checkpoint_dir, server_state, metrics_hook.results,
                     round_num)

    return server_state
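
The loop above reads and writes `metrics_hook.results`, so the hook is expected to be a callable object that carries its accumulated results. A minimal sketch of such a hook follows; the class name and `evaluate_fn` constructor argument are hypothetical, and the real hook presumably also logs to disk or TensorBoard.

class SimpleMetricsHook(object):
  """Runs evaluation when called and keeps one metrics row per round."""

  def __init__(self, evaluate_fn):
    self._evaluate_fn = evaluate_fn
    # Read and checkpointed by the training loop via `metrics_hook.results`.
    self.results = []

  def __call__(self, server_state, train_metrics, round_num):
    eval_metrics = self._evaluate_fn(server_state)
    self.results.append(
        dict(round_num=round_num, **train_metrics, **eval_metrics))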